In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import os, wandb, torch, time
import pandas as pd
import numpy as np
from glob import glob
from torch.utils.data import DataLoader
from torchinfo import summary
import seaborn as sns
import matplotlib.pyplot as plt

from delphi import mni_template
from delphi.networks.ConvNets import BrainStateClassifier3d
from delphi.utils.datasets import NiftiDataset
from delphi.utils.tools import ToTensor, compute_accuracy, convert_wandb_config, read_config, z_transform_volume, save_in_mni
from delphi.utils.plots import confusion_matrix

from sklearn.model_selection import StratifiedShuffleSplit

# you can find all these files in ../utils
from utils.tools import attribute_with_method, concat_stat_files, compute_mi
from utils.wandb_funcs import reset_wandb_env, wandb_plots
from utils.random import set_random_seed

from tqdm.notebook import tqdm

from captum.attr import GuidedBackprop
from zennit.rules import Epsilon, Gamma, Pass
from zennit.types import Convolution, Linear, Activation
from zennit.composites import LayerMapComposite
from utils.tools import attribute_with_method

composite_lrp_map = [
    (Activation, Pass()),
    (Convolution, Gamma(gamma=.25)),
    (Linear, Epsilon(epsilon=0)),
]

LRP = LayerMapComposite(
    layer_map=composite_lrp_map,
)
LRP.__name__ = 'LRP'

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

g = set_random_seed(2020) # the project started in the year 2020, hence the seed



In [3]:
attributor_method = [LRP, GuidedBackprop]
#volume_model_path = "../05_volumetric_approach/models/vol-motor-classifier-withrest_fold-09"
volume_model_path = "models/vol-wm-classifier-withrest_fold-02/"

TASK_LABEL = 'wm'

for i, method in enumerate(attributor_method):

    method_name = str(method.__name__).lower()

    # load the trained network
    
    model = BrainStateClassifier3d(volume_model_path)
    model.to(torch.device("cpu"));
    model.eval()

    class_labels = model.config['class_labels']
    
    out_dir_name = f"{method_name}/{TASK_LABEL}"
    if not os.path.exists(out_dir_name):
        os.makedirs(out_dir_name)

    for j in range(model.config["n_classes"]):

        print(f"Running {method_name} on {class_labels[j]}")

        out_fname = os.path.join(out_dir_name, '%s.nii.gz' % class_labels[j])
        if os.path.isfile(out_fname):
            print(f"{out_fname} already exists. Skipping")
            continue

        dl = DataLoader(
            NiftiDataset('../v-maps/test/', [class_labels[j]], 0, device=torch.device("cpu"), transform=ToTensor()),
            batch_size=20, shuffle=False, num_workers=0
        )

        for i, (volume, target) in enumerate(dl):

            attribution = attribute_with_method(method, model, volume, target)

            subject_attr = np.moveaxis(attribution.squeeze().detach().numpy(), 0, -1)
            subject_attr = z_transform_volume(subject_attr)
            avg_attr = subject_attr.mean(axis=-1)

        save_in_mni(subject_attr, out_fname)

        avg_out_name = os.path.join(out_dir_name, "avg")
        if not os.path.exists(avg_out_name):
            os.makedirs(avg_out_name)
        save_in_mni(avg_attr, os.path.join(avg_out_name, '%s.nii.gz' % class_labels[j]))

Loading from config file models/vol-wm-classifier-withrest_fold-02//config.yaml
Running lrp on body
Saving lrp/wm/body.nii.gz
Saving lrp/wm/avg/body.nii.gz
Running lrp on face
Saving lrp/wm/face.nii.gz
Saving lrp/wm/avg/face.nii.gz
Running lrp on place
Saving lrp/wm/place.nii.gz
Saving lrp/wm/avg/place.nii.gz
Running lrp on rest_WM
Saving lrp/wm/rest_WM.nii.gz
Saving lrp/wm/avg/rest_WM.nii.gz
Running lrp on tool
Saving lrp/wm/tool.nii.gz
Saving lrp/wm/avg/tool.nii.gz
Loading from config file models/vol-wm-classifier-withrest_fold-02//config.yaml
Running guidedbackprop on body




Saving guidedbackprop/wm/body.nii.gz
Saving guidedbackprop/wm/avg/body.nii.gz
Running guidedbackprop on face
Saving guidedbackprop/wm/face.nii.gz
Saving guidedbackprop/wm/avg/face.nii.gz
Running guidedbackprop on place
Saving guidedbackprop/wm/place.nii.gz
Saving guidedbackprop/wm/avg/place.nii.gz
Running guidedbackprop on rest_WM
Saving guidedbackprop/wm/rest_WM.nii.gz
Saving guidedbackprop/wm/avg/rest_WM.nii.gz
Running guidedbackprop on tool
Saving guidedbackprop/wm/tool.nii.gz
Saving guidedbackprop/wm/avg/tool.nii.gz


In [10]:
from nltools.data import Brain_Data
import glob

files = sorted(glob.glob("lrp/multi/avg/*.nii.gz"))
test = Brain_Data(files, mask=mni_template)
files

['lrp/multi/avg/body.nii.gz',
 'lrp/multi/avg/face.nii.gz',
 'lrp/multi/avg/footleft.nii.gz',
 'lrp/multi/avg/footright.nii.gz',
 'lrp/multi/avg/handleft.nii.gz',
 'lrp/multi/avg/handright.nii.gz',
 'lrp/multi/avg/match.nii.gz',
 'lrp/multi/avg/mental.nii.gz',
 'lrp/multi/avg/place.nii.gz',
 'lrp/multi/avg/relation.nii.gz',
 'lrp/multi/avg/rest_MOTOR.nii.gz',
 'lrp/multi/avg/rest_RELATIONAL.nii.gz',
 'lrp/multi/avg/rest_SOCIAL.nii.gz',
 'lrp/multi/avg/rest_WM.nii.gz',
 'lrp/multi/avg/rnd.nii.gz',
 'lrp/multi/avg/tongue.nii.gz',
 'lrp/multi/avg/tool.nii.gz']

In [11]:
test.iplot(threshold="95%")



interactive(children=(FloatText(value=95.0, description='Threshold'), IntSlider(value=0, continuous_update=Fal…

# TEST

In [6]:
from glob import glob
from nilearn.image import load_img, smooth_img
from nilearn.masking import apply_mask, unmask
from delphi import mni_template
from tqdm.notebook import tqdm

mask = load_img(mni_template)
data_dir_train = glob(os.path.join("../v-maps/train", "*"))
classes = [os.path.split(x)[-1] for i, x in enumerate(data_dir_train)]
classes

['handleft',
 'handright',
 'footleft',
 'footright',
 'tongue',
 'rest_MOTOR',
 'rest_RELATIONAL',
 'face',
 'body',
 'place',
 'tool',
 'rest_SOCIAL',
 'rest_WM',
 'match',
 'relation',
 'mental',
 'rnd']

In [7]:
fwhm=4

for c, cl in enumerate(data_dir_train):
    
    imgs = glob(os.path.join(cl, "*.nii.gz"))
    out_dir = os.path.join(f"../v-smoothed/train/{classes[c]}")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        
    for i, img in tqdm(enumerate(imgs), desc="img"):
        
        out_name = os.path.join(out_dir, os.path.split(img)[-1])
        img_dat = load_img(img)
        smoothed = unmask(apply_mask(smooth_img(img_dat, fwhm=fwhm), mask), mask)
        smoothed.to_filename(out_name)

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]

img: 0it [00:00, ?it/s]