In [None]:
#----Colab----#
# Clone the project repository and make it the working directory; the custom
# modules imported in the next cell (processing_utils, train_utils, eva_utils)
# presumably live in this repo. Guarded so that re-running the cell neither
# clones a second, nested copy nor fails on os.chdir.
import os
import subprocess

REPO_URL = 'https://github.com/ARIS2333/My_Project.git'
REPO_DIR = 'My_Project'

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        # check=True raises CalledProcessError instead of silently continuing.
        subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)

# 1 Imports

In [None]:
#----system imports----#
import os
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
from pathlib import Path

#----sklearn imports----#
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

#----pytorch imports----#
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as transforms
from torch import optim
import torch.nn.functional as F

#----custom imports----#
from processing_utils import downloader, mat_extractor, cropper, save_processed_data
from train_utils import *
from eva_utils import show_acc_loss_test, evaluation_test, save_results, summary_results

#----settings----#
# NOTE(review): this silences *every* warning for the rest of the session,
# including deprecation warnings from torch/numpy that can explain later
# failures. Consider narrowing it, e.g. filterwarnings('ignore', category=...).
warnings.filterwarnings('ignore')

# 2 Data

## 2.1 Download data

In [None]:
def download_data(path='data/raw', subjects=None,
                  url="http://bnci-horizon-2020.eu/database/data-sets/001-2014/"):
    """Fetch the BNCI 001-2014 files through the project ``downloader``.

    Parameters
    ----------
    path : str
        Destination folder for the raw files (default ``'data/raw'``).
    subjects : iterable of int, optional
        Subject ids to fetch; ``None`` means subjects 1-9.
    url : str
        Base URL of the dataset listing.

    Returns
    -------
    Whatever ``downloader`` returns. The rest of the notebook joins file names
    onto it with ``/``, so it is presumably a ``pathlib.Path``.
    """
    if subjects is None:
        subjects = list(range(1, 10))
    dpath = downloader(path, subjects=list(subjects), url=url)
    return dpath

## 2.2 Processing data

In [None]:
def processing_data(raw_dir=None, config=None, subjects=None):
    """Extract, preprocess, optionally crop and save each subject's two sessions.

    For every subject ``s`` the files ``A0{s}T.mat`` and ``A0{s}E.mat`` are read
    from ``raw_dir`` with ``mat_extractor``. Labels are shifted by -1 so they
    start at 0. If ``config['if_cropper']`` is set, both sets are windowed with
    ``cropper``. Each set is then written with ``save_processed_data`` as
    ``A0{s}T.pt`` / ``A0{s}E.pt``.

    Parameters
    ----------
    raw_dir : path-like, optional
        Folder with the raw ``.mat`` files. Defaults to the notebook-level
        ``dpath``, which keeps the original zero-argument call working.
    config : dict, optional
        Preprocessing options with the keys used in ``config_processor``.
        Defaults to the notebook-level ``config_processor``.
    subjects : iterable of int, optional
        Subject ids to process; ``None`` means 1-9.

    Returns
    -------
    (list, list)
        The values returned by ``save_processed_data`` for the training and
        evaluation files, in subject order.
    """
    # Fall back to the notebook globals (defined in the "Configs" cell).
    if raw_dir is None:
        raw_dir = dpath
    if config is None:
        config = config_processor
    if subjects is None:
        subjects = range(1, 10)
    if isinstance(raw_dir, str):
        raw_dir = Path(raw_dir)

    def _extract(mat_path):
        # Same options for both sessions; a fresh bpf_dict per call, as before.
        return mat_extractor(
            path=mat_path,
            channel_norm=config['if_normalize'],
            remove_eog=config['if_remove_eog'],
            bpf_dict={'apply': config['if_filter'], 'fs': 250,
                      'lc': config['lc'], 'hc': config['hc'], 'order': config['order']},
        )

    dpath_train_all = []
    dpath_test_all = []
    for subject in subjects:
        print(f'*** Subject {subject} ***\n')

        x_train, y_train = _extract(raw_dir / f'A0{subject}T.mat')
        x_test, y_test = _extract(raw_dir / f'A0{subject}E.mat')

        # Shift labels so they start from 0 (assumes 1-based labels on disk).
        y_train, y_test = y_train - 1, y_test - 1

        # Optional windowing to increase the number of samples.
        if config['if_cropper']:
            window, step = config['cropper_window'], config['cropper_step']
            x_train, y_train = cropper(x_train, y_train, window=window, step=step)
            x_test, y_test = cropper(x_test, y_test, window=window, step=step)

        dpath_train_all.append(save_processed_data(
            config['save_path_train'], x_train, y_train, f'A0{subject}T.pt'))
        dpath_test_all.append(save_processed_data(
            config['save_path_test'], x_test, y_test, f'A0{subject}E.pt'))

    return dpath_train_all, dpath_test_all

