In [None]:
import numpy as np
import torch
import scipy
from torch import nn
import torchvision.transforms as transform
import braindecode 
from braindecode.models import *
from braindecode.models.modules import Expression
from braindecode.models.functions import squeeze_final_output
from braindecode.datasets import BaseDataset, BaseConcatDataset,create_from_X_y
from braindecode.models.util import to_dense_prediction_model, get_output_shape
import pandas as pd
import resampy
from skorch.callbacks import Checkpoint,ProgressBar
from skorch.helper import predefined_split
from config import *
from dataset import *
from braindecode.preprocessing import create_fixed_length_windows
from mne import set_log_level
set_log_level(False)
device = 'cuda' if cuda else 'cpu'

In [None]:
preproc_functions = []
preproc_functions.append( lambda data, fs: (data[:, int(sec_to_cut * fs):-int(sec_to_cut * fs)], fs))
preproc_functions.append(lambda data, fs: (data[:, :int(duration_recording_mins * 60 * fs)], fs))
if max_abs_val is not None:
    preproc_functions.append(lambda data, fs:(np.clip(data, -max_abs_val, max_abs_val), fs))
preproc_functions.append(lambda data, fs: (resampy.resample(data, fs,sampling_freq,axis=1,filter='kaiser_fast'),sampling_freq))
if divisor is not None:
    preproc_functions.append(lambda data, fs: (data / divisor, fs))
dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions,
                           data_folders=data_folders,
                           train_or_eval='train',
                           sensor_types=sensor_types)
if test_on_eval:
    test_dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions,
                           data_folders=data_folders,
                           train_or_eval='eval',
                           sensor_types=sensor_types)

In [None]:
X,y=dataset.load()
if test_on_eval:
    test_x,test_y=test_dataset.load()

In [None]:
del divisor,max_abs_val,sec_to_cut,duration_recording_mins,preproc_functions
def create_set(X, y, inds):
    """
    X list and y nparray
    :return: 
    """
    new_X = []
    for i in inds:
        new_X.append(X[i])
    new_y = y[inds]
    return (new_X, new_y)
#Use of TrainValidTestSplitter is not necessary in newer versions of braindecode
class TrainValidSplitter(object):
    def __init__(self, n_folds, i_valid_fold, shuffle):
        self.n_folds = n_folds
        self.i_valid_fold = i_valid_fold
        self.rng = np.random.RandomState(39483948)
        self.shuffle = shuffle

    def split(self, X, y):
        if len(X) < self.n_folds:
            raise ValueError("Less Trials: {:d} than folds: {:d}".format(
                len(X), self.n_folds
            ))
        indices=np.arange(len(y))
        #Compared to paper, the valid set will be unbalanced
        batch_size=len(X)//self.n_folds
        if self.shuffle:
            self.rng.shuffle(indices)
        valid_inds=indices[self.i_valid_fold*batch_size:(self.i_valid_fold+1)*batch_size]
        train_inds = np.setdiff1d(indices,valid_inds)
        train_set = create_set(X, y, train_inds)
        valid_set = create_set(X, y, valid_inds)
        return train_set, valid_set

In [None]:
if test_on_eval==False:
    splitter=TrainValidSplitter(n_folds,i_test_fold,True)
    train_set,valid_set=splitter.split(X,y)
    del X,y
    X,y=train_set
    valid_X,valid_y=valid_set
    del n_folds,i_test_fold,train_set,valid_set

In [None]:
ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
#we take a 20 second stride as 1 sample and 1 second stride takes too long
stride=sampling_freq*30
train_set=create_from_X_y(X,y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                       window_stride_samples=stride)
if test_on_eval==False:
    valid_set=create_from_X_y(valid_X,valid_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=stride)
    del valid_X,valid_y
elif test_on_eval:
    test_set=create_from_X_y(test_x,test_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=stride)
    del test_x,test_y
del stride,ch_names,X,y

In [None]:
n_classes = 2
if model_name=="shallow":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = ShallowFBCSPNet(n_chans,
                                    n_classes,
                                    n_filters_time=n_start_chans,
                                    n_filters_spat=n_start_chans,
                                    input_window_samples=input_time_length,
                                    final_conv_length='auto',)
    test=torch.ones(size=(7,21,6000))
    out=model.forward(test)
    print(out.shape)
if model_name=="deep":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    model = Deep4Net(n_chans, n_classes,
                         n_filters_time=n_start_chans,
                         n_filters_spat=n_start_chans,
                         input_window_samples=input_time_length,
                         n_filters_2 = int(n_start_chans * n_chan_factor),
                         n_filters_3 = int(n_start_chans * (n_chan_factor ** 2.0)),
                         n_filters_4 = int(n_start_chans * (n_chan_factor ** 3.0)),
                         final_conv_length='auto',
                        stride_before_pool=True)
    test=torch.ones(size=(7,21,6000,1))
    out=model.forward(test)
    print(out.shape)
