## Install and import required libraries

In [None]:
import math
import os
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.fft import fft, ifft, fftfreq
from scipy.signal import firwin, freqz, lfilter, welch
from sklearn import metrics
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

!pip install mne
!pip install moabb
!pip install braindecode

import mne
import moabb
from mne.decoding import CSP
from moabb.datasets import BNCI2014_001
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery




In [None]:
import numpy as np

from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import (
    Preprocessor,
    create_windows_from_events,
    exponential_moving_standardize,
    preprocess,
)

# ------------------------------------------------------------------
# Load BNCI 2014-001 (subjects 1..9) through MOABB
# ------------------------------------------------------------------
subject_ids = list(range(1, 10))
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=subject_ids)

# ------------------------------------------------------------------
# Preprocessing
# ------------------------------------------------------------------
low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

preprocessors = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(
        lambda data, factor: np.multiply(data, factor),  # Convert from V to uV
        factor=1e6,
    ),
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Applies the preprocessors to every recording in `dataset` in place.
preprocess(dataset, preprocessors, n_jobs=-1)

# ------------------------------------------------------------------
# Cut trials into windows
# ------------------------------------------------------------------
# Windows start 0.5 s before each trial event and end at the trial end
# (trial_stop_offset_samples=0).
trial_start_offset_seconds = -0.5
# All recordings must share one sampling frequency for fixed-size windows.
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all(ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets), (
    "recordings have differing sampling frequencies"
)
# round() rather than int(): int() truncates toward zero, so a product such as
# -124.99999 (floating-point error) would silently become -124.
trial_start_offset_samples = round(trial_start_offset_seconds * sfreq)

windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)

# Cross-session protocol: train on the '0train' session, evaluate on '1test'.
splitted = windows_dataset.split("session")
train_set = splitted['0train']  # Session train
test_set = splitted['1test']  # Session evaluation

batch_size = 64
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

# Peek at a single batch to sanity-check tensor shapes and labels.
# next(iter(...)) fetches one batch without creating a tqdm bar that would be
# left stalled at 0% after an early `break`.
X, y, _ = next(iter(train_loader))
print(X.shape, y.shape)
print(y)


  set_config(key, get_config("MNE_DATA"))
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01T.mat'.


MNE_DATA is not already configured. It will be set to default location in the home directory - /root/mne_data
All datasets will be downloaded to this location, if anything is already downloaded, please move manually to this location


100%|█████████████████████████████████████| 42.8M/42.8M [00:00<00:00, 25.7GB/s]
SHA256 hash of downloaded file: 054f02e70cf9c4ada1517e9b9864f45407939c1062c6793516585c6f511d0325
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01E.mat'.
100%|█████████████████████████████████████| 43.8M/43.8M [00:00<00:00, 45.8GB/s]
SHA256 hash of downloaded file: 53d415f39c3d7b0c88b894d7b08d99bcdfe855ede63831d3691af1a45607fb62
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A02T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A02T.mat'.
100%|█████████████████████

Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']

  0%|          | 0/41 [00:00<?, ?it/s]

torch.Size([64, 22, 1125]) torch.Size([64])
tensor([2, 1, 3, 2, 2, 3, 1, 1, 3, 2, 3, 0, 3, 2, 1, 1, 0, 0, 0, 2, 3, 0, 1, 2,
        2, 0, 1, 0, 0, 3, 0, 3, 2, 0, 0, 0, 2, 3, 1, 3, 2, 2, 3, 2, 3, 3, 0, 2,
        0, 2, 2, 1, 3, 3, 2, 2, 1, 2, 0, 3, 3, 0, 3, 2])


In [None]:
import plotly.graph_objects as go

In [None]:
# Sanity-check plot: trace of the first channel of the first training window.
# Indexing `train_set` directly (instead of reusing whatever batch happens to
# be bound to `X` in the kernel) makes this cell deterministic and independent
# of execution order; np.asarray also copes with a CPU tensor.
first_window = np.asarray(train_set[0][0])  # (n_chans, n_times)
# Window starts `trial_start_offset_seconds` relative to the trial event.
time_s = np.arange(first_window.shape[1]) / sfreq + trial_start_offset_seconds

fig = go.Figure()
fig.add_trace(go.Scatter(x=time_s, y=first_window[0], mode="lines", name="channel 0"))
fig.update_layout(title="First training window, channel 0", template="simple_white")
fig.update_xaxes(title="Time relative to trial event (s)")
fig.update_yaxes(title="Amplitude (standardized, a.u.)")
fig.show()

