<a href="https://colab.research.google.com/github/PershinIlya/eeg-fm-eeg2025/blob/main/notebooks/labram_challenge1_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train a simple FM model based on LabraM

In [1]:
# Identify whether a CUDA-enabled GPU is available
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    msg ='CUDA-enabled GPU found. Training should be faster.'
else:
    msg = (
        "No GPU found. Training will be carried out on CPU, which might be "
        "slower.\n\nIf running on Google Colab, you can request a GPU runtime by"
        " clicking\n`Runtime/Change runtime type` in the top bar menu, then "
        "selecting \'T4 GPU\'\nunder \'Hardware accelerator\'."
    )
print(msg)

No GPU found. Training will be carried out on CPU, which might be slower.

If running on Google Colab, you can request a GPU runtime by clicking
`Runtime/Change runtime type` in the top bar menu, then selecting 'T4 GPU'
under 'Hardware accelerator'.


In [2]:
!pip install braindecode
!pip install eegdash



## Prepare dataset

In [3]:
from pathlib import Path

DATA_DIR = Path("data")
DATA_DIR.mkdir(parents=True, exist_ok=True)

from eegdash.dataset import EEGChallengeDataset

dataset_ccd = EEGChallengeDataset(task="contrastChangeDetection",
                                  release="R5", cache_dir=DATA_DIR,
                                  mini=False)

In [4]:
from joblib import Parallel, delayed

raws = Parallel(n_jobs=-1)(
    delayed(lambda d: d.raw)(d) for d in dataset_ccd.datasets
)

Downloading dataset_description.json: 100%|██████████| 1.00/1.00 [00:00<00:00, 10.4B/s]
Downloading dataset_description.json: 100%|██████████| 1.00/1.00 [00:00<00:00, 13.2B/s]

Downloading dataset_description.json: 100%|██████████| 1.00/1.00 [00:00<00:00, 23.2B/s]
Downloading participants.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 9.78B/s]
Downloading participants.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 11.9B/s]
Downloading participants.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 12.1B/s]
Downloading participants.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 13.0B/s]
Downloading sub-NDARAC857HDB_task-DiaryOfAWimpyKid_events.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 20.5B/s] ?B/s]
Downloading sub-NDARAC350XUM_task-contrastChangeDetection_run-2_events.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 15.5B/s]
Downloading sub-NDARAC350XUM_task-contrastChangeDetection_run-2_events.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 17.7B/s]
Downloading sub-NDARAC857HDB_task-Despicable

## Prepare epochs

In [5]:
from braindecode.datasets import BaseConcatDataset

In [6]:
from braindecode.preprocessing import preprocess, Preprocessor, create_windows_from_events
from eegdash.hbn.windows import (
    annotate_trials_with_target,
    add_aux_anchors,
    add_extras_columns,
    keep_only_recordings_with,
)

EPOCH_LEN_S = 2.0
SFREQ = 100 # by definition here

transformation_offline = [
    Preprocessor(
        annotate_trials_with_target,
        target_field="rt_from_stimulus", epoch_length=EPOCH_LEN_S,
        require_stimulus=True, require_response=True,
        apply_on_array=False,
    ),
    Preprocessor(add_aux_anchors, apply_on_array=False),
]
preprocess(dataset_ccd, transformation_offline, n_jobs=1)

ANCHOR = "stimulus_anchor"

SHIFT_AFTER_STIM = 0.5
WINDOW_LEN       = 2.0

# Keep only recordings that actually contain stimulus anchors
dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)

# Create single-interval windows (stim-locked, long enough to include the response)
single_windows = create_windows_from_events(
    dataset,
    mapping={ANCHOR: 0},
    trial_start_offset_samples=int(SHIFT_AFTER_STIM * SFREQ),                 # +0.5 s
    trial_stop_offset_samples=int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),   # +2.5 s
    window_size_samples=int(EPOCH_LEN_S * SFREQ),
    window_stride_samples=SFREQ,
    preload=True,
)

