<a href="https://colab.research.google.com/github/Napomini/Individual_models/blob/main/EEGNet_MSTANN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# EEG Signal Acquisition


## Downloading BCI Competition IV 2a Dataset


In [None]:
!wget https://www.bbci.de/competition/download/competition_iv/BCICIV_2a_gdf.zip

In [None]:
!mkdir -p /content/cleaned_data/

In [None]:
%%capture
!unzip /content/BCICIV_2a_gdf.zip -d raw_data

# Install Packages

In [None]:
%%capture
!pip install mne

In [None]:
%%capture
!pip install torch-summary

In [None]:
%%capture
!pip install captum

In [None]:
%%capture
# NOTE(review): this command looks malformed - the spaces around `==` split each
# requirement into separate shell arguments, and `--quite` is presumably a typo
# for `--quiet`. If pip rejects it, `%%capture` hides the error and neither pin
# takes effect. Before correcting it, confirm that the pinned mne release
# accepts `Epochs.get_data(copy=True)`, which later cells call.
!pip install --upgrade mne == 1.4.2 numpy == 1.23.5 --quite

In [None]:
%%capture
!pip install tensorflow==2.15.0

In [None]:
import os

# Send signal 9 to the process running this cell, which terminates it
# immediately. Presumably this forces the runtime to restart so that the
# packages installed above are the ones imported afterwards.
kernel_pid = os.getpid()
os.kill(kernel_pid, 9)

# Libraries


In [None]:
import os
import mne
import math
import copy
import gdown
import random
import scipy.io
import numpy as np
import pandas as pd
import seaborn as sn
import tensorflow as tf
import matplotlib.pyplot as plt
from captum.attr import DeepLift
from tensorflow import keras


# Torch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split

# Scikit-Learn
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split

# Building Dataset

In [None]:
!mkdir -p /content/cleaned_data/first_session
!mkdir -p /content/cleaned_data/second_session

# First Session

## Pre-processing

In [None]:
# Pre-process every first-session recording ('*T.gdf'): drop the EOG channels,
# band-pass and notch filter, save each cleaned recording as .fif, then
# concatenate all of them into `final_raw` and save that as well.
raw_data_folder = '/content/raw_data/'
cleaned_data_folder = '/content/cleaned_data/first_session/'
# Create the output folder here too, so this cell does not depend on an earlier shell step.
os.makedirs(cleaned_data_folder, exist_ok=True)
files = os.listdir(raw_data_folder)

# Selecting files with suffix 'T.gdf'.
# os.listdir returns names in arbitrary order; sorting keeps the concatenation
# order (and therefore the later seeded train/test split) reproducible.
filtered_files = sorted(file for file in files if file.endswith('T.gdf'))
if not filtered_files:
    raise FileNotFoundError(f"No '*T.gdf' recordings found in {raw_data_folder}")

eog_channels = ['EOG-left', 'EOG-central', 'EOG-right']
raw_list = []

# Iterating through filtered files
for file in filtered_files:
    file_path = os.path.join(raw_data_folder, file)

    # Reading raw data
    raw = mne.io.read_raw_gdf(file_path, eog=eog_channels, preload=True)
    # Dropping EOG channels
    raw.drop_channels(eog_channels)

    # Band-pass filtering 4-40 Hz
    raw.filter(l_freq=4, h_freq=40, method='iir')

    # Notch filter for removal of line voltage
    raw.notch_filter(freqs=50)

    # Saving the modified raw data to a file with .fif suffix
    new_file_path = os.path.join(cleaned_data_folder, file[:-4] + '.fif')
    raw.save(new_file_path, overwrite=True)
    # Appending data to the list
    raw_list.append(raw)

final_raw = mne.concatenate_raws(raw_list)
new_file_path = os.path.join(cleaned_data_folder, 'First_Session_Subjects.fif')
final_raw.save(new_file_path, overwrite=True)

In [None]:
# Derive events from the annotations of the concatenated first-session recording.
# The result is kept as a single object because the next cell indexes it:
# events[0] is what gets passed to mne.Epochs as the events array.
events = mne.events_from_annotations(final_raw)
# Display the second element of the result (presumably the mapping from
# annotation description to integer event code - verify against the output).
events[1]

