# Training the MI and SSVEP models using a simple CNN 

In [1]:
import torch

import os
import pandas as pd
from typing import Tuple

from torch import nn, optim
from torcheval.metrics.functional import (multiclass_f1_score,
                                          binary_accuracy,
                                          binary_f1_score)

from pytorch_lightning import seed_everything

# Code necessary to create reproducible runs.
# The cuBLAS workspace setting is exported before any CUDA work is done;
# PyTorch's reproducibility notes ask for it when deterministic algorithms
# are enabled on CUDA.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Fixed seed for the whole run (workers=True is forwarded to Lightning's
# seeding helper so DataLoader workers are covered as well).
seed_everything(42, workers=True)
# Ask PyTorch for deterministic kernels; with warn_only=True an operation
# lacking a deterministic implementation produces a warning, not an error.
torch.use_deterministic_algorithms(True, warn_only=True)

# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Dataset Abstraction

In [2]:
# data.py
from torch.utils.data import DataLoader, Dataset

class BCIDataset(Dataset):
    """Eagerly loaded EEG trials for one task ("MI" or "SSVEP").

    The metadata CSV is filtered to ``task_type``; every matching trial is
    sliced out of its session's ``EEGdata.csv`` and stored in one
    pre-allocated float32 tensor of shape
    ``(num_trials, sequence_length, num_channels)``.

    Args:
        csv_file: Metadata CSV name, relative to ``base_path``. The columns
            read here are ``id``, ``task``, ``subject_id``,
            ``trial_session``, ``trial`` and (optionally) ``label``.
        base_path: Dataset root directory.
        task_type: ``"MI"`` selects 2250-sample trials; any other value is
            treated as SSVEP with 1750-sample trials.
        label_mapping: Dict from label string to integer class id. When it
            is empty/``None`` or the CSV has no ``label`` column the dataset
            is unlabeled.

    Indexing returns ``(data, label)`` for labeled datasets and ``data``
    alone for unlabeled ones.

    Raises:
        KeyError: If a label string is missing from ``label_mapping``.
    """

    # The 8 EEG channel columns selected from every EEGdata.csv file.
    EEG_CHANNELS = ["FZ", "C3", "CZ", "C4", "PZ", "PO7", "OZ", "PO8"]

    def __init__(self, csv_file, base_path, task_type="MI", label_mapping=None):
        # Filter the main dataframe for the specific task (MI or SSVEP)
        self.metadata = pd.read_csv(os.path.join(base_path, csv_file))
        self.metadata = self.metadata[self.metadata["task"] == task_type]
        self.base_path = base_path
        self.task_type = task_type
        self.label_mapping = label_mapping

        # 9 seconds * 250 Hz = 2250 for MI
        # 7 seconds * 250 Hz = 1750 for SSVEP
        self.sequence_length = 2250 if task_type == "MI" else 1750

        num_trials = len(self.metadata)

        # A dataset is labeled only when there is both a mapping and a
        # "label" column to apply it to. Tensor elements are never None, so
        # this flag (not the label tensor) decides what __getitem__ returns.
        self.has_labels = bool(label_mapping) and "label" in self.metadata.columns

        self.tensor_data = torch.empty(
            num_trials,
            self.sequence_length,
            len(self.EEG_CHANNELS),
            dtype=torch.float32,
        )
        # -1 marks "no label" instead of leaving uninitialised memory behind.
        self.labels = torch.full((num_trials,), -1, dtype=torch.long)

        # Consecutive trials usually come from the same session file, so the
        # most recently parsed file is kept instead of re-reading it.
        cached_path = None
        cached_frame = None

        for i, (_, row) in enumerate(self.metadata.iterrows()):
            # Path to the EEG data file
            eeg_path = os.path.join(
                self.base_path,
                row["task"],
                self._split_for_id(row["id"]),
                row["subject_id"],
                str(row["trial_session"]),
                "EEGdata.csv",
            )

            if eeg_path != cached_path:
                cached_frame = pd.read_csv(eeg_path)
                cached_path = eeg_path

            # Extract the correct trial segment. Trials are 1-based and laid
            # out back to back; .loc slicing is inclusive on both ends, hence
            # the "- 1" on the end index.
            trial_num = int(row["trial"])
            start_idx = (trial_num - 1) * self.sequence_length
            end_idx = start_idx + self.sequence_length - 1

            # Select only the EEG channels. (Using every recorded column
            # instead would also require widening self.tensor_data.)
            trial_data = cached_frame.loc[start_idx:end_idx, self.EEG_CHANNELS].values

            # Preprocess the data, then store it in the shared buffer
            processed_data = self.preprocess(trial_data)
            self.tensor_data[i] = torch.tensor(processed_data, dtype=torch.float32)

            if self.has_labels:
                self.labels[i] = self.label_mapping[row["label"]]

    @staticmethod
    def _split_for_id(id_num):
        """Map a trial id to the split directory its recording lives under."""
        if id_num <= 4800:
            return "train"
        if id_num <= 4900:
            return "validation"
        return "test"

    def __len__(self):
        """Return the number of trials held by this dataset."""
        return len(self.tensor_data)

    def __getitem__(self, idx):
        """Return ``(data, label)`` when labeled, otherwise just ``data``."""
        if self.has_labels:
            return self.tensor_data[idx], self.labels[idx]
        return self.tensor_data[idx]

    def preprocess(self, eeg_data):
        """Hook for per-trial preprocessing; currently returns the input unchanged."""
        # Apply preprocessing steps here (filtering, normalization, etc.)
        # This will be different for MI and SSVEP
        return eeg_data


