In [1]:
import torch
from torch import nn
from torch.nn.functional import elu,relu,leaky_relu
import braindecode
from braindecode.models import *
from braindecode.models.modules import Expression
from braindecode.models.functions import squeeze_final_output,square,safe_log
from skorch.dataset import Dataset
from skorch.callbacks import Checkpoint
from skorch.helper import predefined_split
from config import *
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from mne import set_log_level
from chrononet import ChronoNet
# Keep MNE quiet: False limits its logging to warnings and errors.
set_log_level(False)
# `cuda` is not defined in this notebook; presumably a boolean from `from config import *` — confirm there.
if cuda:
    device = 'cuda'
else:
    device = 'cpu'

Tensorflow not install, you could not use those pipelines


In [2]:
# Load the preprocessed EEG windows from MATLAB .mat files: each file stores the
# windows under "x" and their labels under "y".
import scipy.io  # also binds `scipy`; older SciPy releases do not load `scipy.io` on a bare `import scipy`
import numpy as np

DATA_DIR = "E:/"  # absolute local path — change this to point at your copy of the data
TRAIN_FILES = ("train_set_1.mat", "train_set_2.mat")
TEST_FILE = "test_set.mat"

# Stack every training file along the window axis.
train_parts = [scipy.io.loadmat(DATA_DIR + name) for name in TRAIN_FILES]
X = np.concatenate([part["x"] for part in train_parts], axis=0)
y = np.concatenate([part["y"].squeeze() for part in train_parts], axis=0)
input_time_length = X.shape[-1]

inputs = scipy.io.loadmat(DATA_DIR + TEST_FILE)
test_x = inputs["x"]
test_y = inputs["y"].squeeze()
# Release the raw .mat dictionaries; only the extracted arrays are needed from here on.
del inputs, train_parts

