# Train a 3D Convolutional Neural Network (3dCNN) to classify motor tasks

In [1]:
import os
import wandb
import torch
import pandas as pd
import numpy as np
import time
from delphi import mni_template
from glob import glob
from torch.utils.data import DataLoader
from torchinfo import summary

from delphi.networks.ConvNets import BrainStateClassifier3d
from delphi.utils.datasets import NiftiDataset
from delphi.utils.tools import ToTensor, compute_accuracy, convert_wandb_config, read_config, z_transform_volume
from delphi.utils.plots import confusion_matrix

from sklearn.model_selection import StratifiedShuffleSplit

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def set_random_seed(seed):
    import random
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    g = torch.Generator()  # can be used in pytorch dataloaders for reproducible sample selection when shuffle=True
    g.manual_seed(seed)

    return g

g = set_random_seed(2020)

In [3]:
def wandb_plots(y_true, y_pred, y_prob, class_labels, dataset):
    wandb.log({
        f"{dataset}-ROC": wandb.plot.roc_curve(y_true=y_true, y_probas=y_prob, labels=class_labels, title=f"{dataset}-ROC"),
        f"{dataset}-PR": wandb.plot.pr_curve(y_true=y_true, y_probas=y_prob, labels=class_labels, title=f"{dataset}-PR"),
        f"{dataset}-ConfMat": wandb.plot.confusion_matrix(y_true=y_true, preds=y_pred, class_names=class_labels, title=f"{dataset}-ConfMat")
    })

In [4]:
def reset_wandb_env():
    exclude = {
        "WANDB_PROJECT",
        "WANDB_ENTITY",
        "WANDB_API_KEY",
    }
    for k, v in os.environ.items():
        if k.startswith("WANDB_") and k not in exclude:
            del os.environ[k]

In [5]:
class_labels = sorted(["handleft", "handright", "footleft", "footright", "tongue"])
print(class_labels)

data_test = NiftiDataset("../t-maps/test", class_labels, 0, device=DEVICE, transform=ToTensor())

# we will split the train dataset into a train (80%) and validation (20%) set.
data_train_full = NiftiDataset("../t-maps/train", class_labels, 0, device=DEVICE, transform=ToTensor(), shuffle_labels=False)

['footleft', 'footright', 'handleft', 'handright', 'tongue']


In [None]:
# set the wandb sweep config
# os.environ['WANDB_MODE'] = 'offline'
os.environ['WANDB_ENTITY'] = "philis893" # this is my wandb account name. This can also be a group name, for example
os.environ['WANDB_PROJECT'] = "thesis" # this is simply the project name where we want to store the sweep logs and plots

In [None]:
#hp = read_config("hyperparameter.yaml")
hp = read_config("hyperparameter.yaml")

num_folds = 10
run = 0
input_dims = (91, 109, 91)

# we want one stratified shuffled split
sss = StratifiedShuffleSplit(n_splits=num_folds, test_size=0.2, random_state=2020)