In [None]:
# Cut one epoch (0 s to 4 s after the event) per selected event from the
# concatenated first-session recording, without rejection or baseline correction.
# NOTE(review): 7-10 are presumably the integer codes assigned to the four
# motor-imagery cue annotations - verify against the mapping shown by the previous cell.
epochs = mne.Epochs(
    final_raw,
    events[0],
    event_id=[7, 8, 9, 10],
    tmin=0,
    tmax=4,
    reject=None,
    baseline=None,
    preload=True,
)
# The last column of the epochs' event array is used as the label of each epoch.
first_session_labels = epochs.events[:, -1]
first_session_data = epochs.get_data(copy=True)

In [None]:
# Report the shape of the epoched first-session array as a quick sanity check.
print("First_session_dataset shape:",first_session_data.shape)

# Second Session

In [None]:
# Google Drive share link of the label archive (replace with your own link if needed)
shareable_link = 'https://drive.google.com/file/d/11Ke2Xta1kv2xu2Mybuu_X51zJYjQ-VFo/view?usp=drive_link'

# The file ID is the part of the link between '/d/' and '/view'
file_id = shareable_link.partition('/d/')[2].partition('/view')[0]

# Direct-download URL built from the file ID
download_url = f'https://drive.google.com/uc?id={file_id}&export=download'

# Local name of the downloaded archive
output_file = 'true_labels.zip'

# Fetch the archive with progress output enabled
gdown.download(download_url, output_file, quiet=False)

In [None]:
%%capture
!unzip /content/true_labels.zip -d second_session_labels

In [None]:
# Pre-process every second-session recording ('*E.gdf') the same way as the
# first session and collect the class labels from the matching .mat files.
# A recording is only added to the dataset when its labels were loaded, so the
# concatenated data and `second_session_labels` always stay aligned.
raw_data_folder = '/content/raw_data/'
cleaned_data_folder = '/content/cleaned_data/second_session/'
mat_folder = '/content/second_session_labels/'
# Create the output folder here too, so this cell does not depend on an earlier shell step.
os.makedirs(cleaned_data_folder, exist_ok=True)

# Selecting files with suffix 'E.mat'
mat_files = os.listdir(mat_folder)
filtered_math_labels = [file for file in mat_files if file.endswith('E.mat')]

# Selecting files with suffix 'E.gdf'.
# os.listdir returns names in arbitrary order; sorting keeps the run reproducible.
files = os.listdir(raw_data_folder)
filtered_files = sorted(file for file in files if file.endswith('E.gdf'))

eog_channels = ['EOG-left', 'EOG-central', 'EOG-right']
raw_list = []
# Integer dtype from the start: a bare np.array([]) is float64 and would turn
# every concatenated label into a float.
second_session_labels = np.array([], dtype=int)

# Iterating through filtered files
for file in filtered_files:
    file_path = os.path.join(raw_data_folder, file)

    # Reading raw data
    raw = mne.io.read_raw_gdf(file_path, eog=eog_channels, preload=True)
    # Dropping EOG channels
    raw.drop_channels(eog_channels)

    # Band-pass filtering 4-40 Hz
    raw.filter(l_freq=4, h_freq=40, method='iir')

    # Notch filter for removal of line voltage (same step as the first session,
    # so both sessions go through identical pre-processing)
    raw.notch_filter(freqs=50)

    # Saving the modified raw data to a file with .fif suffix
    new_file_path = os.path.join(cleaned_data_folder, file[:-4] + '.fif')
    raw.save(new_file_path, overwrite=True)

    # Mat files for the labels
    mat_file_name = file.replace('.gdf', '.mat')
    mat_file_path = os.path.join(mat_folder, mat_file_name)
    print(f"data:{file}, label:{mat_file_name}")

    if not os.path.exists(mat_file_path):
        print(f"Warning: {mat_file_name} not found.")
        continue

    mat_data = scipy.io.loadmat(mat_file_path)
    # np.asarray makes the "key missing" fallback an array too, so `.size` is always available
    class_labels = np.asarray(mat_data.get('classlabel', []))
    if class_labels.size == 0:
        print(f"Warning: 'classlabel' not found or empty in {mat_file_name}.")
        continue

    # Labels are available: keep the recording and its labels together
    raw_list.append(raw)
    second_session_labels = np.concatenate((second_session_labels, class_labels.astype(int).ravel()))

if not raw_list:
    raise FileNotFoundError(
        f"No labelled '*E.gdf' recordings found in {raw_data_folder} / {mat_folder}"
    )

