# Training the MI and SSVEP models using a simple CNN 

In [1]:
# %pip install torcheval

In [1]:
import torch

import os
import pandas as pd
from typing import Tuple

from torch import nn, optim
import torch.nn.functional as F
from torcheval.metrics.functional import (multiclass_f1_score,
                                          binary_accuracy,
                                          binary_f1_score)

from pytorch_lightning import seed_everything

# Code necessary to create reproducible runs
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
seed_everything(42, workers=True)
torch.use_deterministic_algorithms(True, warn_only=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Seed set to 42


In [2]:
# Set the number of threads for computations
torch.set_num_threads(12)
print(f"Using {torch.get_num_threads()} threads for computations")

Using 12 threads for computations


### Dataset Abstraction

In [3]:
# data.py
from torch.utils.data import DataLoader, Dataset

class BCIDataset(Dataset):
    def __init__(self, csv_file, base_path, task_type="MI", label_mapping=None):
        # Filter the main dataframe for the specific task (MI or SSVEP)
        self.metadata = pd.read_csv(os.path.join(base_path, csv_file))
        self.metadata = self.metadata[self.metadata["task"] == task_type]
        self.base_path = base_path
        self.task_type = task_type
        self.label_mapping = label_mapping

        # 9 seconds * 250 Hz = 2250 for MI
        # 7 seconds * 250 Hz = 1750 for SSVEP
        self.sequence_length = 2250 if task_type == "MI" else 1750


        num_trials = len(self.metadata)

        self.tensor_data = torch.empty(
            num_trials, self.sequence_length, 8, dtype=torch.float32
        )
        self.labels = torch.empty(num_trials, dtype=torch.long)

        for i, (idx, row) in enumerate(self.metadata.iterrows()):
            # Determine dataset split (train/validation/test)
            id_num = row["id"]
            if id_num <= 4800:
                dataset_split = "train"
            elif id_num <= 4900:
                dataset_split = "validation"
            else:
                dataset_split = "test"

            # Path to the EEG data file
            eeg_path = os.path.join(
                self.base_path,
                row["task"],
                dataset_split,
                row["subject_id"],
                str(row["trial_session"]),
                "EEGdata.csv",
            )

            eeg_data = pd.read_csv(eeg_path)

            # Extract the correct trial segment
            trial_num = int(row["trial"])

            samples_per_trial = self.sequence_length
            start_idx = (trial_num - 1) * samples_per_trial
            end_idx = start_idx + samples_per_trial - 1

            # Select only the 8 EEG channels
            eeg_channels = ["FZ", "C3", "CZ", "C4", "PZ", "PO7", "OZ", "PO8"]
            trial_data = eeg_data.loc[start_idx : end_idx, eeg_channels].values

            # uncomment the line below and comment the one above to include all 18 columns
            # trial_data = eeg_data.loc[start_idx:end_idx-1].values

            # Preprocess the data (see next section)
            processed_data = self.preprocess(trial_data)

            # Convert to tensor
            tensor_data = torch.tensor(processed_data, dtype=torch.float32)
            self.tensor_data[i] = tensor_data

            # Get label if it exists
            if "label" in row and self.label_mapping:
                label_str = row["label"]
                label_int = self.label_mapping[label_str]
                self.labels[i] = label_int

    def __len__(self):
        return len(self.tensor_data)

    def __getitem__(self, idx):
        if self.labels[idx] is not None:
            return self.tensor_data[idx], self.labels[idx]
        else:
            return self.tensor_data[idx]

    def preprocess(self, eeg_data):
        # Apply preprocessing steps here (filtering, normalization, etc.)
        # This will be different for MI and SSVEP
        # ...
        return eeg_data


def load_data(base_path, task_type, label_mapping) -> Tuple[BCIDataset, BCIDataset, BCIDataset]:
    """
    Loads the train, val, test data for the given {task_type} inside the given {base_path}

    Returns:
        a tuple of BCIDataset in the order (train, val, test)
    """

    train = BCIDataset(csv_file="train.csv", base_path=base_path, task_type=task_type, label_mapping=label_mapping)
    val = BCIDataset(
        csv_file="validation.csv", base_path=base_path, task_type=task_type, label_mapping=label_mapping
    )
    test = BCIDataset(csv_file="test.csv", base_path=base_path, task_type=task_type, label_mapping=label_mapping)

    return train, val, test

### Simple CNN Model Architecture

In [4]:
# models/simple_cnn.py

# Simple CNN model
class BCIModel(nn.Module):
    def __init__(self, input_channels, num_classes, sequence_length):
        super(BCIModel, self).__init__()
        self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=3)  # 8 channels
        self.pool = nn.MaxPool1d(2)
        
        # Helper to calculate the flattened size
        self._get_conv_output_size(input_channels, sequence_length)
        
        self.fc1 = nn.Linear(self._to_linear, num_classes)
        
    def _get_conv_output_size(self, input_channels, sequence_length):
        # Create a dummy input tensor
        dummy_input = torch.randn(1, input_channels, sequence_length)
        
        # Pass it through the convolutional and pooling layers
        output = self.pool(torch.relu(self.conv1(dummy_input)))
        
        # Calculate the flattened size
        self._to_linear = output.numel()

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