In [None]:
import torch
from torch import nn, optim
from tqdm import tqdm

# Braindecode imports
from braindecode.models import EEGConformer
from braindecode.util import set_random_seeds   # nice helper
from braindecode.classifier import EEGClassifier

# ------------------------------------------------------------------
# 1.  Reproducibility + signal-related numbers from the WindowsDataset
# ------------------------------------------------------------------
SEED = 42
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed BEFORE building the network: weight initialisation, dropout and the
# DataLoader's shuffling all draw from these RNGs. Seeding only after the
# model exists leaves the initial weights unseeded.
set_random_seeds(seed=SEED, cuda=device.type == 'cuda')

n_chans = train_set[0][0].shape[0]   # 22 EEG channels (EOG dropped by pick_types)
n_times = train_set[0][0].shape[1]   # 1125 samples (4.5 s @ 250 Hz)
sfreq = train_set.datasets[0].raw.info['sfreq']
n_outputs = 4  # feet, left_hand, right_hand, tongue

# ------------------------------------------------------------------
# 2.  Define the network
# ------------------------------------------------------------------
model = EEGConformer(
    n_outputs=n_outputs,
    n_chans=n_chans,
    n_times=n_times,            # mandatory, otherwise ValueError
    sfreq=sfreq,                # lets Braindecode infer input_window_seconds
    final_fc_length="auto",     # will infer correct size
    # leave the rest at their paper defaults (filter_time_length=25, pool_time_length=75, ...)
)
print(model)

# ------------------------------------------------------------------
# 3.  Plain PyTorch training loop
# ------------------------------------------------------------------
model.to(device)
# The model emits raw logits (its LogSoftmax layer is removed), which is
# what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
epochs = 50


def _train_one_epoch(model, loader, criterion, optimizer, device, desc=""):
    """Run one optimisation pass over `loader`.

    Returns the mean per-sample training loss (0.0 for an empty loader).
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for X, y, _ in tqdm(loader, desc=desc):
        X, y = X.to(device, dtype=torch.float32), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch is averaged correctly.
        loss_sum += loss.item() * X.size(0)
        n_seen += X.size(0)
    return loss_sum / n_seen if n_seen else 0.0


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return the classification accuracy of `model` on `loader` (NaN if empty)."""
    model.eval()
    correct, total = 0, 0
    for X, y, _ in loader:
        X, y = X.to(device, dtype=torch.float32), y.to(device)
        correct += (model(X).argmax(dim=1) == y).sum().item()
        total += y.size(0)
    return correct / total if total else float("nan")


# NOTE: the evaluation session is used for monitoring only. Picking the epoch
# with the best test accuracy would leak test information into model selection.
accs = []
for epoch in range(1, epochs + 1):
    train_loss = _train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        desc=f"Epoch {epoch}/{epochs}",
    )
    acc = _evaluate(model, test_loader, device)
    accs.append(acc)
    print(f"epoch {epoch:>3d}: train-loss={train_loss:.4f} "
          f"test-acc={acc*100:.2f}%")



LogSoftmax final layer will be removed! Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!



Layer (type (var_name):depth-idx)                            Input Shape               Output Shape              Param #                   Kernel Shape
EEGConformer (EEGConformer)                                  [1, 22, 1125]             [1, 4]                    --                        --
├─_PatchEmbedding (patch_embedding): 1-1                     [1, 1, 22, 1125]          [1, 69, 40]               --                        --
│    └─Sequential (shallownet): 2-1                          [1, 1, 22, 1125]          [1, 40, 1, 69]            --                        --
│    │    └─Conv2d (0): 3-1                                  [1, 1, 22, 1125]          [1, 40, 22, 1101]         1,040                     [1, 25]
│    │    └─Conv2d (1): 3-2                                  [1, 40, 22, 1101]         [1, 40, 1, 1101]          35,240                    [22, 1]
│    │    └─BatchNorm2d (2): 3-3                             [1, 40, 1, 1101]          [1, 40, 1, 1101]          80             

Epoch 1/50: 100%|██████████| 41/41 [00:03<00:00, 12.61it/s]


epoch   1: train-loss=1.4033 test-acc=27.39%


Epoch 2/50: 100%|██████████| 41/41 [00:03<00:00, 12.63it/s]


epoch   2: train-loss=1.4026 test-acc=30.90%