final_raw = mne.concatenate_raws(raw_list)
new_file_path = os.path.join(cleaned_data_folder, 'Second_Session_Subjects.fif')
final_raw.save(new_file_path, overwrite=True)

In [None]:
# Derive events from the annotations of the concatenated second-session recording.
# The result is kept as a single object because the next cell indexes it:
# events[0] is what gets passed to mne.Epochs as the events array.
events = mne.events_from_annotations(final_raw)
# Display the second element of the result (presumably the mapping from
# annotation description to integer event code - verify against the output).
events[1]

In [None]:
# Cut one epoch (0 s to 4 s after the event) per event with code 7 from the
# concatenated second-session recording, without rejection or baseline correction.
# NOTE(review): 7 is presumably the code assigned to the cue annotation of the
# evaluation files - verify against the mapping shown by the previous cell.
# preload is a boolean flag; True matches the first-session cell.
epochs = mne.Epochs(final_raw, events[0], event_id=7, tmin=0, tmax=4, reject=None, baseline=None, preload=True)
second_session_data = epochs.get_data(copy=True)

# The labels come from separate .mat files, so fail early with a clear message
# if they do not line up one-to-one with the extracted epochs.
if len(second_session_data) != len(second_session_labels):
    raise ValueError(
        f"Second session has {len(second_session_data)} epochs "
        f"but {len(second_session_labels)} labels"
    )

In [None]:
# Report the shape of the epoched second-session array as a quick sanity check.
print("Second Session Dataset shape:",second_session_data.shape)

# Structuring Data

In [None]:
# Build the train/test tensors: shift labels to start at 0, z-score each
# session, pool both sessions and split them 90/10 with stratification.

# Choosing Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loss Function
criterion = nn.CrossEntropyLoss()

# Normalizing Labels to [0, 1, 2, 3].
# The explicit int64 cast keeps the labels integral even if one of the label
# arrays arrives as float (e.g. when it was accumulated into a float array).
y_first_session = (first_session_labels - np.min(first_session_labels)).astype(np.int64)
y_second_session = (second_session_labels - np.min(second_session_labels)).astype(np.int64)

# Normalizing Input features: z-score(mean=0, std=1), computed over each whole session
X_first_session = (first_session_data - np.mean(first_session_data)) / np.std(first_session_data)
X_second_session = (second_session_data - np.mean(second_session_data)) / np.std(second_session_data)

X = np.concatenate((X_first_session, X_second_session))
y = np.concatenate((y_first_session, y_second_session))

# Splitting Data: 90% for Train and 10% for Test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42, stratify=y)

# Converting to Tensor with explicit dtypes; unsqueeze(1) adds a singleton
# dimension after the batch dimension, as expected by the models below.
X_train = torch.as_tensor(X_train, dtype=torch.float32).unsqueeze(1).to(device)
X_test = torch.as_tensor(X_test, dtype=torch.float32).unsqueeze(1).to(device)
y_train = torch.as_tensor(y_train, dtype=torch.long).to(device)
y_test = torch.as_tensor(y_test, dtype=torch.long).to(device)

# Creating Tensor Dataset
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

# Printing the sizes
print("Size of X_train:", X_train.size())
print("Size of X_test:", X_test.size())
print("Size of y_train:", y_train.size())
print("Size of y_test:", y_test.size())

# Class Training

In [None]:
class TrainModel():
    """Supervised training loop for models whose forward pass returns ``(logits, extras)``."""

    def __init__(self,):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def train_model(self, model, train_dataset, learning_rate=0.001, batch_size=64, epochs=500,
                    save_path='eegnet_model.pth'):
        """Train ``model`` with Adam and cross-entropy, save its weights and return it.

        Args:
            model: module whose ``forward`` returns a pair; the first element is
                used as the class logits, the second is ignored here.
            train_dataset: dataset yielding ``(inputs, labels)`` pairs.
            learning_rate: Adam learning rate.
            batch_size: mini-batch size (batches are reshuffled every epoch).
            epochs: number of passes over ``train_dataset``; ``0`` skips training.
            save_path: file the final ``state_dict`` is written to. The default
                keeps the historical file name, so pass a different path when
                training a second model to avoid overwriting the first checkpoint.

        Returns:
            The trained model (moved to ``self.device``).
        """
        model = model.to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        highest_train_accuracy = 0.0
        # Defined up front so the summary below also works when epochs == 0
        epoch_loss = float('nan')

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            for inputs, labels in train_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                outputs, _ = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # criterion returns the batch mean, so weight it by the batch size
                running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_accuracy = correct / total
            highest_train_accuracy = max(highest_train_accuracy, epoch_accuracy)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {(epoch_accuracy*100):.2f}%")

        # This is the per-sample average loss of the last epoch (nan if no epoch ran)
        print("Average Loss:", epoch_loss)
        print(f"Highest Train Accuracy:{(highest_train_accuracy*100):.2f}")

        # Saving model
        torch.save(model.state_dict(), save_path)
        return model