### Helper Functions

In [4]:
# util.py
def rec_cpu_count() -> int:
    """Returns recommended cpu count based on machine and a simple heuristic"""
    cpu_count = os.cpu_count()

    if cpu_count is None:
        return 4

    return min(cpu_count // 2, 8)

In [5]:
from bci_aic3.paths import RAW_DATA_DIR

# base_path = "/kaggle/input/mtcaic3"
base_path = RAW_DATA_DIR
label_mapping = {'Left': 0, 'Right': 1, 'Forward': 2, 'Backward': 3}
task_type = "MI"  # MI or SSVEP
sequence_length = None

if task_type == "MI":
    num_classes = 2
    sequence_length = 2250
elif task_type == "SSVEP":
    num_classes = 4
    sequence_length = 1750

batch_size = 128
max_num_workers = 8

# Loading the data
train_mi, val_mi, test_ssvep = load_data(
    base_path=base_path, task_type=task_type, label_mapping=label_mapping
)

train_loader_mi = DataLoader(
    train_mi, 
    batch_size=batch_size,
    shuffle=True,
    num_workers=max_num_workers,
)
val_loader_mi = DataLoader(
    val_mi, 
    batch_size=batch_size,
    shuffle=False,
    num_workers=max_num_workers, 
)

## MI

In [6]:
# Defining the mi_model, loss, and optimizer
mi_model = BCIModel(train_mi[0][0].shape[1], num_classes=num_classes, sequence_length=sequence_length).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mi_model.parameters(), lr=0.001)

In [None]:
from tqdm import tqdm

# Training loop
num_epochs = 10
with torch.profiler.profile() as prof:
    for epoch in tqdm(range(num_epochs)):
        mi_model.train()
        
        for data, labels in train_loader_mi:
            data = data.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            outputs = mi_model(data.transpose(1, 2))

            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

        # Validation
        mi_model.eval()
        val_loss = 0
        with torch.no_grad():
            for data, labels in val_loader_mi:
                data = data.to(device)
                labels = labels.to(device)
                
                outputs = mi_model(data.transpose(1, 2))

                val_loss += criterion(outputs, labels).item()

        # if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch + 1}, Val Loss: {val_loss / len(val_loader_mi)}")

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

NameError: name 'prof' is not defined

#### MI Model Summary

In [11]:
from torchinfo import summary

summary(mi_model, 
        input_size=(1, 8, 2250),
        col_names=['input_size',
                   'output_size',
                   'num_params'])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
BCIModel                                 [1, 8, 2250]              [1, 2]                    --
├─Conv1d: 1-1                            [1, 8, 2250]              [1, 16, 2248]             400
├─MaxPool1d: 1-2                         [1, 16, 2248]             [1, 16, 1124]             --
├─Linear: 1-3                            [1, 17984]                [1, 2]                    35,970
Total params: 36,370
Trainable params: 36,370
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.94
Input size (MB): 0.07
Forward/backward pass size (MB): 0.29
Params size (MB): 0.15
Estimated Total Size (MB): 0.51

