In [None]:
import numpy as np
import torch
import scipy
from torch import nn
from torch.nn.functional import elu,relu,leaky_relu
import braindecode 
from braindecode.models import *
from braindecode.models.modules import Expression
from braindecode.models.functions import squeeze_final_output,square,safe_log
from braindecode.datasets import BaseConcatDataset,create_from_X_y
from skorch.dataset import Dataset
from skorch.callbacks import Checkpoint
from skorch.helper import predefined_split
from config import *
from dataset import *
from sklearn.metrics import roc_auc_score
from mne import set_log_level
import resampy
#Pass False to MNE's set_log_level (presumably to quieten MNE's console logging - confirm)
set_log_level(False)
#Device string handed to skorch later; `cuda` is not defined in this cell
#(presumably a boolean flag from `config` - confirm)
device = 'cuda' if cuda else 'cpu'

In [None]:
def _build_preproc_functions(cut_secs, max_mins, clip_abs, target_fs, scale):
    """Build the ordered list of ``(data, fs) -> (data, fs)`` preprocessing steps.

    The configuration values are bound here as arguments instead of being
    read from module globals when a step runs. A later cell deletes
    ``sec_to_cut``, ``duration_recording_mins``, ``max_abs_val`` and
    ``divisor``; steps that looked those names up lazily would raise
    NameError on any subsequent ``load()`` call.

    :param cut_secs: seconds removed from both ends of every recording
    :param max_mins: maximum recording duration kept, in minutes
    :param clip_abs: clip amplitudes to ``[-clip_abs, clip_abs]``; ``None`` skips the step
    :param target_fs: sampling frequency the data is resampled to
    :param scale: divisor applied to the data; ``None`` skips the step
    :return: list of callables, applied in order
    """
    def cut_edges(data, fs):
        n_cut = int(cut_secs * fs)
        #explicit stop index: `data[:, n:-n]` would be empty when n == 0
        return data[:, n_cut:data.shape[1] - n_cut], fs

    def limit_duration(data, fs):
        return data[:, :int(max_mins * 60 * fs)], fs

    def clip_amplitude(data, fs):
        return np.clip(data, -clip_abs, clip_abs), fs

    def resample(data, fs):
        resampled = resampy.resample(data, fs, target_fs, axis=1, filter='kaiser_fast')
        return resampled, target_fs

    def rescale(data, fs):
        return data / scale, fs

    steps = [cut_edges, limit_duration]
    if clip_abs is not None:
        steps.append(clip_amplitude)
    steps.append(resample)
    if scale is not None:
        steps.append(rescale)
    return steps

preproc_functions = _build_preproc_functions(sec_to_cut, duration_recording_mins,
                                             max_abs_val, sampling_freq, divisor)
#Training recordings
dataset = DiagnosisSet(n_recordings=n_recordings,
                       max_recording_mins=max_recording_mins,
                       preproc_functions=preproc_functions,
                       data_folders=data_folders,
                       train_or_eval='train',
                       sensor_types=sensor_types)
#Evaluation recordings are only set up when the final evaluation is requested
if test_on_eval:
    test_dataset = DiagnosisSet(n_recordings=n_recordings,
                                max_recording_mins=max_recording_mins,
                                preproc_functions=preproc_functions,
                                data_folders=data_folders,
                                train_or_eval='eval',
                                sensor_types=sensor_types)

In [None]:
#Numpy array doesn't work as they take too much space, use BaseConcatDataset instead
#BaseConcatDataset does work recursively with itself.
#Load the training recordings and their labels (X is later indexed and popped like a list)
X,y=dataset.load()
if test_on_eval:
    #Evaluation recordings are loaded only in test mode
    test_x,test_y=test_dataset.load()

In [None]:
#Drop configuration names that this cell's author no longer needs.
#NOTE(review): this makes the cell non re-runnable, and any callable that reads
#one of these globals lazily will raise NameError after this point - confirm nothing does.
del divisor,max_abs_val,sec_to_cut,duration_recording_mins,preproc_functions,data_folders
def create_set(X, y, inds):
    """
    Select a subset of trials and the matching labels.

    :param X: list of trials, indexed one element at a time
    :param y: numpy array of labels, indexed with ``inds`` in one go
    :param inds: iterable of integer positions to keep, in the desired order
    :return: tuple ``(selected_trials, selected_labels)``
    """
    return [X[idx] for idx in inds], y[inds]

#Use of TrainValidTestSplitter is not necessary in newer versions of braindecode
class TrainValidSplitter(object):
    """Split trials into a training set and a single validation fold.

    The fold size is ``len(X) // n_folds``; trials left over by that integer
    division always end up in the training set. The class labels are not
    balanced inside the validation fold.
    """

    def __init__(self, n_folds, i_valid_fold, shuffle):
        self.n_folds, self.i_valid_fold = n_folds, i_valid_fold
        self.shuffle = shuffle
        #fixed seed: a fresh splitter always produces the same permutation
        self.rng = np.random.RandomState(39483948)

    def split(self, X, y):
        """Return ``(train_set, valid_set)``, each an ``(X, y)`` pair from ``create_set``.

        :raises ValueError: when there are fewer trials than folds
        """
        n_trials = len(X)
        if n_trials < self.n_folds:
            raise ValueError(
                "Less Trials: {:d} than folds: {:d}".format(n_trials, self.n_folds))
        order = np.arange(len(y))
        if self.shuffle:
            self.rng.shuffle(order)
        fold_size = n_trials // self.n_folds
        fold_start = self.i_valid_fold * fold_size
        valid_inds = order[fold_start:fold_start + fold_size]
        #setdiff1d returns sorted indices, so the training trials keep their original order
        train_inds = np.setdiff1d(order, valid_inds)
        return create_set(X, y, train_inds), create_set(X, y, valid_inds)
    