# Evaluating Model

In [None]:
class EvalModel():
    """Evaluation helpers for models whose forward pass returns ``(logits, extras)``."""

    def __init__(self, model):
        # The device is resolved first so the model is moved with this object's
        # own device instead of relying on a notebook-level global.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)

    def _collect_predictions(self, dataset, batch_size):
        """Run the model in eval mode over ``dataset``; return ``(y_true, y_pred)`` as lists of ints."""
        self.model.eval()
        y_true = []
        y_pred = []
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        with torch.no_grad():
            for inputs, labels in loader:
                inputs = inputs.to(self.device)
                outputs, _ = self.model(inputs)
                _, predicted = torch.max(outputs, 1)
                y_pred.extend(predicted.tolist())
                y_true.extend(labels.tolist())
        return y_true, y_pred

    def test_model(self, test_dataset, batch_size=1):
        """Print and return the accuracy (in percent) on ``test_dataset``.

        Raises:
            ValueError: if ``test_dataset`` is empty.
        """
        y_true, y_pred = self._collect_predictions(test_dataset, batch_size)
        if not y_true:
            raise ValueError("test_dataset is empty; accuracy is undefined.")

        correct = sum(1 for truth, guess in zip(y_true, y_pred) if truth == guess)
        accuracy = (correct / len(y_true)) * 100
        print("/------------------------------/")
        print(f"Test Accuracy: {accuracy:.2f}%")
        print("/------------------------------/")
        return accuracy

    def plot_confusion_matrix(self, test_dataset, classes, batch_size=1):
        """Plot, save and show the row-normalised confusion matrix.

        Args:
            test_dataset: dataset yielding ``(inputs, labels)`` pairs.
            classes: class names; index ``i`` is the name of label ``i``.
            batch_size: batch size used for inference.
        """
        y_true, y_pred = self._collect_predictions(test_dataset, batch_size)

        # Fixing the label set keeps the matrix len(classes) x len(classes)
        # even when a class never occurs in the test data.
        cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
        row_totals = cf_matrix.sum(axis=1)[:, np.newaxis]
        # Rows of classes without any sample stay at 0 instead of dividing by zero
        cf_matrix = np.divide(cf_matrix.astype('float'), row_totals,
                              out=np.zeros(cf_matrix.shape, dtype=float),
                              where=row_totals != 0)

        df_cm = pd.DataFrame(cf_matrix, index=classes, columns=classes)

        plt.figure(figsize=(10, 7))
        sn.heatmap(df_cm, annot=True, cmap='Blues', fmt='.2f')
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.savefig('confusion_matrix_model.png')
        plt.show()

# EEGNet Model