for fold, (idx_train, idx_valid) in enumerate(sss.split(data_train_full.data, data_train_full.labels)):
    reset_wandb_env()
    wandb_kwargs = {
        "entity": "philis893",
        "project": "thesis",
        "group": "first-steps-motor",
        "name": f"motor-classifier_fold-{fold:02d}",
        "job_type": "CV-motor" if num_folds > 1 else "train",
    }

    data_train = torch.utils.data.Subset(data_train_full, idx_train)
    data_valid = torch.utils.data.Subset(data_train_full, idx_valid)
    
    save_name = os.path.join("models", wandb_kwargs["name"])
    if os.path.exists(save_name):
        continue
        
    g = set_random_seed(2020 + fold + run)
        
    with wandb.init(config=hp, **wandb_kwargs) as run:
        
        model_cfg = {
            "channels": [1, 8, 16, 32, 64],
            "lin_neurons": [128, 64],
            "pooling_kernel": 2,
            "kernel_size": run.config.kernel_size,
            "dropout": run.config.dropout,
        }
        model = BrainStateClassifier3d(input_dims, len(class_labels), model_cfg)
        model.to(DEVICE);
        
        dl_train = DataLoader(data_train, batch_size=run.config.batch_size, shuffle=True, generator=g)
        dl_valid = DataLoader(data_valid, batch_size=run.config.batch_size, shuffle=True, generator=g)
        dl_test = DataLoader(data_test, batch_size=run.config.batch_size, shuffle=False, generator=g)

        best_loss, best_acc = 100, 0
        loss_acc = []
        train_stats, valid_stats = [], []

        # loop for the above set number of epochs
        for epoch in range(run.config.epochs):
            _, _ = model.fit(dl_train, lr=run.config.learning_rate, device=DEVICE)

            # for validating or testing set the network into evaluation mode such that layers like dropout are not active
            with torch.no_grad():
                tloss, tstats = model.fit(dl_train, device=DEVICE, train=False)
                vloss, vstats = model.fit(dl_valid, device=DEVICE, train=False)

            # the model.fit() method has 2 output parameters: loss, stats = model.fit()
            # the first parameter is simply the loss for each sample
            # the second parameter is a matrix of n_classes+2-by-n_samples
            # the first n_classes columns are the output probabilities of the model per class
            # the second to last column (i.e., [:, -2]) represents the real labels
            # the last column (i.e., [:, -1]) represents the predicted labels
            tacc = compute_accuracy(tstats[:, -2], tstats[:, -1])
            vacc = compute_accuracy(vstats[:, -2], vstats[:, -1])

            loss_acc.append(pd.DataFrame([[tloss, vloss, tacc, vacc]],
                                         columns=["train_loss", "valid_loss", "train_acc", "valid_acc"]))

            train_stats.append(pd.DataFrame(tstats.tolist(), columns=[*class_labels, *["real", "predicted"]]))
            train_stats[epoch]["epoch"] = epoch
            valid_stats.append(pd.DataFrame(vstats.tolist(), columns=[*class_labels, *["real", "predicted"]]))
            valid_stats[epoch]["epoch"] = epoch

            wandb.log({
                "train_acc": tacc, "train_loss": tloss,
                "valid_acc": vacc, "valid_loss": vloss
            }, step=epoch)

            print('Epoch=%03d, train_loss=%2.3f, train_acc=%1.3f, valid_loss=%2.3f, valid_acc=%1.3f' % 
                 (epoch, tloss, tacc, vloss, vacc))

            if (vacc >= best_acc) and (vloss <= best_loss):
                # assign the new best values
                best_acc, best_loss = vacc, vloss
                wandb.run.summary["best_valid_accuracy"] = best_acc
                wandb.run.summary["best_valid_epoch"] = epoch
                # save the current best model
                model.save(save_name)
                # plot some graphs for the validation data
                wandb_plots(vstats[:, -2], vstats[:, -1], vstats[:, :-2], class_labels, "valid")


        # save the files
        full_df = pd.concat(loss_acc)
        full_df.to_csv(os.path.join(save_name, "loss_acc_curves.csv"), index=False)
        full_df = pd.concat(train_stats)
        full_df.to_csv(os.path.join(save_name, "train_stats.csv"), index=False)
        full_df = pd.concat(valid_stats)
        full_df.to_csv(os.path.join(save_name, "valid_stats.csv"), index=False)

        # EVALUATE THE MODEL ON THE TEST DATA
        with torch.no_grad():
            testloss, teststats = model.fit(dl_test, train=False)
        testacc = compute_accuracy(teststats[:, -2], teststats[:, -1])
        wandb.run.summary["test_accuracy"] = testacc

        wandb.log({"test_accuracy": testacc, "test_loss": testloss})
        wandb_plots(teststats[:, -2], teststats[:, -1], teststats[:, :-2], class_labels, "test")

        wandb.finish()

# Run the Layer-wise Relevance Propagation (LRP) algorithm

In [6]:
from delphi.utils.tools import save_in_mni
from zennit import composites