### MI Validation Metrics

In [12]:
all_preds = []
all_labels = []

mi_model.eval()
with torch.no_grad():
    for x_batch, y_batch in val_loader_mi:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        
        logits = mi_model(x_batch.transpose(1, 2))
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds)
        all_labels.append(y_batch)

# Concatenate all predictions and labels
all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

f1 = binary_f1_score(all_preds, all_labels)
print(f"F1 Score: {f1.item():.4f}")

multi_f1_micro = multiclass_f1_score(all_preds, all_labels, num_classes=2, average='micro')
print(f"Multiclass F1 Score (micro): {multi_f1_micro.item():.4f}")

multi_f1_macro = multiclass_f1_score(all_preds, all_labels, num_classes=2, average='macro')
print(f"Multiclass F1 Score (macro): {multi_f1_macro.item():.4f}")

acc = binary_accuracy(all_preds, all_labels)
print(f"Accuracy: {acc.item():.4f}")

F1 Score: 0.6111
Multiclass F1 Score (micro): 0.4400
Multiclass F1 Score (macro): 0.3056
Accuracy: 0.4400


### Save MI model

In [13]:
mi_save_path = "mi_model_simple_cnn_cuda.pth"
torch.save({
    "epoch": epoch,                     
    "model_state_dict": mi_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    "f1_score": multi_f1_macro,
}, mi_save_path)

### Load Model

In [14]:
def load_model(model_path, input_size, num_classes, sequence_length):
    model = BCIModel(input_size, num_classes, sequence_length).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    
    return model, optimizer

mi_model_loaded, optimizer_loaded = load_model("mi_model_simple_cnn.pth", 8, 2, 2250)

### Loaded Model Evaluation

In [15]:
all_preds = []
all_labels = []

mi_model_loaded.eval()
with torch.no_grad():
    for x_batch, y_batch in val_loader_mi:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        
        logits = mi_model_loaded(x_batch.transpose(1, 2))
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds)
        all_labels.append(y_batch)

# Concatenate all predictions and labels
all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

multi_f1_macro = multiclass_f1_score(all_preds, all_labels, num_classes=2, average='macro')
print(f"Multiclass F1 Score (macro): {multi_f1_macro.item():.4f}")

acc = binary_accuracy(all_preds, all_labels)
print(f"Accuracy: {acc.item():.4f}")

Multiclass F1 Score (macro): 0.3056
Accuracy: 0.4400


## SSVEP

In [45]:
base_path = "/kaggle/input/mtcaic3"
# base_path = "/data/raw/mtcaic3"
label_mapping = {'Left': 0, 'Right': 1, 'Forward': 2, 'Backward': 3}
task_type = "SSVEP"  # MI or SSVEP
sequence_length = None

if task_type == "MI":
    num_classes = 2
    sequence_length = 2250
elif task_type == "SSVEP":
    num_classes = 4
    sequence_length = 1750

batch_size = 128
max_num_workers = rec_cpu_count()


In [46]:

# Loading the data
train_ssvep, val_ssvep, test_ssvep = load_data(
    base_path=base_path, task_type=task_type, label_mapping=label_mapping
)

train_loader_ssvep = DataLoader(
    train_ssvep, batch_size=batch_size, shuffle=True, num_workers=max_num_workers
)
val_loader_ssvep = DataLoader(
    val_ssvep, batch_size=batch_size, shuffle=False, num_workers=max_num_workers
)

In [47]:
### EEGNet Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import weight_norm


