In [1]:
import numpy as np
import torch
import scipy
from torch import nn
from torch.nn.functional import elu,relu,leaky_relu
import torchvision.transforms as transform
import braindecode 
from braindecode.models import *
from braindecode.models.modules import Expression
from braindecode.models.functions import squeeze_final_output
from braindecode.datasets import BaseDataset, BaseConcatDataset,create_from_X_y
from braindecode.models.util import to_dense_prediction_model, get_output_shape
import pandas as pd
import resampy
from skorch.dataset import Dataset
from skorch.callbacks import Checkpoint,ProgressBar
from skorch.helper import predefined_split
from config import *
from dataset import *
from braindecode.preprocessing import create_fixed_length_windows
from mne import set_log_level
# Reduce MNE's log verbosity while recordings are read.
set_log_level(False)
# `cuda` is not defined in this notebook; presumably it is a boolean flag
# exported by `from config import *` — TODO confirm.
device = 'cuda' if cuda else 'cpu'

Tensorflow not install, you could not use those pipelines


In [2]:
# Preprocessing chain applied to every recording. Each step takes (data, fs)
# and returns (data, fs); time runs along axis 1 of `data`.
# The config values are bound as default arguments at definition time, so the
# steps keep working even if the config globals are deleted later on (a plain
# lambda looks the names up lazily and raises NameError in that case).
def _cut_edges(data, fs, _sec=sec_to_cut):
    """Drop `_sec` seconds from the start and from the end of the recording."""
    n_cut = int(_sec * fs)
    # Explicit end index: with n_cut == 0 the slice `n_cut:-n_cut` is `0:0`
    # and would throw the whole recording away.
    stop = max(n_cut, data.shape[1] - n_cut)
    return data[:, n_cut:stop], fs


def _limit_duration(data, fs, _mins=duration_recording_mins):
    """Keep at most `_mins` minutes from the start of the recording."""
    return data[:, :int(_mins * 60 * fs)], fs


def _clip_amplitude(data, fs, _limit=max_abs_val):
    """Clip values to the range [-_limit, _limit]."""
    return np.clip(data, -_limit, _limit), fs


def _resample(data, fs, _target_fs=sampling_freq):
    """Resample along the time axis to `_target_fs` and report the new rate."""
    return resampy.resample(data, fs, _target_fs, axis=1, filter='kaiser_fast'), _target_fs


def _rescale(data, fs, _divisor=divisor):
    """Divide all values by `_divisor`."""
    return data / _divisor, fs


preproc_functions = [_cut_edges, _limit_duration]
if max_abs_val is not None:
    preproc_functions.append(_clip_amplitude)
preproc_functions.append(_resample)
if divisor is not None:
    preproc_functions.append(_rescale)

# Both datasets share every argument except the train/eval switch.
dataset_kwargs = dict(n_recordings=n_recordings,
                      max_recording_mins=max_recording_mins,
                      preproc_functions=preproc_functions,
                      data_folders=data_folders,
                      sensor_types=sensor_types)
dataset = DiagnosisSet(train_or_eval='train', **dataset_kwargs)
if test_on_eval:
    test_dataset = DiagnosisSet(train_or_eval='eval', **dataset_kwargs)