#To get timesteps, we can use numpy.reshape
def create_windows(X,y,stride=sampling_freq):
    """Cut every recording into fixed-length windows and repeat its label per window.

    The window length is the module-level ``input_time_length``. The output
    buffer is sized from the recordings themselves rather than a hard-coded
    ``(21, 6000)``, so it always agrees with the slices written into it.

    :param X: sequence of 2-D arrays; axis 1 is the one that gets windowed
    :param y: sequence of labels, one per recording
    :param stride: step in samples between the starts of consecutive windows
    :return: ``(trials, labels)`` - a float32 array of shape
        ``(n_windows, n_rows, input_time_length)`` and an array with one label per window
    """
    def _n_windows(recording):
        #same count as before, ((length - window) // stride) - 1, but never negative:
        #recordings too short for a single window contribute nothing instead of
        #corrupting the total and the write position
        return max(0, (recording.shape[1] - input_time_length) // stride - 1)

    counts = [_n_windows(recording) for recording in X]
    n_rows = X[0].shape[0] if len(X) else 0
    trials = np.zeros(shape=(sum(counts), n_rows, input_time_length), dtype=np.float32)
    labels = []
    position = 0
    for recording, label, count in zip(X, y, counts):
        for j in range(count):
            start = j * stride
            trials[position] = recording[:, start:start + input_time_length]
            position += 1
        labels.extend([label] * count)
    return trials, np.array(labels)

In [None]:
#Only when not evaluating on the eval set: hold out one fold of the training
#recordings as a validation set (shuffled split, fold index i_test_fold)
if test_on_eval==False:
    splitter=TrainValidSplitter(n_folds,i_test_fold,True)
    train_set,valid_set=splitter.split(X,y)
    del X,y
    #From here on X,y refer to the training part only
    X,y=train_set
    valid_X,valid_y=valid_set
    #Deleting n_folds/i_test_fold means this cell cannot be run a second time
    del n_folds,i_test_fold,train_set,valid_set

In [None]:
'''
Methods of data augmentation:-
Time Warping: This involves stretching or compressing the time axis. In time-series analysis, it can lead to a better understanding of variations in time.
Mathematical Explanation:
x′(t)=x(a⋅t)
where a is the warping factor.
Window Slicing: Similar to cropping, but with fixed-size windows. Overlapping windows can also be used to increase the amount of data.
Time Masking: Certain time steps are masked (set to zero or mean value), which can help the model become more robust to missing data.
Noise Injection: Random noise can be added to the sequence, aiding the model in learning to ignore irrelevant variations.
Mathematical Explanation:
x′(t)=x(t)+N(0,σ²)
where N(0,σ²) is Gaussian noise with mean 0 and variance σ².
Data Mixing: By mixing two or more sequences, you can create a new sequence. For instance, in audio processing, overlaying two sound tracks.
Temporal Jittering: It involves adding small random shifts to the temporal alignment of the sequence. It's often used in speech and audio processing.
Sequence-to-sequence Transformation: This involves applying complex transformations like Fourier transform followed by an inverse transformation after modifications in the frequency domain.
Mathematical Explanation:
X′=F⁻¹(F(X)+N)
where F and F⁻¹ are the Fourier and inverse Fourier transforms, and N is a noise term.
'''

In [None]:
#This block will be used to separate the abnormal and normal training trials
#Boolean mask of the abnormal (non-zero label) recordings
abnormal_mask=np.asarray(y)!=0
#Pop from the highest index downwards so the remaining positions stay valid
abnormal_indexes=np.nonzero(abnormal_mask)[0][::-1]
abnormal=[X.pop(i) for i in abnormal_indexes]
#Select labels by index instead of slicing y[i:] / y[:i]: slicing is only right
#when every abnormal label sits in one contiguous block at the end of y, and it
#fails with a NameError when there is no abnormal recording at all
abnormal_labels=y[abnormal_indexes]
y=y[~abnormal_mask]
del abnormal_indexes,abnormal_mask
print(f"normal cases:{len(y)}")
print(f"abnormal cases:{len(abnormal_labels)}")
#Normal-to-abnormal ratio; skipped when there is nothing to divide by
if len(abnormal_labels):
    print(len(y)//len(abnormal_labels))
else:
    print("no abnormal cases found")

In [None]:
#Names for the 21 channels, passed to create_from_X_y below
ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
#we take a 10 second stride as 1 second stride takes too long
stride=sampling_freq*10
#Windowed dataset built from X,y with the long (10 s) stride
train_set=create_from_X_y(X,y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                       window_stride_samples=stride)
del X,y
if test_on_eval==False:
    #Validation windows use a stride of sampling_freq samples (1 second)
    valid_set=create_from_X_y(valid_X,valid_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=sampling_freq)
    del valid_X,valid_y
elif test_on_eval:
    #Evaluation windows use the same 1 second stride
    test_set=create_from_X_y(test_x,test_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=sampling_freq)
    del test_x,test_y
#The separately held `abnormal` recordings are windowed with the 1 second stride,
#i.e. ten times denser than train_set above
abnormal_train_set=create_from_X_y(abnormal,abnormal_labels,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names
                            ,window_size_samples=input_time_length,window_stride_samples=sampling_freq)
del abnormal,abnormal_labels

print(f"normal windows:{len(train_set)}")
print(f"abnormal windows:{len(abnormal_train_set)}")
#Final training set: abnormal windows followed by the windows of train_set
train_set=BaseConcatDataset([abnormal_train_set,train_set])
del abnormal_train_set

In [None]:
#Build the network selected by `model_name`. Every branch sets `model`,
#`optimizer_lr` and `optimizer_weight_decay`, then pushes a dummy batch through
#the network and prints the output shape as a smoke test.
n_classes = 2
if model_name=="shallow":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = ShallowFBCSPNet(n_chans,
                            n_classes,
                            n_filters_time=n_start_chans,
                            n_filters_spat=n_start_chans,
                            n_times=input_time_length,
                            final_conv_length='auto',)
    test=torch.ones(size=(7,21,6000))
    out=model.forward(test)
    print(out.shape)
elif model_name == 'shallow_smac':
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    #conv_nonlin = identity
    do_batch_norm = True
    drop_prob = 0.328794
    filter_time_length = 56
    n_filters_spat = 73
    n_filters_time = 24
    pool_mode = 'max'
    #pool_nonlin = identity
    pool_time_length = 84
    pool_time_stride = 3
    split_first_layer = True
    model = ShallowFBCSPNet(in_chans=n_chans, n_classes=n_classes,
                            n_filters_time=n_filters_time,
                            n_filters_spat=n_filters_spat,
                            n_times=input_time_length,
                            final_conv_length='auto',
                            #conv_nonlin=conv_nonlin,
                            batch_norm=do_batch_norm,
                            drop_prob=drop_prob,
                            filter_time_length=filter_time_length,
                            pool_mode=pool_mode,
                            #pool_nonlin=pool_nonlin,
                            pool_time_length=pool_time_length,
                            pool_time_stride=pool_time_stride,
                            split_first_layer=split_first_layer,
                            )
    test=torch.ones(size=(7,21,6000))
    out=model.forward(test)
    print(out.shape)
elif model_name=="deep":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    model = Deep4Net(n_chans, n_classes,
                     n_filters_time=n_start_chans,
                     n_filters_spat=n_start_chans,
                     n_times=input_time_length,
                     n_filters_2 = int(n_start_chans * n_chan_factor),
                     n_filters_3 = int(n_start_chans * (n_chan_factor ** 2.0)),
                     n_filters_4 = int(n_start_chans * (n_chan_factor ** 3.0)),
                     final_conv_length='auto',
                     stride_before_pool=True)
    #This branch feeds a 4-D dummy batch with a trailing singleton axis
    test=torch.ones(size=(7,21,6000,1))
    out=model.forward(test)
    print(out.shape)
elif model_name=="deep_smac" or model_name == 'deep_smac_bnorm':
    optimizer_lr = 0.0000625
    #Set here as well so every branch defines the same optimizer settings
    optimizer_weight_decay = 0
    #The two variants differ only in batch normalisation
    if model_name == 'deep_smac':
        do_batch_norm = False
    else:
        do_batch_norm = True
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 14
    filter_length_4 = 32
    filter_time_length = 21
    #final_conv_length = 1
    first_nonlin = elu
    first_pool_mode = 'mean'
    later_nonlin = elu
    later_pool_mode = 'mean'
    n_filters_factor = 1.679066
    n_filters_start = 32
    pool_time_length = 1
    pool_time_stride = 2
    split_first_layer = True
    #Overrides the values of the same names used by the other branches
    n_chan_factor = n_filters_factor
    n_start_chans = n_filters_start
    model = Deep4Net(n_chans, n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            n_times=input_time_length,
            n_filters_2=int(n_start_chans * n_chan_factor),
            n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
            n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
            final_conv_length='auto',
            batch_norm=do_batch_norm,
            drop_prob=drop_prob,
            filter_length_2=filter_length_2,
            filter_length_3=filter_length_3,
            filter_length_4=filter_length_4,
            filter_time_length=filter_time_length,
            first_conv_nonlin=first_nonlin,
            first_pool_mode=first_pool_mode,
            later_conv_nonlin=later_nonlin,
            later_pool_mode=later_pool_mode,
            pool_time_length=pool_time_length,
            pool_time_stride=pool_time_stride,
            split_first_layer=split_first_layer,
            stride_before_pool=True)
    test=torch.ones(size=(6,21,6000,1))
    out=model.forward(test)
    print(out.shape)
    #NOTE: this also deletes n_chan_factor and n_start_chans, which the "deep" and
    #"shallow_deep" branches read, so those need the config re-imported afterwards
    del do_batch_norm,drop_prob,filter_length_2,filter_length_3,filter_length_4,filter_time_length,first_nonlin,n_chan_factor,n_start_chans,first_pool_mode,later_nonlin,later_pool_mode,n_filters_factor,n_filters_start,pool_time_length,pool_time_stride,split_first_layer
#Works properly, fit the hybrid cnn
elif model_name=="hybrid":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = HybridNet(n_chans, n_classes,input_window_samples=input_time_length,)
    test=torch.ones(size=(2,21,6000))
    out=model.forward(test)
    #Replace the final conv with one spanning the whole remaining length of axis 2,
    #then squeeze the singleton axes away
    out_length=out.shape[2]
    model.final_conv=nn.Conv2d(100,n_classes,(out_length,1),bias=True,)
    model=nn.Sequential(model,Expression(torch.squeeze))
    out=model.forward(test)
    print(out.shape)
    del out_length
elif model_name=="TCN":
    import warnings
    #This disables the warning of the dropout2d layers receiving 3d input
    #(note that it silences every other warning as well)
    warnings.filterwarnings("ignore")
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    n_blocks=7
    n_filters=32
    kernel_size=24
    drop_prob = 0.3
    add_log_softmax=False
    x=TCN(n_chans,n_classes,n_blocks,n_filters,kernel_size,drop_prob,add_log_softmax)
    test=torch.ones(size=(7,21,6000))
    out=x.forward(test)
    out_length=out.shape[2]
    #There is no hyperparameter where output of TCN is (Batch_Size,Classes) when input is (Batch_Size,21,6000) so add new layers to meet size
    model=nn.Sequential(x,nn.Conv1d(n_classes,n_classes,out_length,bias=True,),Expression(torch.squeeze),nn.LogSoftmax(dim=1))
    out=model.forward(test)
    print(out.shape)
    del out_length,x
elif model_name=="shallow_deep":
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 14
    filter_length_4 = 32
    n_filters_factor = 1.679066
    n_filters_start = 32
    split_first_layer = True
    n_chan_factor = n_filters_factor
    #n_start_chans = n_filters_start

    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    conv_time_length=25
    first_conv_nonlin=relu
    first_pool_nonlin=safe_log
    later_conv_nonlin=elu
    later_pool_nonlin=safe_log
    first_pool_mode = 'mean'
    later_pool_mode = 'mean'
    pool_time_length=15
    model = Deep4Net(n_chans, n_classes,
                     n_filters_time=n_start_chans,
                     n_filters_spat=n_start_chans,
                     n_times=input_time_length,
                     n_filters_2 = int(n_start_chans * n_chan_factor),
                     n_filters_3 = int(n_start_chans * (n_chan_factor ** 2.0)),
                     n_filters_4 = int(n_start_chans * (n_chan_factor ** 3.0)),
                     final_conv_length='auto',
                     first_pool_nonlin=first_pool_nonlin,
                     first_conv_nonlin=first_conv_nonlin,
                     #later_pool_nonlin=later_pool_nonlin,
                     #later_conv_nonlin=later_conv_nonlin,
                     filter_time_length=conv_time_length,
                     pool_time_length=pool_time_length,
                     first_pool_mode=first_pool_mode,
                     later_pool_mode=later_pool_mode,
                     split_first_layer=split_first_layer,
                     drop_prob=drop_prob,
                     filter_length_2=filter_length_2,
                     filter_length_3=filter_length_3,
                     filter_length_4=filter_length_4,
                     )
    test=torch.ones(size=(7,21,6000))
    out=model(test)
    print(out.shape)

elif model_name=="attention":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    model=ATCNet(n_chans,n_classes,input_time_length//sampling_freq,sampling_freq,concat=True)
    test=torch.ones(size=(7,n_chans,input_time_length))
    out=model.forward(test)
    print(out.shape)
else:
    #Fail with a clear message instead of a NameError on `model`/`test` further down
    raise ValueError(f"Unknown model_name: {model_name!r}")
if cuda:
    model.cuda()
del test,out
print(model_name)

In [None]:
#Display the constructed network as the cell output
model

In [None]:
#NOTE: `monitor` is not passed to anything in this cell; the checkpoint below
#monitors 'valid_f1_best' directly
monitor = lambda net: any(net.history[-1, ('valid_accuracy_best','valid_f1_best','valid_loss_best')])
#Saves parameters, optimizer state and history under model/ whenever the validation F1 improves
cp=Checkpoint(monitor='valid_f1_best',dirname='model',f_params=f'{model_name}best_param.pkl',
               f_optimizer=f'{model_name}best_opt.pkl', f_history=f'{model_name}best_history.json')
#The two modes differ only in the file prefix and in which held-out set is used
#for the per-epoch validation scores
if test_on_eval:
    path=f'{model_name}II'
    held_out_set=test_set
else:
    path=f'{model_name}'
    held_out_set=valid_set
classifier = braindecode.EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(held_out_set),
    optimizer__lr=optimizer_lr,
    #optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    callbacks=["accuracy","f1",cp],#Try 'roc_auc'
    warm_start=True,
    )
classifier.initialize()
del held_out_set
#`model` is intentionally kept: later cells build another EEGClassifier and a
#skorch NeuralNet from it, which raised NameError after the previous `del model`

In [None]:
#Smoke test: predict on a random batch of three windows.
#np.random.rand returns float64; cast to float32 so the input dtype matches the
#network weights instead of failing with a double/float mismatch
test=np.random.rand(3,n_chans,input_time_length).astype(np.float32)
out=classifier.predict(test)
print(out)

In [None]:
#Loads Phase 1 parameters and fit them further in phase 2
#Phase 1 files are stored under the bare model name
path=f'{model_name}'
if test_on_eval:
    #Start phase 2 from the phase 1 parameters, optimizer state and history
    classifier.load_params(
        f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
    print("Paramters Loaded")
    #From here on `path` carries the 'II' suffix used for the phase 2 files
    path=f'{model_name}II'

In [None]:
#Used to load parameters for ongoing training
#Best-effort resume: a missing checkpoint is expected on the first run, but say
#so instead of silently swallowing every error
try:
    classifier.load_params(
        f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
    print("Paramters Loaded")
except FileNotFoundError:
    print(f"No saved parameters found for 'model/{path}_*'; continuing with the current weights")
except Exception as exc:
    #e.g. a checkpoint written for a different architecture; keep going but make it visible
    print(f"Could not load parameters for '{path}': {exc!r}")

In [None]:
#Shows the history of training the neural network
#Displayed as the cell output; empty unless a history was loaded or training has run
classifier.history_

In [None]:
#Train for one epoch; y=None because the labels are taken from train_set itself.
#The classifier was created with warm_start=True, so re-running this cell continues training
classifier.fit(train_set,y=None,epochs=1)

In [None]:
#Persist the current parameters, optimizer state and history under the active `path` prefix
classifier.save_params(
    f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')

In [None]:
#Rebuild the classifier without passing train_split (no predefined held-out set).
#NOTE(review): this needs `model`, `optimizer_lr` and `cp` from the earlier cells to
#still be defined - confirm `model` has not been deleted before running this cell
classifier = braindecode.EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=optimizer_lr,
    #optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    callbacks=["accuracy","f1",cp],#Try ‘roc_auc’
    warm_start=True,
        )
classifier.initialize()

In [None]:
#Load the evaluation recordings. NOTE(review): test_dataset is only created when
#test_on_eval is true in the dataset cell - confirm before running this cell
test_x,test_y=test_dataset.load()
ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
#Window the evaluation recordings with a stride of sampling_freq samples
test_set=create_from_X_y(test_x,test_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=sampling_freq)
del test_x,test_y

In [None]:
#This block loads the best parameters and finds the accuracy, f1 score and roc auc of the valid/test set
#Load the parameters and history written by the checkpoint callback
classifier.load_params(
        f_params=f'model/{model_name}best_param.pkl', f_history=f'model/{model_name}best_history.json')
print("Paramters Loaded")
#Score on the test windows in test mode, otherwise on the validation windows
eval_set=test_set if test_on_eval else valid_set
#One forward pass: the predicted class is the arg-max of the class probabilities,
#which is how skorch's classifier derives predict() from predict_proba()
proba=classifier.predict_proba(eval_set)
pred_labels=proba.argmax(axis=1)
#Each dataset item carries its label at position 1
actual_labels=np.array([item[1] for item in eval_set])
auc=roc_auc_score(actual_labels,proba[:,1])
accuracy=np.mean(pred_labels==actual_labels)
#Binary metrics for the positive class (label 1); a zero denominator yields 0.0
#instead of a nan/RuntimeWarning
tp=np.sum(pred_labels*actual_labels)
n_predicted_positive=np.sum(pred_labels)
n_actual_positive=np.sum(actual_labels)
precision=tp/n_predicted_positive if n_predicted_positive else 0.0
recall=tp/n_actual_positive if n_actual_positive else 0.0
f1=2*precision*recall/(precision+recall) if (precision+recall) else 0.0
del eval_set,proba,n_predicted_positive,n_actual_positive

print(model_name)
print(f"Accuracy:{accuracy}")
print(f"F1-Score:{f1}")
print(f"roc_auc score:{auc}")

In [None]:
#This will load the model and parameters and then replace it with one whose classification layer is removed
from skorch import NeuralNet
#Wrap `model` in a plain skorch NeuralNet and load the checkpointed best parameters.
#NOTE(review): requires `model` to still be defined - confirm it was not deleted earlier
network=NeuralNet(module=model,criterion=torch.nn.modules.loss.NLLLoss,batch_size=batch_size,device=device)
network.initialize()
network.load_params(
    f_params=f'model/{model_name}best_param.pkl', f_optimizer=f'model/{model_name}best_opt.pkl', f_history=f'model/{model_name}best_history.json')
print("Paramters Loaded")
#Keep all child modules except the last three and flatten the result, so predict()
#returns one feature vector per window.
#NOTE(review): assumes the last three children are the classification head of the
#chosen architecture - confirm against the module print-out
network.module_=torch.nn.Sequential(*(list(network.module_.children())[:-3]),nn.modules.Flatten())

In [None]:
#Display the truncated module to check which layers remain
network.module_

In [None]:
#Push a dummy batch through the truncated network to find the feature vector length.
#The dummy shape is hard-coded to 21 rows by 6000 samples
test=torch.ones(size=(2,21,6000))
feat=network.predict(test).shape[1]
print(feat)
del test

In [None]:
#Reload the raw training and evaluation recordings; the following cells cut them
#into windows and turn every window into a feature vector.
#NOTE(review): test_dataset only exists when test_on_eval was true when the
#dataset cell ran - confirm before running this cell
X,y=dataset.load()
test_x,test_y=test_dataset.load()

In [None]:
#Separates normal and abnormal recordings
#Separate the abnormal (non-zero label) recordings from the normal ones.
#Labels are selected by index: slicing y[i:] / y[:i] is only right when every
#abnormal label sits in one contiguous block at the end of y
abnormal_mask=np.asarray(y)!=0
abnormal_indexes=np.nonzero(abnormal_mask)[0][::-1]
#Pop from the highest index downwards so the remaining positions stay valid
abnormal=[X.pop(i) for i in abnormal_indexes]
abnormal_labels=y[abnormal_indexes]
y=y[~abnormal_mask]
del abnormal_indexes,abnormal_mask

def _count_windows(recording,stride):
    """Number of windows taken from one recording; never negative for short recordings."""
    return max(0,((recording.shape[1]-input_time_length)//stride)-1)

def _extract_into(features,labels,position,recordings,rec_labels,stride):
    """Write the features of every window of `recordings` into `features` starting at
    row `position`, append one label per window to `labels`, return the next free row."""
    for recording,label in zip(recordings,rec_labels):
        no_of_windows=_count_windows(recording,stride)
        if no_of_windows==0:
            #too short for a single window: nothing to predict on
            continue
        windows=np.array([recording[:,j*stride:j*stride+input_time_length]
                          for j in range(no_of_windows)])
        features[position:position+no_of_windows]=network.predict(windows)
        labels.extend([label]*no_of_windows)
        position+=no_of_windows
    return position

#Normal recordings use a 10 second stride, abnormal ones a 1 second stride
stride=sampling_freq*10
abstride=sampling_freq
no_of_trials=(sum(_count_windows(recording,stride) for recording in X)
              +sum(_count_windows(recording,abstride) for recording in abnormal))
features=np.zeros(shape=(no_of_trials,feat),dtype=np.float32)
labels=[]

#Normal features first, abnormal features after them
position=_extract_into(features,labels,0,X,y,stride)
position=_extract_into(features,labels,position,abnormal,abnormal_labels,abstride)
del X,y,abnormal,abnormal_labels
labels=np.array(labels)

#This saves the features along with labels of each trial in a .mat file
scipy.io.savemat("E:/train_features.mat",{"x":features,"y":labels})
del features,labels

In [None]:
#Test set must match test set from paper as much as possible
stride=sampling_freq
#Windows per recording, clamped at zero so that recordings shorter than one
#window neither shrink the total nor move the write position backwards
window_counts=[max(0,((recording.shape[1]-input_time_length)//stride)-1)
               for recording in test_x]
no_of_trials=sum(window_counts)
test_features=np.zeros(shape=(no_of_trials,feat),dtype=np.float32)
test_labels=[]
position=0
for recording,label,no_of_windows in zip(test_x,test_y,window_counts):
    if no_of_windows==0:
        #nothing to predict on for this recording
        continue
    windows=np.array([recording[:,j*stride:j*stride+input_time_length]
                      for j in range(no_of_windows)])
    #one feature vector per window, written into the preallocated buffer
    test_features[position:position+no_of_windows]=network.predict(windows)
    test_labels.extend([label]*no_of_windows)
    position+=no_of_windows
del test_x,test_y,window_counts
test_labels=np.array(test_labels)

#Saves the test features and the per-window labels in a .mat file
scipy.io.savemat("E:/test_features.mat",{"x":test_features,"y":test_labels})
del test_features,test_labels

In [None]:
#Import the submodule explicitly: a bare `import scipy` does not guarantee that
#scipy.io is available on every SciPy version
import scipy.io
import numpy as np
#Per-window features and labels written by the feature-extraction cells
inputs=scipy.io.loadmat("E:/train_features.mat")
features=inputs["x"]
#squeeze() removes the singleton axis the 1-D label vector gets in the .mat round trip
labels=inputs["y"].squeeze()
inputs=scipy.io.loadmat("E:/test_features.mat")
test_features=inputs["x"]
test_labels=inputs["y"].squeeze()
del inputs

In [None]:
#t variable determines timesteps for hybrid model
#Group consecutive feature vectors into sequences of t timesteps for the LSTM
t=7
f=features.shape[-1]
_n_train_seq=len(labels)//t
#Rows that do not fill a complete sequence are dropped
_n_train_rows=_n_train_seq*t
seq_features=features[:_n_train_rows].reshape((_n_train_seq,t,f))
#Every sequence takes the label of its first window
seq_labels=labels[:_n_train_rows].reshape((_n_train_seq,t))[:,0]

In [None]:
class SimpleModel(torch.nn.Module):
  """Single-layer LSTM classifier over sequences of feature vectors.

  Input is ``(batch, timesteps, input_features)``; output is ``(batch, 2)``
  log-probabilities.
  """
  def __init__(self,input_features):
    super().__init__()
    self.lstm = torch.nn.LSTM(input_size=input_features, hidden_size=50, batch_first=True)
    self.fc = torch.nn.Linear(50, 2)
    self.tanh = torch.nn.Tanh()
    self.softmax = torch.nn.LogSoftmax(dim=1)

  def forward(self, inputs):
    # h_n is (num_layers, batch, hidden): the hidden state after the last timestep
    _, (h_n,_) = self.lstm(inputs)
    # Index the layer axis instead of calling squeeze(): a bare squeeze() also drops
    # the batch axis when the batch holds one sequence, which breaks LogSoftmax(dim=1)
    last_hidden = self.tanh(h_n[-1])
    logits = self.fc(last_hidden)
    return self.softmax(logits)
#Instantiate the LSTM classifier for feature vectors of length f (rebinds `model`)
model = SimpleModel(f)

In [None]:
#Same grouping as for the training features: sequences of t consecutive windows,
#incomplete trailing rows dropped, label taken from the first window of a sequence
_n_test_seq=len(test_labels)//t
_n_test_rows=_n_test_seq*t
seq_test_features=test_features[:_n_test_rows].reshape((_n_test_seq,t,f))
seq_test_labels=test_labels[:_n_test_rows].reshape((_n_test_seq,t))[:,0]
#skorch Dataset pairing the test sequences with their labels
test_set=Dataset(seq_test_features,seq_test_labels)

In [None]:
#NOTE: `monitor` is not passed to anything in this cell
monitor = lambda net: any(net.history[-1, ('valid_accuracy_best','valid_f1_best','valid_loss_best')])
#Checkpoint for the LSTM, written under model/ when the validation F1 improves
cp=Checkpoint(monitor='valid_f1_best',dirname='model',f_params='LSTMbest_param.pkl',f_optimizer='LSTMbest_opt.pkl',f_history='LSTMbest_history.json')
#NOTE(review): the test sequences are used as the validation split, so the
#checkpoint selection above is driven by test data - confirm this is intended
classifier = braindecode.EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(test_set),
        optimizer__lr=0.0001,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        device=device,
        callbacks=["accuracy","f1",'roc_auc',cp],
        warm_start=True,
        )
classifier.initialize()

In [None]:
#Smoke test: predict on two random sequences and print the shape of the result
test=torch.randn(size=(2,t,f))
shape=classifier.predict(test).shape
print(shape)

In [None]:
#Try deep smac by itself and as feature extractor and determine effectiveness
#Train the LSTM on the training sequences for ten epochs
classifier.fit(seq_features,y=seq_labels,epochs=10)

In [None]:
#Evaluate on the test SEQUENCES: the LSTM takes (n, t, f) input and yields one
#prediction per sequence, so both the inputs and the labels must be the seq_* arrays
#(the flat per-window test_features/test_labels do not match the model or its output)
out=classifier.predict(seq_test_features)
accuracy=np.mean(out==seq_test_labels)
print(f"Accuracy:{accuracy}")
#Binary metrics for the positive class (label 1); a zero denominator yields 0.0
tp=np.sum(out*seq_test_labels)
n_predicted_positive=np.sum(out)
n_actual_positive=np.sum(seq_test_labels)
precision=tp/n_predicted_positive if n_predicted_positive else 0.0
recall=tp/n_actual_positive if n_actual_positive else 0.0
f1=2*precision*recall/(precision+recall) if (precision+recall) else 0.0
print(f"F1-Score:{f1}")
#ROC AUC from the positive-class probability of the same test sequences
#(previously computed on valid_set, whose length does not match the test labels)
roc_auc_score(seq_test_labels,classifier.predict_proba(seq_test_features)[:,1])

In [None]:
#Crop every recording to one common length so they fit into dense arrays.
#The length is the shortest recording across BOTH sets: taking it from X alone
#fails as soon as an evaluation recording is shorter than every training one
min_length=min(trial.shape[1] for trial in list(X)+list(test_x))
trials=len(X)
X_new=np.zeros(shape=(trials,n_chans,min_length),dtype=np.float32)
for i in range(trials):
    X_new[i]=X[i][:,:min_length]
trials=len(test_x)
test_x_new=np.zeros(shape=(trials,n_chans,min_length),dtype=np.float32)
for i in range(trials):
    test_x_new[i]=test_x[i][:,:min_length]

In [None]:
#Crop every recording to one common length so they fit into dense arrays.
#The length is the shortest recording across BOTH sets: taking it from X alone
#fails as soon as an evaluation recording is shorter than every training one
min_length=min(trial.shape[1] for trial in list(X)+list(test_x))
trials=len(X)
X_new=np.zeros(shape=(trials,n_chans,min_length),dtype=np.float32)
for i in range(trials):
    X_new[i]=X[i][:,:min_length]
trials=len(test_x)
test_x_new=np.zeros(shape=(trials,n_chans,min_length),dtype=np.float32)
for i in range(trials):
    test_x_new[i]=test_x[i][:,:min_length]
#The training set is written as two halves; the evaluation set goes into one file
mid=len(X_new)//2
scipy.io.savemat("E:/train_set_1.mat",{"x":X_new[:mid,:,:],"y":y[:mid]})
scipy.io.savemat("E:/train_set_2.mat",{"x":X_new[mid:,:,:],"y":y[mid:]})
scipy.io.savemat("E:/test_set.mat",{"x":test_x_new,"y":test_y})

In [1]:
import torch
from torch import nn
from torch.nn.functional import elu,relu,leaky_relu
import torchvision.transforms as transform
import braindecode
from braindecode.models import *
from braindecode.models.modules import Expression
from braindecode.models.functions import squeeze_final_output,square,safe_log
from skorch.dataset import Dataset
from skorch.callbacks import Checkpoint
from skorch.helper import predefined_split
from config import *
from dataset import *
from sklearn.metrics import roc_auc_score
from mne import set_log_level
#Pass False to MNE's set_log_level (presumably to quieten MNE's console logging - confirm)
set_log_level(False)
#Device string for skorch; `cuda` is not defined in this cell
#(presumably a boolean flag from `config` - confirm)
device = 'cuda' if cuda else 'cpu'

Tensorflow not install, you could not use those pipelines


In [2]:
#Loads preprocessed data from mat files
#Import the submodule explicitly: a bare `import scipy` does not guarantee that
#scipy.io is available on every SciPy version
import scipy.io
import numpy as np
#The training set was saved as two halves; load the first one ...
inputs=scipy.io.loadmat("E:/train_set_1.mat")
X=inputs["x"]
y=inputs["y"].squeeze()

#... and append the second one along the recording axis
inputs=scipy.io.loadmat("E:/train_set_2.mat")
X=np.concatenate((X,inputs["x"]),axis=0)
y=np.concatenate((y,inputs["y"].squeeze()),axis=0)
#Every recording now has the same length, which becomes the network input length
input_time_length=X.shape[-1]
inputs=scipy.io.loadmat("E:/test_set.mat")
test_x=inputs["x"]
test_y=inputs["y"].squeeze()
del inputs

In [16]:
# The model is trained on a chosen combination of channels; each selected
# channel is passed in at its full length.
# NOTE(review): assumes the order of ch_names matches the channel axis (axis 1)
# of X and test_x -- TODO confirm against the dataset code.
ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']

# Selected channels, given by name and resolved to their positions in ch_names
# (T4 deliberately precedes T3, reproducing the index list 0,1,2,3,4,11,16,18,17).
picked_ch=[ch_names.index(name) for name in ('A1','A2','C3','C4','CZ','FZ','PZ','T4','T3')]
n_chans=len(picked_ch)

# Same channel subset for the training and the evaluation arrays.
train_x=X[:,picked_ch]
eval_x=test_x[:,picked_ch]

print(f'Channels chosen:{[ch_names[ch] for ch in picked_ch]}')

Channels chosen:['A1', 'A2', 'C3', 'C4', 'CZ', 'FZ', 'PZ', 'T4', 'T3']


In [17]:
# Architecture to build in the next cell. Names handled there:
# "shallow", "deep", "deep_smac", "deep_smac_bnorm", "shallow_deep", "attention".
model_name='shallow_deep'

In [18]:
# Build the network selected by `model_name`.
#
# Reads from earlier cells / config: n_chans, input_time_length, n_start_chans,
# n_chan_factor, init_lr, sampling_freq, cuda.
# Leaves behind: n_classes, model, optimizer_lr, optimizer_weight_decay and the
# per-branch hyper-parameter variables.
n_classes = 2


def _print_output_shape(net):
    """Push a dummy batch of 7 all-ones trials through `net` and print the output shape.

    The pass runs in eval mode and without autograd so that the dummy data
    neither updates BatchNorm running statistics nor allocates a backward
    graph; the previous train/eval mode of `net` is restored afterwards.
    """
    was_training = net.training
    net.eval()
    with torch.no_grad():
        dummy = torch.ones(size=(7,n_chans,input_time_length))
        print(net(dummy).shape)
    net.train(was_training)


if model_name=="shallow":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    pool_time_length=150
    pool_time_stride=50
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = ShallowFBCSPNet(n_chans,
                                    n_classes,
                                    n_filters_time=n_start_chans,
                                    n_filters_spat=n_start_chans,
                                    input_window_samples=input_time_length,
                                    pool_time_length=pool_time_length,
                                    pool_time_stride=pool_time_stride,
                                    final_conv_length='auto',)
    _print_output_shape(model)
elif model_name=="deep":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    # Filter counts of blocks 2-4 grow geometrically from n_start_chans by n_chan_factor.
    model = Deep4Net(n_chans, n_classes,
                         n_filters_time=n_start_chans,
                         n_filters_spat=n_start_chans,
                         input_window_samples=input_time_length,
                         n_filters_2 = int(n_start_chans * n_chan_factor),
                         n_filters_3 = int(n_start_chans * (n_chan_factor ** 2.0)),
                         n_filters_4 = int(n_start_chans * (n_chan_factor ** 3.0)),
                         final_conv_length='auto',
                        stride_before_pool=True)
    _print_output_shape(model)
elif model_name=="deep_smac" or model_name == 'deep_smac_bnorm':
    optimizer_lr = 0.0000625
    # Defined here as in every other branch so that enabling weight decay in the
    # classifier cell does not hit an undefined name for this model.
    optimizer_weight_decay = 0
    # Batch norm is the only difference between the two names of this branch.
    do_batch_norm = model_name == 'deep_smac_bnorm'
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 24
    filter_length_4 = 36
    filter_time_length = 21
    #final_conv_length = 1
    first_nonlin = elu
    first_pool_mode = 'mean'
    later_nonlin = elu
    later_pool_mode = 'mean'
    n_filters_factor = 1.679066
    n_filters_start = 32
    # pool_time_length / pool_time_stride are set but NOT passed to Deep4Net
    # (the keyword arguments are commented out below).
    pool_time_length = 3
    pool_time_stride = 3
    split_first_layer = True
    # NOTE: these two assignments overwrite the values of n_chan_factor and
    # n_start_chans for the rest of the session, not only for this branch.
    n_chan_factor = n_filters_factor
    n_start_chans = n_filters_start
    model = Deep4Net(n_chans, n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_window_samples=input_time_length,
            n_filters_2=int(n_start_chans * n_chan_factor),
            n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
            n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
            final_conv_length='auto',
            batch_norm=do_batch_norm,
            drop_prob=drop_prob,
            filter_length_2=filter_length_2,
            filter_length_3=filter_length_3,
            filter_length_4=filter_length_4,
            filter_time_length=filter_time_length,
            first_conv_nonlin=first_nonlin,
            first_pool_mode=first_pool_mode,
            later_conv_nonlin=later_nonlin,
            later_pool_mode=later_pool_mode,
            #pool_time_length=pool_time_length,
            #pool_time_stride=pool_time_stride,
            split_first_layer=split_first_layer,
            stride_before_pool=True)
elif model_name=="shallow_deep":
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 14
    filter_length_4 = 32
    n_filters_factor = 1.679066
    n_filters_start = 32
    split_first_layer = True
    # NOTE: overwrites n_chan_factor for the rest of the session. n_start_chans
    # is left at its existing value because the override below is commented out,
    # so n_filters_start is currently unused.
    n_chan_factor = n_filters_factor
    #n_start_chans = n_filters_start

    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    conv_time_length=25
    first_conv_nonlin=relu
    first_pool_nonlin=safe_log
    # later_conv_nonlin / later_pool_nonlin are set but NOT passed to Deep4Net
    # (the keyword arguments are commented out below).
    later_conv_nonlin=elu
    later_pool_nonlin=safe_log
    first_pool_mode = 'mean'
    later_pool_mode = 'mean'
    pool_time_length=15
    # input_time_length is passed positionally as the third argument here,
    # whereas the other Deep4Net branches pass it as input_window_samples=...
    model = Deep4Net(n_chans, n_classes,
                            input_time_length,
                            n_filters_time=n_start_chans,
                            n_filters_spat=n_start_chans,
                            n_filters_2 = int(n_start_chans * n_chan_factor),
                            n_filters_3 = int(n_start_chans * (n_chan_factor ** 2.0)),
                            n_filters_4 = int(n_start_chans * (n_chan_factor ** 3.0)),
                            final_conv_length='auto',
                            first_pool_nonlin=first_pool_nonlin,
                            first_conv_nonlin=first_conv_nonlin,
                            #later_pool_nonlin=later_pool_nonlin,
                            #later_conv_nonlin=later_conv_nonlin,
                            filter_time_length=conv_time_length,
                            pool_time_length=pool_time_length,
                            first_pool_mode=first_pool_mode,
                            later_pool_mode=later_pool_mode,
                            split_first_layer=split_first_layer,
                            drop_prob=drop_prob,
                            filter_length_2=filter_length_2,
                            filter_length_3=filter_length_3,
                            filter_length_4=filter_length_4,
                            )
elif model_name=="attention":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    # Third argument is the input length divided by sampling_freq (integer division).
    model=ATCNet(n_chans,n_classes,input_time_length//sampling_freq,sampling_freq,concat=True,tcn_depth=3)
    _print_output_shape(model)
else:
    # Without this branch an unknown name either raises a confusing NameError
    # below or silently reuses a `model` left over from an earlier run.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of 'shallow', 'deep', "
        "'deep_smac', 'deep_smac_bnorm', 'shallow_deep', 'attention'"
    )
if cuda:
    model.cuda()

print(model_name)

shallow_deep




In [19]:
# Display the repr of the network built in the previous cell (its layer listing).
model

Deep4Net(
  (ensuredims): Ensure4d()
  (dimshuffle): Rearrange('batch C T 1 -> batch 1 T C')
  (conv_time_spat): CombinedConv(
    (conv_time): Conv2d(1, 25, kernel_size=(25, 1), stride=(1, 1))
    (conv_spat): Conv2d(25, 25, kernel_size=(1, 9), stride=(1, 1), bias=False)
  )
  (bnorm): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=relu) 
  (pool): AvgPool2dWithConv()
  (pool_nonlin): Expression(expression=safe_log) 
  (drop_2): Dropout(p=0.244445, inplace=False)
  (conv_2): Conv2d(25, 41, kernel_size=(12, 1), stride=(1, 1), bias=False)
  (bnorm_2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (nonlin_2): Expression(expression=elu) 
  (pool_2): AvgPool2dWithConv()
  (pool_nonlin_2): Expression(expression=identity) 
  (drop_3): Dropout(p=0.244445, inplace=False)
  (conv_3): Conv2d(41, 70, kernel_size=(14, 1), stride=(1, 1), bias=False)
  (bnorm_3): BatchNorm2d(70, eps=1e-05, mom

In [20]:
# True when the most recent history entry flags a new best validation
# accuracy, F1 or loss. It is not handed to the Checkpoint below, which
# monitors 'valid_f1_best' only.
def monitor(net):
    return any(net.history[-1, ('valid_accuracy_best','valid_f1_best','valid_loss_best')])


# Checkpoint callback: monitors 'valid_f1_best' and writes parameters, optimizer
# state and history into the 'model' directory.
cp=Checkpoint(
    monitor='valid_f1_best',
    dirname='model',
    f_params=f'chanbest_param.pkl',
    f_optimizer=f'chanbest_opt.pkl',
    f_history=f'chanbest_history.json',
)

# Wrap the network in a braindecode/skorch classifier.
# NOTE(review): the validation split is built from eval_x/test_y, i.e. the same
# data the final metrics cell evaluates on, so checkpoint selection sees that set.
classifier = braindecode.EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=optimizer_lr,
    #optimizer__weight_decay=optimizer_weight_decay,
    train_split=predefined_split(Dataset(eval_x,test_y)),
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    callbacks=["accuracy","f1",'roc_auc',cp],#Try 'roc_auc'
    # warm_start=True: later fit() calls are meant to continue from the current weights.
    warm_start=True,
)
classifier.initialize()

# Remove the notebook-level name `model`; re-running this cell (or the cell that
# displays `model`) therefore requires re-running the model-building cell first.
del model

In [11]:
# Restore weights, optimizer state and training history from the files under
# model/ (the same paths the save_params cell writes). Only run this cell when
# those files exist, e.g. to resume an earlier training session.
classifier.load_params(f_params='model/chanbest_param.pkl',f_optimizer='model/chanbest_opt.pkl', f_history='model/chanbest_history.json')
print("Parameters Loaded")

Paramters Loaded


In [21]:
# Smoke test: predict on 7 random trials shaped like the real input.
# The random data is cast to float32 so the check uses the same dtype as the
# preprocessed recordings (which are stored as float32) instead of rand()'s float64.
test=np.random.rand(7,n_chans,input_time_length).astype(np.float32)
out=classifier.predict(test)
print(out)

[1 1 1 1 1 1 1]


In [32]:
# Train for one epoch on the selected channels. The classifier was created with
# warm_start=True, so re-running this cell is meant to continue training from
# the current weights rather than start over.
classifier.fit(train_x,y,epochs=1)

      9            [36m0.9766[0m      [32m0.9240[0m        [35m0.1249[0m           [31m0.9982[0m       0.6444            0.6444      0.4419        0.9154           0.8259        465.1024


<class 'braindecode.classifier.EEGClassifier'>[initialized](
  Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
  Deep4Net (Deep4Net)                      [1, 9, 60000]             [1, 2]                    --                        --
  ├─Ensure4d (ensuredims): 1-1             [1, 9, 60000]             [1, 9, 60000, 1]          --                        --
  ├─Rearrange (dimshuffle): 1-2            [1, 9, 60000, 1]          [1, 1, 60000, 9]          --                        --
  ├─CombinedConv (conv_time_spat): 1-3     [1, 1, 60000, 9]          [1, 25, 59976, 1]         6,275                     --
  ├─BatchNorm2d (bnorm): 1-4               [1, 25, 59976, 1]         [1, 25, 59976, 1]         50                        --
  ├─Expression (conv_nonlin): 1-5          [1, 25, 59976, 1]         [1, 25, 59976, 1]         --                        --
  ├─AvgPool2dWithConv (pool): 1-6          [1, 25, 59976, 1] 

In [29]:
# Save the CURRENT weights, optimizer state and history under model/.
# NOTE(review): these look like the same paths the Checkpoint callback writes
# (dirname='model', chanbest_* file names), so running this cell would replace
# the best-F1 checkpoint with the latest state -- confirm this is intended.
classifier.save_params(f_params=f'model/chanbest_param.pkl',f_optimizer=f'model/chanbest_opt.pkl', f_history=f'model/chanbest_history.json')

In [14]:
#This block finds the accuracy, f1 score and roc auc of the valid/test set
pred_labels=classifier.predict(eval_x)
# ROC AUC is computed from column 1 of predict_proba (the score of class 1).
auc=roc_auc_score(test_y,classifier.predict_proba(eval_x)[:,1])
accuracy=np.mean(pred_labels==test_y)

# The counts below rely on labels being 0/1: the element-wise product is 1 only
# where both prediction and truth are 1.
tp=np.sum(pred_labels*test_y)
predicted_pos=np.sum(pred_labels)
actual_pos=np.sum(test_y)

# Guard the divisions: with no predicted (or no actual) positives the plain
# ratios are 0/0 and yield NaN plus a RuntimeWarning. Report 0.0 in those cases.
precision=tp/predicted_pos if predicted_pos else 0.0
recall=tp/actual_pos if actual_pos else 0.0
f1=2*precision*recall/(precision+recall) if (precision+recall) else 0.0

print(model_name)
print(f"Accuracy:{accuracy}")
print(f"F1-Score:{f1}")
print(f"roc_auc score:{auc}")

Paramters Loaded
shallow_deep
Accuracy:0.6444444444444445
F1-Score:0.4146341463414634
roc_auc score:0.8455105633802817
