# EEGNet

2024.04.18 Written by @Chahyunee (Chaehyun Lee)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import datetime
import scipy.io
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# Cross Validation
from sklearn.model_selection import KFold

# Plot
import matplotlib.pyplot as plt

# Directory prefixes. They are concatenated directly with file names
# (e.g. dataset_dir + 'class1.mat'), so a non-empty value should end with a
# path separator such as 'data/'.
dataset_dir = ''
model_dir = ''
# Placeholders for per-class arrays. These used to be bound to the `np.array`
# *function* itself, which is never a usable array; the loading cell stores the
# actual data in the `class_data` dictionary instead.
class1_data = class2_data = class3_data = class4_data = class5_data = None



### Load data

In [None]:

# Load one .mat file per class (class1.mat ... class5.mat) and store each
# class's array in `class_data` under the key 'class<k>_data'.
class_data = {}

# NOTE: despite its name, `subj_num` iterates over class indices 1..5 here.
for subj_num in range(1, 6):
    load_spec_dir = f'class{subj_num}.mat'
    # Plain string concatenation: dataset_dir must end with a separator if non-empty.
    data = scipy.io.loadmat(dataset_dir + load_spec_dir)
    # The array is read from the 'data' variable inside each .mat file.
    data = data['data']
    data = np.array(data)
    # NOTE(review): reshape(600, ...) produces a 1-D array of exactly 600
    # elements (column-major order) and raises unless the stored array holds
    # exactly 600 values. Later cells label 600 trials per class and feed each
    # trial to Conv2d layers after unsqueeze(1), which needs 2-D trials, so this
    # presumably should move the trial axis first, e.g. to (600, C, T) —
    # confirm against the actual layout of the .mat data.
    data = data.reshape(600, order='F')
    print('data shape : ',data.shape)
    
    # Add to dictionary
    class_data[f'class{subj_num}_data'] = np.array(data)

### Device configuration

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### EEG Dataset Class

In [None]:
# torchvision pipeline that converts array-like samples to tensors.
# NOTE(review): EEGDataset is constructed without a transform in this
# notebook, and ToTensor interprets inputs as images (H x W x C) — confirm the
# resulting shape before applying it to EEG trials.
transform = transforms.Compose([transforms.ToTensor()])

class EEGDataset(Dataset):
    """Map-style dataset that pairs EEG trials with their labels.

    Parameters:
    inputs    : indexable collection of trials (e.g. a NumPy array); each item
                is returned as a float32 tensor.
    labels    : indexable collection of labels aligned with ``inputs`` (e.g. a
                one-hot tensor); each item is returned as an int8 tensor.
    transform : optional callable applied to the float32 input tensor of every
                sample. It was previously stored but never applied.
    """

    def __init__(self, inputs, labels, transform=None):
        self.inputs = inputs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        eeg = torch.tensor(self.inputs[idx], dtype=torch.float32)
        # The labels are usually already a tensor; torch.as_tensor converts it
        # without the copy-construct warning that torch.tensor(<tensor>) emits.
        label = torch.as_tensor(self.labels[idx], dtype=torch.int8)
        if self.transform is not None:
            eeg = self.transform(eeg)
        return eeg, label

Extract data from the dictionary and convert it to PyTorch tensors
- Change this part!!

In [None]:
# Stack the per-class arrays along the first (trial) axis: class 1 first, class 5 last.
class_arrays = [class_data[f'class{i}_data'] for i in range(1, 6)]
X_data = np.concatenate(class_arrays, axis=0)
# (A former `for d in X_data: d = torch.tensor(d)` loop was removed: it only
# rebound a loop variable and never changed X_data; EEGDataset converts each
# sample to a tensor on access.)

subj_num = 5  # number of classes (name kept because later cells use it)
# One integer label per trial: 0 for class 1, ..., 4 for class 5. The count per
# class is taken from the loaded arrays instead of being hard-coded to 600.
y_labels = torch.tensor(
    np.concatenate([np.full(len(arr), i) for i, arr in enumerate(class_arrays)]),
    dtype=torch.int64,
)
y_labels_one_hot = torch.nn.functional.one_hot(y_labels, subj_num)


### Split dataset into train, validation, and test sets
0.8 train, 0.1 validation, 0.1 test

In [None]:
# Split into train / validation / test subsets (80% / 10% / 10%).
dataset = EEGDataset(X_data, y_labels_one_hot)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
# The test split takes the remainder. Truncating all three fractions with int()
# can make the sizes sum to less than len(dataset), which random_split rejects.
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

### Data Loader

In [None]:
batch_size = 16 # If you need, change this parameter.

# Only the training loader needs shuffling. The evaluation loaders keep a fixed
# order, so batches come out identically every epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
train_size, val_size, test_size  # displayed by the notebook