class EEGNet_SSVEP(nn.Module):
    """SSVEP Variant of EEGNet, as used in [1].
    
    Inputs:
        nb_classes      : int, number of classes to classify
        Chans, Samples  : number of channels and time points in the EEG data
        dropoutRate     : dropout fraction
        kernLength      : length of temporal convolution in first layer
        F1, F2          : number of temporal filters (F1) and number of pointwise
                          filters (F2) to learn.
        D               : number of spatial filters to learn within each temporal
                          convolution.
        dropoutType     : Either 'SpatialDropout2D' or 'Dropout', passed as a string.
    
    [1]. Waytowich, N. et. al. (2018). Compact Convolutional Neural Networks
    for Classification of Asynchronous Steady-State Visual Evoked Potentials.
    Journal of Neural Engineering vol. 15(6).
    http://iopscience.iop.org/article/10.1088/1741-2552/aae5d8
    """
    
    def __init__(self, nb_classes=4, Chans=8, Samples=1750, 
                 dropoutRate=0.25, kernLength=250, F1=96, 
                 D=1, F2=96, dropoutType='Dropout'):
        super(EEGNet_SSVEP, self).__init__()
        
        self.nb_classes = nb_classes
        self.Chans = Chans
        self.Samples = Samples
        self.dropoutRate = dropoutRate
        self.kernLength = kernLength
        self.F1 = F1
        self.D = D
        self.F2 = F2
        
        # Validate dropout type
        if dropoutType not in ['SpatialDropout2D', 'Dropout']:
            raise ValueError('dropoutType must be one of SpatialDropout2D '
                           'or Dropout, passed as a string.')
        self.dropoutType = dropoutType
        
        # Block 1
        self.conv1 = nn.Conv2d(1, F1, (1, kernLength), padding='same', bias=False)
        self.batchnorm1 = nn.BatchNorm2d(F1)
        
        # Depthwise convolution - equivalent to DepthwiseConv2D
        self.depthwise_conv = nn.Conv2d(F1, F1 * D, (Chans, 1), 
                                       groups=F1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(F1 * D)
        
        # Apply max_norm constraint to depthwise conv weights
        self.depthwise_conv = weight_norm(self.depthwise_conv, name='weight')
        
        self.avgpool1 = nn.AvgPool2d((1, 4))
        
        # Block 2
        # SeparableConv2D is equivalent to depthwise + pointwise convolution
        self.separable_conv_depthwise = nn.Conv2d(F1 * D, F1 * D, (1, 16), 
                                                 groups=F1 * D, padding='same', bias=False)
        self.separable_conv_pointwise = nn.Conv2d(F1 * D, F2, 1, bias=False)
        
        self.batchnorm3 = nn.BatchNorm2d(F2)
        self.avgpool2 = nn.AvgPool2d((1, 8))
        
        # Calculate the size after convolutions for the linear layer
        # We need to be more careful about the feature size calculation
        # This will be computed dynamically in the first forward pass
        self.feature_size = None
        
        # Dense layer - will be initialized on first forward pass
        self.classifier = None
        
        # Dropout layers
        if dropoutType == 'SpatialDropout2D':
            self.dropout1 = SpatialDropout2d(dropoutRate)
            self.dropout2 = SpatialDropout2d(dropoutRate)
        else:  # 'Dropout'
            self.dropout1 = nn.Dropout2d(dropoutRate)
            self.dropout2 = nn.Dropout2d(dropoutRate)
        
        # Initialize weights
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Initialize weights similar to Keras defaults"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # Handle different input formats
        if x.dim() == 3:
            # Input shape: (batch_size, samples, channels)
            # Reshape to (batch_size, channels, samples, 1) then to (batch_size, 1, channels, samples)
            x = x.permute(0, 2, 1)  # (batch_size, channels, samples)
            x = x.unsqueeze(1)      # (batch_size, 1, channels, samples)
        elif x.dim() == 4:
            # Input shape: (batch_size, Chans, Samples, 1)
            # Reshape to (batch_size, 1, Chans, Samples) for PyTorch conv2d
            x = x.permute(0, 3, 1, 2)
        else:
            raise ValueError(f"Expected input to be 3D or 4D, got {x.dim()}D tensor")
        
        # Block 1
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.depthwise_conv(x)
        x = self.batchnorm2(x)
        x = F.elu(x)
        x = self.avgpool1(x)
        x = self.dropout1(x)
        
        # Block 2
        x = self.separable_conv_depthwise(x)
        x = self.separable_conv_pointwise(x)
        x = self.batchnorm3(x)
        x = F.elu(x)
        
        # Check if we can apply the second pooling
        if x.size(-1) >= 8:  # Check if temporal dimension is >= 8
            x = self.avgpool2(x)
        else:
            # Use adaptive pooling or smaller kernel
            x = F.avg_pool2d(x, (1, min(x.size(-1), 2)))
        
        x = self.dropout2(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Initialize classifier on first forward pass
        if self.classifier is None:
            self.feature_size = x.size(1)
            self.classifier = nn.Linear(self.feature_size, self.nb_classes).to(x.device)
            # Initialize the classifier weights
            nn.init.kaiming_normal_(self.classifier.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(self.classifier.bias, 0)
        
        # Dense layer
        x = self.classifier(x)
        
        # Softmax (often applied in loss function, but included here for completeness)
        x = F.softmax(x, dim=1)
        
        return x


class SpatialDropout2d(nn.Module):
    """Spatial Dropout implementation for PyTorch
    
    Drops entire feature maps instead of individual elements.
    This is equivalent to Keras' SpatialDropout2D.
    """
    
    def __init__(self, p=0.5):
        super(SpatialDropout2d, self).__init__()
        self.p = p
    
    def forward(self, x):
        if not self.training:
            return x
        
        # x shape: (N, C, H, W)
        N, C, H, W = x.size()
        
        # Create mask for entire feature maps
        mask = torch.bernoulli(torch.full((N, C, 1, 1), 1 - self.p, device=x.device))
        mask = mask.expand_as(x)
        
        # Apply mask and scale
        return x * mask / (1 - self.p)


In [90]:
# Defining the model, loss, and optimizer
# ssvep_model = BCIModel(train_ssvep[0][0].shape[1], num_classes=num_classes, sequence_length=sequence_length).to(device)
ssvep_model = EEGNet_SSVEP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(ssvep_model.parameters(), lr=0.001)

In [92]:
from tqdm import tqdm

# Training loop
num_epochs = 100
for epoch in tqdm(range(num_epochs)):
    ssvep_model.train()
    for data, labels in train_loader_ssvep:

        data = data.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = ssvep_model(data.transpose(1, 2))

        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

    # Validation
    ssvep_model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, labels in val_loader_ssvep:
            data = data.to(device)
            labels = labels.to(device)
            
            outputs = ssvep_model(data.transpose(1, 2))

            val_loss += criterion(outputs, labels).item()

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch + 1}, Val Loss: {val_loss / len(val_loader_ssvep)}")

  5%|▌         | 5/100 [00:34<10:49,  6.83s/it]

Epoch 5, Val Loss: 1.4636753797531128


 10%|█         | 10/100 [01:08<10:09,  6.77s/it]

Epoch 10, Val Loss: 1.463667869567871


 15%|█▌        | 15/100 [01:42<09:38,  6.80s/it]

Epoch 15, Val Loss: 1.4636681079864502


 20%|██        | 20/100 [02:16<09:03,  6.79s/it]

Epoch 20, Val Loss: 1.48002290725708


 25%|██▌       | 25/100 [02:49<08:28,  6.78s/it]

Epoch 25, Val Loss: 1.429288625717163


 30%|███       | 30/100 [03:23<07:54,  6.78s/it]

Epoch 30, Val Loss: 1.4636681079864502


 35%|███▌      | 35/100 [03:57<07:21,  6.79s/it]

Epoch 35, Val Loss: 1.463456392288208


 40%|████      | 40/100 [04:31<06:47,  6.80s/it]

Epoch 40, Val Loss: 1.5036680698394775


 45%|████▌     | 45/100 [05:05<06:14,  6.80s/it]

Epoch 45, Val Loss: 1.560676097869873


 50%|█████     | 50/100 [05:39<05:39,  6.80s/it]

Epoch 50, Val Loss: 1.5436680316925049


 55%|█████▌    | 55/100 [06:13<05:05,  6.80s/it]

Epoch 55, Val Loss: 1.5436680316925049


 60%|██████    | 60/100 [06:47<04:31,  6.79s/it]

Epoch 60, Val Loss: 1.5436680316925049


 65%|██████▌   | 65/100 [07:21<03:57,  6.79s/it]

Epoch 65, Val Loss: 1.5436680316925049


 70%|███████   | 70/100 [07:55<03:23,  6.79s/it]

Epoch 70, Val Loss: 1.5436680316925049


 75%|███████▌  | 75/100 [08:29<02:49,  6.79s/it]

Epoch 75, Val Loss: 1.4836680889129639


 80%|████████  | 80/100 [09:03<02:16,  6.80s/it]

Epoch 80, Val Loss: 1.463484764099121


 85%|████████▌ | 85/100 [09:37<01:41,  6.80s/it]

Epoch 85, Val Loss: 1.5436680316925049


 90%|█████████ | 90/100 [10:11<01:07,  6.79s/it]

Epoch 90, Val Loss: 1.5436680316925049


 95%|█████████▌| 95/100 [10:45<00:33,  6.79s/it]

Epoch 95, Val Loss: 1.4636681079864502


100%|██████████| 100/100 [11:19<00:00,  6.80s/it]

Epoch 100, Val Loss: 1.5436680316925049





#### SSVEP Model Summary

In [97]:
from torchinfo import summary

summary(ssvep_model, 
        input_size=(128, 8, 1750),
        col_names=['input_size',
                   'output_size',
                   'num_params'])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
EEGNet_SSVEP                             [128, 8, 1750]            [128, 4]                  --
├─Conv2d: 1-1                            [128, 1, 1750, 8]         [128, 96, 1750, 8]        24,000
├─BatchNorm2d: 1-2                       [128, 96, 1750, 8]        [128, 96, 1750, 8]        192
├─ParametrizedConv2d: 1-3                [128, 96, 1750, 8]        [128, 96, 1743, 8]        --
│    └─ModuleDict: 2-1                   --                        --                        --
│    │    └─ParametrizationList: 3-1     --                        [96, 1, 8, 1]             769
├─BatchNorm2d: 1-4                       [128, 96, 1743, 8]        [128, 96, 1743, 8]        192
├─AvgPool2d: 1-5                         [128, 96, 1743, 8]        [128, 96, 1743, 2]        --
├─Dropout2d: 1-6                         [128, 96, 1743, 2]        [128, 96, 1743, 2]        --
├─Conv2d: 1-7               

### SSVEP Validation Metrics

In [98]:
from torcheval.metrics.functional import multiclass_accuracy

In [99]:
all_preds = []
all_labels = []

ssvep_model.eval()
with torch.no_grad():
    for x_batch, y_batch in val_loader_ssvep:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        
        logits = ssvep_model(x_batch.transpose(1, 2))
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds)
        all_labels.append(y_batch)

# Concatenate all predictions and labels
all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

multi_f1_micro = multiclass_f1_score(all_preds, all_labels, num_classes=num_classes, average='micro')
print(f"Multiclass F1 Score (micro): {multi_f1_micro.item():.4f}")

multi_f1_macro = multiclass_f1_score(all_preds, all_labels, num_classes=num_classes, average='macro')
print(f"Multiclass F1 Score (macro): {multi_f1_macro.item():.4f}")

acc = multiclass_accuracy(all_preds, all_labels, num_classes=num_classes)
print(f"Multiclass Accuracy: {acc.item():.4f}")

Multiclass F1 Score (micro): 0.2000
Multiclass F1 Score (macro): 0.1264
Multiclass Accuracy: 0.2000


In [100]:
all_labels

tensor([3, 0, 2, 2, 2, 0, 1, 3, 0, 2, 3, 1, 1, 0, 2, 0, 1, 2, 0, 3, 2, 0, 3, 1,
        0, 2, 2, 3, 0, 1, 1, 3, 3, 3, 0, 3, 1, 1, 2, 3, 0, 0, 2, 0, 3, 1, 0, 2,
        3, 3], device='cuda:0')

In [101]:
all_preds

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3], device='cuda:0')

### SSVEP Save Model

In [28]:
ssvep_save_path = "ssvep_model_simple_cnn.pth"
torch.save({
    "epoch": epoch,                     
    "model_state_dict": ssvep_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    "f1_score": multi_f1_macro,
}, ssvep_save_path)