In [1]:
import torch
from torch.nn.functional import elu,relu,leaky_relu
import braindecode
from braindecode.models import *
from braindecode.models.functions import squeeze_final_output,square,safe_log
from skorch.dataset import Dataset
from skorch.callbacks import Checkpoint
from skorch.helper import predefined_split
from config import *
from sklearn.metrics import roc_auc_score
from mne import set_log_level
# Pass False to MNE's set_log_level to cut down MNE's console logging.
set_log_level(False)

# Pick the torch device string from the `cuda` flag.
# NOTE(review): `cuda` is not defined in this cell; presumably it comes from `from config import *`.
if cuda:
    device = 'cuda'
else:
    device = 'cpu'

Tensorflow not install, you could not use those pipelines


In [2]:
#Loads preprocessed data from mat files
import numpy as np
# Import the submodule explicitly: a bare `import scipy` is not guaranteed to
# make `scipy.io` available as an attribute on every SciPy version.
import scipy.io

# Folder that holds the preprocessed .mat files; change this one line if the data moves.
DATA_DIR = "E:/"


def _load_xy(mat_path):
    """Read one .mat file and return its "x" array and its squeezed "y" array."""
    contents = scipy.io.loadmat(mat_path)
    return contents["x"], contents["y"].squeeze()


# The training data is split over two files; stack both parts along the first axis.
x_part_1, y_part_1 = _load_xy(DATA_DIR + "train_set_1.mat")
x_part_2, y_part_2 = _load_xy(DATA_DIR + "train_set_2.mat")
X = np.concatenate((x_part_1, x_part_2), axis=0)
y = np.concatenate((y_part_1, y_part_2), axis=0)
# Drop the per-file arrays so only the concatenated copies stay in memory.
del x_part_1, x_part_2, y_part_1, y_part_2

# Length of the last axis of X (used downstream as the model's input length).
input_time_length = X.shape[-1]

test_x, test_y = _load_xy(DATA_DIR + "test_set.mat")

In [137]:
# Train on a subset of the recorded channels, using each selected channel's full length.
# `picked_ch` holds indices into `ch_names`; list several indices to use a combination.
ch_names = [
    'A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4',
    'F7', 'F8', 'FP1', 'FP2', 'FZ', 'O1', 'O2',
    'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6',
]
picked_ch = [20]
n_chans = len(picked_ch)

# Reset the input length from the raw data, then slice the chosen channels
# (axis 1) out of both the training and the test arrays.
input_time_length = X.shape[-1]
train_x = X[:, picked_ch]
eval_x = test_x[:, picked_ch]

print(f'Channels chosen:{[ch_names[idx] for idx in picked_ch]}')

Channels chosen:['T6']


In [138]:
#Fold the single picked channel into n_chans pseudo-channels, e.g. (trials,1,60000) -> (trials,10,6000),
#and pass the result through the network (whose first layer is conv_time_spat).
#The reshape uses Fortran order, so consecutive samples are dealt out round-robin:
#pseudo-channel c holds the original samples c, c+n_chans, c+2*n_chans, ...
#Everything is recomputed from X/test_x and picked_ch, so re-running this cell is safe
#(the previous version divided input_time_length again and re-reshaped train_x on every re-run).
n_chans=10