def load_data(base_path, task_type, label_mapping) -> Tuple[BCIDataset, BCIDataset, BCIDataset]:
    """
    Builds one BCIDataset per split for the given {task_type} under {base_path}.

    The splits are described by "train.csv", "validation.csv" and "test.csv"
    inside {base_path} and are constructed in that order.

    Returns:
        a tuple of BCIDataset in the order (train, val, test)
    """
    common = dict(base_path=base_path, task_type=task_type, label_mapping=label_mapping)

    train, val, test = (
        BCIDataset(csv_file=csv_name, **common)
        for csv_name in ("train.csv", "validation.csv", "test.csv")
    )

    return train, val, test

### Simple CNN Model Architecture

In [3]:
# models/simple_cnn.py

# Simple CNN model
class BCIModel(nn.Module):
    """Minimal 1-D CNN: Conv1d(k=3, 16 filters) -> ReLU -> MaxPool1d(2) -> Linear.

    Inputs are expected as (batch, input_channels, sequence_length); the
    output has one logit per class.
    """

    def __init__(self, input_channels, num_classes, sequence_length):
        super().__init__()
        self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=3)
        self.pool = nn.MaxPool1d(2)

        # Sets self._to_linear, the input width of the classifier head.
        self._get_conv_output_size(input_channels, sequence_length)

        self.fc1 = nn.Linear(self._to_linear, num_classes)

    def _conv_features(self, x):
        """Run the convolution, activation and pooling stage."""
        return self.pool(torch.relu(self.conv1(x)))

    def _get_conv_output_size(self, input_channels, sequence_length):
        """Measure the flattened feature size by pushing one random sample through."""
        probe = torch.randn(1, input_channels, sequence_length)
        self._to_linear = self._conv_features(probe).numel()

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        features = self._conv_features(x)
        flattened = features.view(features.size(0), -1)
        return self.fc1(flattened)

### Helper Functions