#Works properly, fit the hybrid cnn
if model_name=="hybrid":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = HybridNet(n_chans, n_classes,input_window_samples=input_time_length,)
    test=torch.ones(size=(2,21,6000))
    out=model.forward(test)
    out_length=out.shape[2]
    model.final_conv=nn.Conv2d(100,n_classes,(out_length,1),bias=True,)
    model=nn.Sequential(model,Expression(torch.squeeze))
    out=model.forward(test)
    print(out.shape)
    del out_length
if model_name=="TCN":
    import warnings
    #This disables the warning of the dropout2d layers receiving 3d input
    warnings.filterwarnings("ignore")
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    n_blocks=7
    n_filters=32
    kernel_size=24
    drop_prob = 0.3
    add_log_softmax=False
    x=TCN(n_chans,n_classes,n_blocks,n_filters,kernel_size,drop_prob,add_log_softmax)
    test=torch.ones(size=(7,21,6000,1))
    out=x.forward(test)
    print(out.shape)
    out_length=out.shape[2]
    #There is no hyperparameter where output of TCN is (Batch_Size,Classes) when input is (Batch_Size,21,6000) so add new layers to meet size
    model=nn.Sequential(x,nn.Conv1d(n_classes,n_classes,out_length,bias=True,),Expression(torch.squeeze),nn.LogSoftmax(dim=1))
    out=model.forward(test)
    print(out.shape)
    del out_length,x
if cuda:
    model.cuda()
del test,out

In [None]:
model

In [None]:
monitor = lambda net: any(net.history[-1, ('valid_accuracy_best','valid_f1_best','valid_loss_best')])
cp=Checkpoint(monitor='valid_f1_best',dirname='model',f_params=f'{model_name}best_param.pkl',
               f_optimizer=f'{model_name}best_opt.pkl', f_history=f'{model_name}best_history.json')
if test_on_eval==False:
    classifier = braindecode.EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=optimizer_lr,
        #optimizer__weight_decay=optimizer_weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        device=device,
        callbacks=["accuracy","f1",cp],
        warm_start=True,
        )
elif test_on_eval:
    classifier = braindecode.EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(test_set),
        optimizer__lr=optimizer_lr,
        #optimizer__weight_decay=optimizer_weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        device=device,
        callbacks=["accuracy","f1",cp],
        warm_start=True,
        )
classifier.initialize()
del model