Epoch 3/50: 100%|██████████| 41/41 [00:03<00:00, 12.75it/s]


epoch   3: train-loss=1.3908 test-acc=32.14%


Epoch 4/50: 100%|██████████| 41/41 [00:03<00:00, 12.56it/s]


epoch   4: train-loss=1.3526 test-acc=33.95%


Epoch 5/50: 100%|██████████| 41/41 [00:03<00:00, 12.33it/s]


epoch   5: train-loss=1.3398 test-acc=34.07%


Epoch 6/50: 100%|██████████| 41/41 [00:03<00:00, 12.51it/s]


epoch   6: train-loss=1.3320 test-acc=32.99%


Epoch 7/50: 100%|██████████| 41/41 [00:03<00:00, 12.44it/s]


epoch   7: train-loss=1.3102 test-acc=35.92%


Epoch 8/50: 100%|██████████| 41/41 [00:03<00:00, 12.25it/s]


epoch   8: train-loss=1.3133 test-acc=33.80%


Epoch 9/50: 100%|██████████| 41/41 [00:03<00:00, 12.51it/s]


epoch   9: train-loss=1.2968 test-acc=37.35%


Epoch 10/50: 100%|██████████| 41/41 [00:03<00:00, 12.39it/s]


epoch  10: train-loss=1.2979 test-acc=36.42%


Epoch 11/50: 100%|██████████| 41/41 [00:03<00:00, 12.27it/s]


epoch  11: train-loss=1.2735 test-acc=38.89%


Epoch 12/50: 100%|██████████| 41/41 [00:03<00:00, 12.51it/s]


epoch  12: train-loss=1.2541 test-acc=41.82%


Epoch 13/50: 100%|██████████| 41/41 [00:03<00:00, 12.41it/s]


epoch  13: train-loss=1.2226 test-acc=41.36%


Epoch 14/50: 100%|██████████| 41/41 [00:03<00:00, 12.35it/s]


epoch  14: train-loss=1.2133 test-acc=40.05%


Epoch 15/50: 100%|██████████| 41/41 [00:03<00:00, 12.70it/s]


epoch  15: train-loss=1.2128 test-acc=41.13%


Epoch 16/50: 100%|██████████| 41/41 [00:03<00:00, 12.72it/s]


epoch  16: train-loss=1.1883 test-acc=43.36%


Epoch 17/50: 100%|██████████| 41/41 [00:03<00:00, 12.67it/s]


epoch  17: train-loss=1.1899 test-acc=45.33%


Epoch 18/50: 100%|██████████| 41/41 [00:03<00:00, 12.93it/s]


epoch  18: train-loss=1.1912 test-acc=43.90%


Epoch 19/50: 100%|██████████| 41/41 [00:03<00:00, 12.76it/s]


epoch  19: train-loss=1.1773 test-acc=44.21%


Epoch 20/50: 100%|██████████| 41/41 [00:03<00:00, 12.62it/s]


epoch  20: train-loss=1.1687 test-acc=42.94%


Epoch 21/50: 100%|██████████| 41/41 [00:03<00:00, 12.94it/s]


epoch  21: train-loss=1.1647 test-acc=45.99%


Epoch 22/50: 100%|██████████| 41/41 [00:03<00:00, 12.71it/s]


epoch  22: train-loss=1.1657 test-acc=46.30%


Epoch 23/50: 100%|██████████| 41/41 [00:03<00:00, 12.58it/s]


epoch  23: train-loss=1.1480 test-acc=46.57%


Epoch 24/50: 100%|██████████| 41/41 [00:03<00:00, 12.86it/s]


epoch  24: train-loss=1.1466 test-acc=47.18%


Epoch 25/50: 100%|██████████| 41/41 [00:03<00:00, 12.38it/s]


epoch  25: train-loss=1.1394 test-acc=47.45%


Epoch 26/50: 100%|██████████| 41/41 [00:03<00:00, 12.59it/s]


epoch  26: train-loss=1.1258 test-acc=46.53%


Epoch 27/50: 100%|██████████| 41/41 [00:03<00:00, 12.70it/s]


epoch  27: train-loss=1.1283 test-acc=46.88%


Epoch 28/50: 100%|██████████| 41/41 [00:03<00:00, 12.32it/s]


epoch  28: train-loss=1.1248 test-acc=46.99%


Epoch 29/50: 100%|██████████| 41/41 [00:03<00:00, 12.48it/s]


