In [1]:
import sys
import os
import warnings
import mne
import torch

warnings.filterwarnings("ignore")
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))
mne.set_config("MNE_DATA", os.path.join(os.getcwd(), "data"))
mne.set_config("MNE_DATASETS_BNCI_PATH", os.path.join(os.getcwd(), "data"))
mne.set_config("MNE_DATASETS_EEGBCI_PATH", os.path.join(os.getcwd(), "data"))
mne.set_config("MNE_DATASETS_SHIN_PATH", os.path.join(os.getcwd(), "data"))
mne.set_config("MOABB_RESULTS", os.path.join(os.getcwd(), "results"))
os.makedirs(os.environ["MNE_DATA"], exist_ok=True)
os.makedirs(os.environ["MOABB_RESULTS"], exist_ok=True)

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from importlib import reload
from mne.decoding import CSP, PSDEstimator
from sklearn.model_selection import KFold
from torchmetrics.classification import Accuracy

import scripts.transformer.transformer_models as trans
from scripts.dataset.eeg_dataset import EEGDataset
from scripts.features_extract.welch import extract_welch_features
from eeg_logger import logger

import moabb
from moabb.datasets import PhysionetMI
from moabb.paradigms import LeftRightImagery

moabb.set_log_level("info")

# Instantiate dataset and load data

MOABB makes working with datasets super easy. With Physionet we can specify which runs we want to analyze using `imagined` and `executed` flags. Setting `imagined` to **True** makes `get_data` method return runs 4, 8, 12, 6, 10, 14 because they contain imaginary tasks. For this work, we only need runs 4, 8, 12 because they contain left and right hand movement. The `get_data` method returns dict structured like this:  
`data[subject_index]["session_index"]["run_index"]`


In [3]:
dataset = PhysionetMI(imagined=True, executed=False)
subject_list = np.delete(np.arange(1, 110), [87, 91, 99, 103])
dataset.subject_list = subject_list
epochs_all = []
labels_all = []

data = dataset.get_data(subjects=subject_list.tolist())
data[1]

{'0': {'0': <RawEDF | S001R04.edf, 65 x 20000 (125.0 s), ~10.0 MiB, data loaded>,
  '1': <RawEDF | S001R08.edf, 65 x 20000 (125.0 s), ~10.0 MiB, data loaded>,
  '2': <RawEDF | S001R12.edf, 65 x 20000 (125.0 s), ~10.0 MiB, data loaded>,
  '3': <RawEDF | S001R06.edf, 65 x 20000 (125.0 s), ~10.0 MiB, data loaded>,
  '4': <RawEDF | S001R10.edf, 65 x 20000 (125.0 s), ~10.0 MiB, data loaded>,
  '5': <RawEDF | S001R14.edf, 65 x 20000 (125.0 s), ~10.0 MiB, data loaded>}}

# Extract epochs related to motor imagery


In [4]:
def normalize(epochs: mne.Epochs) -> mne.epochs:
    """
    Applies z-score normalization according to this formula:
    X* = (X - mean) / std + aN
    """

    data: np.ndarray = epochs.get_data()  # shape: (n_epochs, n_channels, n_times)
    mean = data.mean(axis=2, keepdims=True)
    std = data.std(axis=2, keepdims=True)
    std[std == 0] = 1.0
    N = np.random.randn(*data.shape)
    a = 0.01

    zscored_data = (data - mean) / std + a * N
    epochs._data = zscored_data

    return epochs


selected_event_id = {"left_hand": 1, "right_hand": 3}  # BASED ON EVENT_IDS
tmin_3s, tmax_3s = 2.0, 5.0
channels = [
    "FC5",
    "FC3",
    "FC1",
    "FC2",
    "FC4",
    "FC6",
    "C5",
    "C3",
    "C1",
    "Cz",
    "C2",
    "C4",
    "C6",
    "CP5",
    "CP3",
    "CP1",
    "CP2",
    "CP4",
    "CP6",
]
epochs_all_subjects = []