In [3]:
#To get timesteps, we can use numpy.reshape
def create_windows(X,y,stride=sampling_freq):
    """Cut every recording into fixed-length windows and stack them.

    :param X: sequence of 2-D arrays indexed as [channel, time]; all recordings
        must have the same number of channels.
    :param y: per-recording labels, in the same order as X.
    :param stride: step in samples between the starts of consecutive windows.
    :return: (trials, labels). trials is a float32 array of shape
        (n_windows, n_channels, input_time_length); labels has one entry per
        window, copied from the recording the window was cut from.
    """
    def count_windows(recording):
        # Same formula as before, (T - window) // stride - 1, but clamped at
        # zero: a short recording used to produce a negative count, which
        # shrank the buffer and broke the slice assignment further down.
        return max(0, (recording.shape[1] - input_time_length) // stride - 1)

    counts = [count_windows(recording) for recording in X]
    # Buffer dimensions come from the data and the configured window length
    # instead of the previously hard-coded (21, 6000).
    n_channels = X[0].shape[0] if len(X) else 0
    trials = np.zeros(shape=(sum(counts), n_channels, input_time_length), dtype=np.float32)
    labels = []
    position = 0
    for recording, label, n_windows in zip(X, y, counts):
        for j in range(n_windows):
            start = j * stride
            trials[position + j] = recording[:, start:start + input_time_length]
        labels.extend([label] * n_windows)
        position += n_windows
    return trials, np.array(labels)

In [4]:
#Numpy array doesn't work as they take too much space, use BaseConcatDataset instead
#BaseConcatDataset does work recursively with itself.
# Load the training recordings and their labels.
X,y=dataset.load()
# The evaluation recordings are only needed (and test_dataset only exists) when test_on_eval is set.
if test_on_eval:
    test_x,test_y=test_dataset.load()

In [5]:
'''
Methods of data augmentation:-
Time Warping: This involves stretching or compressing the time axis. In time-series analysis, it can lead to a better understanding of variations in time.
Mathematical Explanation:
x′(t)=x(a⋅t)
where a is the warping factor.
Window Slicing: Similar to cropping, but with fixed-size windows. Overlapping windows can also be used to increase the amount of data.
Time Masking: Certain time steps are masked (set to zero or mean value), which can help the model become more robust to missing data.
Noise Injection: Random noise can be added to the sequence, aiding the model in learning to ignore irrelevant variations.
Mathematical Explanation:
x′(t)=x(t)+N(0,σ²)
where N(0,σ²) is Gaussian noise with mean 0 and variance σ².
Data Mixing: By mixing two or more sequences, you can create a new sequence. For instance, in audio processing, overlaying two sound tracks.
Temporal Jittering: It involves adding small random shifts to the temporal alignment of the sequence. It's often used in speech and audio processing.
Sequence-to-sequence Transformation: This involves applying complex transformations like Fourier transform followed by an inverse transformation after modifications in the frequency domain.
Mathematical Explanation:
X′=F⁻¹(F(X)+N)
where F and F⁻¹ are the Fourier and inverse Fourier transforms, and N is a noise term.
'''

"\nMethods of data augmentation:-\nTime Warping: This involves stretching or compressing the time axis. In time-series analysis, it can lead to a better understanding of variations in time.\nMathematical Explanation:\nx′(t)=x(a⋅t)\nwhere a is the warping factor.\nWindow Slicing: Similar to cropping, but with fixed-size windows. Overlapping windows can also be used to increase the amount of data.\nTime Masking: Certain time steps are masked (set to zero or mean value), which can help the model become more robust to missing data.\nNoise Injection: Random noise can be added to the sequence, aiding the model in learning to ignore irrelevant variations.\nMathematical Explanation:\nx′(t)=x(t)+N(0,σ2)\nwhere N(0,σ2) is Gaussian noise with mean 0 and variance σ2.\nData Mixing: By mixing two or more sequences, you can create a new sequence. For instance, in audio processing, overlaying two sound tracks.\nTemporal Jittering: It involves adding small random shifts to the temporal alignment of the

In [6]:
# Free only the objects nothing reads again. The scalar config values
# (divisor, max_abs_val, sec_to_cut, duration_recording_mins) are deliberately
# NOT deleted: keeping them costs nothing, and removing them makes a later
# dataset (re)load fail with NameError whenever a preprocessing callable
# looks one of them up by name.
del preproc_functions,data_folders
def create_set(X, y, inds):
    """
    Pick the entries at positions ``inds`` from a list/array pair.

    :param X: list of recordings, read one position at a time
    :param y: numpy array of labels, indexed with ``inds`` in one step
    :param inds: iterable of integer positions
    :return: tuple ``(selected_X, selected_y)``
    """
    selected_X = [X[position] for position in inds]
    selected_y = y[inds]
    return selected_X, selected_y
#Use of TrainValidTestSplitter is not necessary in newer versions of braindecode
class TrainValidSplitter(object):
    """Cut one contiguous validation fold out of the trials; everything else is training data."""

    def __init__(self, n_folds, i_valid_fold, shuffle):
        self.n_folds = n_folds
        self.i_valid_fold = i_valid_fold
        self.shuffle = shuffle
        # Fixed seed, so a freshly built splitter always draws the same permutation.
        self.rng = np.random.RandomState(39483948)

    def split(self, X, y):
        """
        Split (X, y) into a training and a validation part.

        :param X: list of trials
        :param y: numpy array of labels
        :return: ``(train_set, valid_set)``, each an ``(X, y)`` pair built by create_set
        :raises ValueError: if there are fewer trials than folds
        """
        n_trials = len(X)
        if n_trials < self.n_folds:
            raise ValueError("Less Trials: {:d} than folds: {:d}".format(
                n_trials, self.n_folds
            ))
        order = np.arange(len(y))
        if self.shuffle:
            self.rng.shuffle(order)
        #Compared to paper, the valid set will be unbalanced
        fold_size = n_trials // self.n_folds
        fold_start = self.i_valid_fold * fold_size
        valid_inds = order[fold_start:fold_start + fold_size]
        # setdiff1d returns the remaining indices in sorted order.
        train_inds = np.setdiff1d(order, valid_inds)
        return create_set(X, y, train_inds), create_set(X, y, valid_inds)

In [7]:
if test_on_eval==False:
    # Hold out one fold for validation; from here on X/y are the training part only.
    splitter=TrainValidSplitter(n_folds,i_test_fold,shuffle=True)
    (X,y),(valid_X,valid_y)=splitter.split(X,y)
    del n_folds,i_test_fold

In [8]:
#This block will be used to separate the abnormal and normal training trials
# Select by label instead of by position. The earlier slicing (y[i:] / y[:i]
# with the leaked loop variable) was only right when every abnormal recording
# came after every normal one, and raised NameError when none was abnormal.
y=np.asarray(y)
abnormal_indexes=np.nonzero(y)[0][::-1]
# Highest index first, the same order the pop()-based loop produced.
abnormal=[X[i] for i in abnormal_indexes]
abnormal_labels=y[abnormal_indexes]
is_abnormal=np.zeros(len(y),dtype=bool)
is_abnormal[abnormal_indexes]=True
# Slice assignment updates the existing list object in place, as pop() did.
X[:]=[recording for recording,flagged in zip(X,is_abnormal) if not flagged]
y=y[~is_abnormal]
del abnormal_indexes,is_abnormal

In [9]:
# Number of training recordings per class.
print(f"normal cases:{len(y)}")
print(f"abnormal cases:{len(abnormal_labels)}")
# Integer ratio of normal to abnormal recordings; raises ZeroDivisionError when abnormal_labels is empty.
print(len(y)//len(abnormal_labels))

normal cases:1396
abnormal cases:274
5


In [10]:
ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
#we take a 10 second stride as 1 second stride takes too long
stride=sampling_freq*10
# Arguments shared by every create_from_X_y call in this cell; only the data and the stride differ.
window_kwargs=dict(sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,
                   window_size_samples=input_time_length)
train_set=create_from_X_y(X,y,window_stride_samples=stride,**window_kwargs)
if test_on_eval==False:
    valid_set=create_from_X_y(valid_X,valid_y,window_stride_samples=sampling_freq,**window_kwargs)
    del valid_X,valid_y
elif test_on_eval:
    test_set=create_from_X_y(test_x,test_y,window_stride_samples=sampling_freq,**window_kwargs)
    del test_x,test_y
del X,y,window_kwargs

In [11]:
# The abnormal recordings are windowed with a stride of sampling_freq samples.
abnormal_train_set=create_from_X_y(
    abnormal,
    abnormal_labels,
    sfreq=sampling_freq,
    ch_names=ch_names,
    window_size_samples=input_time_length,
    window_stride_samples=sampling_freq,
    drop_last_window=True,
)
del abnormal,abnormal_labels,stride,ch_names

In [12]:
# Number of windows in each of the two training window datasets.
print(f"normal windows:{len(train_set)}")
print(f"abnormal windows:{len(abnormal_train_set)}")

normal windows:72722
abnormal windows:142814


In [13]:
# Merge both window datasets into one training set (the abnormal windows are listed first).
train_set=BaseConcatDataset([abnormal_train_set,train_set])
del abnormal_train_set

In [14]:
n_classes = 2
# Build the network selected by model_name. Every branch pushes a dummy batch
# through the new network and prints the output shape as a sanity check; the
# dummy is sized with n_chans / input_time_length so it always matches the
# configuration the network was built for (it was hard-coded to 21 x 6000).
if model_name=="shallow":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = ShallowFBCSPNet(n_chans,
                                    n_classes,
                                    n_filters_time=n_start_chans,
                                    n_filters_spat=n_start_chans,
                                    input_window_samples=input_time_length,
                                    final_conv_length='auto',)
    test=torch.ones(size=(7,n_chans,input_time_length))
    out=model.forward(test)
    print(out.shape)
elif model_name=="deep":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    model = Deep4Net(n_chans, n_classes,
                         n_filters_time=n_start_chans,
                         n_filters_spat=n_start_chans,
                         input_window_samples=input_time_length,
                         n_filters_2 = int(n_start_chans * n_chan_factor),
                         n_filters_3 = int(n_start_chans * (n_chan_factor ** 2.0)),
                         n_filters_4 = int(n_start_chans * (n_chan_factor ** 3.0)),
                         final_conv_length='auto',
                        stride_before_pool=True)
    test=torch.ones(size=(7,n_chans,input_time_length,1))
    out=model.forward(test)
    print(out.shape)
elif model_name=="deep_smac" or model_name == 'deep_smac_bnorm':
    optimizer_lr = init_lr
    # Defined here as well so every branch leaves the same optimizer settings behind.
    optimizer_weight_decay = 0
    # The two variants differ only in whether batch normalisation is used.
    do_batch_norm = model_name != 'deep_smac'
    double_time_convs = False
    drop_prob = 0.244445
    filter_length_2 = 12
    filter_length_3 = 14
    filter_length_4 = 32
    filter_time_length = 21
    #final_conv_length = 1
    first_nonlin = elu
    first_pool_mode = 'mean'
    later_nonlin = elu
    later_pool_mode = 'mean'
    n_filters_factor = 1.679066
    n_filters_start = 32
    pool_time_length = 1
    pool_time_stride = 2
    split_first_layer = True
    # These overwrite the values imported from config for this branch.
    n_chan_factor = n_filters_factor
    n_start_chans = n_filters_start
    model = Deep4Net(n_chans, n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_window_samples=input_time_length,
            n_filters_2=int(n_start_chans * n_chan_factor),
            n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
            n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
            final_conv_length='auto',
            batch_norm=do_batch_norm,
            drop_prob=drop_prob,
            filter_length_2=filter_length_2,
            filter_length_3=filter_length_3,
            filter_length_4=filter_length_4,
            filter_time_length=filter_time_length,
            first_conv_nonlin=first_nonlin,
            first_pool_mode=first_pool_mode,
            later_conv_nonlin=later_nonlin,
            later_pool_mode=later_pool_mode,
            pool_time_length=pool_time_length,
            pool_time_stride=pool_time_stride,
            split_first_layer=split_first_layer,
            stride_before_pool=True)
    test=torch.ones(size=(6,n_chans,input_time_length,1))
    out=model.forward(test)
    print(out.shape)
    del do_batch_norm,double_time_convs,drop_prob,filter_length_2,filter_length_3,filter_length_4,filter_time_length,first_nonlin,n_chan_factor,n_start_chans,first_pool_mode,later_nonlin,later_pool_mode,n_filters_factor,n_filters_start,pool_time_length,pool_time_stride,split_first_layer
#Works properly, fit the hybrid cnn
elif model_name=="hybrid":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = HybridNet(n_chans, n_classes,input_window_samples=input_time_length,)
    test=torch.ones(size=(2,n_chans,input_time_length))
    out=model.forward(test)
    out_length=out.shape[2]
    # Replace the final convolution with one that spans the whole remaining
    # time axis, so that a single prediction per window is produced.
    model.final_conv=nn.Conv2d(100,n_classes,(out_length,1),bias=True,)
    # Flatten(start_dim=1) removes the trailing size-1 axes but always keeps
    # the batch axis; torch.squeeze also removed it for a batch of one.
    model=nn.Sequential(model,nn.Flatten(start_dim=1))
    out=model.forward(test)
    print(out.shape)
    del out_length
elif model_name=="TCN":
    import warnings
    #This disables the warning of the dropout2d layers receiving 3d input
    # NOTE: this silences every warning for the rest of the session, not only that one.
    warnings.filterwarnings("ignore")
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    n_blocks=7
    n_filters=32
    kernel_size=24
    drop_prob = 0.3
    add_log_softmax=False
    x=TCN(n_chans,n_classes,n_blocks,n_filters,kernel_size,drop_prob,add_log_softmax)
    test=torch.ones(size=(7,n_chans,input_time_length,1))
    out=x.forward(test)
    print(out.shape)
    out_length=out.shape[2]
    #There is no hyperparameter where output of TCN is (Batch_Size,Classes) when input is (Batch_Size,21,6000) so add new layers to meet size
    # Flatten(start_dim=1) instead of torch.squeeze: the batch axis survives a
    # batch of one, which LogSoftmax(dim=1) needs.
    model=nn.Sequential(x,nn.Conv1d(n_classes,n_classes,out_length,bias=True,),nn.Flatten(start_dim=1),nn.LogSoftmax(dim=1))
    out=model.forward(test)
    print(out.shape)
    del out_length,x
else:
    # Fail with a clear message instead of a NameError on `model` below.
    raise ValueError(f"Unknown model_name: {model_name!r}")
if cuda:
    model.cuda()
del test,out
print(model_name)

torch.Size([6, 2])
deep_smac_bnorm


In [15]:
# Display the layer structure of the network built in the previous cell.
model

Deep4Net(
  (ensuredims): Ensure4d()
  (dimshuffle): Expression(expression=transpose_time_to_spat) 
  (conv_time): Conv2d(1, 32, kernel_size=(21, 1), stride=(1, 1))
  (conv_spat): Conv2d(32, 32, kernel_size=(1, 21), stride=(2, 1), bias=False)
  (bnorm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=elu) 
  (pool): AvgPool2dWithConv()
  (pool_nonlin): Expression(expression=identity) 
  (drop_2): Dropout(p=0.244445, inplace=False)
  (conv_2): Conv2d(32, 53, kernel_size=(12, 1), stride=(2, 1), bias=False)
  (bnorm_2): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (nonlin_2): Expression(expression=elu) 
  (pool_2): AvgPool2dWithConv()
  (pool_nonlin_2): Expression(expression=identity) 
  (drop_3): Dropout(p=0.244445, inplace=False)
  (conv_3): Conv2d(53, 90, kernel_size=(14, 1), stride=(2, 1), bias=False)
  (bnorm_3): BatchNorm2d(90, eps=1e-05, momentum=0.1, affine=True, track_runni

In [16]:
# NOTE: `monitor` is not handed to the Checkpoint below, which tracks 'valid_f1_best' by name.
monitor = lambda net: any(net.history[-1, ('valid_accuracy_best','valid_f1_best','valid_loss_best')])
cp=Checkpoint(monitor='valid_f1_best',dirname='model',f_params=f'{model_name}best_param.pkl',
               f_optimizer=f'{model_name}best_opt.pkl', f_history=f'{model_name}best_history.json')
# The two configurations differ only in the dataset used as the predefined validation split.
if test_on_eval==False:
    holdout_set=valid_set
elif test_on_eval:
    holdout_set=test_set
classifier = braindecode.EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(holdout_set),
    optimizer__lr=optimizer_lr,
    #optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    callbacks=["accuracy","f1",cp],#Try ‘roc_auc’
    warm_start=True,
    )
classifier.initialize()
del model,holdout_set

In [17]:
# Stem of the checkpoint file names: the load/save cells use model/<path>_param.pkl, _opt.pkl and _history.json.
path=f'{model_name}II'

In [None]:
#Loads Phase 1 parameters and fit them further in phase 2
# Phase-1 checkpoints are stored under the plain model name.
path=f'{model_name}'
if test_on_eval:
    classifier.load_params(
        f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
    print("Paramters Loaded")
    # Switch to the 'II' stem, which the save cell then writes to — presumably so the phase-1 files stay untouched.
    path=f'{model_name}II'

In [17]:
# Pick the checkpoint stem for the current phase and resume from it if it exists.
if test_on_eval:
    path=f'{model_name}II'
elif test_on_eval==False:
    path=f'{model_name}'
try:
    classifier.load_params(
        f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
    print("Paramters Loaded")
except FileNotFoundError:
    # No checkpoint yet: keep the freshly initialised weights.
    print(f"No saved parameters found for 'model/{path}', starting from scratch")
except Exception as error:
    # Still best-effort, but a corrupt or mismatching checkpoint is now
    # reported instead of being silently ignored by a bare except.
    print(f"Could not load parameters for 'model/{path}': {error!r}")

Paramters Loaded


In [18]:
#Shows the history of training the neural network
# (includes the per-batch training losses recorded for each epoch)
classifier.history_

[{'batches': [{'train_loss': 0.8555908799171448, 'train_batch_size': 64},
   {'train_loss': 1.7851495742797852, 'train_batch_size': 64},
   {'train_loss': 3.5661375522613525, 'train_batch_size': 64},
   {'train_loss': 2.349621295928955, 'train_batch_size': 64},
   {'train_loss': 1.4877064228057861, 'train_batch_size': 64},
   {'train_loss': 1.1939198970794678, 'train_batch_size': 64},
   {'train_loss': 1.3494302034378052, 'train_batch_size': 64},
   {'train_loss': 1.188534140586853, 'train_batch_size': 64},
   {'train_loss': 1.6109366416931152, 'train_batch_size': 64},
   {'train_loss': 1.0558499097824097, 'train_batch_size': 64},
   {'train_loss': 1.3449910879135132, 'train_batch_size': 64},
   {'train_loss': 1.5934693813323975, 'train_batch_size': 64},
   {'train_loss': 1.2171958684921265, 'train_batch_size': 64},
   {'train_loss': 1.0476877689361572, 'train_batch_size': 64},
   {'train_loss': 1.0419310331344604, 'train_batch_size': 64},
   {'train_loss': 1.2860733270645142, 'train_b

In [19]:
# Train for 3 epochs on the combined window dataset. y=None: no separate label
# array is passed — presumably the targets come from the dataset items; confirm
# against the skorch documentation. The classifier was created with
# warm_start=True.
classifier.fit(train_set,y=None,epochs=3)

     23            0.9816      0.9863        [35m0.0702[0m            0.7273      0.6975        1.4381        3030.0230
     24            0.8693      0.8926        [35m0.0676[0m            0.7000      0.6283        3.0188        2969.8640
     25            0.9812      0.9860        [35m0.0660[0m            0.7074      0.6695        1.6854        2924.9533


<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=Deep4Net(
    (ensuredims): Ensure4d()
    (dimshuffle): Expression(expression=transpose_time_to_spat) 
    (conv_time): Conv2d(1, 32, kernel_size=(21, 1), stride=(1, 1))
    (conv_spat): Conv2d(32, 32, kernel_size=(1, 21), stride=(2, 1), bias=False)
    (bnorm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin): Expression(expression=elu) 
    (pool): AvgPool2dWithConv()
    (pool_nonlin): Expression(expression=identity) 
    (drop_2): Dropout(p=0.244445, inplace=False)
    (conv_2): Conv2d(32, 53, kernel_size=(12, 1), stride=(2, 1), bias=False)
    (bnorm_2): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (nonlin_2): Expression(expression=elu) 
    (pool_2): AvgPool2dWithConv()
    (pool_nonlin_2): Expression(expression=identity) 
    (drop_3): Dropout(p=0.244445, inplace=False)
    (conv_3): Conv2d(53, 90, kernel_size=(14, 1), st

In [20]:
# Write weights, optimizer state and training history to the files named by the current `path` stem.
classifier.save_params(
    f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
#torch.save({"model":classifier.module_.state_dict(),"optimizer":classifier.optimizer_.state_dict()}, path)

In [None]:
# Window-level accuracy and F1 on the held-out dataset (label 1 is the positive class).
if test_on_eval==False:
    pred_labels=classifier.predict(valid_set)
    actual_labels=[label[1] for label in valid_set]
elif test_on_eval:
    pred_labels=classifier.predict(test_set)
    actual_labels=[label[1] for label in test_set]
actual_labels=np.array(actual_labels)
accuracy=np.mean(pred_labels==actual_labels)
print(f"Accuracy:{accuracy}")
tp=np.sum(pred_labels*actual_labels)
n_predicted_pos=np.sum(pred_labels)
n_actual_pos=np.sum(actual_labels)
# Guard every ratio: without positives (predicted or actual) the plain
# divisions produced NaN plus a runtime warning. 0.0 is reported instead.
precision=tp/n_predicted_pos if n_predicted_pos else 0.0
recall=tp/n_actual_pos if n_actual_pos else 0.0
f1=2*precision*recall/(precision+recall) if (precision+recall) else 0.0
print(f"F1-Score:{f1}")

In [None]:
#Test the model on proper test set according to paper
if test_on_eval:
    # Drop the old window datasets one name at a time. A single
    # `del train_set,test_set` inside a bare except stopped at the first
    # missing name (leaving the other dataset alive) and hid unrelated errors.
    try:
        del train_set
    except NameError:
        pass
    try:
        del test_set
    except NameError:
        pass
    test_x,test_y=test_dataset.load()
    ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
    #Stride between windows is set to sampling frequency as written in paper
    test_set=create_from_X_y(test_x,test_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=sampling_freq)
    del test_x,test_y

In [None]:
# Window-level accuracy and F1 on the rebuilt test set (label 1 is the positive class).
if test_on_eval:
    pred_labels=classifier.predict(test_set)
    actual_labels=[label[1] for label in test_set]
    actual_labels=np.array(actual_labels)
    accuracy=np.mean(pred_labels==actual_labels)
    print(f"Accuracy:{accuracy}")
    tp=np.sum(pred_labels*actual_labels)
    n_predicted_pos=np.sum(pred_labels)
    n_actual_pos=np.sum(actual_labels)
    # Guard every ratio: without positives (predicted or actual) the plain
    # divisions produced NaN plus a runtime warning. 0.0 is reported instead.
    precision=tp/n_predicted_pos if n_predicted_pos else 0.0
    recall=tp/n_actual_pos if n_actual_pos else 0.0
    f1=2*precision*recall/(precision+recall) if (precision+recall) else 0.0
    print(f"F1-Score:{f1}")

In [5]:
#This will load the model and parameters and then replace it with one whose classification layer is removed
# NOTE(review): importing here instead of in the first cell hides the dependency; consider moving it up.
from skorch import NeuralNet
# NOTE(review): `model` is removed by `del model` at the end of the classifier-setup cell, so on a
# top-to-bottom run the model-building cell has to be executed again before this line.
network=NeuralNet(module=model,criterion=torch.nn.modules.loss.NLLLoss,batch_size=batch_size,device=device)
network.initialize()
# Restore the '<model_name>best_*' files, the names the Checkpoint callback is configured with.
network.load_params(
    f_params=f'model/{model_name}best_param.pkl', f_optimizer=f'model/{model_name}best_opt.pkl', f_history=f'model/{model_name}best_history.json')
print("Paramters Loaded")
# Keep every child module except the last three and flatten what is left to (batch, features).
# Presumably the three dropped modules are the classification head of Deep4Net — verify before
# using this with another model_name.
network.module_=torch.nn.Sequential(*(list(network.module_.children())[:-3]),nn.modules.Flatten())

Paramters Loaded


In [6]:
# Display the truncated network that now serves as feature extractor.
network.module_

Sequential(
  (0): Ensure4d()
  (1): Expression(expression=transpose_time_to_spat) 
  (2): Conv2d(1, 32, kernel_size=(21, 1), stride=(1, 1))
  (3): Conv2d(32, 32, kernel_size=(1, 21), stride=(2, 1), bias=False)
  (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Expression(expression=elu) 
  (6): AvgPool2dWithConv()
  (7): Expression(expression=identity) 
  (8): Dropout(p=0.244445, inplace=False)
  (9): Conv2d(32, 53, kernel_size=(12, 1), stride=(2, 1), bias=False)
  (10): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): Expression(expression=elu) 
  (12): AvgPool2dWithConv()
  (13): Expression(expression=identity) 
  (14): Dropout(p=0.244445, inplace=False)
  (15): Conv2d(53, 90, kernel_size=(14, 1), stride=(2, 1), bias=False)
  (16): BatchNorm2d(90, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (17): Expression(expression=elu) 
  (18): AvgPool2dWithConv()
  (19): Expression(expression

In [7]:
# Push a dummy batch of 2 windows (21 channels x 6000 samples) through the truncated network
# to read off how many features it emits per window.
test=torch.ones(size=(2,21,6000))
feat=network.predict(test).shape[1]
print(feat)
del test

53605


In [8]:
#Loads dataset, finds smallest trial, with this, we find number of windows using stride and convert it to array of windows of trials
#shape is (no_of_trials,no_of_windows,channels,input_time_length) in the end
# NOTE(review): this cell itself only loads the recordings; the windowing happens in later cells.
X,y=dataset.load()
# NOTE(review): test_dataset is only created when test_on_eval is true in the dataset cell,
# so this line raises NameError otherwise.
test_x,test_y=test_dataset.load()

In [9]:
# Split the recordings into normal (label 0) and abnormal (non-zero label).
# Selection is by label instead of by position: the earlier slicing
# (y[i:] / y[:i] with the leaked loop variable) was only right when every
# abnormal recording came after every normal one, and raised NameError when
# none was abnormal.
y=np.asarray(y)
abnormal_indexes=np.nonzero(y)[0][::-1]
# Highest index first, the same order the pop()-based loop produced.
abnormal=[X[i] for i in abnormal_indexes]
abnormal_labels=y[abnormal_indexes]
is_abnormal=np.zeros(len(y),dtype=bool)
is_abnormal[abnormal_indexes]=True
# Slice assignment updates the existing list object in place, as pop() did.
X[:]=[recording for recording,flagged in zip(X,is_abnormal) if not flagged]
y=y[~is_abnormal]
del abnormal_indexes,is_abnormal

In [10]:
#Counting normal trials windows
# Turn every training recording into windows and store the network's feature
# vector for each window. Normal recordings are processed first, then abnormal.
stride=sampling_freq*10
abstride=sampling_freq*10
# Windows per recording: the same formula as before, (T - window)//stride - 1,
# but clamped at zero. A short recording used to yield a negative count, which
# shrank the feature buffer and later broke the slice assignment.
normal_counts=[max(0,(rec.shape[1]-input_time_length)//stride-1) for rec in X]
abnormal_counts=[max(0,(rec.shape[1]-input_time_length)//abstride-1) for rec in abnormal]
no_of_trials=sum(normal_counts)+sum(abnormal_counts)
features=np.zeros(shape=(no_of_trials,feat),dtype=np.float32)
labels=[]
position=0
recording=windows=None
for recordings,rec_labels,counts,step in ((X,y,normal_counts,stride),(abnormal,abnormal_labels,abnormal_counts,abstride)):
    for recording,rec_label,no_of_windows in zip(recordings,rec_labels,counts):
        if no_of_windows==0:
            # Nothing to extract; predict() must not be called with an empty batch.
            continue
        windows=np.array([recording[:,j*step:j*step+input_time_length] for j in range(no_of_windows)])
        features[position:position+no_of_windows]=network.predict(windows)
        labels.extend([rec_label]*no_of_windows)
        position+=no_of_windows
# Release the raw recordings and the references kept by the loop variables.
recording=windows=None
del X,y,abnormal,abnormal_labels
del recordings,rec_labels,counts,step,normal_counts,abnormal_counts
labels=np.array(labels)

In [None]:
#This saves the features along with labels of each trial in a .mat file
# NOTE(review): hard-coded absolute path on drive E: — not portable to other machines.
scipy.io.savemat("E:/train_features.mat",{"x":features,"y":labels})
# Free the arrays; a later cell reads them back from the file.
del features,labels

In [11]:
# Turn every evaluation recording into windows and store the network's feature
# vector for each window.
#Test set must match test set from paper as much as possible
stride=sampling_freq*2
# Windows per recording: the same formula as before, (T - window)//stride - 1,
# but clamped at zero. A short recording used to yield a negative count, which
# shrank the feature buffer and later broke the slice assignment.
window_counts=[max(0,(rec.shape[1]-input_time_length)//stride-1) for rec in test_x]
no_of_trials=sum(window_counts)
test_features=np.zeros(shape=(no_of_trials,feat),dtype=np.float32)
test_labels=[]
position=0
recording=windows=None
for recording,rec_label,no_of_windows in zip(test_x,test_y,window_counts):
    if no_of_windows==0:
        # Nothing to extract; predict() must not be called with an empty batch.
        continue
    windows=np.array([recording[:,j*stride:j*stride+input_time_length] for j in range(no_of_windows)])
    test_features[position:position+no_of_windows]=network.predict(windows)
    test_labels.extend([rec_label]*no_of_windows)
    position+=no_of_windows
# Release the raw recordings and the references kept by the loop variables.
recording=windows=None
del test_x,test_y,window_counts
test_labels=np.array(test_labels)

In [None]:
# Store the evaluation features and labels in a .mat file.
# NOTE(review): hard-coded absolute path on drive E: — not portable to other machines.
scipy.io.savemat("E:/test_features.mat",{"x":test_features,"y":test_labels})
del test_features,test_labels

In [None]:
import scipy
import numpy as np
# Read the training features back from disk: "x" is the feature matrix, "y" the labels.
inputs=scipy.io.loadmat("E:/train_features.mat")
features=inputs["x"]
# squeeze() removes size-1 axes — presumably the label vector comes back 2-D from the .mat file; confirm.
labels=inputs["y"].squeeze()
del inputs

In [12]:
#t variable determines timesteps for hybrid model
t=7
f=features.shape[-1]
# Group every t consecutive feature rows into one sequence and drop the
# leftover rows at the end; each sequence takes the label of its first row.
# NOTE(review): the grouping is purely positional, so nothing stops a sequence
# from spanning two recordings.
n_sequences=len(labels)//t
features=features[:n_sequences*t].reshape((n_sequences,t,f))
labels=labels[:n_sequences*t].reshape((n_sequences,t))[:,0]
del n_sequences

In [13]:
class SimpleModel(torch.nn.Module):
  """Single-layer LSTM over a sequence of feature vectors with a linear read-out.

  forward() takes a batch-first tensor (batch, timesteps, input_features) and
  returns log-probabilities of shape (batch, 2).
  """
  def __init__(self,input_features):
    super().__init__()
    self.lstm = torch.nn.LSTM(input_size=input_features, hidden_size=50, batch_first=True)
    self.fc = torch.nn.Linear(50, 2)
    self.tanh = torch.nn.Tanh()
    self.softmax = torch.nn.LogSoftmax(dim=1)

  def forward(self, inputs):
    # h_n is the hidden state after the last timestep, shape (1, batch, 50).
    _, (h_n,_) = self.lstm(inputs)
    # Remove only the layer axis. A bare squeeze() also removed the batch axis
    # for a batch of one, which made LogSoftmax(dim=1) fail.
    last_hidden=self.tanh(h_n.squeeze(0))
    logits = self.fc(last_hidden)
    output = self.softmax(logits)
    return output
# Instantiate the sequence model with f input features per timestep.
model = SimpleModel(f)

In [None]:
# Read the evaluation features back from disk: "x" is the feature matrix, "y" the labels.
inputs=scipy.io.loadmat("E:/test_features.mat")
test_features=inputs["x"]
test_labels=inputs["y"].squeeze()
# Window-level dataset; the next cell rebuilds test_set from the reshaped sequences.
test_set=Dataset(test_features,test_labels)

In [14]:
# Same grouping as for the training features: t consecutive rows form one
# sequence, leftover rows are dropped, the first row's label is kept.
n_test_sequences=len(test_labels)//t
test_features=test_features[:n_test_sequences*t].reshape((n_test_sequences,t,f))
test_labels=test_labels[:n_test_sequences*t].reshape((n_test_sequences,t))[:,0]
test_set=Dataset(test_features,test_labels)
del n_test_sequences

In [15]:
# NOTE: `monitor` is not handed to the Checkpoint below, which tracks 'valid_f1_best' by name.
monitor = lambda net: any(net.history[-1, ('valid_accuracy_best','valid_f1_best','valid_loss_best')])
# Checkpoint files for the sequence model are written to model/LSTMbest_*.
cp=Checkpoint(monitor='valid_f1_best',dirname='model',f_params='LSTMbest_param.pkl',f_optimizer='LSTMbest_opt.pkl',f_history='LSTMbest_history.json')
# NOTE(review): test_set is used as the validation split, so the checkpoint is selected on the
# same data that is reported as the test result further down.
classifier = braindecode.EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(test_set),
        optimizer__lr=init_lr,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        device=device,
        callbacks=["accuracy","f1",'roc_auc',cp],
        warm_start=True,
        )
classifier.initialize()

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=SimpleModel(
    (lstm): LSTM(53605, 50, batch_first=True)
    (fc): Linear(in_features=50, out_features=2, bias=True)
    (tanh): Tanh()
    (softmax): LogSoftmax(dim=1)
  ),
)

In [16]:
# Smoke test: run a random batch of 2 sequences through the classifier.
test=torch.randn(size=(2,t,f))
# Despite its name, `shape` holds the predicted class labels returned by predict().
shape=classifier.predict(test)
print(shape)

[0 1]


In [18]:
#Try deep smac by itself and as feature extractor and determine effectiveness
# Train the sequence model for 10 epochs on the grouped feature sequences.
classifier.fit(features,y=labels,epochs=10)

     11            0.9115      0.6734        [35m0.2190[0m           0.9375            0.6851      0.5370        0.8713           0.8279        50.8554
     12            0.9127      0.6821        [35m0.2119[0m           [31m0.9424[0m            0.6860      0.5426        0.8647           0.8284        67.6770
     13            [36m0.9225[0m      [32m0.7360[0m        [35m0.2114[0m           [31m0.9490[0m            0.6982      0.5743        0.7397           0.8243        49.1320
     14            0.9195      0.7178        [35m0.2017[0m           0.9487            0.6992      0.5727        0.8886           0.8285        48.5494
     15            [36m0.9230[0m      [32m0.7469[0m        [35m0.2007[0m           0.9478            0.7035      0.5921        0.8155           0.8257        49.1985
     16            [36m0.9304[0m      [32m0.7694[0m        [35m0.1977[0m           [31m0.9531[0m            [94m0.7141[0m      0.6069        0.7667           0.8308  

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=SimpleModel(
    (lstm): LSTM(53605, 50, batch_first=True)
    (fc): Linear(in_features=50, out_features=2, bias=True)
    (tanh): Tanh()
    (softmax): LogSoftmax(dim=1)
  ),
)

In [None]:
# Sequence-level accuracy and F1 on the evaluation features (label 1 is the positive class).
out=classifier.predict(test_features)
accuracy=np.mean(out==test_labels)
print(f"Accuracy:{accuracy}")
tp=np.sum(out*test_labels)
n_predicted_pos=np.sum(out)
n_actual_pos=np.sum(test_labels)
# Guard every ratio: without positives (predicted or actual) the plain
# divisions produced NaN plus a runtime warning. 0.0 is reported instead.
precision=tp/n_predicted_pos if n_predicted_pos else 0.0
recall=tp/n_actual_pos if n_actual_pos else 0.0
f1=2*precision*recall/(precision+recall) if (precision+recall) else 0.0
print(f"F1-Score:{f1}") 