In [1]:
import os
import wandb
import torch
import pandas as pd
import numpy as np
import time
from glob import glob
from torch.utils.data import DataLoader

from delphi.networks.ConvNets import BrainStateClassifier3d
from delphi.utils.datasets import NiftiDataset
from delphi.utils.tools import ToTensor, compute_accuracy, convert_wandb_config, read_config
from sklearn.model_selection import StratifiedShuffleSplit

# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def set_random_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs and return a seeded generator.

    Also switches off the cuDNN benchmark mode.

    Parameters
    ----------
    seed : int
        Value used for every random number generator touched here.

    Returns
    -------
    torch.Generator
        A fresh generator seeded with ``seed``; pass it to a ``DataLoader`` so that
        sample selection with ``shuffle=True`` is repeatable.
    """
    import random

    # global RNG state of the three libraries
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    torch.backends.cudnn.benchmark = False

    # Generator.manual_seed returns the generator itself, so build and seed in one go
    return torch.Generator().manual_seed(seed)


g = set_random_seed(2020)

In [3]:
def wandb_plots(y_true, y_pred, y_prob, class_labels, dataset):
    """Log a ROC curve, a precision-recall curve and a confusion matrix to wandb.

    Parameters
    ----------
    y_true
        True labels, forwarded to all three wandb plot helpers.
    y_pred
        Predicted labels, used for the confusion matrix.
    y_prob
        Per-class scores, used for the ROC and PR curves.
    class_labels
        Class names shown in the charts.
    dataset : str
        Prefix of the logged keys (``"<dataset>-ROC"``, ``"<dataset>-PR"``,
        ``"<dataset>-ConfMat"``).
    """
    roc_chart = wandb.plot.roc_curve(y_true=y_true, y_probas=y_prob, labels=class_labels)
    pr_chart = wandb.plot.pr_curve(y_true=y_true, y_probas=y_prob, labels=class_labels)
    confmat_chart = wandb.plot.confusion_matrix(y_true=y_true, preds=y_pred, class_names=class_labels)

    # send all three charts in a single log call
    charts = {
        f"{dataset}-ROC": roc_chart,
        f"{dataset}-PR": pr_chart,
        f"{dataset}-ConfMat": confmat_chart,
    }
    wandb.log(charts)

# Define the classes and data to use

In [4]:
# Class names used as labels throughout the notebook, kept in alphabetical order.
class_labels = ["handleft", "handright", "footleft", "footright", "tongue"]
class_labels.sort()

In [5]:
# Test set, read from its own directory.
# NOTE(review): the meaning of the positional `0` is defined by NiftiDataset - confirm there.
data_test = NiftiDataset("../t-maps/test", class_labels, 0, device=DEVICE, transform=ToTensor())

# Everything under "train" is loaded once and then divided into
# a train (80%) and a validation (20%) part.
data_train_full = NiftiDataset("../t-maps/train", class_labels, 0, device=DEVICE, transform=ToTensor())

# A single stratified, shuffled split with a fixed random state.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2020)
idx_train, idx_valid = next(iter(sss.split(data_train_full.data, data_train_full.labels)))

# Index-based views onto the full training dataset.
data_train, data_valid = (
    torch.utils.data.Subset(data_train_full, indices) for indices in (idx_train, idx_valid)
)

# Set up the sweep

# Define the training function

In [6]:
def train_net(model, config, save_name, logwandb=True):
    """Train ``model`` with early stopping, write per-epoch CSVs and score the test set.

    Uses the notebook-level ``data_train``, ``data_valid``, ``data_test``, ``g``,
    ``class_labels`` and ``DEVICE``.

    Parameters
    ----------
    model
        Network exposing ``fit(dataloader, ...)`` returning ``(loss, stats)`` and
        ``save(path)``.
    config
        Object with ``batch_size``, ``epochs`` and ``learning_rate`` attributes
        (``wandb.config`` when called from ``run_train``).
    save_name : str
        Directory that receives the best checkpoint (through ``model.save``) and
        ``loss_acc_curves.csv``, ``train_stats.csv`` and ``valid_stats.csv``.
    logwandb : bool, default True
        When True, metrics, summary values and plots are sent to the active wandb run
        and the run is finished at the end. When False, no wandb call is made, so the
        function can be used without an initialised run.

    Returns
    -------
    None
    """
    dl_test = DataLoader(data_test, batch_size=config.batch_size, shuffle=True, generator=g)
    dl_train = DataLoader(data_train, batch_size=config.batch_size, shuffle=True, generator=g)
    dl_valid = DataLoader(data_valid, batch_size=config.batch_size, shuffle=True, generator=g)

    # Layout of the stats arrays returned by model.fit(..., train=False):
    # one column per class followed by the real and the predicted label.
    stat_columns = [*class_labels, "real", "predicted"]

    # Start from +inf so the first epoch always counts as an improvement,
    # whatever the scale of the loss.
    best_loss, best_acc = float("inf"), 0
    loss_acc = []
    train_stats, valid_stats = [], []
    patience = 9
    patience_ctr = 0

    for epoch in range(0, config.epochs):
        # one optimisation pass over the training data
        _, _ = model.fit(dl_train, lr=config.learning_rate, device=DEVICE)

        # score train and validation data without tracking gradients
        # (presumably fit(train=False) also puts the network into eval mode - confirm in delphi)
        with torch.no_grad():
            tloss, tstats = model.fit(dl_train, device=DEVICE, train=False)
            vloss, vstats = model.fit(dl_valid, device=DEVICE, train=False)

        # second-to-last column holds the real label, last column the predicted one
        tacc = compute_accuracy(tstats[:, -2], tstats[:, -1])
        vacc = compute_accuracy(vstats[:, -2], vstats[:, -1])

        loss_acc.append(pd.DataFrame([[tloss, vloss, tacc, vacc]],
                                     columns=["train_loss", "valid_loss", "train_acc", "valid_acc"]))

        epoch_train = pd.DataFrame(tstats.tolist(), columns=stat_columns)
        epoch_train["epoch"] = epoch
        train_stats.append(epoch_train)

        epoch_valid = pd.DataFrame(vstats.tolist(), columns=stat_columns)
        epoch_valid["epoch"] = epoch
        valid_stats.append(epoch_valid)

        if logwandb:
            wandb.log({
                "train_acc": tacc, "train_loss": tloss,
                "valid_acc": vacc, "valid_loss": vloss
            })

        print('Epoch=%03d, train_loss=%2.3f, train_acc=%1.3f, valid_loss=%2.3f, valid_acc=%1.3f' %
             (epoch, tloss, tacc, vloss, vacc))

        # An epoch is a new best only if accuracy did not drop AND loss did not rise.
        if (vacc >= best_acc) and (vloss <= best_loss):
            best_acc, best_loss = vacc, vloss
            # keep the checkpoint of the current best epoch
            model.save(save_name)

            if logwandb:
                wandb.run.summary["best_valid_accuracy"] = best_acc
                wandb.run.summary["best_valid_epoch"] = epoch
                wandb_plots(vstats[:, -2], vstats[:, -1], vstats[:, :-2], class_labels, "valid")

            patience_ctr = 0
        else:
            patience_ctr += 1

        # strict '>' : training stops after patience + 1 consecutive epochs without a new best
        if patience_ctr > patience:
            print('Reached patience. Stopping training and continuing with test set.')
            break

    # Write the per-epoch tables. Skipped when no epoch ran (pd.concat rejects empty lists).
    if loss_acc:
        # do not rely on model.save having created the directory
        os.makedirs(save_name, exist_ok=True)
        pd.concat(loss_acc).to_csv(os.path.join(save_name, "loss_acc_curves.csv"), index=False)
        pd.concat(train_stats).to_csv(os.path.join(save_name, "train_stats.csv"), index=False)
        pd.concat(valid_stats).to_csv(os.path.join(save_name, "valid_stats.csv"), index=False)

    # EVALUATE THE MODEL ON THE TEST DATA
    # NOTE(review): this scores the weights of the last epoch that ran, not the checkpoint
    # written by model.save above - reload the best checkpoint first if that is the intent.
    with torch.no_grad():
        # same device as the train/validation evaluation calls above
        testloss, teststats = model.fit(dl_test, device=DEVICE, train=False)
    testacc = compute_accuracy(teststats[:, -2], teststats[:, -1])

    if logwandb:
        wandb.run.summary["test_accuracy"] = testacc
        wandb.log({"test_accuracy": testacc, "test_loss": testloss})
        wandb_plots(teststats[:, -2], teststats[:, -1], teststats[:, :-2], class_labels, "test")
        wandb.finish()

# Define the run_train function

In [7]:
# Function executed by the wandb agent: one call corresponds to one sweep run.
def run_train():
    """Initialise a wandb run, build the network from its config and train it.

    Takes no arguments and returns nothing; the hyperparameters come from the
    config of the run created by ``wandb.init()``.
    """
    with wandb.init() as run:
        # Translate the flat wandb config into the dict the network constructor expects,
        # e.g. converted_config['lin_neurons'] = [512, 8, 128]
        converted_config = convert_wandb_config(run.config, BrainStateClassifier3d._REQUIRED_PARAMS)

        model = BrainStateClassifier3d((91, 109, 91), len(class_labels), converted_config)
        model.to(DEVICE)

        # Optional: mirror the model's final config back into wandb.
        #wandb.config.update(model.config, allow_val_change=True)

        # The current timestamp names both the wandb run and the output directory.
        t_stamp = time.time()
        run.name = f"motor-explo-{t_stamp}"
        save_name = os.path.join("models", f"motor-explo_{t_stamp}")

        # now train the network
        train_net(model, run.config, save_name)

# Run the sweep

In [8]:
# Load the sweep definition (search method, metric and hyperparameter grid) from YAML ...
sweep_config = read_config("exploration_sweep.yaml")
# ... and print it so the grid used for this sweep is recorded in the notebook output.
print(sweep_config)

{'name': 'exploration-sweep', 'entity': 'philis893', 'project': 'thesis', 'method': 'grid', 'metric': {'name': 'valid_acc'}, 'parameters': {'channels1': {'value': 1}, 'channels2': {'value': 8}, 'channels3': {'value': 16}, 'channels4': {'value': 32}, 'channels5': {'value': 64}, 'kernel_size': {'values': [3, 5, 7]}, 'lin_neurons1': {'value': 128}, 'lin_neurons2': {'value': 64}, 'batch_size': {'values': [4, 8, 16, 32]}, 'dropout': {'values': [0.3, 0.4, 0.5, 0.6, 0.7]}, 'learning_rate': {'values': [1e-05, 0.0001, 0.001]}, 'epochs': {'value': 60}}}


In [9]:
# Point wandb at the account and project that hold the sweep.
#os.environ['WANDB_MODE'] = 'offline'
os.environ.update({
    # wandb account name; a group/team name works here as well
    'WANDB_ENTITY': "philis893",
    # project in which the sweep logs and plots are stored
    'WANDB_PROJECT': "thesis",
})
#sweep_id = wandb.sweep(sweep_config)

In [10]:
# 180 = 3 kernel sizes x 4 batch sizes x 5 dropout rates x 3 learning rates,
# i.e. the size of the full grid printed above.
count=180
# Attach an agent to a hard-coded sweep id (the wandb.sweep call that would create a new
# sweep is commented out above); the agent calls run_train once per run, at most `count` times.
wandb.agent("bhhpc7mn", function=run_train, count=count)

[34m[1mwandb[0m: Agent Starting Run: 4q8vwzir with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	channels1: 1
[34m[1mwandb[0m: 	channels2: 8
[34m[1mwandb[0m: 	channels3: 16
[34m[1mwandb[0m: 	channels4: 32
[34m[1mwandb[0m: 	channels5: 64
[34m[1mwandb[0m: 	dropout: 0.7
[34m[1mwandb[0m: 	epochs: 60
[34m[1mwandb[0m: 	kernel_size: 3
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	lin_neurons1: 128
[34m[1mwandb[0m: 	lin_neurons2: 64
[34m[1mwandb[0m: Currently logged in as: [33mphilis893[0m. Use [1m`wandb login --relogin`[0m to force relogin


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=000, train_loss=1.598, train_acc=0.330, valid_loss=1.600, valid_acc=0.321
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=001, train_loss=1.570, train_acc=0.386, valid_loss=1.576, valid_acc=0.386
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=002, train_loss=1.498, train_acc=0.359, valid_loss=1.512, valid_acc=0.336


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=003, train_loss=1.171, train_acc=0.764, valid_loss=1.203, valid_acc=0.743
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=004, train_loss=0.628, train_acc=0.930, valid_loss=0.684, valid_acc=0.864
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=005, train_loss=0.394, train_acc=0.941, valid_loss=0.443, valid_acc=0.914
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=006, train_loss=0.348, train_acc=0.805, valid_loss=0.395, valid_acc=0.764


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=007, train_loss=0.232, train_acc=0.973, valid_loss=0.265, valid_acc=0.943
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=008, train_loss=0.173, train_acc=0.979, valid_loss=0.233, valid_acc=0.943
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=009, train_loss=0.183, train_acc=0.968, valid_loss=0.248, valid_acc=0.914


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Epoch=010, train_loss=0.108, train_acc=0.982, valid_loss=0.171, valid_acc=0.950
Saving models/motor-explo_1665415391.7978227/state_dict.pth


0it [00:00, ?it/s]

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [11]:
# Display the device selected at the top of the notebook.
DEVICE

device(type='cuda', index=0)