for subject in subject_list:

    session_data = data[subject]["0"]
    run_4 = session_data["0"]
    run_8 = session_data["1"]
    run_12 = session_data["2"]

    all_runs = mne.concatenate_raws([run_4, run_8, run_12])
    events, event_ids = mne.events_from_annotations(all_runs)
    if subject == 1:
        logger.info(f"Event ids: {event_ids}")

    epochs_3s = mne.Epochs(
        all_runs,
        events,
        event_id=selected_event_id,
        tmin=tmin_3s,
        tmax=tmax_3s,
        picks=channels,
        baseline=None,
        preload=True,
    )
    epochs_3s = normalize(epochs_3s)
    epochs_all_subjects.append(epochs_3s)

[34m2025-08-08 15:11:36,881 - INFO - Event ids: {'left_hand': 1, 'rest': 2, 'right_hand': 3}[0m


# Get data and labels from extracted epochs


In [5]:
def extract_data_from_epochs(epochs: mne.Epochs, label_mapping: dict | None = None) -> tuple[np.ndarray, np.ndarray]:
    X = epochs.get_data()
    y = epochs.events[:, -1]
    if label_mapping:
        y = np.array([label_mapping[label] for label in y])
    return X, y


X_all = []
y_all = []
for subject_epochs in epochs_all_subjects:
    X, y = extract_data_from_epochs(subject_epochs, label_mapping={1: 0, 3: 1})
    X_all.append(X)
    y_all.append(y)
X_all = np.concatenate(X_all)
y_all = np.concatenate(y_all)

# Train and evaluate model


In [6]:
def train_model(
    model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, device: torch.device, verbose: bool
) -> None:
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0007, weight_decay=0.0001)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(50):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if verbose:
            logger.info(f"Epoch {epoch+1}/{50}, Loss: {total_loss:.4f}")


def evaluate_model(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> float:
    """
    Computes accuracy of provided model.

    :param model: model to evaluate
    :param test_loader: loader for testing data
    :param device: device to evaluate model on
    """
    acc = Accuracy(task="binary").to(device)
    model.eval()

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            output = model(X_batch)
            preds = torch.argmax(output, dim=1)
            acc.update(preds, y_batch)

    return acc.compute().item()


def train(X, y):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device == "cpu":
        logger.warning("Warning - training model on cpu")
    else:
        logger.info("Training model on gpu")

    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    accuracies = []

    for fold, (train_idx, test_idx) in enumerate(kf.split(X, y)):

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        train_dataset = EEGDataset(X_train, y_train, cnn_mode=False)
        test_dataset = EEGDataset(X_test, y_test, cnn_mode=False)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

        model = trans.SpatialTransformer(input_size=X_train.shape[2], d_model=64, num_heads=8, num_classes=2)

        logger.info(f"Training {model._get_name()} in fold {fold + 1}...")
        train_model(model, train_loader, device, verbose=False)

        accuracy = evaluate_model(model, test_loader, device)
        logger.info(f"Accuracy for {model._get_name()} in fold {fold + 1}: {accuracy * 100:.2f}%")
        accuracies.append(accuracy)

    logger.info(f"Accuracy across 5 folds: {np.mean(accuracies) * 100:.2f}%")


train(X_all, y_all)

[34m2025-08-08 15:12:10,873 - INFO - Training model on gpu[0m
[34m2025-08-08 15:12:11,135 - INFO - Training SpatialTransformer in fold 1...[0m
[34m2025-08-08 15:12:46,199 - INFO - Accuracy for SpatialTransformer in fold 1: 67.35%[0m
[34m2025-08-08 15:12:46,351 - INFO - Training SpatialTransformer in fold 2...[0m
[34m2025-08-08 15:13:22,221 - INFO - Accuracy for SpatialTransformer in fold 2: 65.87%[0m
[34m2025-08-08 15:13:22,371 - INFO - Training SpatialTransformer in fold 3...[0m
[34m2025-08-08 15:13:55,872 - INFO - Accuracy for SpatialTransformer in fold 3: 69.27%[0m
[34m2025-08-08 15:13:56,005 - INFO - Training SpatialTransformer in fold 4...[0m
[34m2025-08-08 15:14:31,738 - INFO - Accuracy for SpatialTransformer in fold 4: 64.06%[0m
[34m2025-08-08 15:14:31,865 - INFO - Training SpatialTransformer in fold 5...[0m
[34m2025-08-08 15:15:01,782 - INFO - Accuracy for SpatialTransformer in fold 5: 65.87%[0m
[34m2025-08-08 15:15:01,783 - INFO - Accuracy across 5 folds