In [None]:
from avalanche.benchmarks.generators import ni_benchmark
from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive, Cumulative, LwF, EWC, JointTraining, GEM, Replay
from torch.optim import Adam
from torch.nn import CrossEntropyLoss, MSELoss
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics,class_accuracy_metrics, loss_metrics, timing_metrics, cpu_usage_metrics, confusion_matrix_metrics, disk_usage_metrics, gpu_usage_metrics
from avalanche.training.plugins import EvaluationPlugin, EarlyStoppingPlugin
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
import pickle
import torch.nn as nn
import torch
import numpy as np
import sys
import time

from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce


from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch.utils.data as Data
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score,accuracy_score,precision_score,recall_score,f1_score,classification_report

from sklearn.utils import class_weight

import myimporter
from BCI_functions import *  # BCI_functions.ipynb contains some functions we might use multiple times in this tutorial
import warnings
warnings.filterwarnings('ignore')
import os
os.getcwd()

In [None]:
# TODO: accept a list of subject ids so that several subjects can be combined into one dataset.
# TODO: add a single function that splits the data into train and test sets, so the same preprocessing code is not repeated in both dataset classes.

class EEGMMIDTrSet(Data.Dataset):
    """Training split of one EEGMMID subject's recording, as a PyTorch Dataset.

    Pipeline (duplicated in ``EEGMMIDTsSet``; both must stay in sync so that
    they produce the two halves of the same split):

    1. Load ``../Deep-Learning-for-BCI/dataset/<subject_id>.npy``. The last
       column is used as the label, every other column as a signal.
    2. Band-pass filter each signal column (8-30 Hz, fs = 160 Hz, order 3).
    3. Keep only labels 2, 3, 4, 5 and map 2/4 -> 0 and 3/5 -> 1.
    4. Cut the data into non-overlapping 160-sample windows of 64 features
       with ``extract``.
    5. Stratified ``train_test_split`` with ``random_state=0``.
    6. Standardise with a ``StandardScaler`` fitted on the training windows.

    Attributes:
        data: Training windows, tensor of shape (n_train, 1, 64, 160).
        targets: int64 labels of the training windows.
        data_seg_label: (n_windows, 1) label array of *all* windows (train
            and test); used by ``get_class_weights``.

    Args:
        subject_id: File stem of the subject's ``.npy`` recording.
        transform: Accepted for interface compatibility; not used.
    """

    def __init__(self, subject_id, transform=None):
        root_dir = "../Deep-Learning-for-BCI/dataset/"
        dataset_raw = np.load(root_dir + str(subject_id) + '.npy')

        # Band-pass filter (8-30 Hz, fs = 160 Hz, order 3) every column except
        # the last one, which holds the labels and is re-attached unfiltered.
        fs = 160.0
        lowcut = 8.0
        highcut = 30.0
        filtered = [butter_bandpass_filter(dataset_raw[:, ch], lowcut, highcut, fs, order=3)
                    for ch in range(dataset_raw.shape[1] - 1)]
        dataset = np.hstack((np.array(filtered).T, dataset_raw[:, -1:]))
        print(dataset.shape)

        # Keep labels 2, 3, 4 and 5 only (label meanings: see 1-Data.ipynb; the
        # original notes call 4/5 the left/right fist open-close imagery
        # classes). The notes list [0, 1, 2, 3, 4, 5, 10] for the 'hf' task and
        # the list below for the 'lr' task.
        removed_label = [0, 1, 6, 7, 8, 9, 10]
        keep = ~np.isin(dataset[:, -1], removed_label)  # was `id`, shadowing the builtin
        dataset = dataset[keep]

        # PyTorch needs class indices starting at 0: merge 2/4 -> 0 and 3/5 -> 1.
        labels = dataset[:, -1]  # a view, so the writes below update `dataset`
        labels[(labels == 2) | (labels == 4)] = 0
        labels[(labels == 3) | (labels == 5)] = 1

        # Segment into non-overlapping windows (moving == window length) so that
        # no sample ends up in both the training and the test set.
        n_class = 2
        no_feature = 64        # number of feature columns per time step
        segment_length = 160   # window length in samples
        data_seg = extract(dataset, n_classes=n_class, n_fea=no_feature,
                           time_window=segment_length, moving=segment_length)
        print('After segmentation, the shape of the data:', data_seg.shape)

        no_longfeature = no_feature * segment_length
        data_seg_feature = data_seg[:, :no_longfeature]
        self.data_seg_label = data_seg[:, no_longfeature:no_longfeature + 1]

        # random_state must equal the one in EEGMMIDTsSet so that the two
        # classes get complementary halves of the same split.
        train_feature, test_feature, train_label, test_label = train_test_split(
            data_seg_feature, self.data_seg_label, random_state=0, shuffle=True,
            stratify=self.data_seg_label)

        # Report class balance of the whole set and of each split.
        _, counts = np.unique(self.data_seg_label, return_counts=True)
        left_perc = counts[0] / sum(counts)
        if left_perc < 0.4 or left_perc > 0.6:
            print("Imbalanced dataset with split of: ", left_perc, 1 - left_perc)
        else:
            print("Classes balanced.")
        unique, counts = np.unique(train_label, return_counts=True)
        print("Class label splits in training set \n ", np.asarray((unique, counts)).T)
        unique, counts = np.unique(test_label, return_counts=True)
        print("Class label splits in test set\n ", np.asarray((unique, counts)).T)

        # Normalisation: flatten windows back to (samples, features) and fit the
        # scaler on the training windows only.
        train_feature_2d = train_feature.reshape([-1, no_feature])
        test_feature_2d = test_feature.reshape([-1, no_feature])
        scaler1 = StandardScaler().fit(train_feature_2d)
        train_fea_norm1 = scaler1.transform(train_feature_2d)
        test_fea_norm1 = scaler1.transform(test_feature_2d)
        print('After normalization, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter normalization, the shape of test feature:', test_fea_norm1.shape)

        # Back to windows: (trial, time samples, electrode channel).
        train_fea_norm1 = train_fea_norm1.reshape([-1, segment_length, no_feature])
        test_fea_norm1 = test_fea_norm1.reshape([-1, segment_length, no_feature])
        print('After reshape, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter reshape, the shape of test feature:', test_fea_norm1.shape)

        # (trial, conv channel, electrode channel, time samples) for Conv2d.
        train_fea_reshape1 = np.swapaxes(np.expand_dims(train_fea_norm1, 1), 2, 3)
        test_fea_reshape1 = np.swapaxes(np.expand_dims(test_fea_norm1, 1), 2, 3)
        print('After expand dims, the shape of training feature:', train_fea_reshape1.shape,
              '\nAfter expand dims, the shape of test feature:', test_fea_reshape1.shape)

        self.data = torch.tensor(train_fea_reshape1)
        self.targets = torch.tensor(train_label.flatten()).long()

        print("data and target type:", type(self.data), type(self.targets))

    def __len__(self):
        """Number of training windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(window, label)`` pair at position ``idx``."""
        data, target = self.data[idx], self.targets[idx]
        return data, target

    def get_class_weights(self):
        """Return scikit-learn 'balanced' class weights, ordered by class value.

        Computed over ``self.data_seg_label``, i.e. the labels of all windows
        (train and test).
        NOTE(review): using only the training labels would avoid reading test
        label statistics; kept as-is to preserve existing results.
        """
        labels = self.data_seg_label[:, 0]
        # `classes` and `y` are keyword-only since scikit-learn 1.0; passing
        # them positionally raises TypeError there.
        return class_weight.compute_class_weight(
            class_weight='balanced', classes=np.unique(labels), y=labels)