# Injecting metadata into the extra mne annotation.
single_windows = add_extras_columns(
    single_windows,
    dataset,
    desc=ANCHOR,
    keys=("target", "rt_from_stimulus", "rt_from_trialstart",
          "stimulus_onset", "response_onset", "correct", "response_type")
          )

Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(

In [7]:
meta_information = single_windows.get_metadata()

In [8]:
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state

valid_frac = 0.1
test_frac = 0.1
seed = 2025

subjects = meta_information["subject"].unique()
sub_rm = ["NDARWV769JM7", "NDARME789TD2", "NDARUA442ZVF", "NDARJP304NK1",
          "NDARTY128YLU", "NDARDW550GU6", "NDARLD243KRE", "NDARUJ292JXV", "NDARBA381JGH"]
subjects = [s for s in subjects if s not in sub_rm]

train_subj, valid_test_subject = train_test_split(
    subjects, test_size=(valid_frac + test_frac), random_state=check_random_state(seed), shuffle=True
)

valid_subj, test_subj = train_test_split(
    valid_test_subject, test_size=test_frac, random_state=check_random_state(seed + 1), shuffle=True
)
# sanity check
assert (set(valid_subj) | set(test_subj) | set(train_subj)) == set(subjects)

In [9]:
# and finally using braindecode split function, we can do:
subject_split = single_windows.split("subject")

train_set = []
valid_set = []
test_set = []

for s in subject_split:
    if s in train_subj:
        train_set.append(subject_split[s])
    elif s in valid_subj:
        valid_set.append(subject_split[s])
    elif s in test_subj:
        test_set.append(subject_split[s])

train_set = BaseConcatDataset(train_set)
valid_set = BaseConcatDataset(valid_set)
test_set = BaseConcatDataset(test_set)

print("Number of examples in each split in the minirelease")
print(f"Train:\t{len(train_set)}")
print(f"Valid:\t{len(valid_set)}")
print(f"Test:\t{len(test_set)}")

Number of examples in each split in the minirelease
Train:	12100
Valid:	2714
Test:	330


In [10]:
# Create datasets and dataloaders
from torch.utils.data import DataLoader

batch_size = 128
num_workers = 1 # We are using a single worker, but you can increase this for faster data loading

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

## Build a simple model

from braindecode.models.util import models_dict

names = sorted(models_dict)
w = max(len(n) for n in names)

for i in range(0, len(names), 3):
    row = names[i:i+3]
    print("  ".join(f"{n:<{w}}" for n in row))

In [11]:
# for any braindecode model, you can initialize only inputing the signal related parameters
from braindecode.models import EEGNeX

model = EEGNeX(n_chans=129, # 129 channels
                n_outputs=1, # 1 output for regression
                n_times=200, #2 seconds
                sfreq=100,      # sample frequency 100 Hz
                )

In [12]:
# Defining training parameters

lr = 1E-3
weight_decay = 1E-5
n_epochs = 100
early_stopping_patience = 50

## Train

In [13]:
from typing import Optional
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler

# Define a method for training one epoch
def train_one_epoch(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    optimizer,
    scheduler: Optional[LRScheduler],
    epoch: int,
    device,
    print_batch_stats: bool = True,
):
    model.train()

    total_loss = 0.0
    sum_sq_err = 0.0
    n_samples = 0

    progress_bar = tqdm(
        enumerate(dataloader), total=len(dataloader), disable=not print_batch_stats
    )

    for batch_idx, batch in progress_bar:
        # Support datasets that may return (X, y) or (X, y, ...)
        X, y = batch[0], batch[1]
        X, y = X.to(device).float(), y.to(device).float()

        optimizer.zero_grad(set_to_none=True)
        preds = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Flatten to 1D for regression metrics and accumulate squared error
        preds_flat = preds.detach().view(-1)
        y_flat = y.detach().view(-1)
        sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
        n_samples += y_flat.numel()

        if print_batch_stats:
            running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
            progress_bar.set_description(
                f"Epoch {epoch}, Batch {batch_idx + 1}/{len(dataloader)}, "
                f"Loss: {loss.item():.6f}, RMSE: {running_rmse:.6f}"
            )

    if scheduler is not None:
        scheduler.step()

    avg_loss = total_loss / len(dataloader)
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
    return avg_loss, rmse

In [14]:
import torch
from torch.utils.data import DataLoader
from torch.nn import Module
from tqdm import tqdm

@torch.no_grad()
def valid_model(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    device,
    print_batch_stats: bool = True,
):
    model.eval()

    total_loss = 0.0
    sum_sq_err = 0.0
    n_batches = len(dataloader)
    n_samples = 0

    iterator = tqdm(
        enumerate(dataloader),
        total=n_batches,
        disable=not print_batch_stats
    )

    for batch_idx, batch in iterator:
        # Supports (X, y) or (X, y, ...)
        X, y = batch[0], batch[1]
        X, y = X.to(device).float(), y.to(device).float()
        # casting X to float32

        preds = model(X)
        batch_loss = loss_fn(preds, y).item()
        total_loss += batch_loss

        preds_flat = preds.detach().view(-1)
        y_flat = y.detach().view(-1)
        sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
        n_samples += y_flat.numel()

        if print_batch_stats:
            running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
            iterator.set_description(
                f"Val Batch {batch_idx + 1}/{n_batches}, "
                f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}"
            )

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5

    print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}\n")
    return avg_loss, rmse