## EEGNet Model Class

In [None]:
class EEGNet(torch.nn.Module):
    """
    EEGNet-style convolutional classifier.

    Parameters:
    C : number of channels in your EEG dataset
    T : number of time points in one trial (e.g. sec x sampling rate)
    dropout : the rate of dropout  Default: 0.5
    kernelLength : size of the temporal kernel (e.g. half of T (timepoints))

    F1, F2  : number of temporal filters (F1) and number of pointwise
            filters (F2) to learn. Default: F1 = 8, F2 = F1 * D.
    D       : number of spatial filters to learn within each temporal
            convolution. Default: D = 2

    num_classes : number of target classes

    Input:
    Expects a tensor of shape (batch, 1, C, T). The input size of the final
    linear layer is derived from a dry run on exactly that shape.

    Returns:
    (scores, features). scores are the class scores after a Sigmoid when
    num_classes == 2, otherwise after a Softmax over dim 1. features is the
    flattened output of the convolutional stack.
    """
    def __init__(self, C=64, T=128, dropout=0.5, kernelLength=64, F1=8, D=2, F2=16, num_classes=2):
        super().__init__()
        # NOTE(review): every kernel below has height 1, so with a (batch, 1, C, T)
        # input the convolutions only slide along the last (time) axis and no
        # layer mixes EEG channels before the linear layer. This differs from the
        # canonical EEGNet (a (C, 1) depthwise spatial conv) — confirm it is intended.
        self.channelConv = torch.nn.Conv2d(1, F1, (1, C), padding=(0, int(C/2)))
        self.bn1 = torch.nn.BatchNorm2d(F1)
        self.depthwiseConv1 = torch.nn.Conv2d(F1, F1, (1, kernelLength), padding=(0, int(kernelLength/2)))
        self.pointwiseConv1 = torch.nn.Conv2d(F1, D * F1, 1)

        # Clamp the pointwise weights into [-1, 1]. This happens once, at
        # construction time; it is not re-applied after optimizer steps.
        with torch.no_grad():
            for param in self.pointwiseConv1.parameters():
                if param.dim() > 1:  # skip the 1-D bias
                    param.clamp_(min=-1.0, max=1.0)

        self.bn2 = torch.nn.BatchNorm2d(D*F1)
        self.elu1 = torch.nn.ELU()
        self.pooling1 = torch.nn.AvgPool2d((1,4))
        # self.pooling1 = torch.nn.MaxPool2d((1,4)) # If you want to try, you can replace AvgPool with MaxPool!
        self.dropout1 = torch.nn.Dropout(dropout)

        self.separableConv = torch.nn.Conv2d( D * F1, D * F1, kernel_size=(1,int(kernelLength/2)), padding=(0,int(kernelLength/4)), bias=False)
        self.pointwiseConv2 = torch.nn.Conv2d(D * F1, F2, 1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(F2)
        self.elu2 = torch.nn.ELU()
        self.pooling2 = torch.nn.AvgPool2d((1,8))
        # self.pooling2 = torch.nn.MaxPool2d((1,8)) # If you want to try, you can replace AvgPool with MaxPool!
        self.dropout2 = torch.nn.Dropout(dropout)

        self.flatten = torch.nn.Flatten()
        # Size the classifier from the real flattened feature size. The former
        # F2 * T ignored the EEG-channel axis and both pooling stages (4096 vs.
        # 2048 features with the defaults), so every forward pass failed here.
        self.linear1 = torch.nn.Linear(self._num_flat_features(C, T), num_classes)
        # self.max_norm = nn.utils.weight_norm(self.linear1, dim=None) # TODO
        self.classifier = torch.nn.Sigmoid() if num_classes == 2 else torch.nn.Softmax(dim=1)

    def _num_flat_features(self, C, T):
        """Return the flattened feature size produced for one (1, 1, C, T) input."""
        # In eval mode under no_grad the dry run neither updates BatchNorm running
        # statistics nor applies dropout. The previous mode is restored afterwards.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            n_features = self._features(torch.zeros(1, 1, C, T)).shape[1]
        self.train(was_training)
        return n_features

    def _features(self, x):
        """Run the convolutional stack and return the flattened features."""
        ##### First layer #####
        x = self.channelConv(x)
        x = self.bn1(x)
        x = self.depthwiseConv1(x)
        x = self.pointwiseConv1(x)
        x = self.bn2(x)
        x = self.elu1(x)
        x = self.pooling1(x)
        x = self.dropout1(x)

        ##### Second layer #####
        x = self.separableConv(x)
        x = self.pointwiseConv2(x)
        x = self.bn3(x)
        x = self.elu2(x)
        x = self.pooling2(x)
        x = self.dropout2(x)
        return self.flatten(x)

    def forward(self, x):
        out = self._features(x)
        x = self.linear1(out)
        x = self.classifier(x)
        return x, out



### Model evaluation

In [None]:

from pprint import pprint as pp

def evaluate(true_labels, predicted_labels, subj_num = 5, mode='train'):
    """Compute one-vs-rest classification metrics for every class column.

    Parameters:
    true_labels      : 2-D tensor indexed as [sample, class] holding the targets
                       (one-hot in this notebook).
    predicted_labels : 2-D tensor of the same layout holding 0/1 predictions
                       (e.g. rounded model outputs).
    subj_num         : number of class columns to evaluate.
    mode             : kept for backward compatibility. Both modes now compute
                       the same metrics, because tensors are always detached
                       before conversion.

    Returns a dict with:
    - per-class lists 'recall_per_class', 'f1_per_class', 'acc_per_class' and
      'precision_per_class';
    - their means 'average_recall', 'average_f1', 'average_acc' and
      'average_prec'.

    Note: 'average_acc' is the mean of per-class binary (one-vs-rest)
    accuracies, not the multi-class accuracy.
    """
    # Convert each tensor once, not once per metric per class. detach() makes
    # this safe for tensors that still require grad, whatever `mode` is.
    true_np = true_labels.detach().cpu().numpy()
    pred_np = predicted_labels.detach().cpu().numpy()

    result = dict(recall_per_class = [], f1_per_class = [], acc_per_class = [], precision_per_class = [])

    for class_idx in range(subj_num):
        y_true = true_np[:, class_idx]
        y_pred = pred_np[:, class_idx]
        result['recall_per_class'].append(recall_score(y_true, y_pred))
        result['f1_per_class'].append(f1_score(y_true, y_pred))
        result['acc_per_class'].append(accuracy_score(y_true, y_pred))
        result['precision_per_class'].append(precision_score(y_true, y_pred, zero_division=1))

    result['average_recall'] = sum(result['recall_per_class']) / len(result['recall_per_class'])
    result['average_f1'] = sum(result['f1_per_class']) / len(result['f1_per_class'])
    result['average_acc'] = sum(result['acc_per_class']) / len(result['acc_per_class'])
    result['average_prec'] = sum(result['precision_per_class']) / len(result['precision_per_class'])

    return result


## Training without k-fold

In [None]:

# Set fixed random number seed
# Draw a seed at random and print it so the run can be reproduced, then seed
# NumPy and PyTorch (CPU, current CUDA device, all CUDA devices) with it.
seed_n = np.random.randint(500)
print(f'seed is {seed_n}')
np.random.seed(seed_n)
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed_n)