## 2.3 Configs

In [None]:
#----Download data----#
dpath = download_data()

#----configuration----#
# Preprocessing options consumed by processing_data(). Written with dict(...)
# so each option reads as a keyword; key order matches the original literal.
config_processor = dict(
    # cropper: cut each trial into windows of `cropper_window`, advancing by
    # `cropper_step` (units presumably samples - TODO confirm in cropper)
    if_cropper=False,
    cropper_window=500,
    cropper_step=500,

    # band-pass filter: flag, cut-offs `lc` / `hc` and filter `order`,
    # forwarded to mat_extractor inside bpf_dict
    if_filter=False,
    lc=4,
    hc=40,
    order=3,

    # forwarded to mat_extractor as `channel_norm`
    if_normalize=False,

    # forwarded to mat_extractor as `remove_eog`
    if_remove_eog=False,

    # output folders for the processed .pt files
    save_path_train='data/processed/train',
    save_path_test='data/processed/test',
)

#----process data----#
dpath_train_all, dpath_test_all = processing_data()

# 3 Torch utils

In [None]:
def set_seed(seed=42, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Parameters
    ----------
    seed : int
        Seed applied to every RNG.
    deterministic : bool
        If True, also force cuDNN to use deterministic kernels and disable its
        autotuner (``benchmark``). This trades speed for run-to-run
        reproducibility on GPU. The default False leaves the cuDNN flags
        untouched, as before.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU (including the current one); safe without CUDA.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

## 3.1 Dataset

In [None]:
class feature_dataset(Dataset):
    """Map-style dataset over one processed ``.pt`` file.

    The file must hold a dict with ``'data'`` and ``'label'`` entries. A
    singleton axis is inserted at dim 1 of ``data`` so each sample can be fed
    to ``nn.Conv2d`` with ``in_channels=1``. The on-disk layout is presumably
    (n_trials, n_electrodes, n_times) - TODO confirm against
    ``save_processed_data``.

    Parameters
    ----------
    file_path : path-like
        Processed file written by ``save_processed_data``.
    transform : callable, optional
        Stored on the instance but NOT applied in ``__getitem__``. The default
        passed by ``getDataLoader`` (``transforms.ToTensor()``) would reject
        tensor samples, so it must not be applied blindly.
    """

    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.data, self.label = self.parse_data_file(file_path)
        self.transform = transform

    @staticmethod
    def _load(file_path):
        """``torch.load`` that also accepts pickled non-tensor payloads."""
        try:
            # torch>=2.6 defaults to weights_only=True, which refuses e.g.
            # pickled numpy arrays. These files are produced locally by
            # save_processed_data, so full unpickling is acceptable.
            # SECURITY: never point this at untrusted files.
            return torch.load(file_path, weights_only=False)
        except TypeError:
            # torch<1.13 has no `weights_only` argument.
            return torch.load(file_path)

    def parse_data_file(self, file_path):
        """Return ``(data, label)``: float32 data with an axis added at dim 1,
        and int64 labels (the class-index dtype ``nn.CrossEntropyLoss`` expects).
        """
        pt_file = self._load(file_path)
        # as_tensor avoids an extra copy and the copy-construct UserWarning
        # that torch.tensor() emits when handed an existing tensor.
        data = torch.as_tensor(pt_file['data']).unsqueeze(1)
        label = torch.as_tensor(pt_file['label'])
        return data.float(), label.long()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Returns (sample, label); self.transform is intentionally not applied.
        return self.data[index], self.label[index]
    

## 3.2 DataLoader

In [None]:
def getDataLoader(train_path, test_path,
                  data_transforms=None,
                  batch_size=32,
                  shuffle_train=False,
                  seed=None):
    """Build the training and test ``DataLoader``s for one subject.

    Parameters
    ----------
    train_path, test_path : path-like
        Processed files handed to ``feature_dataset``.
    data_transforms : callable, optional
        Passed to both datasets as ``transform``. ``None`` builds
        ``transforms.Compose([transforms.ToTensor()])`` on each call. This is
        the previous default, which used to be one instance created at
        definition time.
    batch_size : int
        Batch size of both loaders.
    shuffle_train : bool
        Reshuffle the training loader every epoch. The default False keeps
        the original unshuffled behaviour.
    seed : int, optional
        Passed to ``set_seed`` first. ``None`` uses ``config_train['seed']``,
        as before.

    Returns
    -------
    (DataLoader, DataLoader)
        The training loader and the (never shuffled) test loader.
    """
    set_seed(config_train['seed'] if seed is None else seed)
    if data_transforms is None:
        data_transforms = transforms.Compose([transforms.ToTensor()])

    train_dataset = feature_dataset(train_path, transform=data_transforms)
    test_dataset = feature_dataset(test_path, transform=data_transforms)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

## 3.3 Model blocks

In [None]:
# Adapter so that plain functions can be placed inside nn.Sequential & co.
class Lambda(nn.Module):
    """Module wrapper around an arbitrary callable.

    ``forward(x)`` simply returns ``func(x)`` for the callable given at
    construction time.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        wrapped = self.func
        return wrapped(x)

# Flatten helper for Conv2d architectures (e.g. wrapped in a Lambda module).
def myreshape(xb):
    """Flatten every non-batch dimension: ``(N, C, H, W) -> (N, C*H*W)``.

    A fixed ``xb.view(-1, C*W)`` only matches this when ``H == 1``. For
    ``H > 1`` it folds ``H`` into the batch axis, silently changing the number
    of rows. It also fails on non-contiguous inputs. ``torch.flatten`` keeps
    the batch axis, accepts non-contiguous tensors and handles an empty batch.
    """
    return torch.flatten(xb, start_dim=1)


# Max-norm constrained layers, as used in EEGNet / ShallowConvNet designs.
class Conv2dConstrained(nn.Conv2d):
    """``nn.Conv2d`` whose output filters are max-norm constrained.

    Right before every forward pass, each slice of ``weight`` along dim 0
    (one per output filter) is rescaled with ``renorm`` so that its L2 norm is
    at most ``max_norm``. The result is written through ``.data``, so the
    rescaling is not recorded by autograd.
    """

    def __init__(self, *args, max_norm=1, **kwargs):
        self.max_norm = max_norm
        super().__init__(*args, **kwargs)

    def forward(self, x):
        clipped = self.weight.data.renorm(2, 0, self.max_norm)
        self.weight.data = clipped
        return super().forward(x)

class LinearConstrained(nn.Linear):
    """``nn.Linear`` with a max-norm constraint per output unit.

    Each row of ``weight`` (slice along dim 0, one per output feature) is
    rescaled via ``renorm`` to an L2 norm of at most ``max_norm`` right before
    the affine map is applied. The write goes through ``.data``, so autograd
    does not see it.
    """

    def __init__(self, *args, max_norm = 0.25, **kwargs):
        self.max_norm = max_norm
        super().__init__(*args, **kwargs)

    def forward(self, x):
        self.weight.data = self.weight.data.renorm(2, 0, self.max_norm)
        return super().forward(x)
    
# A linear layer whose in_features is inferred from the first input it sees.
class LinearModified(nn.Module):
    """Lazily built ``nn.Linear``, optionally max-norm constrained.

    The inner layer is created on the first ``forward`` call, with
    ``in_features = xb.shape[1]``, and placed on ``xb``'s device.

    NOTE: the inner parameters only exist after the first forward pass.
    Create the optimizer *after* running one batch through the model,
    otherwise these weights are never optimized.

    Parameters
    ----------
    out_features : int
        Output size of the inner linear layer.
    bias : bool
        Whether the inner layer has a bias term.
    max_norm : float, optional
        If given, a ``LinearConstrained`` with this ``max_norm`` is used.
    """

    def __init__(self, out_features, bias=False, max_norm=None):
        super().__init__()
        self.in_features = None
        self.out_features = out_features
        self.bias = bias
        self.max_norm = max_norm
        self.__built = False
        self.lin = None  # replaced by the real layer on the first forward

    def forward(self, xb):
        assert xb.ndim == 2, 'xb should have 2 dimensions'
        if not self.__built:
            in_features = xb.shape[1]
            if self.max_norm is None:
                layer = nn.Linear(in_features, self.out_features, bias=self.bias)
            else:
                layer = LinearConstrained(in_features, self.out_features,
                                          max_norm=self.max_norm, bias=self.bias)
            # Follow the input's actual device. The old check compared the
            # *method* xb.get_device to -1 (always False), so the layer was
            # always moved to 'cuda' and crashed on CPU-only hosts.
            self.lin = layer.to(xb.device)
            self.in_features = in_features
            self.__built = True  # only after the layer exists
        return self.lin(xb)

# 4 Models

## 4.1 EEGNet

In [None]:
class EEGNet(nn.Module):
    """EEGNet-style classifier for single-trial EEG.

    Pipeline: temporal convolution (``block_1``), then depthwise spatial
    convolution over all electrodes (``block_2``), then separable convolution
    (``block_3``), then a linear classifier (``out``).

    Input is ``(N, 1, C, T)``; ``feature_dataset`` adds the singleton axis.
    In the per-sample shape comments below:

    - ``T1 = T + 125``: the asymmetric zero padding in ``block_1`` adds 188
      samples and the 64-tap kernel removes 63;
    - ``T2 = T1 // 4``;
    - ``T3 = T2 // 8``.

    Parameters
    ----------
    classes_num : int
        Number of output classes.
    n_channels : int
        Electrode count ``C``; the depthwise conv in ``block_2`` spans all of
        them. Default 22, the previously hard-coded value.
    n_samples : int
        Time samples per trial, used only to size ``out``. The default 1000
        reproduces the previously hard-coded ``16 * 35`` input features.
    dropout : float
        Dropout probability in blocks 2 and 3.
    return_logits : bool
        False (default, previous behaviour): ``forward`` returns softmax
        probabilities. True: it returns raw scores. NOTE:
        ``nn.CrossEntropyLoss`` applies log-softmax itself, so training it on
        the softmax output applies softmax twice and flattens the gradients.
    """

    def __init__(self, classes_num, n_channels=22, n_samples=1000,
                 dropout=0.5, return_logits=False):
        # Previously super(mine, self): a NameError on every construction.
        super().__init__()

        self.drop_out = dropout
        self.return_logits = return_logits

        self.block_1 = nn.Sequential(
            # Zero-pad the time axis: (left, right, top, bottom).
            nn.ZeroPad2d((31 + 62, 32 + 63, 0, 0)),
            nn.Conv2d(
                in_channels=1,        # input shape (1, C, T)
                out_channels=8,       # number of temporal filters
                kernel_size=(1, 64),  # temporal kernel
                bias=False
            ),                        # output shape (8, C, T1)
            nn.BatchNorm2d(8)         # output shape (8, C, T1)
        )

        # block_2: depthwise spatial convolution;
        # block_3: separable (depthwise temporal + pointwise) convolution.
        self.block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=8,                # input shape (8, C, T1)
                out_channels=16,              # 2 spatial filters per temporal filter
                kernel_size=(n_channels, 1),  # spans every electrode
                groups=8,
                bias=False
            ),                                # output shape (16, 1, T1) when C == n_channels
            nn.BatchNorm2d(16),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),             # output shape (16, 1, T2)
            nn.Dropout(self.drop_out)
        )

        self.block_3 = nn.Sequential(
            nn.ZeroPad2d((7, 8, 0, 0)),       # keeps length T2 through the 16-tap conv
            nn.Conv2d(
                in_channels=16,               # input shape (16, 1, T2)
                out_channels=16,
                kernel_size=(1, 16),          # depthwise temporal kernel
                groups=16,
                bias=False
            ),                                # output shape (16, 1, T2)
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=(1, 1),           # pointwise mixing
                bias=False
            ),                                # output shape (16, 1, T2)
            nn.BatchNorm2d(16),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),             # output shape (16, 1, T3)
            nn.Dropout(self.drop_out)
        )

        # 16 * T3 features (16 * 35 for the default n_samples=1000).
        self.out = nn.Linear(16 * self._pooled_length(n_samples), classes_num)

    @staticmethod
    def _pooled_length(n_samples):
        """Temporal length T3 left after the three blocks for ``n_samples`` inputs."""
        t = n_samples + (31 + 62) + (32 + 63) - 64 + 1  # block_1: pad + conv(1, 64)
        t //= 4                                         # block_2: AvgPool (1, 4)
        t = t + 7 + 8 - 16 + 1                          # block_3: pad + conv(1, 16)
        return t // 8                                   # block_3: AvgPool (1, 8)

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = x.view(x.size(0), -1)  # (N, 16 * T3)
        x = self.out(x)
        if self.return_logits:
            return x
        return F.softmax(x, dim=1)

# 5 Training

In [None]:
##############################################################################
#                            Configurations                                  #
##############################################################################
config_train = {
    # Name
    'trial': '1', # trial number, used for naming the saved files

    # Hyperparameters
    'batch_size': 32,
    'lr': 0.0001,
    'epochs': 5000,
    'weight_decay': 1e-4,
    'seed': 42,
    # Use the GPU when present, otherwise fall back to CPU. A hard-coded
    # torch.device('cuda') makes the .to(device) calls fail on CPU-only hosts.
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),

    # Early stopping
    'if_EarlyStop': True, # if we want to use early stopping
    'patience': 300, # number of epochs to wait before early stopping
    'attribute': 'acc', # attribute to monitor for early stopping, 'acc' or 'loss'

    # Print
    'period': 100 # print every period epochs
}
set_seed(config_train['seed'])
# Single shared instances. NOTE: any code that reuses model_init /
# optimizer_init for several subjects continues from the previous subject's
# weights and Adam state instead of starting fresh.
model_init = EEGNet(classes_num=4).to(config_train['device'])
optimizer_init = optim.Adam(model_init.parameters(), lr=config_train['lr'],
                            weight_decay=config_train['weight_decay'], eps=1e-8)
criterion = nn.CrossEntropyLoss()
##############################################################################
#                            Training                                        #
##############################################################################
#----Lists for each subject----#
history_allSubjects = []
checkpoint_allSubjects = []

for subject in range(9):
    print('#########################################\n#########################################\n#########################################')
    print(f'----------------Subject {subject+1}----------------\n')

    #----DataLoader----#
    train_dataloader, test_dataloader = getDataLoader(dpath_train_all[subject], dpath_test_all[subject], batch_size=config_train['batch_size'])

    #----Model, Optimizer, Loss----#
    # Build a fresh model and optimizer for every subject. Reusing the single
    # model_init / optimizer_init instances made subject k start training from
    # subject k-1's weights and Adam moments, so the per-subject results were
    # not independent. Re-seeding first gives every subject the same initial
    # weights (same seed and constructor call as model_init).
    set_seed(config_train['seed'])
    model = EEGNet(classes_num=4).to(config_train['device'])
    optimizer = optim.Adam(model.parameters(), lr=config_train['lr'],
                           weight_decay=config_train['weight_decay'], eps=1e-8)
    loss_fn = criterion.to(config_train['device'])

    #----Training----#
    e_stop = EarlyStopping(state=config_train['if_EarlyStop'], patience=config_train['patience'], attribute=config_train['attribute'])
    history, checkpoint = train(model, optimizer,
                                train_dataloader, test_dataloader,
                                epochs=config_train['epochs'],
                                loss_func=loss_fn, period=config_train['period'],
                                er_stop=e_stop)

    history_allSubjects.append(history)
    checkpoint_allSubjects.append(checkpoint)

##############################################################################
#                            Evaluation                                      #
##############################################################################
# acc_all, kappa_all, precision_all, recall_all = [], [], [], []
# for i in range(9):
#     print(f'*** Subject {i+1} ***\n')

#     checkpoint = checkpoint_allSubjects[i]
#     history = history_allSubjects[i]
    
#     # checkpoint = checkpoint_allSubjects[0]
#     # history = history_allSubjects[0]
#     model = model_init
#     _, test_dataloader = getDataLoader(dpath_train_all[i], dpath_test_all[i], batch_size=config_train['batch_size'])
#     show_acc_loss_test(history, subject=i)
#     acc, kappa, precision, recall = evaluation_test(config_train, checkpoint, model, test_dataloader, subject=i) if checkpoint is not None else (None, None, None, None)
    
#     acc_all.append(acc)
#     kappa_all.append(kappa)
#     precision_all.append(precision)
#     recall_all.append(recall)
    
# print('\n\n')
# print('*'*50)
# print('*** Summary ***')
# print(f'Accuracy: {np.mean(acc_all):.2f} +/- {np.std(acc_all):.2f}')
# print(f'Kappa: {np.mean(kappa_all):.2f} +/- {np.std(kappa_all):.2f}')
# print(f'Precision: {np.mean(precision_all):.2f} +/- {np.std(precision_all):.2f}')
# print(f'Recall: {np.mean(recall_all):.2f} +/- {np.std(recall_all):.2f}')

##############################################################################
#                            Saving Results                                  #
##############################################################################
# For each subject: write the preprocessing/model/training info, checkpoint
# info, curves and TensorBoard logs via save_results, evaluate on the test
# split, and collect the four returned metrics for the final summary.
acc_all, kappa_all, precision_all, recall_all = [], [], [], []
for subject_idx in range(9):
    checkpoint = checkpoint_allSubjects[subject_idx]
    history = history_allSubjects[subject_idx]
    model = model_init
    loaders = getDataLoader(dpath_train_all[subject_idx], dpath_test_all[subject_idx],
                            batch_size=config_train['batch_size'])
    test_dataloader = loaders[1]

    savings = save_results(save_dir='results', trial=config_train['trial'],
                           subject=str(subject_idx + 1))
    savings.info_Preprocessing(config_processor)
    savings.info_Model(config_train)
    savings.info_Training(history)
    savings.info_model_optimizer(checkpoint)
    savings.show_acc_loss(history)
    metrics = savings.evaluation(config_train, checkpoint, model, test_dataloader)
    acc, kappa, precision, recall = metrics
    savings.write_TensorBoard(history)

    for bucket, value in zip((acc_all, kappa_all, precision_all, recall_all),
                             (acc, kappa, precision, recall)):
        bucket.append(value)

summary_results(acc_all, kappa_all, precision_all, recall_all,
                save_dir='results', trial=config_train['trial'])
