# Train and evaluate a 3D Convolutional Neural Network (3dCNN) to classify motor tasks

<p style='text-align: justify;'>In this notebook, we create and train a 3D convolutional neural network (3D-CNN) that learns to classify different patterns of whole-brain fMRI statistical parameters (t-scores). In this first approach, our goal is to train a classifier that can reliably distinguish between such whole-brain patterns for five motor tasks (i.e., left/right hand, left/right foot, and tongue movements).

</p>

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import os, wandb, torch, time
import pandas as pd
import numpy as np
from glob import glob
from torch.utils.data import DataLoader
from torchinfo import summary
import seaborn as sns
import matplotlib.pyplot as plt

from delphi import mni_template
from delphi.networks.ConvNets import BrainStateClassifier3d
from delphi.utils.datasets import NiftiDataset
from delphi.utils.tools import ToTensor, compute_accuracy, convert_wandb_config, read_config, z_transform_volume
from delphi.utils.plots import confusion_matrix

from sklearn.model_selection import StratifiedShuffleSplit

# you can find all these files in ../utils
from utils.tools import attribute_with_method, concat_stat_files, compute_mi
from utils.wandb_funcs import reset_wandb_env, wandb_plots
from utils.random import set_random_seed

from tqdm.notebook import tqdm

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

<p style='text-align: justify;'>To make sure that we obtain (almost) the same results for each execution, we set the random seeds of several libraries (i.e., torch, random, numpy).</p>

In [3]:
g = set_random_seed(2020) # the project started in the year 2020, hence the seed
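
`set_random_seed` comes from `../utils/random.py`, which is not shown in this notebook. The following is a hypothetical stand-in (`set_random_seed_sketch`) illustrating what such a helper typically does, assuming it seeds Python's `random`, NumPy and PyTorch and returns a seeded `torch.Generator` for the `DataLoader`s; the actual helper may differ.

```python
import random

import numpy as np


def set_random_seed_sketch(seed):
    """Hypothetical stand-in for utils.random.set_random_seed."""
    random.seed(seed)          # Python's built-in RNG
    np.random.seed(seed)       # NumPy's global RNG
    try:
        import torch
    except ImportError:        # PyTorch not installed: nothing more to seed
        return None
    torch.manual_seed(seed)            # PyTorch's default RNGs
    torch.cuda.manual_seed_all(seed)   # all GPUs; silently ignored without CUDA
    g = torch.Generator()
    g.manual_seed(seed)                # generator for DataLoader shuffling
    return g


set_random_seed_sketch(2020)
first = (random.random(), np.random.rand())
set_random_seed_sketch(2020)
second = (random.random(), np.random.rand())
```

Seeding does not by itself make GPU training bit-for-bit reproducible (some cuDNN kernels are non-deterministic), which fits the "(almost)" above.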

## Initializations

In this section, we define and initialize the required variables. We first need to define which classes we want to predict, i.e., the conditions of the motor mapper. We then define a PyTorch dataset; in this case, `NiftiDataset` is a custom-written Dataset class (see https://github.com/PhilippS893/delphi). As is common practice in machine learning projects, we split our training data into a training and a validation dataset (80 % and 20 % of the data, respectively). The held-out test data are loaded separately.

Note: If a null model is needed, i.e., a neural network trained on data with randomized labels, set the parameter `shuffle_labels` from `False` to `True`. Such a model provides a baseline for the null hypothesis that "everything is random", i.e., that the input-label mapping carries no information.
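
A label-shuffled null model keeps the input volumes and the class frequencies but breaks the pairing between inputs and labels. The NumPy sketch below illustrates this on 100 hypothetical labels (5 classes x 20 samples); it shows the idea only and is not `NiftiDataset`'s implementation of `shuffle_labels`.

```python
import numpy as np

rng = np.random.default_rng(2020)
toy_labels = np.repeat(np.arange(5), 20)     # 5 classes x 20 samples (hypothetical)
toy_shuffled = rng.permutation(toy_labels)   # same labels, random order

# class frequencies are unchanged ...
same_counts = np.array_equal(np.bincount(toy_labels), np.bincount(toy_shuffled))
# ... but only about 1/5 of the samples keep their original label, by chance
kept_fraction = np.mean(toy_labels == toy_shuffled)
print(same_counts, kept_fraction)
```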

In [4]:
TASK_LABEL = "motor"
class_labels = sorted(["handleft", "handright", "footleft", "footright", "tongue"])

data_test = NiftiDataset("../t-maps/test", class_labels, 0, device=DEVICE, transform=ToTensor())

<p style='text-align: justify;'>We now set some parameters required by w&b to properly store information about our trained neural networks.</p>

In [None]:
# set the wandb sweep config
# os.environ['WANDB_MODE'] = 'offline'
os.environ['WANDB_ENTITY'] = "philis893" # this is my wandb account name. This can also be a group name, for example
os.environ['WANDB_PROJECT'] = "thesis" # this is simply the project name where we want to store the sweep logs and plots

## Training the neural network(s)

<p style='text-align: justify;'>In this first approach, we chose the parameters of our 3D-CNN based on existing literature. We make use of some functionality I wrote for the "delphi" toolbox (see https://github.com/PhilippS893/delphi), e.g., the function "read_config", which reads .yaml files formatted as below:</p>

The contents of `hyperparameter.yaml`: <br>
`kernel_size: 7`<br>
`batch_size: 4`<br>
`dropout: .5`<br>
`learning_rate: 0.00001`<br>
`epochs: 60`<br>
`channels: [1, 8, 16, 32, 64]`<br>
`lin_neurons: [128, 64]`<br>
`pooling_kernel: 2`<br>

The function `read_config` returns a dictionary containing all key-value pairs set in `hyperparameter.yaml`. We do this because the `delphi` toolbox was written with dictionaries as configuration variables in mind, and because dictionaries can easily be passed to `w&b` to keep track of such parameters.
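
`read_config` is part of delphi and its implementation is not shown here. To illustrate the kind of dictionary it returns, here is a hypothetical stdlib-only parser (`read_flat_config`) that handles only flat `key: value` lines like the ones above; a real YAML file should be read with a proper YAML parser (for example PyYAML's `yaml.safe_load`).

```python
import ast


def read_flat_config(text):
    """Parse flat 'key: value' lines into a dict (hypothetical, not delphi's read_config)."""
    config = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()   # drop comments and surrounding whitespace
        if not line:
            continue
        key, value = line.split(":", 1)
        value = value.strip()
        try:
            config[key.strip()] = ast.literal_eval(value)  # numbers, lists, ...
        except (ValueError, SyntaxError):
            config[key.strip()] = value                    # fall back to a plain string
    return config


hp_text = """
kernel_size: 7
batch_size: 4
dropout: .5
learning_rate: 0.00001
epochs: 60
channels: [1, 8, 16, 32, 64]
lin_neurons: [128, 64]
pooling_kernel: 2
"""
hp_sketch = read_flat_config(hp_text)
```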

In [None]:
def train_network(run, fold, data_train, data_valid, data_test, hp, class_labels, shuffled_labels=False):
    reset_wandb_env()
    
    job_name = "CV-shuffled" if shuffled_labels else "CV-real"
    wandb_kwargs = {
        "entity": os.environ['WANDB_ENTITY'],
        "project": os.environ['WANDB_PROJECT'],
        "group": "first-steps",
        "name": f"motor_fold-{fold:02d}",
        "job_type": job_name if num_folds > 1 else "train",
    }
    
    save_name = os.path.join("models", job_name, wandb_kwargs["name"])
    if os.path.exists(save_name):
        return
    
    # we adjust the random seed here to ensure that each run of each fold has a unique seed!
    g = set_random_seed(2020 + fold + run)
    
    # we now use the wandb context to track the training and evaluation process.
    # all w&b settings are reset by reset_wandb_env() at the start of this function, i.e., once per fold.
    # note: inside this block, `run` refers to the wandb run object, not to the run index.
    with wandb.init(config=hp, **wandb_kwargs) as run:
        
        # please note that this conversion is unnecessary if not using w&b!
        model_cfg = convert_wandb_config(run.config, BrainStateClassifier3d._REQUIRED_PARAMS)
        
        # setup a model with the parameters given in model_cfg
        model = BrainStateClassifier3d(input_dims, len(class_labels), model_cfg)
        model.to(DEVICE);
        
        model.config["class_labels"] = class_labels
        
        dl_train = DataLoader(data_train, batch_size=run.config.batch_size, shuffle=True, generator=g)
        dl_valid = DataLoader(data_valid, batch_size=run.config.batch_size, shuffle=True, generator=g)
        dl_test = DataLoader(data_test, batch_size=run.config.batch_size, shuffle=False, generator=g)

        best_loss, best_acc = 100, 0
        loss_acc = []
        train_stats, valid_stats = [], []

        # loop for the above set number of epochs
        for epoch in range(run.config.epochs):
            _, _ = model.fit(dl_train, lr=run.config.learning_rate)

            # for validation/testing, fit(..., train=False) only evaluates the network (no weight updates,
            # layers like dropout inactive) and torch.no_grad() additionally disables gradient tracking
            with torch.no_grad():
                tloss, tstats = model.fit(dl_train, train=False)
                vloss, vstats = model.fit(dl_valid, train=False)

            # the model.fit() method has 2 outputs: loss, stats = model.fit()
            # the first output is the loss on the given data
            # the second output is an n_samples-by-(n_classes+2) matrix:
            # the first n_classes columns are the output probabilities of the model per class
            # the second-to-last column (i.e., [:, -2]) holds the real labels
            # the last column (i.e., [:, -1]) holds the predicted labels
            tacc = compute_accuracy(tstats[:, -2], tstats[:, -1])
            vacc = compute_accuracy(vstats[:, -2], vstats[:, -1])

            loss_acc.append(pd.DataFrame([[tloss, vloss, tacc, vacc]],
                                         columns=["train_loss", "valid_loss", "train_acc", "valid_acc"]))

            train_stats.append(pd.DataFrame(tstats.tolist(), columns=[*class_labels, *["real", "predicted"]]))
            train_stats[epoch]["epoch"] = epoch
            valid_stats.append(pd.DataFrame(vstats.tolist(), columns=[*class_labels, *["real", "predicted"]]))
            valid_stats[epoch]["epoch"] = epoch

            wandb.log({
                "train_acc": tacc, "train_loss": tloss,
                "valid_acc": vacc, "valid_loss": vloss
            }, step=epoch)

            print('Epoch=%03d, train_loss=%2.3f, train_acc=%1.3f, valid_loss=%2.3f, valid_acc=%1.3f' % 
                 (epoch, tloss, tacc, vloss, vacc))

            if (vacc >= best_acc) and (vloss <= best_loss):
                # assign the new best values
                best_acc, best_loss = vacc, vloss
                wandb.run.summary["best_valid_accuracy"] = best_acc
                wandb.run.summary["best_valid_epoch"] = epoch
                # save the current best model
                model.save(save_name)
                # plot some graphs for the validation data
                wandb_plots(vstats[:, -2], vstats[:, -1], vstats[:, :-2], class_labels, "valid")


        # save the files
        full_df = pd.concat(loss_acc)
        full_df.to_csv(os.path.join(save_name, "loss_acc_curves.csv"), index=False)
        full_df = pd.concat(train_stats)
        full_df.to_csv(os.path.join(save_name, "train_stats.csv"), index=False)
        full_df = pd.concat(valid_stats)
        full_df.to_csv(os.path.join(save_name, "valid_stats.csv"), index=False)

        # load the best performing model
        model = BrainStateClassifier3d(save_name)
        model.to(DEVICE)
        
        # EVALUATE THE MODEL ON THE TEST DATA
        with torch.no_grad():
            testloss, teststats = model.fit(dl_test, train=False)
            
        testacc = compute_accuracy(teststats[:, -2], teststats[:, -1])
        df_test = pd.DataFrame(teststats.tolist(), columns=[*class_labels, *["real", "predicted"]])
        df_test.to_csv(os.path.join(save_name, "test_stats.csv"), index=False)
        
        wandb.run.summary["test_accuracy"] = testacc

        wandb.log({"test_accuracy": testacc, "test_loss": testloss})
        wandb_plots(teststats[:, -2], teststats[:, -1], teststats[:, :-2], class_labels, "test")

        wandb.finish()
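
The comments in `train_network` describe the second output of `model.fit()` as an n_samples-by-(n_classes + 2) matrix: class probabilities, then the real label, then the predicted label. The sketch below builds such a matrix by hand for three hypothetical samples and computes the accuracy from its last two columns; `accuracy_from_stats` is a hypothetical helper, not delphi's `compute_accuracy`.

```python
import numpy as np

# three hypothetical samples: 5 class probabilities per row
toy_probs = np.array([
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.10, 0.20, 0.50, 0.10, 0.10],
    [0.05, 0.05, 0.10, 0.10, 0.70],
])
toy_real = np.array([0, 3, 4])                  # true class indices
toy_pred = toy_probs.argmax(axis=1)             # predicted class indices
# n_samples-by-(n_classes + 2): probabilities, real label, predicted label
toy_stats = np.column_stack([toy_probs, toy_real, toy_pred])


def accuracy_from_stats(stats):
    """Fraction of rows whose real label (column -2) equals the prediction (column -1)."""
    return float(np.mean(stats[:, -2] == stats[:, -1]))


toy_acc = accuracy_from_stats(toy_stats)
```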

In [None]:
hp = read_config("hyperparameter.yaml")

num_folds = 10 # we decided to run a 10-fold cross-validation scheme
run = 0 # in case you decide to do multiple runs of the cross-validation scheme
input_dims = (91, 109, 91) # these are the 3D input dimensions of our whole-brain data (MNI152 2 mm grid)
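
To get a feel for what `kernel_size`, `pooling_kernel` and `channels` imply for a (91, 109, 91) input, the sketch below tracks the spatial size through the four convolution blocks suggested by `channels: [1, 8, 16, 32, 64]`. It assumes that each block is a stride-1 convolution with "same" padding (padding = kernel_size // 2) followed by max pooling that floors odd sizes. These are assumptions about `BrainStateClassifier3d`, whose exact architecture is defined in delphi.

```python
def conv_out(size, kernel, padding, stride=1):
    """Output length of a convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_shape(input_dims, n_blocks, kernel=7, pool=2):
    """Spatial shape after n_blocks of (same-padded conv -> max-pool); assumed architecture."""
    dims = list(input_dims)
    for _ in range(n_blocks):
        dims = [conv_out(d, kernel, padding=kernel // 2) for d in dims]  # 'same' padding, odd kernel
        dims = [d // pool for d in dims]                                 # pooling floors odd sizes
    return tuple(dims)


toy_channels = [1, 8, 16, 32, 64]
final_dims = feature_map_shape((91, 109, 91), n_blocks=len(toy_channels) - 1)
n_flat = toy_channels[-1] * final_dims[0] * final_dims[1] * final_dims[2]
```

Under these assumptions, the first linear layer (128 neurons in `lin_neurons`) would receive 9,600 inputs. Without padding, a kernel of size 7 would shrink the volume too quickly for four blocks (the last convolution would have nothing left to slide over).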

In [None]:
len(data_test)

### Run cross-validation training using correct labels

In [None]:
# we will split the train dataset into a train (80%) and validation (20%) set.
data_train_full = NiftiDataset("../t-maps/train", class_labels, 0, device=DEVICE, 
                               transform=ToTensor(), shuffle_labels=False)

# we want a stratified shuffled split
sss = StratifiedShuffleSplit(n_splits=num_folds, test_size=0.2, random_state=2020)

for fold, (idx_train, idx_valid) in enumerate(sss.split(data_train_full.data, data_train_full.labels)):
    data_train = torch.utils.data.Subset(data_train_full, idx_train)
    data_valid = torch.utils.data.Subset(data_train_full, idx_valid)
    train_network(run, fold, data_train, data_valid, data_test, hp, class_labels)
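
`StratifiedShuffleSplit` draws each fold's validation set so that every class keeps (approximately) its overall proportion. The NumPy sketch below shows the core idea on hypothetical labels; `stratified_split` is illustrative only, and scikit-learn's allocation of samples per class is more careful when class sizes do not divide evenly.

```python
import numpy as np


def stratified_split(labels, test_size, rng):
    """Return (train_idx, valid_idx) with a test_size share of every class in the validation set."""
    labels = np.asarray(labels)
    train_idx, valid_idx = [], []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        rng.shuffle(idx)                         # random order within the class
        n_valid = int(round(test_size * len(idx)))
        valid_idx.extend(idx[:n_valid])
        train_idx.extend(idx[n_valid:])
    return np.sort(train_idx), np.sort(valid_idx)


toy_labels = np.repeat(np.arange(5), 40)        # 5 classes x 40 samples (hypothetical)
toy_rng = np.random.default_rng(2020)
toy_train, toy_valid = stratified_split(toy_labels, test_size=0.2, rng=toy_rng)
```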

### Run cross-validation training using shuffled labels

In [None]:
# we will split the train dataset into a train (80%) and validation (20%) set.
data_train_full = NiftiDataset("../t-maps/train", class_labels, 20, device=DEVICE, 
                               transform=ToTensor(), shuffle_labels=True)

# we want a stratified shuffled split
sss = StratifiedShuffleSplit(n_splits=num_folds, test_size=0.2, random_state=2020)

for fold, (idx_train, idx_valid) in enumerate(sss.split(data_train_full.data, data_train_full.labels)):
    data_train = torch.utils.data.Subset(data_train_full, idx_train)
    data_valid = torch.utils.data.Subset(data_train_full, idx_valid)
    train_network(run, fold, data_train, data_valid, data_test, hp, class_labels, shuffled_labels=True)

## Analyzing the network performances

### Loss and accuracy curves for the real and shuffled labels

In [None]:
SMALL_SIZE = 18
MEDIUM_SIZE = 22
BIGGER_SIZE = 26

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(20, 8), sharex=True)
label_order = ["real", "shuffled"]

folds_real = sorted(glob(os.path.join(f"models/CV-real/*fold*"))) 
curves_real = concat_stat_files(folds_real, "loss_acc_curves.csv")
curves_real["label_order"] = "real"
folds_shuffled = sorted(glob(os.path.join(f"models/CV-shuffled/*fold*"))) 
curves_shuffled = concat_stat_files(folds_shuffled, "loss_acc_curves.csv")   
curves_shuffled["label_order"] = "shuffled"

curves = pd.concat([curves_real, curves_shuffled])
curves["epoch"] = curves.index

sns.lineplot(ax=axes[0], data=curves.melt(value_vars=["train_loss", "valid_loss"], id_vars=["epoch", "label_order"], var_name="loss", value_name="CrossEntropyLoss"), 
             x="epoch", y="CrossEntropyLoss", hue="label_order", style="loss", 
             linewidth=2, errorbar=("ci", 95), n_boot=5000)
sns.lineplot(ax=axes[1], data=curves.melt(value_vars=["train_acc", "valid_acc"], id_vars=["epoch", "label_order"], var_name="accuracy", value_name="acc"), 
             x="epoch", y="acc", hue="label_order", style="accuracy", 
             linewidth=2, errorbar=("ci", 95), n_boot=5000)

for i in range(2):
    axes[i].spines[["top", "right"]].set_visible(False);
    axes[i].legend(frameon=False, loc="center right", fontsize=18)
    
fig.tight_layout()

plt.savefig('figures/loss-acc-curves-across-folds.pdf', facecolor=fig.get_facecolor(), transparent=True)
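
The shaded bands come from `errorbar=("ci", 95), n_boot=5000`: at each epoch, seaborn resamples the per-fold values with replacement and reports the 2.5th and 97.5th percentiles of the resampled means (a percentile bootstrap). A minimal NumPy version for a single epoch, using made-up accuracies for ten folds:

```python
import numpy as np


def bootstrap_ci(values, n_boot=5000, ci=95, seed=0):
    """Percentile bootstrap confidence interval of the mean."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    # resample with replacement: shape (n_boot, n_values)
    samples = rng.choice(values, size=(n_boot, values.size), replace=True)
    boot_means = samples.mean(axis=1)
    half = (100 - ci) / 2
    return np.percentile(boot_means, [half, 100 - half])


fold_accs = [0.80, 0.84, 0.78, 0.86, 0.82, 0.81, 0.85, 0.79, 0.83, 0.80]  # hypothetical values
low, high = bootstrap_ci(fold_accs)
```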

<p style='text-align: justify;'>We now compute (or, if already available, load) the test classification accuracy of each trained fold to check whether the networks generalize well to novel data.
</p>

In [None]:
# compute the test classification accuracy of the folds
label_condition = ["real", "shuffled"]
stat_mats = {
    "real": [],
    "shuffled": [],
}

test_loader = DataLoader(data_test, batch_size=4, shuffle=False)

accs = np.zeros((10, 2))

for i, c in enumerate(label_condition):
    
    # get the folds
    folds = sorted(glob(os.path.join(f"models/CV-{c}", "*fold*")))
    
    for j, fold in enumerate(folds):
        
        if not os.path.isfile(os.path.join(fold, "test_stats.csv")):
            model = BrainStateClassifier3d(fold)
            model.to(DEVICE)

            with torch.no_grad():
                _ , stats = model.fit(test_loader, train=False)

            df_test = pd.DataFrame(stats.tolist(), columns=[*class_labels, *["real", "predicted"]])
            df_test.to_csv(os.path.join(fold, "test_stats.csv"), index=False)
            stat_mats[c].append(df_test)
            accs[j, i] = compute_accuracy(stats[:, -2], stats[:, -1])
        else:
            stat_mats[c].append(pd.read_csv(os.path.join(fold, "test_stats.csv")))
            accs[j, i] = compute_accuracy(stat_mats[c][j]["real"], stat_mats[c][j]["predicted"])
        

df_accs = pd.DataFrame(accs.tolist(), columns=label_condition)
df_accs.to_csv("stats/motor-accs.csv", index=False)

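# arcsine-square-root (angular) transform, the usual variance-stabilizing transform for proportions
df_acc_transformed = np.arcsin(np.sqrt(df_accs))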
df_acc_transformed.to_csv("stats/motor-accs-transformed.csv", index=False)
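
Accuracies are proportions, so their variance depends on their mean. The conventional variance-stabilizing "angular" transform for a proportion p is arcsin(√p), which maps [0, 1] onto [0, π/2]; plain arcsin(p) is a different, non-standard mapping. A small NumPy comparison on example proportions:

```python
import numpy as np

p = np.array([0.0, 0.2, 0.5, 0.8, 1.0])   # example proportions (e.g., accuracies)
angular = np.arcsin(np.sqrt(p))            # arcsine-square-root transform
plain = np.arcsin(p)                       # arcsin without the square root

# both are monotone, but only arcsin(sqrt(p)) is the classical variance-stabilizing transform
```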

<p style='text-align: justify;'>The figure produced by the code below consists of three panels. The left panel shows the average test classification accuracy across all folds for the 3D-CNNs trained with real and shuffled input-label mappings. The blue and violet dots represent the individual test accuracies of each fold. The dashed line at 0.2 marks the chance level (1 out of 5 classes). <br>
We clearly see that the networks trained with real input-label mappings perform reliably above chance level, whereas the networks trained with shuffled input-label mappings perform at chance level.
<br><br>
The other two panels show the confusion matrices for the real (middle) and shuffled (right) input-label mappings. They provide another view showing that the input-label mapping matters for the network to learn to distinguish between the conditions of the motor mapper.
</p>
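
With five balanced classes, random guessing is expected to be correct 1/5 = 0.2 of the time, hence the dashed line. A confusion matrix counts, for every true class (row), how often each class was predicted (column); the diagonal holds the correct predictions. A minimal NumPy version with hypothetical labels (`confusion_counts` is illustrative, not delphi's `confusion_matrix`):

```python
import numpy as np


def confusion_counts(real, predicted, n_classes):
    """Rows = true class, columns = predicted class."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(counts, (np.asarray(real), np.asarray(predicted)), 1)  # one count per (real, predicted) pair
    return counts


n_classes = 5
chance_level = 1 / n_classes
toy_real = np.array([0, 0, 1, 2, 3, 4, 4])
toy_pred = np.array([0, 1, 1, 2, 3, 4, 0])
cm = confusion_counts(toy_real, toy_pred, n_classes)
toy_acc = np.trace(cm) / cm.sum()              # correct predictions / all predictions
```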

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(20,6), gridspec_kw={'width_ratios': [.5, 1, 1]})

# plot the average test accuracy 
sns.barplot(ax=ax[0], data=df_accs, color=[.5,.5,.5], alpha=.2, errorbar=("ci", 95), n_boot=5000, capsize=.1, width=.4)
sns.stripplot(ax=ax[0], data=df_accs, zorder=1, alpha=.7, size=8, legend=False, palette="cool")
ax[0].axhline(0.2, linestyle="--", color="black", zorder=0, label="chance level")
ax[0].set(ylabel="accuracy", xlabel="label order", title="average classification performance");
ax[0].legend(frameon=False)
ax[0].spines[["top", "right"]].set_visible(False);

# plot the confusion matrix across all folds for the "real" label order
df = pd.concat(stat_mats["real"])
conf_mat, conf_ax = confusion_matrix(df["real"], df["predicted"], df.columns[:-2], normalize=False, ax=ax[1], **{"vmin": 0, "vmax": 200})
ax[1].set_title("real label order");

# plot the confusion matrix across all folds for the "shuffled" label order
df = pd.concat(stat_mats["shuffled"])
conf_mat, conf_ax = confusion_matrix(df["real"], df["predicted"], df.columns[:-2], normalize=False, ax=ax[2], **{"vmin": 0, "vmax": 200})
ax[2].set_title("shuffled label order");
fig.colorbar(conf_ax, ax=ax[2], )

fig.tight_layout()

plt.savefig('figures/test-performance-across-folds.pdf', facecolor=fig.get_facecolor(), transparent=True)

## Investigate what information was deemed relevant (XAI)


In [None]:
from captum.attr import GuidedBackprop
from zennit.rules import Epsilon, Gamma, Pass
from zennit.types import Convolution, Linear, Activation
from zennit.composites import LayerMapComposite
from delphi.utils.tools import save_in_mni

composite_lrp_map = [
    (Activation, Pass()),
    (Convolution, Gamma(gamma=.25)),
    (Linear, Epsilon(epsilon=0)),
]

LRP = LayerMapComposite(
    layer_map=composite_lrp_map,
)
LRP.__name__ = 'LRP'
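
`composite_lrp_map` assigns one propagation rule per layer type: activations pass relevance through unchanged (`Pass`), convolutions use the γ-rule, and linear layers use the ε-rule with ε = 0 (plain LRP-0). The NumPy sketch below is a simplified textbook version of these two rules for a single bias-free dense layer, written to show the idea and the conservation property (input relevance sums to output relevance). zennit's actual `Gamma` and `Epsilon` implementations handle biases, signs and stabilization in more detail.

```python
import numpy as np


def lrp_dense(a, W, R_out, epsilon=0.0, gamma=0.0):
    """Redistribute output relevance R_out to the inputs of a bias-free dense layer.

    a: (n_in,) input activations, W: (n_in, n_out) weights, R_out: (n_out,) relevance.
    gamma > 0 favours positive weights (gamma-rule);
    epsilon > 0 stabilizes small denominators (epsilon-rule).
    """
    W_mod = W + gamma * np.clip(W, 0.0, None)       # w + gamma * w^+
    z = a @ W_mod                                   # pre-activation of each output
    z = z + epsilon * np.where(z >= 0, 1.0, -1.0)   # epsilon stabilizer
    s = R_out / z                                   # relevance per unit of pre-activation
    return a * (W_mod @ s)                          # R_j = a_j * sum_k w'_jk * s_k


toy_act = np.array([1.0, 2.0, 0.5])                 # hypothetical non-negative activations
toy_W = np.array([[0.5, -0.2],
                  [0.3, 0.4],
                  [-0.1, 0.6]])
toy_R_out = np.array([1.0, 2.0])

R_lrp0 = lrp_dense(toy_act, toy_W, toy_R_out)               # epsilon=0, gamma=0: LRP-0
R_gamma = lrp_dense(toy_act, toy_W, toy_R_out, gamma=0.25)  # gamma as used for convolutions above
R_eps = lrp_dense(toy_act, toy_W, toy_R_out, epsilon=0.5)   # epsilon absorbs some relevance
```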

In [None]:
label_order = "shuffled"  # run this cell once with "real" and once with "shuffled"

n_folds = 10

attributor_method = [LRP, GuidedBackprop]

for i, method in enumerate(attributor_method):

    method_name = str(method.__name__).lower()

    for fold in range(n_folds):
        # load the trained network
        model = BrainStateClassifier3d(f"models/CV-{label_order}/{TASK_LABEL}_fold-{fold:02d}")
        model.to(torch.device("cpu"));
        model.eval()
        
        out_dir_name = f"{method_name}/{label_order}/{TASK_LABEL}_fold-{fold:02d}"
        if not os.path.exists(out_dir_name):
            os.makedirs(out_dir_name)

        for j in range(len(class_labels)):

            print(f"Running {method_name} on {class_labels[j]}")

            out_fname = os.path.join(out_dir_name, '%s.nii.gz' % class_labels[j])
            if os.path.isfile(out_fname):
                print(f"{out_fname} already exists. Skipping")
                continue

            dl = DataLoader(
                NiftiDataset('../t-maps/test', [class_labels[j]], 0, device=torch.device("cpu"), transform=ToTensor()),
                batch_size=20, shuffle=False, num_workers=0
            )

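            # batch_size=20 equals the number of test subjects per class, so this loop runs once
            # and subject_attr/avg_attr hold the maps of all subjects of the current class
            for i, (volume, target) in enumerate(dl):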

                attribution = attribute_with_method(method, model, volume, target)

                subject_attr = np.moveaxis(attribution.squeeze().detach().numpy(), 0, -1)
                subject_attr = z_transform_volume(subject_attr)
                avg_attr = subject_attr.mean(axis=-1)

            save_in_mni(subject_attr, out_fname)

            avg_out_name = os.path.join(out_dir_name, "avg")
            if not os.path.exists(avg_out_name):
                os.makedirs(avg_out_name)
            save_in_mni(avg_attr, os.path.join(avg_out_name, '%s.nii.gz' % class_labels[j]))
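
After attribution, the maps of each class are z-transformed with delphi's `z_transform_volume` and then averaged across subjects (`avg_attr`). The exact behaviour of `z_transform_volume` (e.g., whether it uses a brain mask or the whole volume) is not shown here; the sketch below assumes a per-subject z-score over the non-zero voxels of an (x, y, z, n_subjects) array, followed by the voxelwise mean, applied to small random data. `zscore_per_subject` is hypothetical.

```python
import numpy as np


def zscore_per_subject(maps):
    """Z-score each volume of an (x, y, z, n_subjects) array over its non-zero voxels (assumed behaviour)."""
    out = np.zeros_like(maps, dtype=float)
    for s in range(maps.shape[-1]):
        vol = maps[..., s]
        inside = vol != 0                        # crude stand-in for a brain mask
        vals = vol[inside]
        out[..., s][inside] = (vals - vals.mean()) / vals.std()  # writes through the view
    return out


toy_rng = np.random.default_rng(0)
toy_maps = toy_rng.normal(loc=2.0, scale=3.0, size=(4, 5, 4, 20))  # 20 hypothetical subjects
toy_z = zscore_per_subject(toy_maps)
toy_avg = toy_z.mean(axis=-1)                    # voxelwise average across subjects
```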

# Second-level GLM analyses

We need some way to compare the attribution maps from LRP and GBP to the original t-maps. To do this, we first compute a group-level analysis across the subjects in the held-out test dataset.
That is, we compute a group-level
* t-map, 
* LRP-map, and 
* GBP-map.

At a later stage, we can then use, for example, mutual information to quantify the similarity in a within-subject and a subject-vs-group-map fashion. 
With such comparisons we can check whether mutual information of LRP/GBP maps with their respective group map is larger than, e.g., the mutual information between single subjects' t-maps (i.e., the original input data) with their respective group map.

We will therefore perform the following comparisons:

* subattr vs grpattr (the group attribution map is computed by leaving out of the GLM the subject whose map we want to compare to the group; e.g., subject-01 is left out when computing the group map, but its own attribution map is used to compute the mutual information with that group map)
* subt vs grpt (same as above)
* subattr vs subt
* grpattr vs grpt

In [None]:
import matplotlib.pyplot as plt
import nilearn
from nilearn.image import load_img
from nilearn.glm.second_level import SecondLevelModel
import nibabel as nib

def save_stat_maps(data, prefix, save_loc='stat-maps', maps_of_interest=["z_score", "stat", "effect_size"]):
    
    if not os.path.exists(save_loc):
        os.makedirs(save_loc)
    
    for i, key in enumerate(maps_of_interest):
        path2save = os.path.join(save_loc, f"{prefix}_{key}.nii.gz")
        nib.save(data[key], path2save)

In [None]:
# load the real images
test_img_files = []
for label in class_labels:
    test_img_files.extend(sorted(glob(f"../copes/test/{label}/*.nii.gz")))
images = load_img(test_img_files)

# assign the correct labels to each file
real = np.repeat([0,1,2,3,4], 20)

# a whole brain mask
mask = load_img(mni_template)

In this group-level analysis, we are interested in computing the average t-map for each condition: footleft, footright, handleft, handright, and tongue. 
Out of curiosity, we also compute an average t-map across all conditions.

The code cell below produces and plots the design matrix for our GLM analysis.

The data and the design matrix are set up so that all files of the same condition are grouped together; hence, each of the first five columns (one per condition) contains 20 ones. The "average" column simply contains ones for all input volumes.
As this is a paired design, we need to assign the subjects to their respective volumes, which is what the sub-XX columns encode.

In [None]:
# set up the 2nd-level contrast matrix
from nilearn.plotting import plot_design_matrix, plot_contrast_matrix
n_subs = 20
n = len(test_img_files)
design_matrix = {}

for label in range(len(class_labels)):
    regressor = np.repeat(np.eye(len(class_labels))[label], n_subs)
    design_matrix[f'{class_labels[label]}'] = regressor

""" In case we would want some extra contrasts
design_matrix['lf_vs_all'] = np.repeat([1, -.25, -.25, -.25, -.25], n_subs)
design_matrix['rf_vs_all'] = np.repeat([-.25, 1, -.25, -.25, -.25], n_subs)
design_matrix['lh_vs_all'] = np.repeat([-.25, -.25, 1, -.25, -.25], n_subs)
design_matrix['rh_vs_all'] = np.repeat([-.25, -.25, -.25, 1, -.25], n_subs)
design_matrix['t_vs_all'] = np.repeat([-.25, -.25, -.25, -.25, 1], n_subs)
"""

design_matrix['average'] = np.ones(n) # the average across all conditions simply takes all files as input

for sub in range(n_subs):
    regressor = np.tile(np.eye(n_subs)[sub], len(class_labels))
    design_matrix[f'sub-{sub:02d}'] = regressor

df_design_matrix = pd.DataFrame(design_matrix)
fig, ax = plt.subplots(figsize=(8, 10))
plot_design_matrix(df_design_matrix, rescale=False, ax=ax);
ax.set_xticklabels(df_design_matrix.columns, fontsize=15, rotation=90, ha="center");
ax.set_yticks([0, 20, 40, 60, 80])
ax.set_yticklabels([0, 20, 40, 60, 80], fontsize=15);
ax.set_ylabel("sample", fontsize=15);
ax.xaxis.set_label_position('top') 
ax.set_title("2nd-level contrasts", fontsize=20);

plt.savefig('figures/2nd-level-matrix-nilearn.svg', facecolor=fig.get_facecolor(), transparent=True)

In [None]:
# create the glm
glm = SecondLevelModel(smoothing_fwhm=None, mask_img=mask)

In [None]:
# perform the GLM once on the original beta maps
glm.fit([nilearn.image.index_img(images, i) for i in range(n)], design_matrix=df_design_matrix)

contrasts = ["footleft", "footright", "handleft", "handright", "tongue", "average"]
#contrasts = ['footleft', 'footright', 'handleft', 'handright', 'tongue', 'lf_vs_all', 'rf_vs_all', 'lh_vs_all', 'rh_vs_all', 't_vs_all', 'average']

for c, contrast in enumerate(contrasts):
    maps_orig = glm.compute_contrast(contrast, output_type="all")
    save_stat_maps(maps_orig, f"{contrast}", save_loc="stat-maps/orig")

In [None]:
# perform the GLM on the relevance maps for each individual fold
# load the relevance images
contrasts = ["footleft", "footright", "handleft", "handright", "tongue", "average"]
#contrasts = ['footleft', 'footright', 'handleft', 'handright', 'tongue', 'lf_vs_all', 'rf_vs_all', 'lh_vs_all', 'rh_vs_all', 't_vs_all', 'average']

attr_types = ["guidedbackprop", "lrp"] #"lrp" "guidedbackprop"

label_order = "real" #, "shuffled"

overwrite = False

for a, attr_type in enumerate(attr_types):

    for fold in range(10):

        print(attr_type, label_order, fold)

        save_loc = os.path.join("stat-maps", attr_type, label_order, f"motor_fold-{fold:02d}")

        if not os.path.isdir(save_loc) or overwrite:

            relevance_files = sorted(glob(os.path.join(attr_type, label_order, f"motor_fold-{fold:02d}/*.nii.gz")))
            relevance_maps = load_img(relevance_files)

            # set up the GLM for the current fold
            glm.fit([nilearn.image.index_img(relevance_maps, i) for i in range(n)], design_matrix=df_design_matrix)

            for c, contrast in enumerate(contrasts):
                maps_attr = glm.compute_contrast(contrast, output_type="all")
                save_stat_maps(maps_attr, f"{contrast}", save_loc=save_loc)


### Leave-one-subject-out group maps

Now that we have the group LRP, GBP, and t-maps, we can turn to the leave-one-subject-out analyses. We need leave-one-subject-out group maps as well, so that the subject-vs-group mutual information (sub(attr/t) vs group(attr/t)) is not inflated by the subject's own contribution to the group map. 
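
Why leave the subject out? If subject s contributes to the group map, part of the similarity between its own map and the group map is guaranteed by construction. The NumPy sketch below uses a voxelwise one-sample t-statistic as a simplified stand-in for the second-level GLM (which additionally models condition and subject effects) and shows the inflated correlation when the subject is included. The data are pure noise, so any similarity is spurious.

```python
import numpy as np


def one_sample_t(maps):
    """Voxelwise one-sample t-statistic over subjects; maps has shape (n_subjects, n_voxels)."""
    n = maps.shape[0]
    return maps.mean(axis=0) / (maps.std(axis=0, ddof=1) / np.sqrt(n))


toy_rng = np.random.default_rng(1)
toy_subject_maps = toy_rng.normal(size=(20, 5000))   # 20 subjects, 5000 voxels, no true effect

toy_s = 0
full_group = one_sample_t(toy_subject_maps)                              # includes subject toy_s
loo_group = one_sample_t(np.delete(toy_subject_maps, toy_s, axis=0))     # leaves subject toy_s out

r_full = np.corrcoef(toy_subject_maps[toy_s], full_group)[0, 1]
r_loo = np.corrcoef(toy_subject_maps[toy_s], loo_group)[0, 1]
```

With 20 subjects, `r_full` is driven mainly by the subject's own 1/20 share of the group map, while `r_loo` stays near zero.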

In [None]:
# compute a GLM for the test data set by leaving 1-subject out. We will need this for the mutual information analysis later
contrasts = ["footleft", "footright", "handleft", "handright", "tongue", "average"]

sub_list = np.tile(np.arange(n_subs), 5)

for s in tqdm(range(n_subs)):
    
    # copy the original design matrix
    this_dm = df_design_matrix.copy()
    
    # now drop the respective subject
    this_dm = this_dm.drop(labels=np.squeeze(np.where(sub_list == s)), axis="index")
    this_dm = this_dm.drop(columns=[f"sub-{s:02d}"])
    #plot_design_matrix(this_dm)
    
    idcs = np.arange(n)
    idcs = np.delete(idcs, np.where(sub_list == s))
    
    # perform the GLM on the all subs except s
    glm.fit([nilearn.image.index_img(images, i) for i in idcs], design_matrix=this_dm)

    for c, contrast in enumerate(contrasts):
        maps_orig = glm.compute_contrast(contrast, output_type="all")
        save_stat_maps(maps_orig, f"{contrast}_wo-sub{s:02d}", save_loc="stat-maps/orig/left-out", maps_of_interest=["z_score"])

In [None]:
# COMPUTE THE GLM FOR EACH FOLD BUT LEAVE OUT ONE SUBJECT AT A TIME
# perform the GLM on the relevance maps for each individual fold
# load the relevance images
contrasts = ["footleft", "footright", "handleft", "handright", "tongue", "average"]

attr_types = ["guidedbackprop", "lrp"] #"lrp" "guidedbackprop"

label_order = "real" #, "shuffled"

overwrite = False

n_folds = 10

sub_list = np.tile(np.arange(n_subs), 5)

for a, attr_type in enumerate(attr_types):
        
    print(attr_type, label_order)
        
    for fold in tqdm(range(n_folds), desc="folds"):

        save_loc = os.path.join("stat-maps", attr_type, label_order, f"motor_fold-{fold:02d}", "left-out")

        if not os.path.isdir(save_loc) or overwrite:
            
            # load the files for a given fold
            relevance_files = sorted(glob(os.path.join(attr_type, label_order, f"motor_fold-{fold:02d}/*.nii.gz")))
            relevance_maps = load_img(relevance_files)

            for s in tqdm(range(n_subs), leave=False, desc="subs"):

                # copy the original design matrix
                this_dm = df_design_matrix.copy()

                # now drop subject s from the design matrix
                this_dm = this_dm.drop(labels=np.squeeze(np.where(sub_list == s)), axis="index")
                this_dm = this_dm.drop(columns=[f"sub-{s:02d}"])

                # remove the subject s' indices from the data list
                idcs = np.arange(n)
                idcs = np.delete(idcs, np.where(sub_list == s))
                
                # perform the GLM on the all subs except s
                glm.fit([nilearn.image.index_img(relevance_maps, i) for i in idcs], design_matrix=this_dm)

                for c, contrast in enumerate(contrasts):
                    maps_orig = glm.compute_contrast(contrast, output_type="all")
                    save_stat_maps(maps_orig, f"{contrast}_wo-sub{s:02d}", save_loc=save_loc, maps_of_interest=["z_score"])

## SANITY CHECKS FOR RELEVANCE/XAI
### OCCLUSION (faithfulness test)

In [None]:
from delphi.utils.tools import occlude_images

In [None]:
percentages = np.concatenate([
    np.arange(0, 2, .2),
    np.arange(2, 5, .5),
    np.arange(5, 10, 1),
    np.arange(10, 32, 2)
])

xai_algo = ["lrp", "guidedbackprop"]
#accs = np.zeros((10, len(percentages)))
#n_voxel = np.zeros_like(percentages)

df_occlu = []

if not os.path.isfile("stats/occlusion_curves.csv"):

    for i, algo in enumerate(xai_algo):

        folds = sorted(glob(os.path.join(algo, "real", "*fold*")))

        for f, fold in enumerate(folds):
            # load the trained network
            model = BrainStateClassifier3d(f"models/CV-real/{TASK_LABEL}_fold-{f:02d}")
            model.eval()
            model.to(torch.device("cpu"));

            relevance_files = sorted(glob(os.path.join(fold, "*nii.gz")))
            relevance_maps = load_img(relevance_files)

            accs = np.zeros(len(percentages))

            for i, frac in enumerate(percentages):
                print(fold, i)
                occluded, _ = occlude_images(images, relevance_maps, mask, fraction=frac, get_fdata=True)
                occluded = np.moveaxis(occluded, -1, 0)
                occluded = torch.tensor(occluded).unsqueeze(1)
                with torch.no_grad():  # inference only, no gradients needed
                    pred = np.argmax(model(occluded.float()).cpu().numpy(), axis=1)
                accs[i] = compute_accuracy(real, pred)

            df_inter = pd.DataFrame(accs.tolist(), columns=["accuracy"])
            df_inter["algorithm"] = algo
            df_inter["prct_occluded"] = percentages

            df_occlu.append(df_inter)

    df_occlusion = pd.concat(df_occlu)
    df_occlusion.to_csv("stats/occlusion_curves.csv", index=False)
else:
    df_occlusion = pd.read_csv("stats/occlusion_curves.csv")
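
The occlusion test checks faithfulness: if an attribution method ranks voxels well, removing the most relevant voxels first should hurt accuracy faster than removing less relevant ones. `occlude_images` is delphi's, and its exact semantics (e.g., whether `fraction` is a percentage, which value replaces occluded voxels) are not shown here. The sketch below assumes that the top `percent` % of in-mask voxels, ranked by relevance, are set to zero; `occlude_top_percent` is hypothetical.

```python
import numpy as np


def occlude_top_percent(volume, relevance, mask, percent):
    """Zero out the `percent` % most relevant in-mask voxels (hypothetical helper)."""
    occluded = volume.copy()
    in_mask = np.flatnonzero(mask)                       # flat indices of brain voxels
    n_occlude = int(round(len(in_mask) * percent / 100))
    if n_occlude == 0:
        return occluded
    # in-mask voxels sorted by relevance, highest first
    order = in_mask[np.argsort(relevance.ravel()[in_mask])[::-1]]
    np.put(occluded, order[:n_occlude], 0.0)             # in-place, flat indices
    return occluded


toy_rng = np.random.default_rng(2)
toy_volume = toy_rng.normal(size=(6, 6, 6)) + 5.0        # toy "image" without exact zeros
toy_relevance = toy_rng.random(size=(6, 6, 6))
toy_mask = np.ones((6, 6, 6), dtype=bool)
toy_occluded = occlude_top_percent(toy_volume, toy_relevance, toy_mask, percent=10)
```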

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
sns.scatterplot(ax=ax, data=df_occlusion, x="prct_occluded", y="accuracy", hue="algorithm", alpha=.4, legend=False, zorder=0)
sns.lineplot(ax=ax, data=df_occlusion, x="prct_occluded", y="accuracy", hue="algorithm", style="algorithm", 
             errorbar=("ci", 95), n_boot=5000, linewidth=2, markers=True, markersize=10)
ax.axhline(.2, color="black", linestyle="--", zorder=0)
ax.legend(frameon=False, loc="center right", fontsize=18, title="attribution method")
ax.spines[["top", "right"]].set_visible(False)

plt.savefig("figures/occlusion_curves.pdf", facecolor=fig.get_facecolor(), transparent=True)

# Compute and investigate mutual information between different combinations of attribution maps and (group) t-maps

Mutual information can help us understand how similar the attribution maps of the LRP or GBP algorithm are to the original input data, i.e., the t-statistic maps of single subjects or of the group. 

We can further perform some sanity checks this way:
* Comparing attribution maps obtained from networks trained with the real input-label mapping to those obtained from networks trained with shuffled labels tells us whether the attribution maps are sensitive to the learned network parameters
* Comparing attribution maps to the t-maps of single subjects or to the group t-maps lets us judge how well the attribution maps reflect univariate statistics (a somewhat unfair comparison, because CNNs are multivariate)
* We can see whether LRP or GBP yields higher mutual information with the original input data.
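
To make the idea concrete, the sketch below estimates mutual information between two maps from their joint histogram, using I(X;Y) = Σ p(x,y) log[p(x,y) / (p(x) p(y))]. Be aware that this binning estimator is a simplified stand-in. The notebook itself relies on `compute_mi` from `utils.tools`, and the next cell imports scikit-learn's k-nearest-neighbour estimator `mutual_info_regression`, so absolute values will differ. The helper name `histogram_mi` and the number of bins are assumptions.

```python
import numpy as np


def histogram_mi(x, y, bins=32):
    """Hypothetical helper: mutual information (in nats) between two maps,
    estimated from their joint histogram (plug-in estimator)."""
    joint, _, _ = np.histogram2d(np.ravel(x), np.ravel(y), bins=bins)
    p_xy = joint / joint.sum()
    p_x = p_xy.sum(axis=1, keepdims=True)  # marginal of x, shape (bins, 1)
    p_y = p_xy.sum(axis=0, keepdims=True)  # marginal of y, shape (1, bins)
    expected = p_x * p_y                   # joint distribution under independence
    nonzero = p_xy > 0                     # empty cells contribute 0 * log(0) = 0
    return float(np.sum(p_xy[nonzero] * np.log(p_xy[nonzero] / expected[nonzero])))


rng = np.random.default_rng(0)
t_map = rng.normal(size=(20, 20, 20))                      # stand-in t-map
similar_attr = t_map + 0.3 * rng.normal(size=t_map.shape)  # noisy copy of the t-map
unrelated_attr = rng.normal(size=t_map.shape)              # independent map
mi_similar = histogram_mi(t_map, similar_attr)
mi_unrelated = histogram_mi(t_map, unrelated_attr)
```

A map that shares structure with the t-map yields a clearly higher estimate than an independent one. The independent map still scores slightly above zero, because plug-in estimators are biased upwards.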

In [None]:
from sklearn.feature_selection import mutual_info_regression
from nilearn.masking import apply_mask
from nilearn.image import load_img
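
`apply_mask` from nilearn restricts an image to the voxels inside a mask and returns them as a flat array. For a single 3D image, that is one value per in-mask voxel, which is the form the mutual-information estimators expect. On in-memory arrays this roughly corresponds to boolean indexing, as sketched below. The sketch skips what nilearn additionally handles, such as loading NIfTI files and checking that image and mask are aligned. All arrays are random stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
attribution_map = rng.normal(size=(10, 12, 10))  # stand-in for a 3D attribution map
t_map = rng.normal(size=(10, 12, 10))            # stand-in for a 3D t-map
brain_mask = np.zeros((10, 12, 10), dtype=bool)  # stand-in for a brain mask
brain_mask[2:8, 3:9, 2:8] = True

# boolean indexing keeps only in-mask voxels, flattened in C order,
# so element k of both vectors refers to the same spatial location
attr_vec = attribution_map[brain_mask]
t_vec = t_map[brain_mask]

# scikit-learn expects a 2D feature matrix of shape (n_samples, n_features), e.g. for
# sklearn.feature_selection.mutual_info_regression(X, t_vec, random_state=0)
X = attr_vec.reshape(-1, 1)
```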

In [None]:
# to restrict the analysis to the motor mask, use the line below instead (and change the output prefix)
# mask = "stat-maps/motor_mask.nii.gz"
mask = mni_template
contrasts = ["subattr-vs-grpattr", "subt-vs-grpt"]
for contrast in contrasts:
    compute_mi(TASK_LABEL, class_labels, mask, "mi_whole-brain", n_folds=10, attr_methods=["lrp", "guidedbackprop"], contrast=contrast)

In [None]:
mask = mni_template
contrasts = ["subattr-vs-subt", "attr-real-vs-shuffled", "subattr-vs-grporig", "grpattr-vs-grporig"]

for contrast in contrasts:
    compute_mi(TASK_LABEL, class_labels, mask, "mi_whole-brain", n_folds=10, attr_methods=["lrp", "guidedbackprop"], contrast=contrast)
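
The `attr-real-vs-shuffled` contrast compares attribution maps from networks trained on the real input-label mapping with maps from networks trained on shuffled labels. This section does not show how the shuffled-label networks were produced. A common approach, sketched below as an assumption, is to permute the label vector once with a seeded generator. This breaks the input-label relationship while leaving the number of samples per class unchanged. The helper name `shuffle_labels` is hypothetical. The seed 2020 mirrors the notebook's seed, and the labels are illustrative.

```python
import numpy as np


def shuffle_labels(labels, seed=2020):
    """Hypothetical helper: return a random permutation of `labels`.

    Class frequencies stay identical; only the pairing with the inputs is broken.
    """
    rng = np.random.default_rng(seed)
    return rng.permutation(np.asarray(labels))


labels = np.repeat(np.arange(5), 20)  # illustrative: 5 classes, 20 samples each
shuffled = shuffle_labels(labels)
```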

In [None]:
SMALL_SIZE = 18
MEDIUM_SIZE = 22
BIGGER_SIZE = 26

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
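
The `plt.rc` calls above change matplotlib's global defaults for every figure created later in the session. If the larger fonts should apply only to selected figures, `matplotlib.rc_context` offers a scoped alternative that restores the previous settings when the `with` block exits. Collecting the sizes in a dictionary, as below, is a stylistic choice and not part of the original notebook.

```python
import matplotlib

SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 18, 22, 26

# same settings as the plt.rc calls, expressed as rcParams keys
font_sizes = {
    "font.size": SMALL_SIZE,          # default text size
    "axes.titlesize": MEDIUM_SIZE,    # axes title
    "axes.labelsize": MEDIUM_SIZE,    # x and y labels
    "xtick.labelsize": SMALL_SIZE,    # tick labels
    "ytick.labelsize": SMALL_SIZE,
    "legend.fontsize": SMALL_SIZE,    # legend
    "figure.titlesize": BIGGER_SIZE,  # figure title
}

before = matplotlib.rcParams["axes.titlesize"]
with matplotlib.rc_context(font_sizes):
    # figures created (and saved) inside this block use the larger fonts
    inside = matplotlib.rcParams["axes.titlesize"]
after = matplotlib.rcParams["axes.titlesize"]
```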

In [None]:
contrasts = ["subattr-vs-subt", "attr-real-vs-shuffled", "subattr-vs-grporig", "grpattr-vs-grporig"]
dfs = []
for c, contrast in enumerate(contrasts):
    dfs.append(pd.read_csv(f"stats/mi_whole-brain_{contrast}.csv"))
    dfs[c]["contrast"] = contrast

df = pd.concat(dfs)

fig, ax = plt.subplots(figsize=(20, 8))
sns.boxenplot(ax=ax, data=df, x="attr_method", y="mi", hue="contrast", palette="cool", k_depth="trustworthy")
ax.spines[["top", "right"]].set_visible(False)

ax.set_title("across classes and folds")

fig.tight_layout()
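
In the cell above, the `contrast` column is added to each data frame by hand before concatenating. An equivalent pandas idiom passes the contrast names as `keys` to `pd.concat` and then moves that index level into a column. The small frames below stand in for the CSV files under `stats/`, which are not reproduced here. Their column names follow those used in the plots.

```python
import pandas as pd

contrast_names = ["subattr-vs-subt", "attr-real-vs-shuffled"]
# stand-ins for pd.read_csv(f"stats/mi_whole-brain_{contrast}.csv")
frames = [
    pd.DataFrame({"attr_method": ["lrp", "guidedbackprop"], "mi": [0.1, 0.2]}),
    pd.DataFrame({"attr_method": ["lrp", "guidedbackprop"], "mi": [0.3, 0.4]}),
]

df_long = (
    pd.concat(frames, keys=contrast_names, names=["contrast", None])
    .reset_index(level="contrast")  # turn the key level into a regular column
    .reset_index(drop=True)         # discard the per-file row numbers
)
```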

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(20,12), sharex=True, sharey=True)
axes = axes.flatten()
for a, ax in enumerate(axes):
    plt_legend = a == 2
    sns.boxenplot(ax=ax, data=df[df.contrast==contrasts[a]], x="class", y="mi", hue="attr_method", k_depth="trustworthy")
    ax.spines[["top", "right"]].set_visible(False)
    ax.set_title("{} across folds".format(contrasts[a]))

    if not plt_legend:
        ax.legend([],[], frameon=False)
    else:
        ax.legend(frameon=False)
    
fig.tight_layout()

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(20,12), sharex=True, sharey=True)
axes = axes.flatten()
for a, ax in enumerate(axes):
    plt_legend = a == 2
    sns.boxenplot(ax=ax, data=df[df.contrast==contrasts[a]], x="fold", y="mi", hue="attr_method", k_depth="trustworthy")
    ax.spines[["top", "right"]].set_visible(False)
    ax.set_title("{} across classes".format(contrasts[a]))
    # relabel the categorical positions 0-9 as folds 1-10
    ax.set_xticks(np.arange(0,10))
    ax.set_xticklabels(np.arange(1,11))

    if not plt_legend:
        ax.legend([],[], frameon=False)
    else:
        ax.legend(frameon=False)
    
fig.tight_layout()

In [None]:
dfs[3][:25]

In [None]:
print(dfs[3].columns)
dfs[3].pivot(index="fold", values="mi", columns=["contrast", "attr_method", "class"])
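
When you pass a list to `columns=`, `DataFrame.pivot` builds hierarchical (MultiIndex) columns, one level per listed variable, and produces one row per `fold`. The toy frame below has two folds, two methods, two hypothetical classes and placeholder values. It shows the resulting layout and how to select a single combination afterwards. `pivot` raises a `ValueError` if any (fold, contrast, attr_method, class) combination occurs more than once. In that case, use `pivot_table` with an aggregation function instead.

```python
import itertools

import numpy as np
import pandas as pd

rows = list(itertools.product([0, 1], ["lrp", "guidedbackprop"], ["class_a", "class_b"]))
toy = pd.DataFrame(rows, columns=["fold", "attr_method", "class"])
toy["contrast"] = "grpattr-vs-grporig"
toy["mi"] = np.arange(len(toy), dtype=float)  # placeholder values

wide = toy.pivot(index="fold", values="mi", columns=["contrast", "attr_method", "class"])
# select one column by its full key: (contrast, attr_method, class)
lrp_a = wide[("grpattr-vs-grporig", "lrp", "class_a")]
```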

In [None]:
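test = dfs[0].copy()
# subject index: assumes the rows come in consecutive blocks of 20 subjects
test["sub"] = np.tile(np.arange(0,20), 100)
# note: index=False omits the "sub" index from the CSV
test.pivot(index="sub", values="mi", columns=["contrast", "attr_method", "class", "fold"]).to_csv("test.csv", index=False)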
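
The `np.tile` call hard-codes 100 blocks, i.e., 2 attribution methods × 5 classes × 10 folds. A variant that derives the subject index from the data instead, sketched below, numbers the rows within each (contrast, attr_method, class, fold) block with `groupby(...).cumcount()`. It still assumes that rows within each block are in subject order, but it no longer depends on the number of blocks. The toy frame, its class names and its values are illustrative.

```python
import itertools

import numpy as np
import pandas as pd

n_subjects = 3  # illustrative; the notebook uses 20
blocks = list(itertools.product(["lrp", "guidedbackprop"], ["class_a", "class_b"], [0, 1]))
toy_long = pd.DataFrame(
    [(m, c, f) for (m, c, f) in blocks for _ in range(n_subjects)],
    columns=["attr_method", "class", "fold"],
)
toy_long["contrast"] = "subattr-vs-subt"
toy_long["mi"] = np.linspace(0, 1, len(toy_long))  # placeholder values

# number the rows 0, 1, ..., n-1 within every (contrast, method, class, fold) block
toy_long["sub"] = toy_long.groupby(["contrast", "attr_method", "class", "fold"]).cumcount()

by_subject = toy_long.pivot(index="sub", values="mi", columns=["contrast", "attr_method", "class", "fold"])
```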

In [None]:
mask = ["whole-brain", "motor-mask"]

fig, axes = plt.subplots(2, 2, figsize=(20, 10), sharex=True, sharey=True)

for i, m in enumerate(mask):

    # plot the grp-attribution vs grp-orig mutual information
    df = pd.read_csv(f"stats/mi_{m}-grp-vs-grp.csv")
    sns.boxenplot(ax=axes[0, i], data=df.melt(id_vars="attr_method", var_name="class", value_name="mi"),
                  x="class", y="mi", hue="attr_method")
    axes[0, i].set_ylabel("$MI_{(GrpAttr, GrpOrig)}$")

    # plot the sub-attribution vs grp-orig mutual information
    df = pd.read_csv(f"stats/mi_{m}-sub-vs-grp.csv")
    sns.boxenplot(ax=axes[1, i], data=df, x="class", y="mi", hue="attr_method")
    axes[1, i].set_ylabel("$MI_{(SubAttr, GrpOrig)}$")
    
    if i == 0:
        axes[0, 0].legend(frameon=False, title="attribution method")
        axes[1, 0].legend([],[], frameon=False)
    else:
        axes[0, i].legend([],[], frameon=False)
        axes[1, i].legend([],[], frameon=False)
    axes[0, i].spines[["top", "right"]].set_visible(False)
    axes[1, i].spines[["top", "right"]].set_visible(False)

    axes[0, i].set_title(m)
    
plt.savefig("figures/mutualinfo_motor.pdf", facecolor=fig.get_facecolor(), transparent=True) 
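
The group-level files (`stats/mi_{m}-grp-vs-grp.csv`) are stored in wide format, with one column per class, so the cell reshapes them with `melt` before plotting. The subject-level files are plotted directly. The toy example below uses hypothetical class columns and placeholder values to show what `melt(id_vars="attr_method", var_name="class", value_name="mi")` produces.

```python
import pandas as pd

wide = pd.DataFrame({
    "attr_method": ["lrp", "guidedbackprop"],
    "class_a": [0.5, 0.4],  # placeholder MI values
    "class_b": [0.7, 0.6],
})

# one row per (attr_method, class) pair; former column names go into "class"
long = wide.melt(id_vars="attr_method", var_name="class", value_name="mi")
```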