# Train and evaluate a 3D Convolutional Neural Network (3dCNN) to classify motor tasks

<p style='text-align: justify;'>In this notebook we create and train a 3D convolutional neural network that learns to classify different patterns of whole-brain fMRI statistical parameter maps (t-scores). In this first approach, our goal is to train a classifier that reliably distinguishes between such whole-brain patterns for five movements (left/right hand, left/right foot, and tongue).

</p>

In [1]:
import os
import sys

# Make the parent directory (which holds the local `utils` package) importable,
# without appending it a second time when the cell is re-run.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import os, wandb, torch, time
import pandas as pd
import numpy as np
from glob import glob
from torch.utils.data import DataLoader
from torchinfo import summary
import seaborn as sns
import matplotlib.pyplot as plt

from delphi import mni_template
from delphi.networks.ConvNets import BrainStateClassifier3d
from delphi.utils.datasets import NiftiDataset
from delphi.utils.tools import ToTensor, compute_accuracy, convert_wandb_config, read_config, z_transform_volume
from delphi.utils.plots import confusion_matrix

# you can find all these files in ../utils
from utils.tools import attribute_with_method 
from utils.random import set_random_seed

from tqdm.notebook import tqdm

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")



<p style='text-align: justify;'>To make sure that we obtain (almost) the same results on every execution, we set the random seed of several libraries (i.e., torch, random, numpy).</p>

In [3]:
# The project started in the year 2020, hence the seed.
# NOTE(review): `g` holds whatever set_random_seed returns (presumably a torch.Generator) — confirm in utils/random.py.
RANDOM_SEED = 2020
g = set_random_seed(RANDOM_SEED)

## Initializations