In [None]:
# Change it!
epochs = 300
train_losses, val_losses = [], []

history = {'val_loss': [], 'val_acc': [], 
            'train_loss': [], 'train_acc' : []}


model = EEGNet(num_classes=5)
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:

# Train for `epochs` epochs. After each epoch, evaluate on the validation set
# and record epoch-level metrics in `history`.
for epoch in range(epochs):

    # ---- training ----
    model.train()  # once per epoch; validation below switches to eval mode
    epoch_train_loss = 0.0
    train_true, train_pred = [], []

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs, _ = model(inputs.unsqueeze(1))
        true_labels = labels.float()

        train_loss = criterion(outputs, true_labels)
        train_loss.backward()
        optimizer.step()

        # Store floats, not loss tensors: a stored tensor keeps a reference to
        # its autograd graph for the whole run.
        batch_loss = train_loss.item()
        train_losses.append(batch_loss)
        epoch_train_loss += batch_loss
        # Collect the whole epoch so the metrics cover every batch, not only the last one.
        train_true.append(true_labels.detach())
        train_pred.append(torch.round(outputs.detach()))

    train_result = evaluate(torch.cat(train_true), torch.cat(train_pred)) # dictionary return
    epoch_train_loss /= len(train_loader)

    # ---- validation ----
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        valid_true, valid_pred = [], []

        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs, _ = model(inputs.unsqueeze(1))
            true_labels = labels.float()

            batch_val_loss = criterion(outputs, true_labels).item()
            val_losses.append(batch_val_loss)  # per-batch loss, not the running sum
            val_loss += batch_val_loss
            valid_true.append(true_labels)
            valid_pred.append(torch.round(outputs))

        valid_result = evaluate(torch.cat(valid_true), torch.cat(valid_pred), mode='valid') # dictionary return
    val_loss /= len(val_loader)

    print(f'\nEpoch {epoch + 1}/{epochs} \n\
        train loss: {epoch_train_loss}, valid loss: {val_loss}')
    pp(train_result)
    pp(valid_result)
    history['train_loss'].append(epoch_train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_result['average_acc'])
    history['val_acc'].append(valid_result['average_acc'])