In [None]:
class EEGNetModel(nn.Module):
    """EEGNet-style network: temporal conv, depthwise spatial conv, separable conv, linear head.

    NOTE(review): the original header comment read "EEGNET-8,2", but the
    defaults here are f1=16, d=2 - confirm which configuration is intended.

    Args:
        chans: number of EEG channels (height of the input).
        classes: number of output classes.
        time_points: number of samples per trial (width of the input).
        temp_kernel: length of the temporal convolution kernel.
        f1: number of temporal filters.
        f2: number of pointwise filters.
        d: depth multiplier of the depthwise spatial convolution.
        pk1, pk2: average-pooling sizes along time after block2 / block3.
        dropout_rate: dropout probability in block2 and block3.
        max_norm1: maximum L2 norm of each depthwise spatial filter.
        max_norm2: maximum L2 norm of each row of the linear layer.
    """

    def __init__(self, chans=22, classes=4, time_points=1001, temp_kernel=32,
                 f1=16, f2=32, d=2, pk1=8, pk2=16, dropout_rate=0.5, max_norm1=1, max_norm2=0.25):
        super(EEGNetModel, self).__init__()
        # Calculating FC input features
        linear_size = (time_points//(pk1*pk2))*f2
        # Kept so the constraint can be re-applied during training (see forward)
        self.max_norm1 = max_norm1
        self.max_norm2 = max_norm2

        # Temporal Filters
        self.block1 = nn.Sequential(
            nn.Conv2d(1, f1, (1, temp_kernel), padding='same', bias=False),
            nn.BatchNorm2d(f1),
        )
        # Spatial Filters
        self.block2 = nn.Sequential(
            nn.Conv2d(f1, d * f1, (chans, 1), groups=f1, bias=False), # Depthwise Conv
            nn.BatchNorm2d(d * f1),
            nn.ELU(),
            nn.AvgPool2d((1, pk1)),
            nn.Dropout(dropout_rate)
        )
        # Separable Conv = per-channel temporal conv followed by a pointwise conv.
        # The per-channel part keeps d*f1 channels, so any f2 is valid; with the
        # defaults (f2 == d*f1) the layer shapes are the same as before.
        self.block3 = nn.Sequential(
            nn.Conv2d(d * f1, d * f1, (1, 16), groups=d * f1, bias=False, padding='same'),
            nn.Conv2d(d * f1, f2, kernel_size=1, bias=False), # Pointwise Conv
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pk2)),
            nn.Dropout(dropout_rate)
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(linear_size, classes)

        # Constrain the freshly initialised weights as well
        self._apply_weight_constraints()

    def _apply_max_norm(self, layer, max_norm):
        """Rescale each dim-0 slice of the layer's weight so its L2 norm is at most ``max_norm``."""
        for name, param in layer.named_parameters():
            if 'weight' in name:
                param.data = torch.renorm(param.data, p=2, dim=0, maxnorm=max_norm)

    def _apply_weight_constraints(self):
        """Apply the max-norm limits to the depthwise layer in block2 and to the linear layer."""
        self._apply_max_norm(self.block2[0], self.max_norm1)
        self._apply_max_norm(self.fc, self.max_norm2)

    def forward(self, x):
        """Return ``(logits, [block1_output, block2_output, block3_output])``."""
        if self.training:
            # A max-norm constraint has to hold after every optimizer update,
            # not only at construction time, so it is re-applied on each
            # training forward pass before the weights are used.
            self._apply_weight_constraints()
        x = self.block1(x)
        temporal_features = x
        x = self.block2(x)
        spatial_features1 = x
        x = self.block3(x)
        spatial_features2 = x
        x = self.flatten(x)
        x = self.fc(x)
        return x, [temporal_features, spatial_features1, spatial_features2]

# Model Summary

In [None]:
# Layer-by-layer summary of a default EEGNet for one (1, 22, 1001) input
eegnet_model = EEGNetModel().to(device)
input_size = (1, 22, 1001)
summary(eegnet_model, input_size)

# Training Model

In [None]:
# Training Hyperparameters
EPOCHS = 500
BATCH_SIZE = 64
LEARNING_RATE = 0.001

# Fresh EEGNet with default settings, trained on the training split
eegnet_model = EEGNetModel().to(device)
trainer = TrainModel()
trained_eegnet_model = trainer.train_model(
    eegnet_model,
    train_dataset,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
)

# Store the trained weights
torch.save(trained_eegnet_model.state_dict(), 'eegnet_model.pth')

# Evaluating Model (Confusion Matrix)

In [None]:
# Accuracy of the trained EEGNet on the held-out split
eval_model = EvalModel(trained_eegnet_model)
test_accuracy = eval_model.test_model(test_dataset)

# Confusion matrix; the list gives the display name of each label index in order
classes_list = ['Left', 'Right', 'Foot', 'Tongue']
eval_model.plot_confusion_matrix(test_dataset, classes_list)

# MSTANN Model