In [None]:
# Execute the LRP algorithm and save the resulting LRP maps somewhere.
shape = (1, 1, 91, 109, 91)

composite_kwargs = {
    'low': -1 * torch.ones(*shape, device=torch.device("cpu")),  # the lowest and ...
    'high': 1 * torch.ones(*shape, device=torch.device("cpu")),  # the highest pixel value for ZBox
}
# In this line I am telling zennit what kind of rule I want to use and supply
# some arguments. 
attributor = composites.COMPOSITES['epsilon_gamma_box'](**composite_kwargs)

for param in model.parameters():
    param.requires_grad = False


for fold in range(10):
# load the trained network
    model = BrainStateClassifier3d(f"models/motor-classifier_fold-{fold:02d}")
    model.eval()
    model.to(torch.device("cpu"));

    out_dir_name = f"lrp/epsgamma_fold-{fold:02d}"
    if not os.path.exists(out_dir_name):
        os.mkdir(out_dir_name)

    for j in range(len(class_labels)):

        print(f"Running LRP on {class_labels[j]}")
        dl = DataLoader(
            NiftiDataset('../t-maps/test', [class_labels[j]], 0, device=torch.device("cpu"), transform=ToTensor()),
            batch_size=20, shuffle=False, num_workers=0
        )

        for i, (volume, target) in enumerate(dl):

            volume.requires_grad = True
            grad_dummy = torch.eye(model.config['n_classes'], device=torch.device("cpu"))[target]

            with attributor.context(model) as modified_model:

                output = modified_model(volume)
                attribution, = torch.autograd.grad(output, volume, grad_outputs=grad_dummy)
                subject_lrp = np.moveaxis(attribution.squeeze().detach().numpy(), 0, -1)
                subject_lrp = z_transform_volume(subject_lrp)
                avg_lrp = subject_lrp.mean(axis=-1)

        save_in_mni(subject_lrp, os.path.join(out_dir_name, '%s.nii.gz' % class_labels[j]))
        #save_in_mni(avg_lrp, os.path.join(out_dir_name, 'avg_%s.nii.gz' % class_labels[j]))
        

In [None]:
# let us visualize the average LRP maps
# Brain_Data from the nltools toolbox allows us to visualize brain maps quit nicely
from nltools.data import Brain_Data

files = sorted(glob(f"motor-mapper-lrp-epsgamma/handleft.nii.gz"))
#files = sorted(glob("../t-maps/test/handleft/*.nii.gz"))
dat = Brain_Data(files)
dat.iplot()

# SecondLevel GLM analyses

In [7]:
import matplotlib.pyplot as plt

In [8]:
from nilearn.image import load_img

In [12]:
# load the real images
test_img_files = sorted(glob("../t-maps/test/*/*.nii.gz"))
images = load_img(test_img_files)
real = np.repeat([0,1,2,3,4], 20)
#real = np.repeat([2,3], 20)

# load the relevance images
fold = 1
relevance_files = sorted(glob(f"lrp/epsgamma_fold-{fold:02d}/*nii.gz"))
relevance_maps = load_img(relevance_files)

mask = load_img(mni_template)

In [None]:
from nilearn.plotting import plot_design_matrix, plot_contrast_matrix
n_subs = 20
n = len(test_img_files)
design_matrix = {}

for label in range(len(class_labels)):
    regressor = np.repeat(np.eye(len(class_labels))[label], n_subs)
    design_matrix[f'{class_labels[label]}'] = regressor
               
design_matrix['lh_vs_rh'] = np.repeat([0, 0, 1, -1, 0], n_subs)
design_matrix['lf_vs_rf'] = np.repeat([1, -1, 0, 0, 0], n_subs)
design_matrix['t_vs_all'] = np.repeat([-.25, -.25, -.25, -.25, 1], n_subs)
design_matrix['f_vs_h'] = np.repeat([.5, .5, -.5, -.5, 0], n_subs)
design_matrix['intercept'] = np.ones(n)