In [None]:
#Loads Phase 1 parameters and fit them further in phase 2
path=f'{model_name}'
if test_on_eval:
    classifier.load_params(
        f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
    print("Paramters Loaded")
    path=f'{model_name}II'

In [None]:
if test_on_eval:
    path=f'{model_name}II'
elif test_on_eval==False:
    path=f'{model_name}'
try:
    classifier.load_params(
        f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
    print("Paramters Loaded")
except:
    pass

In [None]:
#Shows the history of training the neural network
classifier.history_

In [None]:
classifier.fit(train_set,y=None,epochs=1)

In [None]:
classifier.save_params(
    f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
#torch.save({"model":classifier.module_.state_dict(),"optimizer":classifier.optimizer_.state_dict()}, path)

In [None]:
if test_on_eval==False:
    pred_labels=classifier.predict(valid_set)
    actual_labels=[label[1] for label in valid_set]
elif test_on_eval:
    pred_labels=classifier.predict(test_set)
    actual_labels=[label[1] for label in test_set]
actual_labels=np.array(actual_labels)
accuracy=np.mean(pred_labels==actual_labels)
print(f"Accuracy:{accuracy}")
tp=np.sum(pred_labels*actual_labels)
precision=tp/np.sum(pred_labels)
recall=tp/np.sum(actual_labels)
f1=2*precision*recall/(precision+recall)
print(f"F1-Score:{f1}")

In [None]:
#Test the model on proper test set according to paper
if test_on_eval:
    try:
        del train_set,test_set
    except:
        pass
    test_x,test_y=test_dataset.load()
    ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
    #Stride between windows is set to sampling frequency as written in paper
    test_set=create_from_X_y(test_x,test_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=sampling_freq)
    del test_x,test_y

In [None]:
if test_on_eval:    
    pred_labels=classifier.predict(test_set)
    actual_labels=[label[1] for label in test_set]
    actual_labels=np.array(actual_labels)
    accuracy=np.mean(pred_labels==actual_labels)
    print(f"Accuracy:{accuracy}")
    tp=np.sum(pred_labels*actual_labels)
    precision=tp/np.sum(pred_labels)
    recall=tp/np.sum(actual_labels)
    f1=2*precision*recall/(precision+recall)
    print(f"F1-Score:{f1}") 

In [None]:
#This will load the model and parameters and then replace it with one whose classification layer is removed
from skorch import NeuralNet
network=NeuralNet(module=model,criterion=torch.nn.modules.loss.NLLLoss,batch_size=batch_size,device=device)
network.initialize()
network.load_params(
    f_params=f'model/{model_name}best_param.pkl', f_optimizer=f'model/{model_name}best_opt.pkl', f_history=f'model/{model_name}best_history.json')
print("Paramters Loaded")
network.module_=torch.nn.Sequential(*(list(network.module_.children())[:-3]),nn.modules.Flatten())

In [None]:
#Loads dataset, finds smallest trial, with this, we find number of windows using stride and convert it to array of windows of trials
#shape is (no_of_trials,no_of_windows,channels,input_time_length) in the end
X,y=dataset.load()
min_shape=X[0].shape[1]
for arr in X:
    if min_shape>arr.shape[1]:
        min_shape=arr.shape[1]
print(min_shape)

In [None]:
#30 second stride between windows
stride=sampling_freq*20
no_of_windows=((min_shape-input_time_length)//stride)
#To make the features for the LSTM, we will make all the trials of same length as smallest to allow batch training
for i in range(len(X)):
    windows=[]
    for j in range(no_of_windows):
        windows.append(X[i][:,j*stride:j*stride+input_time_length])
    X[i]=np.asarray(windows)
trials=np.array(X)
del windows,X

In [None]:
#This will calculate the features before classification layer
features=[]
for i in range(len(trials)):
    out=network.predict(trials[i])
    features.append(out)
features=np.asarray(features)
del trials,out

In [None]:
#This saves the features along with labels of each trial in a .mat file
scipy.io.savemat("E:/train_features.mat",{"x":features,"y":y})
del features,y

In [1]:
import scipy
import numpy as np
inputs=scipy.io.loadmat("E:/train_features.mat")
X=inputs["x"][:,:16,:]
y=inputs["y"].squeeze()
_,t,f=X.shape
del inputs

In [3]:
import tensorflow as tf
from keras.layers import LSTM, Dense,Input
from keras.models import Model
from keras.callbacks import ModelCheckpoint
from keras.callbacks import EarlyStopping
inputsin= Input(shape=(t,f))

x=LSTM(50,activation='tanh')(inputsin)
predictions = Dense(2,activation='softmax')(x)
model = Model(inputs=inputsin, outputs=predictions)
del inputsin,predictions,x
opt=tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.1, amsgrad=False)

model.compile(optimizer = opt, loss = 'sparse_categorical_crossentropy', metrics=['accuracy'])
es = EarlyStopping(monitor='val_loss', min_delta=0.01, mode='min', verbose=1, patience=15)
mc = ModelCheckpoint('model/LSTM_acc.hdf5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
mces = ModelCheckpoint('model/LSTM_loss.hdf5', monitor='val_loss', mode='min', verbose=1, save_best_only=True)

In [4]:
model.fit(X,y,validation_split=0.2,epochs=500,batch_size=8,verbose=1,callbacks=[es, mc,mces],shuffle=True)

Epoch 1/500
Epoch 1: val_accuracy improved from -inf to 0.23952, saving model to model\LSTM_acc.hdf5

Epoch 1: val_loss improved from inf to 5.83784, saving model to model\LSTM_loss.hdf5
Epoch 2/500
Epoch 2: val_accuracy did not improve from 0.23952

Epoch 2: val_loss did not improve from 5.83784
Epoch 3/500
Epoch 3: val_accuracy did not improve from 0.23952

Epoch 3: val_loss did not improve from 5.83784
Epoch 4/500
Epoch 4: val_accuracy did not improve from 0.23952

Epoch 4: val_loss did not improve from 5.83784
Epoch 5/500
Epoch 5: val_accuracy did not improve from 0.23952

Epoch 5: val_loss did not improve from 5.83784
Epoch 6/500
Epoch 6: val_accuracy did not improve from 0.23952

Epoch 6: val_loss did not improve from 5.83784
Epoch 7/500
Epoch 7: val_accuracy did not improve from 0.23952

Epoch 7: val_loss did not improve from 5.83784
Epoch 8/500
Epoch 8: val_accuracy did not improve from 0.23952

Epoch 8: val_loss did not improve from 5.83784
Epoch 9/500
Epoch 9: val_accuracy di

<keras.callbacks.History at 0x209f6ab6920>