In [4]:
# util.py
def rec_cpu_count() -> int:
    """Suggest a worker count: half the reported CPUs, capped at 8.

    Falls back to 4 when the CPU count cannot be determined.
    """
    detected = os.cpu_count()
    return 4 if detected is None else min(detected // 2, 8)

In [5]:
# Configuration for the MI run.
base_path = "/kaggle/input/mtcaic3"
# base_path = "/data/raw/mtcaic3"
label_mapping = {'Left': 0, 'Right': 1, 'Forward': 2, 'Backward': 3}
task_type = "MI"  # MI or SSVEP
sequence_length = None

if task_type == "MI":
    num_classes = 2
    sequence_length = 2250
elif task_type == "SSVEP":
    num_classes = 4
    sequence_length = 1750
else:
    # Fail here rather than later with an undefined num_classes.
    raise ValueError(f"Unknown task_type: {task_type!r}")

batch_size = 128
max_num_workers = rec_cpu_count()

# Loading the data. The third split is the MI test set, so it is named
# test_mi (it used to be bound to the misleading name test_ssvep).
train_mi, val_mi, test_mi = load_data(
    base_path=base_path, task_type=task_type, label_mapping=label_mapping
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader_mi = DataLoader(
    train_mi,
    batch_size=batch_size,
    shuffle=True,
    num_workers=max_num_workers,
)
val_loader_mi = DataLoader(
    val_mi,
    batch_size=batch_size,
    shuffle=False,
    num_workers=max_num_workers,
)

## MI

In [7]:
# Build the MI classifier together with its loss and optimizer.
# A dataset sample is (time, channels), so axis 1 of the first sample gives
# the number of input channels for the first convolution.
mi_input_channels = train_mi[0][0].shape[1]
mi_model = BCIModel(
    mi_input_channels,
    num_classes=num_classes,
    sequence_length=sequence_length,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mi_model.parameters(), lr=0.001)

In [8]:
from tqdm import tqdm

# Train the MI model for a fixed number of epochs under the PyTorch profiler.
num_epochs = 100
with torch.profiler.profile() as prof:
    for epoch in range(num_epochs):
        # ---- optimisation pass over the training set ----
        mi_model.train()

        for data, labels in train_loader_mi:
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            # Batches arrive as (batch, time, channels); the model's Conv1d
            # consumes (batch, channels, time).
            outputs = mi_model(data.transpose(1, 2))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # ---- validation pass (no gradients) ----
        mi_model.eval()
        batch_losses = []
        with torch.no_grad():
            for data, labels in val_loader_mi:
                data, labels = data.to(device), labels.to(device)
                outputs = mi_model(data.transpose(1, 2))
                batch_losses.append(criterion(outputs, labels).item())
        val_loss = sum(batch_losses)

        # Report the mean per-batch validation loss every fifth epoch.
        report_due = (epoch + 1) % 5 == 0
        if report_due:
            print(f"Epoch {epoch + 1}, Val Loss: {val_loss / len(val_loader_mi)}")

Epoch 5, Val Loss: 807.3358764648438
Epoch 10, Val Loss: 1075.530029296875
Epoch 15, Val Loss: 830.7078247070312
Epoch 20, Val Loss: 893.7257080078125
Epoch 25, Val Loss: 801.54736328125
Epoch 30, Val Loss: 818.6141357421875
Epoch 35, Val Loss: 859.888427734375
Epoch 40, Val Loss: 862.5120239257812
Epoch 45, Val Loss: 863.661376953125
Epoch 50, Val Loss: 863.2930297851562
Epoch 55, Val Loss: 862.216064453125
Epoch 60, Val Loss: 862.005126953125
Epoch 65, Val Loss: 861.9745483398438
Epoch 70, Val Loss: 862.0718994140625
Epoch 75, Val Loss: 862.033935546875
Epoch 80, Val Loss: 862.0060424804688
Epoch 85, Val Loss: 862.01220703125
Epoch 90, Val Loss: 862.0165405273438
Epoch 95, Val Loss: 862.0192260742188
Epoch 100, Val Loss: 862.021240234375


In [9]:
# Print the profiler's averaged statistics as a table, sorted by the
# "self_cpu_time_total" column.
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        53.34%       16.902s        53.75%       17.032s       7.742ms       0.000us         0.00%       0.000us       0.000us          2200  
                                        cudaMemcpyAsync        19.57%        6.201s        19.57%        6.201s       1.512ms       0.000us         0.00%       0.000us       0.000us          4100  
         

#### MI Model Summary

In [11]:
from torchinfo import summary

# Layer-by-layer overview of the MI model for a single (1, 8, 2250) input.
mi_summary_columns = ['input_size', 'output_size', 'num_params']
summary(mi_model, input_size=(1, 8, 2250), col_names=mi_summary_columns)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
BCIModel                                 [1, 8, 2250]              [1, 2]                    --
├─Conv1d: 1-1                            [1, 8, 2250]              [1, 16, 2248]             400
├─MaxPool1d: 1-2                         [1, 16, 2248]             [1, 16, 1124]             --
├─Linear: 1-3                            [1, 17984]                [1, 2]                    35,970
Total params: 36,370
Trainable params: 36,370
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.94
Input size (MB): 0.07
Forward/backward pass size (MB): 0.29
Params size (MB): 0.15
Estimated Total Size (MB): 0.51

### MI Validation Metrics

In [12]:
# Gather predictions and targets batch by batch, then score them once.
pred_batches, label_batches = [], []

mi_model.eval()
with torch.no_grad():
    for x_batch, y_batch in val_loader_mi:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        logits = mi_model(x_batch.transpose(1, 2))
        pred_batches.append(torch.argmax(logits, dim=1))
        label_batches.append(y_batch)

# Flat tensors covering the whole validation set
all_preds = torch.cat(pred_batches)
all_labels = torch.cat(label_batches)

# Compute every metric first, then print them in the usual order.
f1 = binary_f1_score(all_preds, all_labels)
multi_f1_micro, multi_f1_macro = (
    multiclass_f1_score(all_preds, all_labels, num_classes=2, average=mode)
    for mode in ('micro', 'macro')
)
acc = binary_accuracy(all_preds, all_labels)

print(f"F1 Score: {f1.item():.4f}")
print(f"Multiclass F1 Score (micro): {multi_f1_micro.item():.4f}")
print(f"Multiclass F1 Score (macro): {multi_f1_macro.item():.4f}")
print(f"Accuracy: {acc.item():.4f}")

F1 Score: 0.6111
Multiclass F1 Score (micro): 0.4400
Multiclass F1 Score (macro): 0.3056
Accuracy: 0.4400


### Save MI model

In [13]:
# Bundle the MI training state into one checkpoint dict and write it to disk.
mi_save_path = "mi_model_simple_cnn_cuda.pth"
mi_checkpoint = {
    "epoch": epoch,
    "model_state_dict": mi_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    "f1_score": multi_f1_macro,
}
torch.save(mi_checkpoint, mi_save_path)

### Load Model

In [14]:
def load_model(model_path, input_size, num_classes, sequence_length):
    """Recreate a BCIModel and an Adam optimizer from a saved checkpoint.

    Args:
        model_path: Path of a checkpoint holding "model_state_dict" and
            "optimizer_state_dict" entries.
        input_size: Number of input channels passed to BCIModel.
        num_classes: Number of output classes passed to BCIModel.
        sequence_length: Samples per trial passed to BCIModel.

    Returns:
        The (model, optimizer) pair with both state dicts restored. The
        model is created on the global ``device``; the checkpoint itself is
        read onto the CPU first.
    """
    restored_model = BCIModel(input_size, num_classes, sequence_length).to(device)
    restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=0.001)

    saved = torch.load(model_path, map_location="cpu")
    restore_plan = (
        (restored_model, "model_state_dict"),
        (restored_optimizer, "optimizer_state_dict"),
    )
    for target, key in restore_plan:
        target.load_state_dict(saved[key])

    return restored_model, restored_optimizer

# Reload from mi_save_path, the file the save cell actually wrote; the
# previously hard-coded "mi_model_simple_cnn.pth" is never produced here.
# Arguments: 8 input channels, 2 classes, 2250 samples per MI trial.
mi_model_loaded, optimizer_loaded = load_model(mi_save_path, 8, 2, 2250)

### Loaded Model Evaluation

In [15]:
# Score the reloaded MI model on the validation set to confirm that the
# checkpoint round-trip reproduces the metrics of the trained model.
pred_batches, label_batches = [], []

mi_model_loaded.eval()
with torch.no_grad():
    for x_batch, y_batch in val_loader_mi:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        logits = mi_model_loaded(x_batch.transpose(1, 2))
        pred_batches.append(torch.argmax(logits, dim=1))
        label_batches.append(y_batch)

# Flat tensors covering the whole validation set
all_preds = torch.cat(pred_batches)
all_labels = torch.cat(label_batches)

multi_f1_macro = multiclass_f1_score(
    all_preds, all_labels, num_classes=2, average='macro'
)
acc = binary_accuracy(all_preds, all_labels)

print(f"Multiclass F1 Score (macro): {multi_f1_macro.item():.4f}")
print(f"Accuracy: {acc.item():.4f}")

Multiclass F1 Score (macro): 0.3056
Accuracy: 0.4400


## SSVEP

In [18]:
# Configuration for the SSVEP run.
base_path = "/kaggle/input/mtcaic3"
# base_path = "/data/raw/mtcaic3"
label_mapping = {'Left': 0, 'Right': 1, 'Forward': 2, 'Backward': 3}
task_type = "SSVEP"  # MI or SSVEP

# (num_classes, sequence_length) for each supported task
task_settings = {"MI": (2, 2250), "SSVEP": (4, 1750)}
sequence_length = None
if task_type in task_settings:
    num_classes, sequence_length = task_settings[task_type]

batch_size = 128
max_num_workers = rec_cpu_count()


In [19]:

# Build the three SSVEP splits, then wrap train/validation in DataLoaders.
train_ssvep, val_ssvep, test_ssvep = load_data(
    base_path=base_path, task_type=task_type, label_mapping=label_mapping
)

# Settings shared by both loaders; only the training loader shuffles.
ssvep_loader_kwargs = {"batch_size": batch_size, "num_workers": max_num_workers}
train_loader_ssvep = DataLoader(train_ssvep, shuffle=True, **ssvep_loader_kwargs)
val_loader_ssvep = DataLoader(val_ssvep, shuffle=False, **ssvep_loader_kwargs)

In [20]:
# Build the SSVEP classifier together with its loss and optimizer.
# A dataset sample is (time, channels), so axis 1 of the first sample gives
# the number of input channels for the first convolution.
ssvep_input_channels = train_ssvep[0][0].shape[1]
ssvep_model = BCIModel(
    ssvep_input_channels,
    num_classes=num_classes,
    sequence_length=sequence_length,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(ssvep_model.parameters(), lr=0.001)

In [21]:
# Train the SSVEP model for a fixed number of epochs.
num_epochs = 100
for epoch in range(num_epochs):
    # ---- optimisation pass over the training set ----
    ssvep_model.train()
    for data, labels in train_loader_ssvep:
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        # Batches arrive as (batch, time, channels); the model's Conv1d
        # consumes (batch, channels, time).
        outputs = ssvep_model(data.transpose(1, 2))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # ---- validation pass (no gradients) ----
    ssvep_model.eval()
    batch_losses = []
    with torch.no_grad():
        for data, labels in val_loader_ssvep:
            data, labels = data.to(device), labels.to(device)
            outputs = ssvep_model(data.transpose(1, 2))
            batch_losses.append(criterion(outputs, labels).item())
    val_loss = sum(batch_losses)

    # Report the mean per-batch validation loss every fifth epoch.
    report_due = (epoch + 1) % 5 == 0
    if report_due:
        print(f"Epoch {epoch + 1}, Val Loss: {val_loss / len(val_loader_ssvep)}")

Epoch 5, Val Loss: 552.5588989257812
Epoch 10, Val Loss: 550.42529296875
Epoch 15, Val Loss: 550.424560546875
Epoch 20, Val Loss: 550.423583984375
Epoch 25, Val Loss: 550.4229125976562
Epoch 30, Val Loss: 550.422607421875
Epoch 35, Val Loss: 550.421875
Epoch 40, Val Loss: 550.421630859375
Epoch 45, Val Loss: 550.4212646484375
Epoch 50, Val Loss: 550.4213256835938
Epoch 55, Val Loss: 550.421142578125
Epoch 60, Val Loss: 550.4210205078125
Epoch 65, Val Loss: 550.4210205078125
Epoch 70, Val Loss: 550.4209594726562
Epoch 75, Val Loss: 550.4208374023438
Epoch 80, Val Loss: 550.4209594726562
Epoch 85, Val Loss: 550.4205932617188
Epoch 90, Val Loss: 550.420654296875
Epoch 95, Val Loss: 550.4210815429688
Epoch 100, Val Loss: 550.4209594726562


#### SSVEP Model Summary

In [22]:
from torchinfo import summary

# Layer-by-layer overview of the SSVEP model for a single (1, 8, 1750) input.
ssvep_summary_columns = ['input_size', 'output_size', 'num_params']
summary(ssvep_model, input_size=(1, 8, 1750), col_names=ssvep_summary_columns)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
BCIModel                                 [1, 8, 1750]              [1, 4]                    --
├─Conv1d: 1-1                            [1, 8, 1750]              [1, 16, 1748]             400
├─MaxPool1d: 1-2                         [1, 16, 1748]             [1, 16, 874]              --
├─Linear: 1-3                            [1, 13984]                [1, 4]                    55,940
Total params: 56,340
Trainable params: 56,340
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.76
Input size (MB): 0.06
Forward/backward pass size (MB): 0.22
Params size (MB): 0.23
Estimated Total Size (MB): 0.51

### SSVEP Validation Metrics

In [23]:
from torcheval.metrics.functional import multiclass_accuracy

In [25]:
# Gather predictions and targets batch by batch, then score them once.
pred_batches, label_batches = [], []

ssvep_model.eval()
with torch.no_grad():
    for x_batch, y_batch in val_loader_ssvep:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        logits = ssvep_model(x_batch.transpose(1, 2))
        pred_batches.append(torch.argmax(logits, dim=1))
        label_batches.append(y_batch)

# Flat tensors covering the whole validation set
all_preds = torch.cat(pred_batches)
all_labels = torch.cat(label_batches)

# Compute every metric first, then print them in the usual order.
multi_f1_micro, multi_f1_macro = (
    multiclass_f1_score(all_preds, all_labels, num_classes=num_classes, average=mode)
    for mode in ('micro', 'macro')
)
acc = multiclass_accuracy(all_preds, all_labels, num_classes=num_classes)

print(f"Multiclass F1 Score (micro): {multi_f1_micro.item():.4f}")
print(f"Multiclass F1 Score (macro): {multi_f1_macro.item():.4f}")
print(f"Multiclass Accuracy: {acc.item():.4f}")

Multiclass F1 Score (micro): 0.3000
Multiclass F1 Score (macro): 0.1736
Multiclass Accuracy: 0.3000


In [26]:
# Display the concatenated validation labels for manual inspection.
all_labels

tensor([3, 0, 2, 2, 2, 0, 1, 3, 0, 2, 3, 1, 1, 0, 2, 0, 1, 2, 0, 3, 2, 0, 3, 1,
        0, 2, 2, 3, 0, 1, 1, 3, 3, 3, 0, 3, 1, 1, 2, 3, 0, 0, 2, 0, 3, 1, 0, 2,
        3, 3], device='cuda:0')

In [27]:
# Display the concatenated validation predictions for manual inspection.
all_preds

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3], device='cuda:0')

### Save SSVEP model

In [28]:
# Bundle the SSVEP training state into one checkpoint dict and write it to disk.
ssvep_save_path = "ssvep_model_simple_cnn.pth"
ssvep_checkpoint = {
    "epoch": epoch,
    "model_state_dict": ssvep_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    "f1_score": multi_f1_macro,
}
torch.save(ssvep_checkpoint, ssvep_save_path)

In [None]:
def make_labels():
    
    _, _, mi_test = 
    _, _, ssvep_test = 