for sub in range(n_subs):
    regressor = np.tile(np.eye(n_subs)[sub], len(class_labels))
    design_matrix[f'sub-{sub}'] = regressor

df_design_matrix = pd.DataFrame(design_matrix)
plot_design_matrix(df_design_matrix, rescale=False);

In [None]:
import nilearn
from nilearn.glm.second_level import SecondLevelModel
glm_orig = SecondLevelModel(smoothing_fwhm=None, mask_img=mask)
glm_orig.fit([nilearn.image.index_img(images, i) for i in range(n)], design_matrix=df_design_matrix)

glm_lrp = SecondLevelModel(smoothing_fwhm=None, mask_img=mask)
glm_lrp.fit([nilearn.image.index_img(relevance_maps , i) for i in range(n)], design_matrix=df_design_matrix)

In [None]:
contrasts = ["footleft", "footright", "handleft", "handright", "tongue"]
#"lh_vs_rh", "lf_vs_rf", "t_vs_all", "f_vs_h"]

for c, contrast in enumerate(contrasts):
    maps_orig = glm_orig.compute_contrast(contrast, output_type="all")
    maps_lrp = glm_lrp.compute_contrast(contrast, output_type="all")

    fig, axes = plt.subplots(1,2,figsize=(18,6))
    nilearn.plotting.plot_glass_brain(maps_orig['z_score'], colorbar=True, plot_abs=False, axes=axes[0]);
    axes[0].set_title(f"{contrast}-orig")
    nilearn.plotting.plot_glass_brain(maps_lrp['z_score'], colorbar=True, plot_abs=False, axes=axes[1]);
    axes[1].set_title(f"{contrast}-lrp")

# TEST OCCLUSION

In [10]:
from delphi.utils.tools import occlude_images

In [None]:
percentages = np.array([0, .01, .1, .2, .5, .75, 1, 5, 10, 20, 30, 50, 75])
accs = np.zeros((8, len(percentages)))
n_voxel = np.zeros_like(percentages)


for f in range(8):
    # load the trained network
    model = BrainStateClassifier3d(f"models/motor-classifier_fold-{f:02d}")
    model.eval()
    model.to(torch.device("cpu"));
    
    relevance_files = sorted(glob(f"lrp/epsgamma_fold-{fold:02d}/*nii.gz"))
    relevance_maps = load_img(relevance_files)

    for i, frac in enumerate(percentages):
        occluded, n_voxel[f] = occlude_images(images, relevance_maps, mask, fraction=frac, get_fdata=True)
        occluded = np.moveaxis(occluded, -1, 0)
        occluded = torch.tensor(occluded).unsqueeze(1)
        pred = np.argmax(model(occluded.float()).detach().cpu().numpy(), axis=1)
        accs[f, i] = compute_accuracy(real, pred)

Loading from config file models/motor-classifier_fold-00/config.yaml
Loading from config file models/motor-classifier_fold-01/config.yaml
Loading from config file models/motor-classifier_fold-02/config.yaml
Loading from config file models/motor-classifier_fold-03/config.yaml
Loading from config file models/motor-classifier_fold-04/config.yaml


In [None]:
mean_acc = accs[:7].mean(axis=0)
sd_acc = accs[:7].std(axis=0)
fig, ax = plt.subplots(figsize=(20, 8))
ax.errorbar(np.log(percentages), mean_acc, yerr=sd_acc, marker="o", linewidth=3, capsize=5, capthick=2);
ax.scatter(np.repeat(np.log(percentages), 7), accs[:7].transpose(), marker="o", alpha=.2, color="black");
ax.set_xticks(np.log(percentages)), ax.set_xticklabels(percentages)
ax.hlines(.2, ax.get_xlim()[0], ax.get_xlim()[1], color="black", linestyle="--")
ax.set_title('Occlusion effect'), ax.set_xlabel('Percent occlusion'), ax.set_ylabel('accuracy')
ax.spines['top'].set_visible(False), ax.spines['right'].set_visible(False);

In [33]:
ax.get_xlim()

(-5.051303100964311, 4.76362102851253)