def _fold_time_into_channels(x, n_fold):
    """Reshape a (trials, 1, samples) array to (trials, n_fold, samples // n_fold).

    Uses order='F', so out[i, c, l] == x[i, 0, c + n_fold * l]; the trial axis
    is left untouched because its size is unchanged.

    Raises:
        ValueError: if `x` is not 3-D with exactly one channel, or if the
            number of samples is not a multiple of `n_fold`.
    """
    n_trials, n_picked, n_samples = x.shape
    if n_picked != 1:
        raise ValueError(f"expected exactly one picked channel, got {n_picked}")
    if n_samples % n_fold:
        raise ValueError(f"{n_samples} samples cannot be split evenly into {n_fold} pseudo-channels")
    return x.reshape(n_trials, n_fold, n_samples // n_fold, order='F')


train_x=_fold_time_into_channels(X[:,picked_ch],n_chans)
eval_x=_fold_time_into_channels(test_x[:,picked_ch],n_chans)
input_time_length=train_x.shape[-1]
# The model is built for one fixed input length, so train and test must agree.
if eval_x.shape[-1]!=input_time_length:
    raise ValueError("train and test recordings have different lengths after folding")

In [5]:
# Architecture to build. The model-construction cell has branches for:
# "shallow", "deep", "deep_smac", "deep_smac_bnorm", "shallow_deep" and "attention".
model_name='shallow_deep'

In [145]:
# Build the network selected by `model_name` and set the matching optimizer settings.
# Reads n_chans / input_time_length from the preceding cells.
# NOTE(review): n_start_chans, n_chan_factor, init_lr, sampling_freq and cuda are not
# defined in this notebook's visible cells; presumably they come from `from config import *`.
n_classes = 2
criterion=torch.nn.NLLLoss


def _print_output_shape(net):
    """Push a dummy all-ones batch of 7 trials through `net` and print the output shape.

    The pass runs in eval mode under `torch.no_grad()` so the dummy batch does
    not update the running statistics of any BatchNorm layers and builds no
    autograd graph. The module's previous train/eval mode is restored afterwards.
    """
    was_training = net.training
    net.eval()
    with torch.no_grad():
        dummy = torch.ones(size=(7, n_chans, input_time_length))
        print(net(dummy).shape)
    net.train(was_training)


if model_name=="shallow":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    pool_time_length=150
    pool_time_stride=50
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = ShallowFBCSPNet(n_chans,
                            n_classes,
                            n_filters_time=n_start_chans,
                            n_filters_spat=n_start_chans,
                            input_window_samples=input_time_length,
                            pool_time_length=pool_time_length,
                            pool_time_stride=pool_time_stride,
                            final_conv_length='auto',)
    _print_output_shape(model)
elif model_name=="deep":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    model = Deep4Net(n_chans, n_classes,
                     n_filters_time=n_start_chans,
                     n_filters_spat=n_start_chans,
                     input_window_samples=input_time_length,
                     n_filters_2 = int(n_start_chans * n_chan_factor),
                     n_filters_3 = int(n_start_chans * (n_chan_factor ** 2.0)),
                     n_filters_4 = int(n_start_chans * (n_chan_factor ** 3.0)),
                     final_conv_length='auto',
                     stride_before_pool=True)
    _print_output_shape(model)
elif model_name=="deep_smac" or model_name == 'deep_smac_bnorm':
    optimizer_lr = 0.0000625
    # Set here too so every branch defines the same optimizer settings.
    optimizer_weight_decay = 0
    # Only the "_bnorm" variant uses batch normalisation.
    do_batch_norm = model_name == 'deep_smac_bnorm'
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 24
    filter_length_4 = 36
    filter_time_length = 21
    first_nonlin = elu
    first_pool_mode = 'mean'
    later_nonlin = elu
    later_pool_mode = 'mean'
    n_filters_factor = 1.679066
    n_filters_start = 32
    # Defined but currently NOT forwarded to Deep4Net (the library defaults apply).
    pool_time_length = 3
    pool_time_stride = 3
    split_first_layer = True
    # Use the branch-local filter settings directly instead of overwriting the shared
    # n_start_chans / n_chan_factor values, which the "shallow" and "deep" branches
    # read on a later run of this cell.
    model = Deep4Net(n_chans, n_classes,
                     n_filters_time=n_filters_start,
                     n_filters_spat=n_filters_start,
                     input_window_samples=input_time_length,
                     n_filters_2=int(n_filters_start * n_filters_factor),
                     n_filters_3=int(n_filters_start * (n_filters_factor ** 2.0)),
                     n_filters_4=int(n_filters_start * (n_filters_factor ** 3.0)),
                     final_conv_length='auto',
                     batch_norm=do_batch_norm,
                     drop_prob=drop_prob,
                     filter_length_2=filter_length_2,
                     filter_length_3=filter_length_3,
                     filter_length_4=filter_length_4,
                     filter_time_length=filter_time_length,
                     first_conv_nonlin=first_nonlin,
                     first_pool_mode=first_pool_mode,
                     later_conv_nonlin=later_nonlin,
                     later_pool_mode=later_pool_mode,
                     split_first_layer=split_first_layer,
                     stride_before_pool=True)
elif model_name=="shallow_deep":
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 14
    filter_length_4 = 32
    n_filters_factor = 1.679066
    n_filters_start = 32
    split_first_layer = True

    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    conv_time_length=25
    first_conv_nonlin=relu
    first_pool_nonlin=safe_log
    # Defined but currently NOT forwarded to Deep4Net (the library defaults apply).
    later_conv_nonlin=elu
    later_pool_nonlin=safe_log
    first_pool_mode = 'mean'
    later_pool_mode = 'mean'
    pool_time_length=15
    # As in the deep_smac branch, the branch-local filter settings are used directly
    # so the shared n_start_chans / n_chan_factor values are left untouched.
    model = Deep4Net(n_chans, n_classes,
                     input_time_length,
                     n_filters_time=n_filters_start,
                     n_filters_spat=n_filters_start,
                     n_filters_2 = int(n_filters_start * n_filters_factor),
                     n_filters_3 = int(n_filters_start * (n_filters_factor ** 2.0)),
                     n_filters_4 = int(n_filters_start * (n_filters_factor ** 3.0)),
                     final_conv_length='auto',
                     first_pool_nonlin=first_pool_nonlin,
                     first_conv_nonlin=first_conv_nonlin,
                     filter_time_length=conv_time_length,
                     pool_time_length=pool_time_length,
                     first_pool_mode=first_pool_mode,
                     later_pool_mode=later_pool_mode,
                     split_first_layer=split_first_layer,
                     drop_prob=drop_prob,
                     filter_length_2=filter_length_2,
                     filter_length_3=filter_length_3,
                     filter_length_4=filter_length_4,
                     )
elif model_name=="attention":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    model=ATCNet(n_chans,n_classes,input_time_length//sampling_freq,sampling_freq,concat=True,tcn_depth=4)
    _print_output_shape(model)
else:
    # Fail loudly instead of silently reusing a `model` left over from an earlier run.
    raise ValueError(f"Unknown model_name: {model_name!r}")

if cuda:
    model.cuda()

print(model_name)
print(model)

shallow_deep
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
Deep4Net (Deep4Net)                      [1, 10, 6000]             [1, 2]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 10, 6000]             [1, 10, 6000, 1]          --                        --
├─Rearrange (dimshuffle): 1-2            [1, 10, 6000, 1]          [1, 1, 6000, 10]          --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 6000, 10]          [1, 32, 5976, 1]          11,072                    --
├─BatchNorm2d (bnorm): 1-4               [1, 32, 5976, 1]          [1, 32, 5976, 1]          64                        --
├─Expression (conv_nonlin): 1-5          [1, 32, 5976, 1]          [1, 32, 5976, 1]          --                        --
├─AvgPool2dWithConv (pool): 1-6          [1, 32, 5976, 1]          [1, 32, 1988, 1]          --                        [15,



In [146]:
def monitor(net):
    """Return True if any of the listed `*_best` entries in the latest history row is truthy.

    Not wired into the Checkpoint below, which monitors 'valid_f1_best' only.
    """
    return any(net.history[-1, ('valid_accuracy_best','valid_f1_best','valid_loss_best')])


# Save parameters, optimizer state and history into ./model whenever 'valid_f1_best' is set.
cp = Checkpoint(
    monitor='valid_f1_best',
    dirname='model',
    f_params='chanbest_param.pkl',
    f_optimizer='chanbest_opt.pkl',
    f_history='chanbest_history.json',
)

# NOTE(review): the validation split is built from eval_x/test_y, i.e. the test set,
# so checkpoint selection on valid_f1 is made on the same data that is evaluated later.
classifier = braindecode.EEGClassifier(
    model,
    criterion=criterion,
    optimizer=torch.optim.AdamW,
    optimizer__lr=optimizer_lr,
    # optimizer__weight_decay=optimizer_weight_decay,  # left disabled
    train_split=predefined_split(Dataset(eval_x, test_y)),
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    # Idea to try: ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=20))
    callbacks=["accuracy", "f1", 'roc_auc', cp],
    warm_start=True,
)
classifier.initialize()

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
  Deep4Net (Deep4Net)                      [1, 10, 6000]             [1, 2]                    --                        --
  ├─Ensure4d (ensuredims): 1-1             [1, 10, 6000]             [1, 10, 6000, 1]          --                        --
  ├─Rearrange (dimshuffle): 1-2            [1, 10, 6000, 1]          [1, 1, 6000, 10]          --                        --
  ├─CombinedConv (conv_time_spat): 1-3     [1, 1, 6000, 10]          [1, 32, 5976, 1]          11,072                    --
  ├─BatchNorm2d (bnorm): 1-4               [1, 32, 5976, 1]          [1, 32, 5976, 1]          64                        --
  ├─Expression (conv_nonlin): 1-5          [1, 32, 5976, 1]          [1, 32, 5976, 1]          --                        --
  ├─AvgPool2dWithConv (pool): 1-6          [1, 32, 5976, 1]  

In [121]:
#   Accuracy    F1-score       Loss              AUC
#A1:0.7481      0.7344        0.5478           0.8070
#A2:0.7556      0.7130        0.6907           0.7879
#C3:0.7259      0.7040        0.6569           0.7496
#C4:0.7111      0.6286        0.8119           0.7738
#CZ:0.7259      0.6942        0.7168           0.7628
#F3:0.7111      0.6549        0.8018           0.7540
#F4:0.7111      0.6422        0.6267           0.7608
#F7:0.7556      0.6733        0.6990           0.8015
#F8:0.7333      0.6897        0.6086           0.7520
#FP1:0.7259     0.7259        0.5895           0.7788
#FP2:0.7481     0.7167        0.6900           0.7705
#FZ:0.7852      0.7521        0.6720           0.8140   #Consistently high results
#O1:0.7556      0.7130        0.5564           0.8158
#O2:0.7185      0.6415        0.7401           0.7711
#P3:0.7037      0.6154        0.5774           0.8110
#P4:0.7407      0.7445        0.5793           0.7993
#PZ:0.7037      0.6825        0.6528           0.7238
#T3:0.7704      0.7395        0.4957           0.8506
#T4:0.7185      0.7286        0.5802           0.7753
#T5:0.7481      0.6909        0.5537           0.8239
#T6:0.7185      0.6607        0.5840           0.7722

In [147]:
# Train for 30 epochs on the training arrays. The classifier is created with
# warm_start=True, so re-running this cell calls fit again on the same classifier object.
classifier.fit(train_x,y,epochs=30)

  epoch    train_accuracy    train_f1    train_loss    train_roc_auc    valid_acc    valid_accuracy    valid_f1    valid_loss    valid_roc_auc    cp     dur
-------  ----------------  ----------  ------------  ---------------  -----------  ----------------  ----------  ------------  ---------------  ----  ------
      1            [36m0.8359[0m      [32m0.0000[0m        [35m0.5263[0m           [31m0.7272[0m       [94m0.5259[0m            [36m0.5259[0m      [32m0.0000[0m        [35m2.0995[0m           [31m0.6791[0m     +  3.2420
      2            0.7533      [32m0.0636[0m        [35m0.3956[0m           0.2904       0.5259            0.5259      [32m0.1351[0m        [35m1.3070[0m           0.2890     +  3.0460
      3            [36m0.8431[0m      [32m0.0966[0m        [35m0.3772[0m           0.6548       [94m0.5407[0m            [36m0.5407[0m      0.0606        [35m1.2865[0m           [31m0.6901[0m        2.8390
      4            [36m0.8599[0m 

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
  Deep4Net (Deep4Net)                      [1, 10, 6000]             [1, 2]                    --                        --
  ├─Ensure4d (ensuredims): 1-1             [1, 10, 6000]             [1, 10, 6000, 1]          --                        --
  ├─Rearrange (dimshuffle): 1-2            [1, 10, 6000, 1]          [1, 1, 6000, 10]          --                        --
  ├─CombinedConv (conv_time_spat): 1-3     [1, 1, 6000, 10]          [1, 32, 5976, 1]          11,072                    --
  ├─BatchNorm2d (bnorm): 1-4               [1, 32, 5976, 1]          [1, 32, 5976, 1]          64                        --
  ├─Expression (conv_nonlin): 1-5          [1, 32, 5976, 1]          [1, 32, 5976, 1]          --                        --
  ├─AvgPool2dWithConv (pool): 1-6          [1, 32, 5976, 1]  

In [None]:
# Save the classifier's CURRENT state (parameters, optimizer, history).
# This is written to its own "chanlast_*" files: the Checkpoint callback already
# writes to model/chanbest_*, and saving to those names here would overwrite the
# checkpointed state with whatever the classifier holds now.
classifier.save_params(f_params='model/chanlast_param.pkl',f_optimizer='model/chanlast_opt.pkl', f_history='model/chanlast_history.json')

In [None]:
# Restore parameters, optimizer state and training history from the model/chanbest_* files.
classifier.load_params(f_params='model/chanbest_param.pkl',f_optimizer='model/chanbest_opt.pkl', f_history='model/chanbest_history.json')
print("Parameters Loaded")

In [None]:
#This block finds the accuracy, f1 score and roc auc of the valid/test set
pred_labels=classifier.predict(eval_x)
# ROC AUC is computed from column 1 of predict_proba.
auc=roc_auc_score(test_y,classifier.predict_proba(eval_x)[:,1])
accuracy=np.mean(pred_labels==test_y)

# For 0/1 labels the element-wise product is 1 only where both prediction and label are 1.
tp=np.sum(pred_labels*test_y)
n_pred_pos=np.sum(pred_labels)
n_true_pos=np.sum(test_y)
# Guard every ratio: a model that predicts no positives (or a set without positives)
# would otherwise divide by zero and report NaN. Fall back to 0.0 instead.
precision=tp/n_pred_pos if n_pred_pos else 0.0
recall=tp/n_true_pos if n_true_pos else 0.0
f1=2*precision*recall/(precision+recall) if (precision+recall) else 0.0

print(model_name)
print(f"Accuracy:{accuracy}")
print(f"F1-Score:{f1}")
print(f"roc_auc score:{auc}")