In [None]:
import copy

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs - 1)
loss_fn = torch.nn.MSELoss()

patience = 5
min_delta = 1e-4
best_rmse = float("inf")
epochs_no_improve = 0
best_state, best_epoch = None, None

for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}: ", end="")

    train_loss, train_rmse = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, device
    )
    val_loss, val_rmse = valid_model(test_loader, model, loss_fn, device)

    print(
        f"Train RMSE: {train_rmse:.6f}, "
        f"Average Train Loss: {train_loss:.6f}, "
        f"Val RMSE: {val_rmse:.6f}, "
        f"Average Val Loss: {val_loss:.6f}"
    )

    if val_rmse < best_rmse - min_delta:
        best_rmse = val_rmse
        best_state = copy.deepcopy(model.state_dict())
        best_epoch = epoch
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}. Best Val RMSE: {best_rmse:.6f} (epoch {best_epoch})")
            break

if best_state is not None:
    model.load_state_dict(best_state)


Epoch 1/100: 

  return F.conv2d(
Epoch 1, Batch 95/95, Loss: 0.273379, RMSE: 0.922488: 100%|██████████| 95/95 [13:17<00:00,  8.39s/it]
Val Batch 3/3, Loss: 0.170270, RMSE: 0.475534: 100%|██████████| 3/3 [00:06<00:00,  2.07s/it]

Val RMSE: 0.475534, Val Loss: 0.218277

Train RMSE: 0.922488, Average Train Loss: 0.848134, Val RMSE: 0.475534, Average Val Loss: 0.218277
Epoch 2/100: 


Epoch 2, Batch 95/95, Loss: 0.242640, RMSE: 0.514397: 100%|██████████| 95/95 [12:54<00:00,  8.15s/it]
Val Batch 3/3, Loss: 0.212414, RMSE: 0.544624: 100%|██████████| 3/3 [00:06<00:00,  2.15s/it]

Val RMSE: 0.544624, Val Loss: 0.284774

Train RMSE: 0.514397, Average Train Loss: 0.264496, Val RMSE: 0.544624, Average Val Loss: 0.284774
Epoch 3/100: 


Epoch 3, Batch 95/95, Loss: 0.224454, RMSE: 0.489990: 100%|██████████| 95/95 [13:02<00:00,  8.24s/it]
Val Batch 3/3, Loss: 0.147555, RMSE: 0.419167: 100%|██████████| 3/3 [00:06<00:00,  2.17s/it]

Val RMSE: 0.419167, Val Loss: 0.171743

Train RMSE: 0.489990, Average Train Loss: 0.240013, Val RMSE: 0.419167, Average Val Loss: 0.171743
Epoch 4/100: 


Epoch 4, Batch 95/95, Loss: 0.259263, RMSE: 0.475527: 100%|██████████| 95/95 [13:14<00:00,  8.37s/it]
Val Batch 3/3, Loss: 0.140835, RMSE: 0.390676: 100%|██████████| 3/3 [00:06<00:00,  2.26s/it]

Val RMSE: 0.390676, Val Loss: 0.150969

Train RMSE: 0.475527, Average Train Loss: 0.226289, Val RMSE: 0.390676, Average Val Loss: 0.150969
Epoch 5/100: 


Epoch 5, Batch 95/95, Loss: 0.211513, RMSE: 0.467395: 100%|██████████| 95/95 [12:56<00:00,  8.17s/it]
Val Batch 3/3, Loss: 0.140249, RMSE: 0.384047: 100%|██████████| 3/3 [00:06<00:00,  2.21s/it]

Val RMSE: 0.384047, Val Loss: 0.146474

Train RMSE: 0.467395, Average Train Loss: 0.218424, Val RMSE: 0.384047, Average Val Loss: 0.146474
Epoch 6/100: 


Epoch 6, Batch 95/95, Loss: 0.237862, RMSE: 0.459975: 100%|██████████| 95/95 [13:05<00:00,  8.27s/it]
Val Batch 3/3, Loss: 0.140447, RMSE: 0.386190: 100%|██████████| 3/3 [00:06<00:00,  2.30s/it]

Val RMSE: 0.386190, Val Loss: 0.147920

Train RMSE: 0.459975, Average Train Loss: 0.211707, Val RMSE: 0.386190, Average Val Loss: 0.147920
Epoch 7/100: 


Epoch 7, Batch 95/95, Loss: 0.188278, RMSE: 0.456022: 100%|██████████| 95/95 [11:35<00:00,  7.32s/it]
Val Batch 3/3, Loss: 0.140061, RMSE: 0.381173: 100%|██████████| 3/3 [00:05<00:00,  1.89s/it]

Val RMSE: 0.381173, Val Loss: 0.144557

Train RMSE: 0.456022, Average Train Loss: 0.207859, Val RMSE: 0.381173, Average Val Loss: 0.144557
Epoch 8/100: 


Epoch 8, Batch 95/95, Loss: 0.179335, RMSE: 0.453204: 100%|██████████| 95/95 [11:35<00:00,  7.32s/it]
Val Batch 3/3, Loss: 0.140180, RMSE: 0.380214: 100%|██████████| 3/3 [00:05<00:00,  1.73s/it]

Val RMSE: 0.380214, Val Loss: 0.143946

Train RMSE: 0.453204, Average Train Loss: 0.205266, Val RMSE: 0.380214, Average Val Loss: 0.143946
Epoch 9/100: 


Epoch 9, Batch 95/95, Loss: 0.190616, RMSE: 0.453018: 100%|██████████| 95/95 [11:02<00:00,  6.97s/it]
Val Batch 3/3, Loss: 0.140090, RMSE: 0.371019: 100%|██████████| 3/3 [00:05<00:00,  1.71s/it]

Val RMSE: 0.371019, Val Loss: 0.137998

Train RMSE: 0.453018, Average Train Loss: 0.205153, Val RMSE: 0.371019, Average Val Loss: 0.137998
Epoch 10/100: 


Epoch 10, Batch 95/95, Loss: 0.286772, RMSE: 0.451077: 100%|██████████| 95/95 [11:02<00:00,  6.97s/it]
Val Batch 3/3, Loss: 0.140758, RMSE: 0.366721: 100%|██████████| 3/3 [00:05<00:00,  1.68s/it]

Val RMSE: 0.366721, Val Loss: 0.135366

Train RMSE: 0.451077, Average Train Loss: 0.203881, Val RMSE: 0.366721, Average Val Loss: 0.135366
Epoch 11/100: 


Epoch 11, Batch 95/95, Loss: 0.120321, RMSE: 0.449219: 100%|██████████| 95/95 [11:01<00:00,  6.96s/it]
Val Batch 3/3, Loss: 0.140072, RMSE: 0.378154: 100%|██████████| 3/3 [00:06<00:00,  2.03s/it]

Val RMSE: 0.378154, Val Loss: 0.142589

Train RMSE: 0.449219, Average Train Loss: 0.201395, Val RMSE: 0.378154, Average Val Loss: 0.142589
Epoch 12/100: 


Epoch 12, Batch 8/95, Loss: 0.178507, RMSE: 0.451870:   8%|▊         | 8/95 [00:57<10:23,  7.17s/it]