epoch  29: train-loss=1.1040 test-acc=46.41%


Epoch 30/50: 100%|██████████| 41/41 [00:03<00:00, 12.45it/s]


epoch  30: train-loss=1.1161 test-acc=48.84%


Epoch 31/50: 100%|██████████| 41/41 [00:03<00:00, 12.32it/s]


epoch  31: train-loss=1.0931 test-acc=46.10%


Epoch 32/50: 100%|██████████| 41/41 [00:03<00:00, 12.48it/s]


epoch  32: train-loss=1.0887 test-acc=48.77%


Epoch 33/50: 100%|██████████| 41/41 [00:03<00:00, 12.49it/s]


epoch  33: train-loss=1.0847 test-acc=50.19%


Epoch 34/50: 100%|██████████| 41/41 [00:03<00:00, 12.22it/s]


epoch  34: train-loss=1.0884 test-acc=48.69%


Epoch 35/50: 100%|██████████| 41/41 [00:03<00:00, 12.53it/s]


epoch  35: train-loss=1.0900 test-acc=50.00%


Epoch 36/50: 100%|██████████| 41/41 [00:03<00:00, 12.52it/s]


epoch  36: train-loss=1.0612 test-acc=48.77%


Epoch 37/50: 100%|██████████| 41/41 [00:03<00:00, 12.18it/s]


epoch  37: train-loss=1.0616 test-acc=49.15%


Epoch 38/50: 100%|██████████| 41/41 [00:03<00:00, 12.62it/s]


epoch  38: train-loss=1.0669 test-acc=51.43%


Epoch 39/50: 100%|██████████| 41/41 [00:03<00:00, 12.75it/s]


epoch  39: train-loss=1.0554 test-acc=50.31%


Epoch 40/50: 100%|██████████| 41/41 [00:03<00:00, 12.50it/s]


epoch  40: train-loss=1.0461 test-acc=51.23%


Epoch 41/50: 100%|██████████| 41/41 [00:03<00:00, 12.70it/s]


epoch  41: train-loss=1.0310 test-acc=50.62%


Epoch 42/50: 100%|██████████| 41/41 [00:03<00:00, 12.62it/s]


epoch  42: train-loss=1.0170 test-acc=50.00%


Epoch 43/50: 100%|██████████| 41/41 [00:03<00:00, 12.50it/s]


epoch  43: train-loss=1.0288 test-acc=52.89%


Epoch 44/50: 100%|██████████| 41/41 [00:03<00:00, 12.66it/s]


epoch  44: train-loss=1.0166 test-acc=51.54%


Epoch 45/50: 100%|██████████| 41/41 [00:03<00:00, 12.69it/s]


epoch  45: train-loss=1.0071 test-acc=52.12%


Epoch 46/50: 100%|██████████| 41/41 [00:03<00:00, 12.45it/s]


epoch  46: train-loss=1.0151 test-acc=53.51%


Epoch 47/50: 100%|██████████| 41/41 [00:03<00:00, 12.63it/s]


epoch  47: train-loss=1.0015 test-acc=51.97%


Epoch 48/50: 100%|██████████| 41/41 [00:03<00:00, 12.50it/s]


epoch  48: train-loss=0.9868 test-acc=52.66%


Epoch 49/50: 100%|██████████| 41/41 [00:03<00:00, 12.30it/s]


epoch  49: train-loss=0.9858 test-acc=52.31%


Epoch 50/50: 100%|██████████| 41/41 [00:03<00:00, 12.71it/s]


epoch  50: train-loss=0.9774 test-acc=52.74%


In [None]:
# Test accuracy per epoch. The training loop reports epochs 1..N, so use an
# explicit 1-based x-axis (plotly would otherwise default to 0..N-1).
epoch_numbers = list(range(1, len(accs) + 1))

fig = go.Figure()
fig.add_trace(go.Scatter(x=epoch_numbers, y=accs, mode='lines+markers', name='Test accuracy'))
# Chance level for n_outputs classes (assumes balanced classes — TODO confirm).
fig.add_hline(y=1 / n_outputs, line_dash="dash", annotation_text="chance level")
fig.update_layout(title="Test Accuracy over Epochs", template="simple_white", width=500, height=500)
fig.update_xaxes(title="Epoch", showgrid=True, dtick=5)
fig.update_yaxes(title="Test accuracy", showgrid=True, range=[0, 1], dtick=0.1)
fig.show()