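# Train a 3D convolutional neural network (3D-CNN) to classify motor tasks

The network is trained on t-maps from five motor tasks (left hand, right hand, left foot, right foot, tongue). The training data is split 80/20 into train and validation sets, and the runs are tracked with Weights & Biases.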

In [1]:
import os
import wandb
import torch
import pandas as pd
import numpy as np
import time
from glob import glob
from torch.utils.data import DataLoader
from torchinfo import summary

from delphi.networks.ConvNets import BrainStateClassifier3d
from delphi.utils.datasets import NiftiDataset
from delphi.utils.tools import ToTensor, compute_accuracy, convert_wandb_config, read_config
from sklearn.model_selection import StratifiedShuffleSplit

# Use the first CUDA GPU when one is present; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch and make cuDNN pick deterministic algorithms.

    Parameters
    ----------
    seed : int
        Seed applied to every random number generator.

    Returns
    -------
    torch.Generator
        A generator seeded with ``seed``. Pass it to DataLoaders (``generator=g``) so that
        ``shuffle=True`` draws a reproducible sample order.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG of the CPU and of all CUDA devices

    # benchmark=False stops cuDNN from auto-tuning (and possibly switching) algorithms between runs;
    # deterministic=True additionally restricts cuDNN to deterministic convolution algorithms.
    # Both are needed for reproducible GPU results.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator


g = set_random_seed(2020)

In [3]:
def wandb_plots(y_true, y_pred, y_prob, class_labels, dataset):
    """Log ROC, precision-recall and confusion-matrix plots for one data split to wandb.

    Parameters
    ----------
    y_true : sequence
        Real class label of each sample.
    y_pred : sequence
        Predicted class label of each sample.
    y_prob : array-like
        Per-class model outputs, one row per sample, columns in ``class_labels`` order.
    class_labels : list of str
        Class names used to label the plots.
    dataset : str
        Prefix of the logged keys (e.g. ``"valid"`` or ``"test"``).
    """
    # Build the three plots in a fixed order, then log them together under "<dataset>-<kind>".
    plots = {
        "ROC": wandb.plot.roc_curve(y_true=y_true, y_probas=y_prob, labels=class_labels),
        "PR": wandb.plot.pr_curve(y_true=y_true, y_probas=y_prob, labels=class_labels),
        "ConfMat": wandb.plot.confusion_matrix(y_true=y_true, preds=y_pred, class_names=class_labels),
    }
    wandb.log({f"{dataset}-{kind}": plot for kind, plot in plots.items()})

In [4]:
# The five motor-task conditions, sorted so the class order is the same on every run.
class_labels = sorted(["handleft", "handright", "footleft", "footright", "tongue"])

# NOTE(review): the meaning of the positional `0` passed to NiftiDataset is not visible here — confirm in delphi.
# Held-out test set (only used once, after training).
data_test = NiftiDataset("../t-maps/test", class_labels, 0, device=DEVICE, transform=ToTensor())

# Full training pool; it is divided below into a train (80%) and a validation (20%) subset.
data_train_full = NiftiDataset("../t-maps/train", class_labels, 0, device=DEVICE, transform=ToTensor())

# One stratified, shuffled split: both subsets keep the class proportions of the full pool.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2020)
split_indices = sss.split(data_train_full.data, data_train_full.labels)
idx_train, idx_valid = next(split_indices)

data_train, data_valid = (torch.utils.data.Subset(data_train_full, idx) for idx in (idx_train, idx_valid))

In [5]:
# Architecture hyper-parameters handed to BrainStateClassifier3d.
model_cfg = {
    "channels": [1, 8, 16, 32, 64],  # input channels, then the output channels of each conv block
    "lin_neurons": [128, 64],  # presumably the widths of the fully connected layers — confirm in delphi
    "pooling_kernel": 2,  # max-pooling kernel; halves each spatial dimension (see the summary output)
    "kernel_size": 5,  # 3D convolution kernel edge length (5x5x5)
    "dropout": .5,
}
# Spatial shape of one input volume.
input_dims = (91, 109, 91)
model = BrainStateClassifier3d(input_dims, len(class_labels), model_cfg)

# Summarise with a batch of one single-channel volume. The shape is derived from input_dims
# so that the summary cannot drift from the model's configured input size.
print(summary(model, (1, 1, *input_dims)))

Layer (type:depth-idx)                   Output Shape              Param #
BrainStateClassifier3d                   [1, 5]                    --
├─Sequential: 1-1                        --                        --
│    └─Sequential: 2-1                   [1, 8, 45, 54, 45]        --
│    │    └─Conv3d: 3-1                  [1, 8, 91, 109, 91]       1,008
│    │    └─ReLU: 3-2                    [1, 8, 91, 109, 91]       --
│    │    └─MaxPool3d: 3-3               [1, 8, 45, 54, 45]        --
│    └─Sequential: 2-2                   [1, 16, 22, 27, 22]       --
│    │    └─Conv3d: 3-4                  [1, 16, 45, 54, 45]       16,016
│    │    └─ReLU: 3-5                    [1, 16, 45, 54, 45]       --
│    │    └─MaxPool3d: 3-6               [1, 16, 22, 27, 22]       --
│    └─Sequential: 2-3                   [1, 32, 11, 13, 11]       --
│    │    └─Conv3d: 3-7                  [1, 32, 22, 27, 22]       64,032
│    │    └─ReLU: 3-8                    [1, 32, 22, 27, 22]       --
│   

In [6]:
# Tell Weights & Biases where to store the run logs and plots.
# setdefault keeps any WANDB_* values already exported in the environment, so another user can
# log to their own account/project without editing this cell.
# To log locally without syncing to the wandb servers, uncomment:
# os.environ.setdefault("WANDB_MODE", "offline")
os.environ.setdefault("WANDB_ENTITY", "philis893")  # wandb account name (a team/group name also works)
os.environ.setdefault("WANDB_PROJECT", "thesis")  # project that collects the run logs and plots

In [7]:
def _stats_frame(stats, epoch):
    """Tabulate one evaluation pass of ``model.fit`` for the CSV export.

    ``stats`` is the second output of ``model.fit``: one row per sample holding the per-class
    outputs, followed by the real and the predicted label. The epoch is appended as a column.
    """
    frame = pd.DataFrame(stats.tolist(), columns=[*class_labels, "real", "predicted"])
    frame["epoch"] = epoch
    return frame


hp = read_config("hyperparameter.yaml")
with wandb.init(config=hp, name=hp["name"], group="first_steps") as run:

    save_name = run.config.name
    batch_size = run.config.batch_size

    # Only the training loader is shuffled, using the seeded generator `g`. Accuracy does not
    # depend on sample order, so the validation and test loaders use a fixed order. They also
    # do not consume draws from `g`.
    dl_train = DataLoader(data_train, batch_size=batch_size, shuffle=True, generator=g)
    dl_valid = DataLoader(data_valid, batch_size=batch_size, shuffle=False)
    dl_test = DataLoader(data_test, batch_size=batch_size, shuffle=False)

    best_loss, best_acc = float("inf"), 0.0
    best_state = None  # in-memory copy of the weights from the best validation epoch
    loss_acc = []  # one record of losses/accuracies per epoch
    train_stats, valid_stats = [], []

    # Optional early stopping: add `patience: <n_epochs>` to hyperparameter.yaml to stop once more
    # than that many consecutive epochs pass without a new best model. Without it, all epochs run.
    patience = run.config.get("patience", None)
    patience_ctr = 0

    for epoch in range(run.config.epochs):
        # one optimisation pass over the training data
        model.fit(dl_train, lr=run.config.learning_rate, device=DEVICE)

        # Re-score the train and validation data without updating the weights.
        # torch.no_grad() only disables gradient tracking. Putting layers like dropout into
        # evaluation mode is presumably done by model.fit(train=False) — confirm in delphi.
        with torch.no_grad():
            tloss, tstats = model.fit(dl_train, device=DEVICE, train=False)
            vloss, vstats = model.fit(dl_valid, device=DEVICE, train=False)

        # model.fit() returns (loss, stats). stats is an n_samples-by-(n_classes + 2) matrix:
        # - the first n_classes columns are the model's output probabilities per class
        # - stats[:, -2] holds the real labels
        # - stats[:, -1] holds the predicted labels
        tacc = compute_accuracy(tstats[:, -2], tstats[:, -1])
        vacc = compute_accuracy(vstats[:, -2], vstats[:, -1])

        loss_acc.append({"train_loss": tloss, "valid_loss": vloss, "train_acc": tacc, "valid_acc": vacc})
        train_stats.append(_stats_frame(tstats, epoch))
        valid_stats.append(_stats_frame(vstats, epoch))

        wandb.log({
            "train_acc": tacc, "train_loss": tloss,
            "valid_acc": vacc, "valid_loss": vloss
        })

        print('Epoch=%03d, train_loss=%2.3f, train_acc=%1.3f, valid_loss=%2.3f, valid_acc=%1.3f' %
              (epoch, tloss, tacc, vloss, vacc))

        # A new best model must be at least as accurate AND have no higher loss than the previous best.
        if (vacc >= best_acc) and (vloss <= best_loss):
            best_acc, best_loss = vacc, vloss
            run.summary["best_valid_accuracy"] = best_acc
            run.summary["best_valid_epoch"] = epoch
            # save the current best model to disk and keep a copy to restore before testing
            model.save(save_name)
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            # plot some graphs for the validation data
            wandb_plots(vstats[:, -2], vstats[:, -1], vstats[:, :-2], class_labels, "valid")
            patience_ctr = 0
        else:
            patience_ctr += 1

        if patience is not None and patience_ctr > patience:
            print('Reached patience. Stopping training and continuing with test set.')
            break

    # save the per-epoch curves and the per-sample statistics
    os.makedirs(save_name, exist_ok=True)
    pd.DataFrame(loss_acc, columns=["train_loss", "valid_loss", "train_acc", "valid_acc"]).to_csv(
        os.path.join(save_name, "loss_acc_curves.csv"), index=False)
    pd.concat(train_stats).to_csv(os.path.join(save_name, "train_stats.csv"), index=False)
    pd.concat(valid_stats).to_csv(os.path.join(save_name, "valid_stats.csv"), index=False)

    # EVALUATE THE BEST (not the last) MODEL ON THE TEST DATA
    if best_state is not None:
        model.load_state_dict(best_state)
    with torch.no_grad():
        testloss, teststats = model.fit(dl_test, device=DEVICE, train=False)
    testacc = compute_accuracy(teststats[:, -2], teststats[:, -1])
    run.summary["test_accuracy"] = testacc

    wandb.log({"test_accuracy": testacc, "test_loss": testloss})
    wandb_plots(teststats[:, -2], teststats[:, -1], teststats[:, :-2], class_labels, "test")
# Leaving the `with wandb.init(...)` block finishes the run, so no explicit wandb.finish() is needed.

[34m[1mwandb[0m: Currently logged in as: [33mphilis893[0m. Use [1m`wandb login --relogin`[0m to force relogin


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=000, train_loss=1.601, train_acc=0.348, valid_loss=1.602, valid_acc=0.314
Saving motor-classifier/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=001, train_loss=1.591, train_acc=0.200, valid_loss=1.594, valid_acc=0.200


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=002, train_loss=1.579, train_acc=0.200, valid_loss=1.585, valid_acc=0.200


0it [00:00, ?it/s]

0it [00:00, ?it/s]

VBox(children=(Label(value='0.085 MB of 0.085 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_acc,█▁▁
train_loss,█▅▁
valid_acc,█▁▁
valid_loss,█▅▁

0,1
best_valid_accuracy,0.31429
best_valid_epoch,0.0
train_acc,0.2
train_loss,1.5792
valid_acc,0.2
valid_loss,1.58547


KeyboardInterrupt: 