In [None]:
class EEGMMIDTsSet(Data.Dataset):
    """Test split of one EEGMMID subject's recording, as a PyTorch Dataset.

    Repeats the preprocessing of ``EEGMMIDTrSet``:
    - 8-30 Hz band-pass;
    - labels 2-5 merged into 0/1;
    - non-overlapping 160-sample windows;
    - stratified split with ``random_state=0``;
    - scaler fitted on the training windows.

    It keeps the *test* side of that split. Both classes must use identical
    settings, otherwise the two splits are no longer complementary.

    Attributes:
        data: Test windows, tensor of shape (n_test, 1, 64, 160).
        targets: int64 labels of the test windows.

    Args:
        subject_id: File stem of the subject's ``.npy`` recording.
        transform: Accepted for interface compatibility; not used.
    """

    def __init__(self, subject_id, transform=None):
        recording = np.load("../Deep-Learning-for-BCI/dataset/" + str(subject_id) + '.npy')

        # 8-30 Hz band-pass (fs = 160 Hz, order 3) on every column except the
        # last. The last column carries the labels and is appended back unfiltered.
        sample_rate, band_low, band_high = 160.0, 8.0, 30.0
        channels = [
            butter_bandpass_filter(recording[:, col], band_low, band_high, sample_rate, order=3)
            for col in range(recording.shape[1] - 1)
        ]
        signals = np.hstack((np.array(channels).T, recording[:, -1:]))

        # Drop every label except 2-5 (the 'lr' setting; the 'hf' setting
        # would drop [0, 1, 2, 3, 4, 5, 10]). Then merge 2/4 -> 0 and 3/5 -> 1
        # so that labels are 0-based, as PyTorch expects.
        signals = signals[~np.isin(signals[:, -1], [0, 1, 6, 7, 8, 9, 10])]
        label_col = signals[:, -1]  # a view, so the writes below update `signals`
        label_col[(label_col == 2) | (label_col == 4)] = 0
        label_col[(label_col == 3) | (label_col == 5)] = 1

        # Non-overlapping 160-sample windows of 64 features each.
        n_classes, n_channels, window = 2, 64, 160
        windows = extract(signals, n_classes=n_classes, n_fea=n_channels,
                          time_window=window, moving=window)
        print('After segmentation, the shape of the data:', windows.shape)

        flat_len = n_channels * window
        features = windows[:, :flat_len]
        labels = windows[:, flat_len:flat_len + 1]
        # Same random_state as EEGMMIDTrSet, so both classes see the same split.
        x_train, x_test, y_train, y_test = train_test_split(
            features, labels, random_state=0, shuffle=True, stratify=labels)

        # Class balance report for the whole set and for each split.
        _, class_counts = np.unique(labels, return_counts=True)
        first_share = class_counts[0] / sum(class_counts)
        if first_share < 0.4 or first_share > 0.6:
            print("Imbalanced dataset with split of: ", first_share, 1 - first_share)
        else:
            print("Classes balanced.")
        for caption, split_labels in (("Class label splits in training set \n ", y_train),
                                      ("Class label splits in test set\n ", y_test)):
            values, counts = np.unique(split_labels, return_counts=True)
            print(caption, np.asarray((values, counts)).T)

        # Standardise each feature column with statistics of the training windows only.
        scaler = StandardScaler().fit(x_train.reshape([-1, n_channels]))
        train_norm = scaler.transform(x_train.reshape([-1, n_channels]))
        test_norm = scaler.transform(x_test.reshape([-1, n_channels]))
        print('After normalization, the shape of training feature:', train_norm.shape,
              '\nAfter normalization, the shape of test feature:', test_norm.shape)

        # Back to windows: (trial, time samples, electrode channel).
        train_norm = train_norm.reshape([-1, window, n_channels])
        test_norm = test_norm.reshape([-1, window, n_channels])
        print('After reshape, the shape of training feature:', train_norm.shape,
              '\nAfter reshape, the shape of test feature:', test_norm.shape)

        # (trial, time, electrode) -> (trial, 1, electrode, time) for Conv2d.
        train_4d = np.expand_dims(train_norm, 1).transpose(0, 1, 3, 2)
        test_4d = np.expand_dims(test_norm, 1).transpose(0, 1, 3, 2)
        print('After expand dims, the shape of training feature:', train_4d.shape,
              '\nAfter expand dims, the shape of test feature:', test_4d.shape)

        self.data = torch.tensor(test_4d)
        self.targets = torch.tensor(y_test.flatten()).long()

        print("data and target type:", type(self.data), type(self.targets))

    def __len__(self):
        """Number of test windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(window, label)`` pair at position ``idx``."""
        return self.data[idx], self.targets[idx]