In this section, we define and initialize the required variables. We first define which classes we want to predict, i.e., the conditions of the motor mapper. We then define a PyTorch dataset; here `NiftiDataset` is a custom-written Dataset class (see https://github.com/PhilippS893/delphi). As is common practice in machine learning projects, we split our data into training and validation sets (80/20 split).

Note: If a null model is needed, i.e., a neural network trained on data whose labels are randomized, set the parameter `shuffle_labels=False` to `True`. This is usually done to obtain a baseline for the null hypothesis that "everything is random".

In [4]:
TASK_LABEL = "motor"
# Sorting makes the class order deterministic (alphabetical), independent of how the list is written.
class_labels = sorted(["tongue", "footright", "footleft", "handright", "handleft"])

# Held-out test t-maps for all classes, loaded onto DEVICE with the ToTensor transform.
# NOTE(review): the meaning of the positional `0` argument is defined by delphi's NiftiDataset — confirm.
data_test = NiftiDataset(
    "../t-maps/test",
    class_labels,
    0,
    device=DEVICE,
    transform=ToTensor(),
)

<p style='text-align: justify;'>We now set the parameters required by Weights &amp; Biases (W&amp;B) to look up the information stored about our trained neural networks.</p>

# Identify the best fold

In [5]:
from utils.wandb_funcs import get_wandb_csv

In [6]:
# Run metadata and metrics to pull for every run of the "first-steps-motor" group.
keys_of_interest = [
    'group', 'job_type', 'run_name',
    'test_accuracy', 'train_acc', 'valid_acc',
    'valid_loss', 'best_valid_epoch', 'best_valid_accuracy',
    'test_loss', 'train_loss',
]
# NOTE(review): overwrite=True presumably re-fetches from W&B instead of reusing a cached CSV — confirm in utils/wandb_funcs.py.
wandb_df = get_wandb_csv(
    "philis893", "thesis", "first-steps-motor", keys_of_interest,
    overwrite=True,
)

{'entity': 'philis893', 'project': 'thesis', 'filters': {'group_name': {'$regex': 'first-steps-motor.*'}, 'jobType': 'None'}, 'order': '-created_at', '_sweeps': {}, 'client': <wandb.apis.public.RetryingClient object at 0x7fe1986038e0>, 'variables': {'project': 'thesis', 'entity': 'philis893', 'order': '-created_at', 'filters': '{"group_name": {"$regex": "first-steps-motor.*"}, "jobType": "None"}'}, 'per_page': 50, 'objects': [], 'index': -1, 'last_response': None}


In [7]:
# Keep only the real 7-fold cross-validation runs, then rank them: highest best validation
# accuracy first, ties broken by higher test accuracy and then by lower test loss.
# The top row is the best fold.
real_runs = wandb_df.loc[wandb_df["job_type"] == "CV-7folds"]
real_runs_sorted = real_runs.sort_values(
    by=["best_valid_accuracy", "test_accuracy", "test_loss"],
    ascending=[False, False, True],
)
real_runs_sorted.head(1)

Unnamed: 0,group,job_type,run_name,test_accuracy,train_acc,valid_acc,valid_loss,best_valid_epoch,best_valid_accuracy,test_loss,train_loss


In [8]:
# NOTE(review): hard-coded; confirm it matches the top row of `real_runs_sorted` above.
BEST_FOLD = 3

# Investigate what information was deemed relevant (XAI)


In [9]:
from captum.attr import GuidedBackprop
from zennit.rules import Epsilon, Gamma, Pass
from zennit.types import Convolution, Linear, Activation
from zennit.composites import LayerMapComposite
from delphi.utils.tools import save_in_mni
from utils.tools import attribute_with_method

# Layer-type -> zennit rule assignment for the LRP composite:
#   activation layers -> Pass rule
#   convolution layers -> Gamma rule with gamma = 0.25
#   linear layers      -> Epsilon rule with epsilon = 0
composite_lrp_map = [
    (Activation, Pass()),
    (Convolution, Gamma(gamma=.25)),
    (Linear, Epsilon(epsilon=0)),
]

LRP = LayerMapComposite(layer_map=composite_lrp_map)
# The attribution loops name their output directories after `method.__name__`,
# so give the composite instance a name like the captum classes have.
setattr(LRP, '__name__', 'LRP')

In [12]:
label_order = "shuffled"  # which set of CV models to explain (suffix of the models/CV-7folds-* directory)
N_CV_FOLDS = 7            # matches the "CV-7folds" model directories

attributor_method = [LRP, GuidedBackprop]

for method in attributor_method:
    # Lower-cased method name, used as the top-level output directory (e.g. "lrp").
    method_name = str(method.__name__).lower()

    for fold in range(N_CV_FOLDS):
        # Load the trained network of this fold (configured from the fold's model directory),
        # move it to the CPU and switch it to eval mode.
        model = BrainStateClassifier3d(f"models/CV-7folds-{label_order}/fold-{fold:02d}")
        model.to(torch.device("cpu"))
        model.eval()

        out_dir_name = f"{method_name}/{label_order}/fold-{fold:02d}"
        os.makedirs(out_dir_name, exist_ok=True)

        for label in class_labels:
            print(f"Running {method_name} on {label}")

            out_fname = os.path.join(out_dir_name, '%s.nii.gz' % label)
            if os.path.isfile(out_fname):
                print(f"{out_fname} already exists. Skipping")
                continue

            dl = DataLoader(
                NiftiDataset('../t-maps/test', [label], 0, device=torch.device("cpu"), transform=ToTensor()),
                batch_size=20, shuffle=False, num_workers=0
            )

            # Accumulate the attributions of every batch. Keeping only the loop variable
            # would silently drop all but the last batch of 20 volumes.
            batch_attrs = []
            for volume, target in dl:
                attribution = attribute_with_method(method, model, volume, target)
                attr = attribution.detach().numpy()
                # Drop singleton non-batch axes but keep the batch axis, so that a batch holding
                # a single volume is not squeezed into a 3-D array with rotated axes.
                # Then move the subject axis last.
                attr = attr.reshape(attr.shape[0], *(d for d in attr.shape[1:] if d != 1))
                batch_attrs.append(np.moveaxis(attr, 0, -1))

            if not batch_attrs:
                # Without this guard the maps of the previous label would be saved under this one.
                print(f"No test volumes found for {label}. Skipping")
                continue

            subject_attr = z_transform_volume(np.concatenate(batch_attrs, axis=-1))
            avg_attr = subject_attr.mean(axis=-1)

            save_in_mni(subject_attr, out_fname)

            avg_out_name = os.path.join(out_dir_name, "avg")
            os.makedirs(avg_out_name, exist_ok=True)
            save_in_mni(avg_attr, os.path.join(avg_out_name, '%s.nii.gz' % label))

Loading from config file models/CV-7folds-shuffled/fold-00/config.yaml
Running lrp on footleft
Saving lrp/shuffled/fold-00/footleft.nii.gz
Saving lrp/shuffled/fold-00/avg/footleft.nii.gz
Running lrp on footright
Saving lrp/shuffled/fold-00/footright.nii.gz
Saving lrp/shuffled/fold-00/avg/footright.nii.gz
Running lrp on handleft
Saving lrp/shuffled/fold-00/handleft.nii.gz
Saving lrp/shuffled/fold-00/avg/handleft.nii.gz
Running lrp on handright
Saving lrp/shuffled/fold-00/handright.nii.gz
Saving lrp/shuffled/fold-00/avg/handright.nii.gz
Running lrp on tongue
Saving lrp/shuffled/fold-00/tongue.nii.gz
Saving lrp/shuffled/fold-00/avg/tongue.nii.gz
Loading from config file models/CV-7folds-shuffled/fold-01/config.yaml
Running lrp on footleft
Saving lrp/shuffled/fold-01/footleft.nii.gz
Saving lrp/shuffled/fold-01/avg/footleft.nii.gz
Running lrp on footright
Saving lrp/shuffled/fold-01/footright.nii.gz
Saving lrp/shuffled/fold-01/avg/footright.nii.gz
Running lrp on handleft
Saving lrp/shuffled

Another sanity check for attribution methods, proposed by Adebayo and colleagues, is to compare the attribution maps of trained networks with those of networks with randomly initialized weights.
If they are similar, the attribution method is insensitive to the model parameters, which reduces our confidence in the attribution maps of the trained networks.

In [14]:
n_folds = 1

attributor_method = [LRP, GuidedBackprop]

model_config = read_config("hyperparameter.yaml")

seed = 1337
# Input volume shape passed to the classifier (presumably the 91x109x91 2 mm MNI grid — confirm).
volume_shape = (91, 109, 91)

for method in attributor_method:
    # Lower-cased method name, used as the top-level output directory (e.g. "lrp").
    method_name = str(method.__name__).lower()

    for fold in range(n_folds):
        # Seed by fold only, so that every attribution method explains the *same* randomly
        # initialised network for a given fold. A counter shared across methods would give
        # LRP and GuidedBackprop different random networks for the same "fold-XX" directory.
        g = set_random_seed(seed + fold)

        # Network built from the hyperparameter config only; no trained weights are loaded here.
        model = BrainStateClassifier3d(volume_shape, len(class_labels), model_config)
        model.to(torch.device("cpu"))
        model.eval()

        out_dir_name = f"{method_name}/model-random/{TASK_LABEL}_fold-{fold:02d}"
        os.makedirs(out_dir_name, exist_ok=True)

        for label in class_labels:
            print(f"Running {method_name} on {label}")

            out_fname = os.path.join(out_dir_name, '%s.nii.gz' % label)
            if os.path.isfile(out_fname):
                print(f"{out_fname} already exists. Skipping")
                continue

            dl = DataLoader(
                NiftiDataset('../t-maps/test', [label], 0, device=torch.device("cpu"), transform=ToTensor()),
                batch_size=20, shuffle=False, num_workers=0
            )

            # Accumulate the attributions of every batch. Keeping only the loop variable
            # would silently drop all but the last batch of 20 volumes.
            batch_attrs = []
            for volume, target in dl:
                attribution = attribute_with_method(method, model, volume, target)
                attr = attribution.detach().numpy()
                # Drop singleton non-batch axes but keep the batch axis, so that a batch holding
                # a single volume is not squeezed into a 3-D array with rotated axes.
                # Then move the subject axis last.
                attr = attr.reshape(attr.shape[0], *(d for d in attr.shape[1:] if d != 1))
                batch_attrs.append(np.moveaxis(attr, 0, -1))

            if not batch_attrs:
                # Without this guard the maps of the previous label would be saved under this one.
                print(f"No test volumes found for {label}. Skipping")
                continue

            subject_attr = z_transform_volume(np.concatenate(batch_attrs, axis=-1))
            avg_attr = subject_attr.mean(axis=-1)

            save_in_mni(subject_attr, out_fname)

            avg_out_name = os.path.join(out_dir_name, "avg")
            os.makedirs(avg_out_name, exist_ok=True)
            save_in_mni(avg_attr, os.path.join(avg_out_name, '%s.nii.gz' % label))

Running lrp on footleft
lrp/model-random/motor_fold-00/footleft.nii.gz already exists. Skipping
Running lrp on footright
lrp/model-random/motor_fold-00/footright.nii.gz already exists. Skipping
Running lrp on handleft
lrp/model-random/motor_fold-00/handleft.nii.gz already exists. Skipping
Running lrp on handright
lrp/model-random/motor_fold-00/handright.nii.gz already exists. Skipping
Running lrp on tongue
lrp/model-random/motor_fold-00/tongue.nii.gz already exists. Skipping
Running guidedbackprop on footleft
guidedbackprop/model-random/motor_fold-00/footleft.nii.gz already exists. Skipping
Running guidedbackprop on footright
guidedbackprop/model-random/motor_fold-00/footright.nii.gz already exists. Skipping
Running guidedbackprop on handleft
guidedbackprop/model-random/motor_fold-00/handleft.nii.gz already exists. Skipping
Running guidedbackprop on handright
guidedbackprop/model-random/motor_fold-00/handright.nii.gz already exists. Skipping
Running guidedbackprop on tongue
guidedbackp