In [3]:
# Select which EEG channel(s) to train on; every picked channel keeps its full length.
ch_names = ['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1', 'FP2', 'FZ',
            'O1', 'O2', 'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
picked_ch = [11]  # indices into ch_names; 11 -> 'FZ'
input_time_length = X.shape[-1]
train_x, train_y = X[:, picked_ch], y
eval_x, eval_y = test_x[:, picked_ch], test_y
n_chans = len(picked_ch)
chosen_names = [ch_names[idx] for idx in picked_ch]
print(f'Channels chosen:{chosen_names}')

Channels chosen:['FZ']


In [4]:
# Fold the single picked channel from a 1D signal into a 2D "pseudo-channel" matrix,
# e.g. (1, 60000) -> (10, 6000), so the 2D-conv models can consume it.
# With order='F', sample k lands in row k % 10, column k // 10. Each column therefore
# holds 10 consecutive samples, so a convolution spanning all 10 rows (such as the
# conv_time_spat layer) combines consecutive samples.
n_folds = 10
if train_x.shape[1] != 1 or eval_x.shape[1] != 1:
    # Folding several channels this way would interleave them; re-running this cell on
    # already-folded data would also land here.
    raise ValueError(
        "Folding into pseudo-channels expects exactly one picked channel; got "
        f"{train_x.shape[1]} (train) and {eval_x.shape[1]} (eval). "
        "Re-run the channel-selection cell before this one.")
n_chans = n_folds
input_time_length = train_x.shape[-1] // n_folds
# Drop trailing samples that do not fill a whole column instead of failing in reshape.
n_used = n_folds * input_time_length
train_x = train_x[..., :n_used].reshape(len(train_x), n_chans, input_time_length, order='F')
eval_x = eval_x[..., :n_used].reshape(len(eval_x), n_chans, input_time_length, order='F')

In [5]:
def balancing(dataset, labels):
    """Oversample the abnormal class (label 1) by whole-number repetition.

    Each abnormal sample is repeated ``len(rest) // len(abnormal)`` times, but
    always at least once. The copies are appended after the non-abnormal samples,
    so the result is ordered [non-abnormal..., abnormal...]; shuffle downstream.

    Parameters
    ----------
    dataset : np.ndarray
        Samples indexed along axis 0.
    labels : np.ndarray
        1-D label array aligned with ``dataset``; ``1`` marks the class to oversample.

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        The rebalanced ``(dataset, labels)``.
    """
    is_abnormal = labels == 1
    abnormal, abnormal_labels = dataset[is_abnormal], labels[is_abnormal]
    rest, rest_labels = dataset[~is_abnormal], labels[~is_abnormal]

    if len(abnormal) == 0:
        # Nothing to oversample (previously a ZeroDivisionError).
        return rest, rest_labels

    # When abnormal samples already outnumber the rest, integer division gives 0, and
    # np.repeat(..., 0) would silently discard the whole class. Keep at least one copy.
    factor = max(1, len(rest) // len(abnormal))
    abnormal = np.repeat(abnormal, factor, axis=0)
    abnormal_labels = np.repeat(abnormal_labels, factor, axis=0)

    return (np.concatenate((rest, abnormal), axis=0),
            np.concatenate((rest_labels, abnormal_labels), axis=0))

In [6]:
# Oversample the abnormal class in the training split only; eval_x / eval_y are left as-is.
train_x, train_y = balancing(dataset=train_x, labels=train_y)

In [7]:
# Architecture for the model-building cell: "shallow", "deep", "hybrid", "deep_smac",
# "deep_smac_bnorm", "shallow_deep", "attention", "ChronoNet" or "TCN".
model_name = "shallow_deep"

In [9]:
# Build the network selected by `model_name` together with its optimizer settings.
# Every head ends in log-probabilities, so NLLLoss is the criterion for all models.
# n_start_chans, n_chan_factor, init_lr, sampling_freq and cuda are not defined in this
# notebook; presumably they come from `from config import *` — TODO confirm.
n_classes = 2
criterion = torch.nn.NLLLoss


def _probe(net, batch_size):
    """Return `net`'s output for an all-ones batch of shape (batch_size, n_chans, input_time_length).

    The probe runs in eval mode under torch.no_grad(). This keeps the fake data
    from updating BatchNorm running statistics and avoids building an autograd
    graph. The module's previous train/eval mode is restored afterwards.
    """
    was_training = net.training
    net.eval()
    with torch.no_grad():
        out = net(torch.ones(size=(batch_size, n_chans, input_time_length)))
    net.train(was_training)
    return out


if model_name == "shallow":
    optimizer_lr = 0.000625
    optimizer_weight_decay = 0
    pool_time_length = 150
    pool_time_stride = 50
    # final_conv_length='auto' so the output gives one score per class for a single EEG window.
    model = ShallowFBCSPNet(n_chans,
                            n_classes,
                            n_filters_time=n_start_chans,
                            n_filters_spat=n_start_chans,
                            n_times=input_time_length,
                            pool_time_length=pool_time_length,
                            pool_time_stride=pool_time_stride,
                            final_conv_length='auto')
    print(_probe(model, 7).shape)
elif model_name == "deep":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    model = Deep4Net(n_chans, n_classes,
                     n_filters_time=n_start_chans,
                     n_filters_spat=n_start_chans,
                     n_times=input_time_length,
                     n_filters_2=int(n_start_chans * n_chan_factor),
                     n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
                     n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
                     final_conv_length='auto',
                     stride_before_pool=True)
    print(_probe(model, 7).shape)
elif model_name == "hybrid":
    optimizer_lr = 0.000625
    optimizer_weight_decay = 0
    model = HybridNet(n_chans, n_classes, n_times=input_time_length)
    # Swap the final layer for a conv spanning the whole remaining time axis, so the
    # network gives one score per class for a single EEG window.
    out_length = _probe(model, 2).shape[2]
    model.final_layer = nn.Conv2d(100, n_classes, (out_length, 1), bias=True)
    # Turn the head's output into log-probabilities for NLLLoss. The previous
    # Softmax + CrossEntropyLoss pairing applied softmax twice. LogSoftmax is
    # idempotent, so this is correct whether HybridNet emits raw scores or
    # log-probabilities. nn.Flatten drops the trailing singleton dims but, unlike
    # torch.squeeze, keeps the batch dim when a batch holds a single window.
    model = nn.Sequential(model, nn.Flatten(), nn.LogSoftmax(dim=1))
    print(_probe(model, 2).shape)
    del out_length
elif model_name in ("deep_smac", "deep_smac_bnorm"):
    optimizer_lr = 0.000625
    optimizer_weight_decay = 0  # defined for parity with the other branches
    do_batch_norm = model_name == 'deep_smac_bnorm'
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 24
    filter_length_4 = 36
    filter_time_length = 21
    #final_conv_length = 1
    first_nonlin = elu
    first_pool_mode = 'mean'
    later_nonlin = elu
    later_pool_mode = 'mean'
    n_filters_factor = 1.679066
    n_filters_start = 32
    # Pool settings are kept for reference; the matching kwargs below are commented out.
    pool_time_length = 3
    pool_time_stride = 3
    split_first_layer = True
    n_chan_factor = n_filters_factor
    n_start_chans = n_filters_start
    model = Deep4Net(n_chans, n_classes,
                     n_filters_time=n_start_chans,
                     n_filters_spat=n_start_chans,
                     n_times=input_time_length,
                     n_filters_2=int(n_start_chans * n_chan_factor),
                     n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
                     n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
                     final_conv_length='auto',
                     batch_norm=do_batch_norm,
                     drop_prob=drop_prob,
                     filter_length_2=filter_length_2,
                     filter_length_3=filter_length_3,
                     filter_length_4=filter_length_4,
                     filter_time_length=filter_time_length,
                     first_conv_nonlin=first_nonlin,
                     first_pool_mode=first_pool_mode,
                     later_conv_nonlin=later_nonlin,
                     later_pool_mode=later_pool_mode,
                     #pool_time_length=pool_time_length,
                     #pool_time_stride=pool_time_stride,
                     split_first_layer=split_first_layer,
                     stride_before_pool=True)
elif model_name == "shallow_deep":
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 14
    filter_length_4 = 32
    n_filters_factor = 1.679066
    n_filters_start = 32
    split_first_layer = True
    n_chan_factor = n_filters_factor
    n_start_chans = n_filters_start

    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    conv_time_length = 25
    first_conv_nonlin = relu
    first_pool_nonlin = safe_log
    later_conv_nonlin = elu
    later_pool_nonlin = safe_log
    first_pool_mode = 'mean'
    later_pool_mode = 'mean'
    pool_time_length = 15
    model = Deep4Net(n_chans, n_classes,
                     n_times=input_time_length,
                     n_filters_time=n_start_chans,
                     n_filters_spat=n_start_chans,
                     n_filters_2=int(n_start_chans * n_chan_factor),
                     n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
                     n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
                     final_conv_length='auto',
                     first_pool_nonlin=first_pool_nonlin,
                     first_conv_nonlin=first_conv_nonlin,
                     #later_pool_nonlin=later_pool_nonlin,
                     #later_conv_nonlin=later_conv_nonlin,
                     filter_time_length=conv_time_length,
                     pool_time_length=pool_time_length,
                     first_pool_mode=first_pool_mode,
                     later_pool_mode=later_pool_mode,
                     split_first_layer=split_first_layer,
                     drop_prob=drop_prob,
                     filter_length_2=filter_length_2,
                     filter_length_3=filter_length_3,
                     filter_length_4=filter_length_4)
elif model_name == "attention":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    model = ATCNet(n_chans, n_classes, input_time_length // sampling_freq, sampling_freq,
                   concat=True, tcn_depth=4)
    print(_probe(model, 7).shape)
elif model_name == "ChronoNet":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    # Local port kept as close as possible to the TensorFlow ChronoNet implementation.
    model = ChronoNet(input_time_length)
    print(_probe(model, 7).shape)
elif model_name == "TCN":
    import warnings
    # The TCN's Dropout2d layers receive 3D input, and PyTorch warns about that.
    # Silence only that warning rather than every warning for the rest of the session.
    warnings.filterwarnings("ignore", message="dropout2d: Received a 3D input")
    optimizer_lr = 0.000625
    optimizer_weight_decay = 0
    n_blocks = 7
    n_filters = 32
    kernel_size = 24
    drop_prob = 0.3
    # NOTE(review): defined but never passed to TCN below — confirm whether it was meant to be.
    add_log_softmax = False
    # Minimum input length for this TCN configuration (formula taken from braindecode's tcn.py).
    min_len = 1 + sum(2 * (kernel_size - 1) * 2 ** i for i in range(n_blocks))
    print(f"Minimum length :{min_len}")
    x = TCN(n_chans, n_classes, n_blocks, n_filters, kernel_size, drop_prob, n_times=input_time_length)
    # No TCN hyper-parameter collapses the output to (batch, n_classes). A Conv1d spanning
    # the remaining length (dim 2 of the probe output) reduces it to one step per class.
    # nn.Flatten then drops that step without ever squeezing away the batch dim.
    out_length = _probe(x, 7).shape[2]
    model = nn.Sequential(x,
                          nn.Conv1d(n_classes, n_classes, out_length, bias=True),
                          nn.Flatten(),
                          nn.LogSoftmax(dim=1))
    print(_probe(model, 7).shape)
    del out_length, x
else:
    raise ValueError(f"Unknown model_name {model_name!r}")

if cuda:
    model.cuda()

print(model_name)

shallow_deep


In [30]:
def monitor(net):
    """True when the latest epoch set a new best validation accuracy, F1 or loss (not used by the checkpoint below)."""
    return any(net.history[-1, ('valid_accuracy_best', 'valid_f1_best', 'valid_loss_best')])


# Save a checkpoint whenever the validation F1 improves.
cp = Checkpoint(monitor='valid_f1_best',
                dirname='model',
                f_params=f'chan_{model_name}best_param.pkl',
                f_optimizer=f'chan_{model_name}best_opt.pkl',
                f_history=f'chan_{model_name}best_history.json')
# The held-out eval split doubles as the validation set reported every epoch.
validation_split = predefined_split(Dataset(eval_x, eval_y))
classifier = braindecode.EEGClassifier(
    model,
    criterion=criterion,
    optimizer=torch.optim.AdamW,
    train_split=validation_split,
    optimizer__lr=optimizer_lr,
    # NOTE(review): with this disabled, AdamW uses its own default weight_decay, not optimizer_weight_decay.
    #optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    # Could also add ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=20)).
    callbacks=["accuracy", "f1", 'roc_auc', cp],
    warm_start=True,
)
classifier.initialize()

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=Sequential(
    (0): TCN(
      (ensuredims): Ensure4d()
      (temporal_blocks): Sequential(
        (temporal_block_0): TemporalBlock(
          (conv1): Conv1d(10, 32, kernel_size=(24,), stride=(1,), padding=(23,))
          (chomp1): Chomp1d(chomp_size=23)
          (relu1): ReLU()
          (dropout1): Dropout2d(p=0.3, inplace=False)
          (conv2): Conv1d(32, 32, kernel_size=(24,), stride=(1,), padding=(23,))
          (chomp2): Chomp1d(chomp_size=23)
          (relu2): ReLU()
          (dropout2): Dropout2d(p=0.3, inplace=False)
          (downsample): Conv1d(10, 32, kernel_size=(1,), stride=(1,))
          (relu): ReLU()
        )
        (temporal_block_1): TemporalBlock(
          (conv1): Conv1d(32, 32, kernel_size=(24,), stride=(1,), padding=(46,), dilation=(2,))
          (chomp1): Chomp1d(chomp_size=46)
          (relu1): ReLU()
          (dropout1): Dropout2d(p=0.3, inplace=False)
          (conv2)

In [None]:
#Per-channel results: a single 10-minute channel folded from 1D to 2D, classified with the shallow-deep CNN
#   Accuracy    F1-score       Loss              AUC
#A1:0.7481      0.7344        0.5478           0.8070
#A2:0.7556      0.7130        0.6907           0.7879
#C3:0.7259      0.7040        0.6569           0.7496
#C4:0.7111      0.6286        0.8119           0.7738
#CZ:0.7259      0.6942        0.7168           0.7628
#F3:0.7111      0.6549        0.8018           0.7540
#F4:0.7111      0.6422        0.6267           0.7608
#F7:0.7556      0.6733        0.6990           0.8015
#F8:0.7333      0.6897        0.6086           0.7520
#FP1:0.7259     0.7259        0.5895           0.7788
#FP2:0.7481     0.7167        0.6900           0.7705
#FZ:0.7852      0.7521        0.6720           0.8140   #Consistently high results
#O1:0.7556      0.7130        0.5564           0.8158
#O2:0.7185      0.6415        0.7401           0.7711
#P3:0.7037      0.6154        0.5774           0.8110
#P4:0.7407      0.7445        0.5793           0.7993
#PZ:0.7037      0.6825        0.6528           0.7238
#T3:0.7704      0.7395        0.4957           0.8506
#T4:0.7185      0.7286        0.5802           0.7753
#T5:0.7481      0.6909        0.5537           0.8239
#T6:0.7185      0.6607        0.5840           0.7722

#Per-channel results: a single 3-second channel folded from 1D to 2D, classified with the TCN
#   Accuracy    F1-score       Loss              AUC
#A1:0.5827      0.2397        0.8152           0.7300
#A2:0.5903      0.2692        0.7758           0.7414
#C3:
#C4:
#CZ:
#F3:
#F4:
#F7:
#F8:
#FP1:0.6015     0.2980        0.7654           0.7980
#FP2:0.5868     0.2593        0.7681           0.7376
#FZ:
#O1:
#O2:
#P3:
#P4:
#PZ:
#T3:
#T4:
#T5:
#T6:

In [20]:
# Train for 10 epochs; the classifier is built with warm_start=True, so re-running continues from the current weights.
classifier.fit(X=train_x, y=train_y, epochs=10)

     11            0.4953      0.6625        0.8183           0.5000       0.4741            0.4741      0.6432        0.8392           0.5000        108.8986
     12            0.4953      0.6625        0.8169           0.5000       0.4741            0.4741      0.6432        0.8392           0.5000        109.6255
     13            0.4953      0.6625        0.8183           0.5000       0.4741            0.4741      0.6432        0.8392           0.5000        108.9973
     14            0.4953      0.6625        0.8176           0.5000       0.4741            0.4741      0.6432        0.8392           0.5000        108.5942
     15            0.4953      0.6625        0.8180           0.5000       0.4741            0.4741      0.6432        0.8392           0.5000        108.7051
     16            0.4953      0.6625        0.8187           0.5000       0.4741            0.4741      0.6432        0.8392           0.5000        109.1842
     17            0.4953      0.6625        0

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=Sequential(
    (0): HybridNet(
      (reduced_deep_model): Sequential(
        (ensuredims): Ensure4d()
        (dimshuffle): Rearrange('batch C T 1 -> batch 1 T C')
        (conv_time_spat): CombinedConv(
          (conv_time): Conv2d(1, 20, kernel_size=(10, 1), stride=(1, 1))
          (conv_spat): Conv2d(20, 30, kernel_size=(1, 10), stride=(1, 1), bias=False)
        )
        (bnorm): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv_nonlin): Expression(expression=elu) 
        (pool): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=(1, 1), ceil_mode=False)
        (pool_nonlin): Expression(expression=identity) 
        (drop_2): Dropout(p=0.5, inplace=False)
        (conv_2): Conv2d(30, 40, kernel_size=(10, 1), stride=(1, 1), dilation=(3, 1), bias=False)
        (bnorm_2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [None]:
# Persist the current weights, optimizer state and history.
# NOTE(review): these are the same files the Checkpoint callback writes, so this replaces
# the best-F1 checkpoint with the current (last-epoch) state.
checkpoint_paths = {
    'f_params': f'model/chan_{model_name}best_param.pkl',
    'f_optimizer': f'model/chan_{model_name}best_opt.pkl',
    'f_history': f'model/chan_{model_name}best_history.json',
}
classifier.save_params(**checkpoint_paths)

In [31]:
# Restore the saved weights, optimizer state and history for the current model_name.
checkpoint_paths = {
    'f_params': f'model/chan_{model_name}best_param.pkl',
    'f_optimizer': f'model/chan_{model_name}best_opt.pkl',
    'f_history': f'model/chan_{model_name}best_history.json',
}
classifier.load_params(**checkpoint_paths)
print("Paramters Loaded")

Paramters Loaded


In [32]:
# Score the classifier on the held-out set: accuracy, F1 and ROC AUC.
pred_labels = classifier.predict(eval_x)
# predict_proba may return log-probabilities (LogSoftmax heads); ROC AUC depends only on
# the ranking of the scores, so column 1 (the abnormal class) works either way.
auc = roc_auc_score(eval_y, classifier.predict_proba(eval_x)[:, 1])
accuracy = accuracy_score(eval_y, pred_labels)
# zero_division=0 reports 0 instead of NaN when the model predicts no positives
# (the previous hand-rolled precision/recall divided by zero in that case).
f1 = f1_score(eval_y, pred_labels, zero_division=0)

print(model_name)
print(f"Accuracy:{accuracy}")
print(f"F1-Score:{f1}")
print(f"roc_auc score:{auc}")

TCN
Accuracy:0.7333333333333333
F1-Score:0.7142857142857143
roc_auc score:0.755281690140845


In [None]:
# Show the layer structure of the current `model`.
print(str(model))

In [None]:
# Hybrid CNN + LSTM: reuse the trained CNN as a feature extractor (no training in this cell).
from skorch import NeuralNet

network = NeuralNet(
    module=model,
    criterion=torch.nn.modules.loss.NLLLoss,
    batch_size=batch_size,
    device=device,
)
network.initialize()
network.load_params(
    f_params=f'model/chan_{model_name}best_param.pkl',
    f_optimizer=f'model/chan_{model_name}best_opt.pkl',
    f_history=f'model/chan_{model_name}best_history.json',
)
# Drop the last three child modules and flatten the rest into one feature vector per window.
# NOTE(review): this assumes the last three children form the classification head, as in the
# TCN Sequential (Conv1d, squeeze/flatten, LogSoftmax) — confirm for other model_name values.
backbone_layers = list(network.module_.children())[:-3]
network.module_ = torch.nn.Sequential(*backbone_layers, nn.modules.Flatten())
test = torch.ones(size=(2, n_chans, input_time_length))
feat = network.predict(test).shape[1]  # flattened feature count per window
print(feat)
del test

In [None]:
# Run the truncated backbone over both splits to get one flattened feature vector per window.
train_features, eval_features = [network.predict(windows) for windows in (train_x, eval_x)]

In [None]:
# The input here is 10 minutes of one channel (folded into 10 pseudo-channels), not
# 1 minute of all channels. Each window's flattened features are therefore treated as
# contiguous and cut into `t` equal slices, which form the LSTM's time steps.
# Trailing features that do not fill a whole slice are dropped.
# NOTE(review): the slices follow nn.Flatten's ordering of the backbone output — confirm
# they correspond to time steps rather than groups of filters.
t = 5  # number of LSTM time steps
factor = feat // t  # features per time step
n_kept = factor * t
train_features = train_features[:, :n_kept].reshape(len(train_y), t, factor)
eval_features = eval_features[:, :n_kept].reshape(len(eval_y), t, factor)
print(factor)

In [None]:
class SimpleModel(torch.nn.Module):
    """Single-layer LSTM classifier over a sequence of feature vectors.

    Parameters
    ----------
    input_features : int
        Size of each time step's feature vector.
    hidden_size : int, default 50
        LSTM hidden size (previously hard-coded to 50).
    n_classes : int, default 2
        Number of output classes (previously hard-coded to 2).

    ``forward`` takes a batch-first tensor of shape (batch, time, input_features) and
    returns log-probabilities of shape (batch, n_classes), suitable for NLLLoss.
    """

    def __init__(self, input_features, hidden_size=50, n_classes=2):
        super().__init__()
        # Attribute names are unchanged so previously saved state_dicts still load.
        self.lstm = torch.nn.LSTM(input_size=input_features, hidden_size=hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, n_classes)
        self.tanh = torch.nn.Tanh()
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        h1, _ = self.lstm(inputs)
        # Only the LSTM output at the last time step feeds the linear layer.
        h2 = self.tanh(h1[:, -1, :])
        h3 = self.fc(h2)
        return self.softmax(h3)
# Instantiate the LSTM head for `factor` features per step and check its output shape
# on a dummy batch of 3 sequences.
model = SimpleModel(factor)
# Use the same sequence length `t` as the feature regrouping instead of a hard-coded 5.
test = torch.ones(3, t, factor)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = model.forward(test)
print(out.shape)

In [None]:
def monitor(net):
    """True when the latest epoch set a new best validation accuracy, F1 or loss (not used by the checkpoint below)."""
    return any(net.history[-1, ('valid_accuracy_best', 'valid_f1_best', 'valid_loss_best')])


# Keep the LSTM weights from the epoch with the best validation F1.
cp = Checkpoint(monitor='valid_f1_best',
                dirname='model',
                f_params='chan_LSTMbest_param.pkl',
                f_optimizer='chan_LSTMbest_opt.pkl',
                f_history='chan_LSTMbest_history.json')
lstm_validation = predefined_split(Dataset(eval_features, eval_y))
classifier = braindecode.EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,  # SimpleModel already ends in LogSoftmax
    optimizer=torch.optim.AdamW,
    train_split=lstm_validation,
    optimizer__lr=0.0001,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    callbacks=["accuracy", "f1", 'roc_auc', cp],
    warm_start=True,
)
classifier.initialize()

In [None]:
# Train the LSTM head on the extracted CNN features for 30 epochs.
classifier.fit(X=train_features, y=train_y, epochs=30)