In [None]:
# Convolution module
# Uses convolutions to capture local features instead of a positional embedding.
class PatchEmbedding(nn.Module):
    """Shallow convolutional front end that turns an EEG window into tokens.

    Input is assumed to be (batch, 1, electrodes, time) — TODO confirm against
    the dataset classes. It passes through:
    - a temporal conv (1 x 25);
    - a conv over 22 rows of the electrode axis (the original notes say 22
      with 64 channels and 7 for 21/19/18 channels);
    - BatchNorm, ELU;
    - average pooling (1 x 75, stride 15), which slices the time axis into
      "patches" as in ViT;
    - dropout.

    A 1x1 conv then projects to ``emb_size`` and the feature map is flattened
    into a (batch, tokens, emb_size) sequence.

    Args:
        emb_size: Size of each output token embedding.
    """

    def __init__(self, emb_size=40):
        super().__init__()
        n_filters = 40

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=(1, 25), stride=(1, 1)),
            nn.Conv2d(n_filters, n_filters, kernel_size=(22, 1), stride=(1, 1)),
            nn.BatchNorm2d(n_filters),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15)),
            nn.Dropout(0.5),
        )

        # A 1x1 conv instead of a plain transpose; it can improve fitting slightly.
        self.projection = nn.Sequential(
            nn.Conv2d(n_filters, emb_size, kernel_size=(1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        # The unpacking enforces a 4-D input; the values themselves are unused.
        _batch, _chan, _height, _width = x.shape
        features = self.shallownet(x.float())
        return self.projection(features)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        emb_size: Token embedding size. Must be divisible by ``num_heads``,
            which the ``(h d)`` split in ``rearrange`` requires.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over ``x`` of shape (batch, tokens, emb_size).

        Args:
            x: Input token sequence.
            mask: Optional boolean tensor broadcastable to
                (batch, heads, tokens, tokens). True marks key positions that
                may be attended to; False positions get (almost) zero weight.

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # `masked_fill` is out-of-place, so the result must be reassigned.
            # The previous code called the non-existent `mask_fill` and would
            # also have discarded the result.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE: this scales by sqrt(emb_size), not sqrt(head dim) as standard
        # scaled dot-product attention does. Kept unchanged so that existing
        # trained weights behave the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    """Wrap module ``fn`` with a residual (skip) connection: ``fn(x) + x``.

    Args:
        fn: Module whose output has the same shape as its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. The previous in-place `x += res` mutated fn's
        # output. When fn returns its input tensor unchanged (e.g.
        # nn.Identity), that silently doubled the caller's tensor, and it
        # fails under autograd for leaf tensors or tensors saved for backward.
        return self.fn(x, **kwargs) + x


class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear(emb -> expansion*emb), GELU, Dropout, Linear(back to emb).

    Args:
        emb_size: Input and output feature size.
        expansion: Width multiplier of the hidden layer.
        drop_p: Dropout probability after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)


class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit: ``x * Phi(x)``.

    Not used by FeedForwardBlock, which uses ``nn.GELU``.
    """

    def forward(self, input: Tensor) -> Tensor:
        # 2.0 ** 0.5 instead of math.sqrt(2.0): `math` is not imported in this
        # file, so the previous version raised NameError unless a star import
        # happened to provide it.
        return input * 0.5 * (1.0 + torch.erf(input / (2.0 ** 0.5)))


class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm transformer encoder layer made of two residual sub-blocks.

    1. x + Dropout(MultiHeadAttention(LayerNorm(x)))
    2. x + Dropout(FeedForwardBlock(LayerNorm(x)))

    Args:
        emb_size: Token embedding size.
        num_heads: Attention heads.
        drop_p: Dropout used inside attention and after each sub-block.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(ResidualAdd(attention_branch), ResidualAdd(feed_forward_branch))


class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers with default settings.

    Args:
        depth: Number of encoder layers.
        emb_size: Token embedding size.
    """

    def __init__(self, depth, emb_size):
        layers = []
        for _ in range(depth):
            layers.append(TransformerEncoderBlock(emb_size))
        super().__init__(*layers)


class ClassificationHead(nn.Sequential):
    """MLP classifier applied to the flattened transformer output.

    Args:
        emb_size: Token embedding size (only used by the unused ``clshead``).
        n_classes: Number of output logits.
        fc_in_features: Size of the flattened input, i.e. tokens * emb_size.
            The default 8600 matches 64-channel, 1 s windows. Per the original
            notes:
            - 25800 for 2 s windows with 64 channels;
            - 3000 for 21 channels (1 s);
            - 2600 for the top 19 channels;
            - 2400 for 18 channels.
    """

    def __init__(self, emb_size, n_classes, fc_in_features=8600):
        super().__init__()

        # Global-average-pooling head. forward() does not use it, but it is
        # kept because its parameters are registered and therefore part of
        # saved state_dicts.
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(fc_in_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        # Flatten everything but the batch dimension; presumably x is
        # (batch, tokens, emb_size) coming from the transformer encoder.
        x = x.contiguous().view(x.size(0), -1)
        return self.fc(x)


class Conformer(nn.Sequential):
    """EEG Conformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

    Args:
        emb_size: Token embedding size shared by all stages.
        depth: Number of transformer encoder layers.
        n_classes: Number of output logits.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, emb_size=40, depth=2, n_classes=2, **kwargs):
        stages = [
            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes),
        ]
        super().__init__(*stages)

In [None]:
# Metric keys read from the dict returned by Avalanche's strategy.eval().
_LOSS_KEY = "Loss_Exp/eval_phase/test_stream/Task000/Exp000"
_ACC_KEY = "Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000"
_CLASS_ACC_KEY = "Top1_ClassAcc_Stream/eval_phase/test_stream/Task000/"
_FORGETTING_KEY = "StreamForgetting/eval_phase/test_stream"
_RESULTS_DIR = "./results/eegconformer/"


def _build_strategy(strat, model, criterion, evaluator, early_stopping, device):
    """Create the Avalanche training strategy named ``strat`` (prints its name).

    Raises:
        ValueError: If ``strat`` is not a known strategy name.
    """
    def adam(lr):
        # Learning rate and betas follow the parameters of the Conformer paper.
        return Adam(model.parameters(), lr=lr, betas=(0.5, 0.99))

    common = dict(plugins=[early_stopping], evaluator=evaluator, device=device)
    if strat == "naive":
        print("Naive continual learning")
        return Naive(model, adam(0.0002), criterion, train_epochs=100, eval_every=10, **common)
    if strat == "offline":
        print("Offline learning")
        return JointTraining(model, adam(0.0005), criterion, train_epochs=1000, eval_every=10,
                             train_mb_size=25, **common)
    if strat == "cumulative":
        print("Cumulative continual learning")
        return Cumulative(model, adam(0.0002), criterion, train_epochs=100, eval_every=10,
                          train_mb_size=25, **common)
    if strat == "replay":
        print("Replay training")
        return Replay(model, adam(0.0002), criterion, train_epochs=100, eval_every=10,
                      mem_size=25, train_mb_size=25, **common)
    if strat == "lwf":
        print("LwF continual learning")
        return LwF(model, adam(0.0002), criterion, train_epochs=100, eval_every=10,
                   alpha=0.5, temperature=1, **common)
    if strat == "ewc":
        print("EWC continual learning")
        # cuDNN is disabled for EWC; train_eegmmid restores the previous
        # setting when it returns.
        torch.backends.cudnn.enabled = False
        return EWC(model, adam(0.0002), criterion, train_epochs=100, eval_every=10,
                   ewc_lambda=0.99, **common)
    if strat == "episodic":
        print("Episodic continual learning")
        return GEM(model, adam(0.0002), criterion, train_epochs=1, eval_every=10,
                   patterns_per_exp=70, **common)
    raise ValueError(f"Unknown strategy: {strat!r}")


def _eval_row(task_type, strat, sub_id, i, r):
    """Summarise one eval() result dict ``r``; accuracies are scaled to percent."""
    return {"task_type": task_type,
            "strategy": strat,
            "sub_id": sub_id,
            "iteration": i,
            "loss": r[_LOSS_KEY],
            "acc": float(r[_ACC_KEY]) * 100,
            "acc0": float(r[_CLASS_ACC_KEY + "0"]) * 100,
            "acc1": float(r[_CLASS_ACC_KEY + "1"]) * 100,
            "forg": r[_FORGETTING_KEY],
            "all": r}


def train_eegmmid(task_type, strat, sub_id, i=""):
    """Train and evaluate an EEG Conformer on one EEGMMID subject.

    Steps:
    1. Build the train/test datasets for ``sub_id`` and wrap them in an
       Avalanche ``ni_benchmark`` with 5 experiences.
    2. Train with strategy ``strat``.
    3. Pickle the collected metrics to
       ``./results/eegconformer/eegmmid_ws_<strat>_<sub_id>_results<i>.pkl``.

    For ``strat == "offline"`` the whole train stream is trained at once and
    the model weights are also saved to ``..._model<i>.pth``. Other strategies
    train experience by experience and evaluate on the test stream after each
    experience. Metrics are also appended to ``eegmmidlog.txt``.

    Args:
        task_type: Label stored in the results (e.g. "within_sub").
        strat: One of "naive", "offline", "cumulative", "replay", "lwf",
            "ewc", "episodic".
        sub_id: Subject id; selects the ``.npy`` recording.
        i: Iteration suffix (a string) appended to the output file names.

    Raises:
        ValueError: If ``strat`` is not a known strategy name.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_ds = EEGMMIDTrSet(subject_id=sub_id)
    test_ds = EEGMMIDTsSet(subject_id=sub_id)
    class_weights = torch.tensor(train_ds.get_class_weights(), dtype=torch.float, device=device)
    scenario = ni_benchmark(train_dataset=train_ds, test_dataset=test_ds,
                            n_experiences=5, task_labels=True)

    # Global state is restored and the log file closed even if training fails.
    cudnn_was_enabled = torch.backends.cudnn.enabled
    log_file = open('eegmmidlog.txt', 'a')
    try:
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            class_accuracy_metrics(epoch=True, stream=True, classes=list(range(scenario.n_classes))),
            loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            timing_metrics(epoch=True, epoch_running=True),
            forgetting_metrics(experience=True, stream=True),
            cpu_usage_metrics(experience=True),
            gpu_usage_metrics(0, experience=True),
            disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            loggers=[TextLogger(log_file)]
        )
        es = EarlyStoppingPlugin(patience=50, val_stream_name="train_stream")

        # .to(device) rather than .cuda(), so CPU-only hosts work too.
        model = Conformer(n_classes=2).to(device)
        strategy = _build_strategy(strat, model, CrossEntropyLoss(weight=class_weights),
                                   eval_plugin, es, device)

        thisresults = []
        print(i + ".")
        start = time.time()
        if strat == "offline":
            strategy.train(scenario.train_stream)
            r = strategy.eval(scenario.test_stream)
            thisresults.append(_eval_row(task_type, strat, sub_id, i, r))
        else:
            for experience in scenario.train_stream:
                strategy.train(experience)
                r = strategy.eval(scenario.test_stream)
                thisresults.append(_eval_row(task_type, strat, sub_id, i, r))

        # Final metrics come from the last evaluation (finalacc stays a fraction).
        results = [{"task_type": task_type,
                    "strategy": strat,
                    "sub_id": sub_id,
                    "iteration": i,
                    "finalloss": r[_LOSS_KEY],
                    "finalacc": r[_ACC_KEY],
                    "results": thisresults}]

        os.makedirs(_RESULTS_DIR, exist_ok=True)
        prefix = _RESULTS_DIR + "eegmmid_ws_" + strat + "_" + str(sub_id)
        if strat == "offline":
            torch.save(model.state_dict(), prefix + "_model" + i + '.pth')
        elapsed = time.time() - start
        with open(prefix + "_results" + i + ".pkl", "wb") as outfile:
            pickle.dump(results, outfile)
        print("\t" + str(elapsed) + " seconds")
    finally:
        torch.backends.cudnn.enabled = cudnn_was_enabled
        log_file.close()

# Subjects used in the within-subject sweep, and repetitions per subject.
EEGMMID_SUBJECT_IDS = [7, 12, 22, 42, 43, 48, 49, 53, 70, 80, 82, 85, 94, 102]
N_REPEATS = 5

# Guarded so that importing this notebook as a module (e.g. through
# myimporter, as done for BCI_functions) does not start the training sweep.
# Inside a Jupyter kernel __name__ is "__main__", so running the cell still
# trains.
if __name__ == "__main__":
    for s_id in EEGMMID_SUBJECT_IDS:
        print("\n --------------------------------------------------- \n")
        print("Starting for subject id:", s_id)
        for itr in range(N_REPEATS):
            train_eegmmid(task_type="within_sub", strat="offline", sub_id=s_id, i=str(itr))