print('Finished Training')


## Training with k-fold

In [None]:

# 5-fold splitter with shuffling. No random_state is given, so the shuffle
# draws from NumPy's global random state.
kfold = KFold(shuffle=True, n_splits=5)

In [None]:

# Set fixed random number seed
torch.manual_seed(42)

# training
epochs = 200
train_losses, val_losses = [], []

for fold, (train_idx, val_idx) in enumerate(kfold.split(train_dataset)):
        
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx) # create index
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx) 
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_subsampler) 
    val_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=val_subsampler)
    

    model = EEGNet(num_classes=5)
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

        
    for epoch in range(epochs):
        
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            model.train()
        
            optimizer.zero_grad()
            
            inputs = inputs.to(device)
            labels = labels.to(device)
        
            outputs, _ = model(inputs.unsqueeze(1))
            predicted_labels = torch.round(outputs)
            true_labels = labels.float()
            
            train_loss = criterion(outputs, labels.float())
            train_losses.append(train_loss)
            
            train_result = evaluate(true_labels, predicted_labels) # dictionary return

            train_loss.backward()
            optimizer.step()

        # evaluation
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs, _ = model(inputs.unsqueeze(1))
                predicted_labels = torch.round(outputs)
                true_labels = labels.float()
                
                val_loss += criterion(outputs, true_labels)
                val_losses.append(val_loss)
                
                valid_result = evaluate(true_labels, predicted_labels, mode='valid') # dictionary return
                
        
    
        print(f'\n{fold+1} fold & Epoch {epoch + 1}/{epochs} \n\
            train loss: {train_loss}, valid loss: {val_loss / len(val_loader)}')
        pp(train_result)
        pp(valid_result)
    
    

print('Finished Training')


Model architecture

In [None]:

# nn.Module's repr lists every registered submodule with its hyper-parameters.
print(repr(model))

### Test the Model


In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

y_pred = []
y_true = []

# Evaluate on the test set. eval() disables dropout and makes BatchNorm use its
# running statistics; no_grad() avoids building autograd graphs.
model.eval()
with torch.no_grad():
    # iterate over test data
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs, _ = model(inputs.unsqueeze(1)) # Feed Network

        # Predicted class = index of the highest score. Rounding the scores
        # first would collapse every row with all scores below 0.5 to class 0.
        output = outputs.argmax(dim=1).cpu().numpy()
        print('output : ', output)
        y_pred.extend(output) # Save Prediction

        # Labels are one-hot, so argmax recovers the class index.
        labels = labels.argmax(dim=1).cpu().numpy()
        print('labels  : ', labels)
        y_true.extend(labels) # Save Truth

In [None]:
# Timestamp (YYYYMMDD_HHMM) used to tag saved figures and checkpoints.
current_time = f'{datetime.datetime.now():%Y%m%d_%H%M}'
current_time

In [None]:
# constant for classes
classes = ('class a', 'calss b', 'class c', 'class d', 'class e')

# Build confusion matrix
cf_matrix = confusion_matrix(y_true, y_pred)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index = [i for i in classes],
                    columns = [i for i in classes])

df_test = pd.DataFrame(cf_matrix)

df_test
df_cm

### Visualize the confusion matrix

In [None]:

import os

plt.figure(figsize = (12,7))
sn.heatmap(df_cm, annot=True, cmap="YlGnBu")

# Adding labels to x-axis and y-axis
plt.xlabel('Predicted')
plt.ylabel('True')

# Tag the figure with the accuracy of the test predictions it visualizes.
# Previously the last validation result's mean one-vs-rest accuracy was used,
# which does not describe this matrix.
acc = round(float(np.mean(np.asarray(y_true) == np.asarray(y_pred))), 4)
os.makedirs('Figure', exist_ok=True)  # savefig does not create missing directories
plt.savefig(f'Figure/class{subj_num}_EEGNet_acc{acc}_{current_time}_output.png')



### Model save code

In [None]:
# model save: persist only the learned parameters (state_dict), with the
# accuracy and timestamp from earlier cells embedded in the file name.
model_save_path = model_dir + f'EEGNet_acc{acc}_{current_time}.pth'
model_save_path
torch.save(obj=model.state_dict(), f=model_save_path)


### Model Load code

In [None]:
# Rebuild the architecture and load the saved weights. map_location remaps
# tensors saved during a GPU run onto the current device, so the checkpoint
# also loads on a CPU-only machine.
loaded_model = EEGNet(num_classes=5)
loaded_model.load_state_dict(torch.load(model_save_path, map_location=device))
loaded_model = loaded_model.to(device)