In [None]:
class MSM(nn.Module): # Multi-Scale Module
    """Three parallel 1-D convolutions with different kernel sizes, fused by pooling and 1x1 convs.

    The input goes straight into ``Conv1d(chans, ...)``, so it is expected as
    ``(batch, chans, time)``. ``time_points`` is accepted but not used here.
    """

    def __init__(self, chans=22, time_points=1001, f1=36, f2=54, f3=108, f4=216):
        super(MSM, self).__init__()
        # Parallel branches: same input, kernel sizes 3 / 11 / 19
        self.conv1 = nn.Conv1d(chans, f1, kernel_size=3, padding='same')
        self.conv2 = nn.Conv1d(chans, f1, kernel_size=11, padding='same')
        self.conv3 = nn.Conv1d(chans, f1, kernel_size=19, padding='same')
        # 1x1 convolutions: conv4 consumes the concatenated branches (f3
        # channels), conv5 the second concatenation (f4 channels)
        self.conv4 = nn.Conv1d(f3, f2, kernel_size=1, padding='same')
        self.conv5 = nn.Conv1d(f4, f3, kernel_size=1, padding='same')
        # Stride-2 max-pools with window sizes 7 / 19 / 31
        self.mp1 = nn.MaxPool1d(kernel_size=7, stride=2, padding=3)
        self.mp2 = nn.MaxPool1d(kernel_size=19, stride=2, padding=9)
        self.mp3 = nn.MaxPool1d(kernel_size=31, stride=2, padding=15)

    def forward(self, x):
        """Return ``(out, features)``; ``features`` lists the seven intermediate tensors."""
        # Parallel Convs
        branch_k3 = self.conv1(x)
        branch_k11 = self.conv2(x)
        branch_k19 = self.conv3(x)
        stacked = torch.cat((branch_k3, branch_k11, branch_k19), dim=1)
        # After the permute the last axis is the concatenated filter axis, so
        # the max-pools below run along that axis
        stacked_t = stacked.permute(0, 2, 1)

        # Max Poolings, each permuted back to (batch, channels, time)
        pooled = [pool(stacked_t).permute(0, 2, 1) for pool in (self.mp1, self.mp2, self.mp3)]

        # Second Concat: the three pooled tensors plus the 1x1-conv projection
        projected = self.conv4(stacked)
        merged = torch.cat([*pooled, projected], dim=1)
        out = self.conv5(merged)

        temporal_features = [branch_k3, branch_k11, branch_k19,
                             projected, stacked_t, merged,
                             out]
        return out, temporal_features
class ResidualModule(nn.Module):
    """Four Conv1d+BatchNorm stages (f1 -> f2 -> f1 -> f2 -> f1) with a skip connection around them."""

    def __init__(self, f1=108, f2=54):
        super(ResidualModule, self).__init__()
        self.block1 = self._conv_bn(f1, f2, 1)                  # Conv6
        self.block2 = self._conv_bn(f2, f1, 3)                  # Conv7
        self.block3 = self._conv_bn(f1, f2, 1)                  # Conv8
        # The last stage has no activation of its own; ReLU follows the skip addition
        self.block4 = self._conv_bn(f2, f1, 3, activate=False)  # Conv9
        self.relu = nn.ReLU()

    @staticmethod
    def _conv_bn(in_channels, out_channels, kernel_size, activate=True):
        """Build Conv1d -> BatchNorm1d (-> ReLU) with padding that keeps the length unchanged."""
        stages = [
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm1d(out_channels),
        ]
        if activate:
            stages.append(nn.ReLU())
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return ``relu(stages(x) + x)``; the output has the same shape as ``x``."""
        # Sequential Convs
        transformed = self.block4(self.block3(self.block2(self.block1(x))))
        # Residual Connection
        return self.relu(transformed + x)

class CTAM(nn.Module):
    """Channel attention followed by temporal attention on a ``(batch, linear_size, time_points)`` tensor.

    Args:
        f1: input channels of the temporal-attention convolution. It consumes
            the concatenation of two ``linear_size``-channel tensors, so it has
            to equal ``2 * linear_size``.
        linear_size: number of channels of the input tensor.
        time_points: length of the input along the last axis (fixes the size
            of the first linear layer of each channel-attention branch).

    NOTE(review): every pooling layer here uses kernel_size=1, which leaves its
    input unchanged, so the "max" and "avg" branches see the same values -
    confirm this is intended.
    """

    def __init__(self, f1=216, linear_size=108, time_points=1001):
        super(CTAM, self).__init__()
        # Channel Attention Module(CAM)
        self.cam_maxpool = nn.Sequential(
            nn.MaxPool1d(kernel_size=1),
            nn.Flatten(),
            nn.Linear(linear_size*time_points, linear_size),
            nn.ReLU(),
            nn.Linear(linear_size, linear_size),
        )
        self.cam_avgpool = nn.Sequential(
            nn.AvgPool1d(kernel_size=1),
            nn.Flatten(),
            nn.Linear(linear_size*time_points, linear_size),
            nn.ReLU(),
            nn.Linear(linear_size, linear_size),
        )

        # Temporal Attention Module(TAM)
        self.tam_maxpool = nn.MaxPool1d(kernel_size=1)
        self.tam_avgpool = nn.AvgPool1d(kernel_size=1)
        self.tam_conv = nn.Conv1d(f1, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` reweighted per channel and then per time step (same shape as ``x``)."""
        # CAM operations: one sigmoid weight per channel
        cam_maxpool_out = self.cam_maxpool(x)
        cam_avgpool_out = self.cam_avgpool(x)
        cam_pool_out = self.sigmoid(cam_maxpool_out + cam_avgpool_out).unsqueeze(2)
        cam_out = x * cam_pool_out
        # TAM operations: one sigmoid weight per time step
        tam_maxpool_out = self.tam_maxpool(cam_out)
        tam_avgpool_out = self.tam_avgpool(cam_out)
        tam_cat = torch.cat((tam_avgpool_out, tam_maxpool_out), dim=1)
        tam_conv_out = self.tam_conv(tam_cat)
        tam_out = self.sigmoid(tam_conv_out)
        # Expand to the actual channel count of the input instead of a fixed
        # 108, so non-default ``linear_size`` values work as well.
        tam_out_expanded = tam_out.expand_as(cam_out)
        ctam_out = tam_out_expanded * cam_out
        return ctam_out


