# Train and evaluate a 3D Convolutional Neural Network (3dCNN) to classify motor tasks

<p style='text-align: justify;'>In this notebook we create and train a 3D convolutional neural network (3D-CNN) that learns to classify patterns of whole-brain fMRI statistical parameters (t-scores). In this first approach our goal is to train a classifier that can reliably distinguish between such whole-brain patterns for five movement conditions (i.e., left/right hand, left/right foot, and tongue).

</p>

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import os, wandb, torch, time
import pandas as pd
import numpy as np
from glob import glob
from torch.utils.data import DataLoader
from torchinfo import summary
import seaborn as sns
import matplotlib.pyplot as plt

from delphi import mni_template
from delphi.networks.ConvNets import BrainStateClassifier3d
from delphi.utils.datasets import NiftiDataset
from delphi.utils.tools import ToTensor, compute_accuracy, convert_wandb_config, read_config, z_transform_volume
from delphi.utils.plots import confusion_matrix

from sklearn.model_selection import StratifiedShuffleSplit

# you can find all these files in ../utils
from utils.tools import attribute_with_method, concat_stat_files, compute_mi
from utils.wandb_funcs import reset_wandb_env, wandb_plots
from utils.random import set_random_seed

from tqdm.notebook import tqdm

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



<p style='text-align: justify;'>To make sure that we obtain (almost) the same results for each execution, we set the random seed of several libraries (i.e., torch, random, numpy).</p>

In [3]:
g = set_random_seed(2020) # the project started in the year 2020, hence the seed
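
The helper `set_random_seed` lives in `../utils` and is not shown in this notebook. As a rough sketch of what such a helper typically does, one might seed each library in turn. The name `seed_everything` and its body are assumptions made for illustration, not the project's code.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed every random number generator that is installed."""
    random.seed(seed)  # Python's built-in generator
    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's legacy global generator
    except ImportError:
        pass  # NumPy is not installed, nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # PyTorch's default generators
    except ImportError:
        pass  # PyTorch is not installed, nothing to seed


# seeding twice with the same value reproduces the same draw
seed_everything(2020)
first_draw = random.random()
seed_everything(2020)
second_draw = random.random()
```

The notebook's helper additionally returns an object `g` that is later passed to `DataLoader(..., generator=g)`, which suggests it is a `torch.Generator`; the sketch omits that part. Results are only "almost" identical because some GPU operations are not deterministic even with fixed seeds.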

## Initializations

In this section, we define and initialize the variables we need. We first define which classes we want to predict, i.e., the conditions of the motor mapper. We then define a PyTorch dataset; `NiftiDataset` is a custom dataset class from the `delphi` toolbox (see https://github.com/PhilippS893/delphi). As is common practice in machine learning projects, we split our data into a training and a validation set (80% and 20%, respectively).

Note: To create a null model, i.e., a neural network trained on data whose labels are randomized, change the parameter `shuffle_labels=False` to `shuffle_labels=True`. Such a model provides a baseline for the null hypothesis that "everything is random", i.e., that the brain maps carry no information about the labels.
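
To illustrate the idea behind a null model, the following sketch permutes a list of labels while leaving the class counts untouched. The function `shuffle_labels` is a hypothetical stand-in; how `NiftiDataset` shuffles labels internally is not shown in this notebook.

```python
import random
from collections import Counter


def shuffle_labels(labels, seed):
    """Return a randomly permuted copy of `labels` (hypothetical helper)."""
    shuffled = list(labels)
    random.Random(seed).shuffle(shuffled)  # local generator, global state stays untouched
    return shuffled


labels = ["footleft", "footright", "handleft", "handright", "tongue"] * 4
null_labels = shuffle_labels(labels, seed=2020)

# the class counts are unchanged, only the assignment of labels to samples is random
same_counts = Counter(labels) == Counter(null_labels)

# with balanced classes, a model trained on such labels should perform near chance
chance_level = 1 / len(set(labels))
```

With five balanced classes the chance level is 0.2, which is the accuracy one would roughly expect from a network trained on shuffled labels.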

In [4]:
TASK_LABEL = "motor"
class_labels = sorted(["handleft", "handright", "footleft", "footright", "tongue"])

data_test = NiftiDataset("../t-maps/test", class_labels, 0, device=DEVICE, transform=ToTensor())

<p style='text-align: justify;'>We now set some parameters required by w&b to properly store information about our trained neural networks.</p>

In [5]:
# set the wandb sweep config
# os.environ['WANDB_MODE'] = 'offline'
os.environ['WANDB_ENTITY'] = "philis893" # this is my wandb account name. This can also be a group name, for example
os.environ['WANDB_PROJECT'] = "thesis" # this is simply the project name where we want to store the sweep logs and plots

## Training the neural network(s)

<p style='text-align: justify;'>In this first approach, we chose the hyperparameters of our 3D-CNN based on existing literature. We make use of some functionality that I wrote for the "delphi" toolbox (see https://github.com/PhilippS893/delphi), e.g., the function "read_config". This function reads .yaml files formatted as shown below:</p>

The contents of `hyperparameter.yaml`: <br>
`kernel_size: 7`<br>
`batch_size: 4`<br>
`dropout: .5`<br>
`learning_rate: 0.00001`<br>
`epochs: 60`<br>
`channels: [1, 8, 16, 32, 64]`<br>
`lin_neurons: [128, 64]`<br>
`pooling_kernel: 2`<br>

The function `read_config` will return a dictionary variable containing all keyword-value pairs as set in the `hyperparameter.yaml`. We do this, because the `delphi` toolbox was written with dictionaries as configuration variables in mind and because we can easily submit dictionaries to `w&b` to keep track of such parameters.
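
For orientation, the dictionary below shows what we would expect `read_config` to return for the file listed above. This is an assumption based on how YAML scalars and lists usually map to Python types; the name `hp_expected` is used only for illustration.

```python
# assumed result of read_config("hyperparameter.yaml") for the file listed above
hp_expected = {
    "kernel_size": 7,
    "batch_size": 4,
    "dropout": 0.5,
    "learning_rate": 0.00001,
    "epochs": 60,
    "channels": [1, 8, 16, 32, 64],
    "lin_neurons": [128, 64],
    "pooling_kernel": 2,
}

# values are accessed by key
learning_rate = hp_expected["learning_rate"]
n_keys = len(hp_expected)
```

Because the configuration is a plain dictionary, it can be passed directly as `config=` to `wandb.init`, which is how it is used in the training function.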

In [None]:
def train_network(run, fold, data_train, data_valid, data_test, hp, class_labels, shuffled_labels=False):
    reset_wandb_env()
    
    job_name = "CV-shuffled" if shuffled_labels else "CV-real"
    wandb_kwargs = {
        "entity": os.environ['WANDB_ENTITY'],
        "project": os.environ['WANDB_PROJECT'],
        "group": "first-steps",
        "name": f"motor_fold-{fold:02d}",
        "job_type": job_name if num_folds > 1 else "train",
    }
    
    save_name = os.path.join("models", job_name, wandb_kwargs["name"])
    if os.path.exists(save_name):
        return
    
    # we adjust the random seed here so that folds and runs do not all share one seed
    # (note: 2020 + fold + run is only unique across folds as long as `run` stays fixed)
    g = set_random_seed(2020 + fold + run)
    
    # we now use the wandb context to track the training and evaluation process.
    # settings and changes from a previous fold are reset by reset_wandb_env() at the top of this function.
    with wandb.init(config=hp, **wandb_kwargs) as run:
        
        # please note that this conversion is unnecessary if not using w&b!
        model_cfg = convert_wandb_config(run.config, BrainStateClassifier3d._REQUIRED_PARAMS)
        
        # setup a model with the parameters given in model_cfg
        model = BrainStateClassifier3d(input_dims, len(class_labels), model_cfg)
        model.to(DEVICE)
        
        model.config["class_labels"] = class_labels
        
        dl_train = DataLoader(data_train, batch_size=run.config.batch_size, shuffle=True, generator=g)
        dl_valid = DataLoader(data_valid, batch_size=run.config.batch_size, shuffle=True, generator=g)
        dl_test = DataLoader(data_test, batch_size=run.config.batch_size, shuffle=False, generator=g)

        best_loss, best_acc = 100, 0
        loss_acc = []
        train_stats, valid_stats = [], []

        # loop for the above set number of epochs
        for epoch in range(run.config.epochs):
            _, _ = model.fit(dl_train, lr=run.config.learning_rate)

            # for validating or testing set the network into evaluation mode such that layers like dropout are not active
            with torch.no_grad():
                tloss, tstats = model.fit(dl_train, train=False)
                vloss, vstats = model.fit(dl_valid, train=False)

            # the model.fit() method has 2 output parameters: loss, stats = model.fit()
            # the first parameter is simply the loss for each sample
            # the second parameter is a matrix of shape n_samples-by-(n_classes+2)
            # the first n_classes columns are the output probabilities of the model per class
            # the second to last column (i.e., [:, -2]) represents the real labels
            # the last column (i.e., [:, -1]) represents the predicted labels
            tacc = compute_accuracy(tstats[:, -2], tstats[:, -1])
            vacc = compute_accuracy(vstats[:, -2], vstats[:, -1])

            loss_acc.append(pd.DataFrame([[tloss, vloss, tacc, vacc]],
                                         columns=["train_loss", "valid_loss", "train_acc", "valid_acc"]))

            train_stats.append(pd.DataFrame(tstats.tolist(), columns=[*class_labels, *["real", "predicted"]]))
            train_stats[epoch]["epoch"] = epoch
            valid_stats.append(pd.DataFrame(vstats.tolist(), columns=[*class_labels, *["real", "predicted"]]))
            valid_stats[epoch]["epoch"] = epoch

            wandb.log({
                "train_acc": tacc, "train_loss": tloss,
                "valid_acc": vacc, "valid_loss": vloss
            }, step=epoch)

            print('Epoch=%03d, train_loss=%2.3f, train_acc=%1.3f, valid_loss=%2.3f, valid_acc=%1.3f' % 
                 (epoch, tloss, tacc, vloss, vacc))

            if (vacc >= best_acc) and (vloss <= best_loss):
                # assign the new best values
                best_acc, best_loss = vacc, vloss
                wandb.run.summary["best_valid_accuracy"] = best_acc
                wandb.run.summary["best_valid_epoch"] = epoch
                # save the current best model
                model.save(save_name)
                # plot some graphs for the validation data
                wandb_plots(vstats[:, -2], vstats[:, -1], vstats[:, :-2], class_labels, "valid")


        # save the files
        full_df = pd.concat(loss_acc)
        full_df.to_csv(os.path.join(save_name, "loss_acc_curves.csv"), index=False)
        full_df = pd.concat(train_stats)
        full_df.to_csv(os.path.join(save_name, "train_stats.csv"), index=False)
        full_df = pd.concat(valid_stats)
        full_df.to_csv(os.path.join(save_name, "valid_stats.csv"), index=False)

        # load the best performing model
        model = BrainStateClassifier3d(save_name)
        model.to(DEVICE)
        
        # EVALUATE THE MODEL ON THE TEST DATA
        with torch.no_grad():
            testloss, teststats = model.fit(dl_test, train=False)
            
        testacc = compute_accuracy(teststats[:, -2], teststats[:, -1])
        df_test = pd.DataFrame(teststats.tolist(), columns=[*class_labels, *["real", "predicted"]])
        df_test.to_csv(os.path.join(save_name, "test_stats.csv"), index=False)
        
        wandb.run.summary["test_accuracy"] = testacc

        wandb.log({"test_accuracy": testacc, "test_loss": testloss})
        wandb_plots(teststats[:, -2], teststats[:, -1], teststats[:, :-2], class_labels, "test")

        wandb.finish()
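
The layout of the `stats` matrix described in the comments above can be made concrete with a toy example. The sketch uses plain Python lists and three classes, and it assumes that `compute_accuracy` returns the fraction of matching labels; it is not the `delphi` implementation.

```python
n_classes = 3  # toy example; the notebook uses five classes

# each row: class probabilities, then the real label, then the predicted label
stats = [
    [0.7, 0.2, 0.1, 0, 0],
    [0.1, 0.8, 0.1, 1, 1],
    [0.3, 0.3, 0.4, 1, 2],
    [0.2, 0.1, 0.7, 2, 2],
]

probabilities = [row[:n_classes] for row in stats]  # corresponds to stats[:, :-2]
real = [int(row[-2]) for row in stats]              # corresponds to stats[:, -2]
predicted = [int(row[-1]) for row in stats]         # corresponds to stats[:, -1]


def accuracy(real, predicted):
    """Fraction of samples whose predicted label equals the real label."""
    hits = sum(r == p for r, p in zip(real, predicted))
    return hits / len(real)


acc = accuracy(real, predicted)  # 3 of the 4 samples are classified correctly
```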

In [None]:
hp = read_config("hyperparameter.yaml")

num_folds = 10 # we decided to run a 10-fold cross-validation scheme
run = 0 # in case you decide to do multiple runs of the cross-validation scheme
input_dims = (91, 109, 91) # these are the 3D input dimension of our whole-brain data

In [None]:
len(data_test)

### Run cross-validation training using correct labels

In [None]:
# we will split the train dataset into a train (80%) and validation (20%) set.
data_train_full = NiftiDataset("../t-maps/train", class_labels, 0, device=DEVICE, 
                               transform=ToTensor(), shuffle_labels=False)

# we want a stratified shuffled split
sss = StratifiedShuffleSplit(n_splits=num_folds, test_size=0.2, random_state=2020)

for fold, (idx_train, idx_valid) in enumerate(sss.split(data_train_full.data, data_train_full.labels)):
    data_train = torch.utils.data.Subset(data_train_full, idx_train)
    data_valid = torch.utils.data.Subset(data_train_full, idx_valid)
    train_network(run, fold, data_train, data_valid, data_test, hp, class_labels)
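
To make explicit what a stratified shuffled split does, here is a minimal plain-Python sketch for a single split. It is a simplified illustration, not scikit-learn's implementation, and the function name `stratified_shuffle_split` is hypothetical.

```python
import random
from collections import defaultdict


def stratified_shuffle_split(labels, test_size, seed):
    """Split indices so that each class contributes `test_size` of its samples to validation."""
    rng = random.Random(seed)

    # collect the sample indices of every class
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    idx_train, idx_valid = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)  # random order within the class
        n_valid = round(len(indices) * test_size)
        idx_valid.extend(indices[:n_valid])
        idx_train.extend(indices[n_valid:])
    return idx_train, idx_valid


labels = ["handleft", "handright", "footleft", "footright", "tongue"] * 10
idx_train, idx_valid = stratified_shuffle_split(labels, test_size=0.2, seed=2020)
valid_labels = [labels[i] for i in idx_valid]
```

Every class keeps the same 80/20 proportion in the training and validation sets. Note that `StratifiedShuffleSplit` draws each of its `n_splits` splits independently, so the validation sets of different folds can overlap, unlike with `StratifiedKFold`.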

### Run cross-validation training using shuffled labels

In [None]:
# we will split the train dataset into a train (80%) and validation (20%) set.
data_train_full = NiftiDataset("../t-maps/train", class_labels, 20, device=DEVICE, 
                               transform=ToTensor(), shuffle_labels=True)

# we want a stratified shuffled split
sss = StratifiedShuffleSplit(n_splits=num_folds, test_size=0.2, random_state=2020)

for fold, (idx_train, idx_valid) in enumerate(sss.split(data_train_full.data, data_train_full.labels)):
    data_train = torch.utils.data.Subset(data_train_full, idx_train)
    data_valid = torch.utils.data.Subset(data_train_full, idx_valid)
    train_network(run, fold, data_train, data_valid, data_test, hp, class_labels, shuffled_labels=True)