class MSCTANNModel(nn.Module):
    """Multi-scale module -> residual module -> channel/temporal attention -> linear classifier.

    Args:
        chans: number of EEG channels.
        f1, f2, f3, f4: filter counts forwarded to the sub-modules.
        classes: number of output classes.
        time_points: number of samples per trial.
        dropout_rate: dropout probability applied to the flattened features.
    """

    def __init__(self, chans=22, f1=36, f2=54, f3=108, f4=216,
                 classes=4, time_points=1001, dropout_rate=0.4):
        super(MSCTANNModel, self).__init__()
        linear_size = f3*time_points
        self.classes = classes
        self.msm = MSM(chans=chans,time_points=time_points, f1=f1, f2=f2, f3=f3, f4=f4)
        self.residual_module = ResidualModule(f1=f3, f2=f2)
        self.ctam = CTAM(f1=f4, linear_size=f3, time_points=time_points)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(linear_size, classes)

    def forward(self, x):
        """Return ``(logits, temporal_features)``; the features come from the multi-scale module."""
        x = x.squeeze(1) # Reducing Input dim from 4 to 3
        x, temporal_features = self.msm(x)
        x = self.residual_module(x)
        x = self.ctam(x)
        x = self.flatten(x)
        # Dropout regularises the features fed to the classifier. Applying it
        # after the linear layer would randomly zero the class logits during
        # training. Outputs in eval mode are unaffected by this ordering.
        x = self.dropout(x)
        x = self.fc(x)
        return x, temporal_features

# Model Summary

In [None]:
# Layer-by-layer summary of a default MSCTANN for one (1, 22, 1001) input
msctaan_model = MSCTANNModel().to(device)
input_size = (1, 22, 1001)
summary(msctaan_model, input_size)

# Training Model

In [None]:
# Training Hyperparameters
EPOCHS = 500
BATCH_SIZE = 64
LEARNING_RATE = 0.001

# Fresh MSCTANN with default settings, trained on the training split.
# NOTE(review): called like this, train_model also writes the returned weights
# to 'eegnet_model.pth', which replaces the EEGNet checkpoint saved earlier.
msctaan_model = MSCTANNModel().to(device)
trainer = TrainModel()
trained_msctaan_model = trainer.train_model(
    msctaan_model,
    train_dataset,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
)

# Store the trained weights
torch.save(trained_msctaan_model.state_dict(), 'msctaan_model.pth')

# Evaluating Model (Confusion Matrix)

In [None]:
# Accuracy of the trained MSCTANN on the held-out split
eval_model = EvalModel(trained_msctaan_model)
test_accuracy = eval_model.test_model(test_dataset)

# Confusion matrix; the list gives the display name of each label index in order
classes_list = ['Left', 'Right', 'Foot', 'Tongue']
eval_model.plot_confusion_matrix(test_dataset, classes_list)