In [100]:
import numpy as np
from numpy import multiply
import torch
import torch.nn as nn
import torch.optim as optim
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import preprocess, Preprocessor, create_windows_from_events
from sklearn.metrics import accuracy_score, cohen_kappa_score
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [101]:
# ==================== CONFIG ====================
# Everything a reader might want to tune lives here so the notebook survives
# "Restart Kernel -> Run All" without relying on leftover kernel state.
SUBJECTS = list(range(1, 10))  # 1 to 9
BATCH_SIZE = 32
EPOCHS = 50
LR = 1e-3  # Adam learning rate; the training cell reads this name
SEED = 42  # fixed seed so weight init and DataLoader shuffling are repeatable
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG this notebook draws from.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [102]:
class EEGDataset(Dataset):
    """Map-style dataset pairing feature arrays with integer class labels.

    Parameters
    ----------
    features : array-like or torch.Tensor
        Indexed along its first axis, one entry per sample. Stored as float32.
        (assumed layout: [trials, channels, time] -- not enforced here)
    labels : array-like or torch.Tensor
        One label per sample. Stored as int64 (``torch.long``).

    Raises
    ------
    ValueError
        If ``features`` and ``labels`` do not contain the same number of samples.
    """

    def __init__(self, features, labels):
        self.features = self._to_tensor(features, torch.float32)   # [trials, channels, time]
        self.labels   = self._to_tensor(labels,   torch.long)

        # Fail early: a length mismatch would otherwise only show up as an
        # IndexError (or silently ignored labels) in the middle of an epoch.
        if len(self.features) != len(self.labels):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(self.features)} and {len(self.labels)}"
            )

    @staticmethod
    def _to_tensor(values, dtype):
        """Return an independent tensor copy of ``values`` with the given dtype."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(tensor) emits a copy-construct UserWarning; clone instead.
            return values.detach().clone().to(dtype)
        return torch.tensor(values, dtype=dtype)

    def __len__(self):
        """Number of samples (size of the first axis of ``features``)."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.features[idx], self.labels[idx]

In [103]:
class BiGRUModel(nn.Module):
    """Spatial-convolution front end followed by a bidirectional GRU classifier.

    Pipeline: a ``(num_channels, 1)`` convolution collapses the channel axis
    into 32 feature maps per time step, a stacked bidirectional GRU reads the
    resulting sequence, and the final hidden states of the top GRU layer
    (forward and backward) are concatenated and passed to a small MLP.

    Parameters
    ----------
    num_channels : int
        Size of the channel axis of the input; also the height of the spatial kernel.
    num_classes : int
        Number of output logits.
    hidden_dim : int
        Hidden size of each GRU direction.
    dropout : float
        Dropout probability used after the convolution, between GRU layers
        and inside the classifier.
    num_layers : int
        Number of stacked GRU layers (default 2, the previously hard-coded value).
    """

    def __init__(self, num_channels=22, num_classes=4, hidden_dim=128, dropout=0.5,
                 num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_channels = num_channels
        self.num_layers = num_layers

        # Spatial convolution to reduce channel dimension (common in EEG DL)
        self.spatial = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(num_channels, 1), stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(dropout)
        )

        self.bigru = nn.GRU(input_size=32,
                            hidden_size=hidden_dim,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True,
                            # nn.GRU applies dropout only between layers and
                            # warns if it is non-zero for a single layer.
                            dropout=dropout if num_layers > 1 else 0.0)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``[batch, num_classes]``.

        ``x`` must be ``[batch, num_channels, time]``; a ValueError is raised
        otherwise, because a wrong channel count would only fail later with an
        unrelated-looking permute error.
        """
        if x.dim() != 3 or x.size(1) != self.num_channels:
            raise ValueError(
                f"expected input of shape [batch, {self.num_channels}, time], "
                f"got {tuple(x.shape)}"
            )

        x = x.unsqueeze(1)                     # [batch, 1, channels, time]
        x = self.spatial(x)                    # [batch, 32, 1, time]
        x = x.squeeze(2)                       # [batch, 32, time]
        x = x.permute(0, 2, 1)                 # [batch, time, 32] for GRU
        _, hn = self.bigru(x)                  # hn: [num_layers * 2, batch, hidden]
        # The last two entries of hn are the top layer's forward and backward
        # states; concatenating them works for any num_layers.
        top = torch.cat((hn[-2], hn[-1]), dim=1)   # [batch, hidden*2]
        return self.classifier(top)

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` with the model in train mode."""
    model.train()
    for X, y, _ in loader:  # batches unpack as (X, y, extra); extra is unused
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()


def _collect_predictions(model, loader, device):
    """Return ``(true_labels, predicted_labels)`` lists for every sample in ``loader``."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for X, y, _ in loader:
            preds = model(X.to(device)).argmax(dim=1)
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(y.numpy())
    return y_true, y_pred


# Per-subject metrics, keyed by subject label (e.g. 'A01').
results = {}

for subj in SUBJECTS:  # SUBJECTS = list(range(1,10))
    subject_id = f"A0{subj}"
    print(f"\n=== Processing subject {subject_id} ===")

    # Load data
    dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subj])

    # Preprocessing: keep EEG channels, rescale, band-pass 8-30 Hz.
    # NOTE(review): MNE Raw objects normally hold EEG in volts, so the 1e6
    # factor is a V -> µV conversion rather than an optional step -- confirm.
    preprocessors = [
        Preprocessor('pick_types', eeg=True, stim=False),
        Preprocessor(lambda x: multiply(x, 1e6)),  # V → µV
        Preprocessor('filter', l_freq=8, h_freq=30),
    ]
    preprocess(dataset, preprocessors)

    # Create one window per event, with no offset at either end of the trial.
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        preload=True
    )

    # Split sessions (computed once and reused for the debug print).
    splits = windows_dataset.split('session')
    print("Available session keys:", list(splits.keys()))  # expected ['0train', '1test']
    train_val_set = splits['0train']  # Training session (T)
    test_set      = splits['1test']   # Evaluation session (E)

    # Split the training session by run: the last run is held out for
    # validation and ALL remaining runs are used for training. With exactly
    # two runs this is run '0' for training and run '1' for validation.
    run_splits = train_val_set.split('run')
    run_keys = list(run_splits.keys())
    print("Run keys:", run_keys)
    if len(run_keys) < 2:
        raise ValueError(
            f"Need at least 2 runs to build a train/validation split, got {run_keys}"
        )

    train_set = torch.utils.data.ConcatDataset([run_splits[k] for k in run_keys[:-1]])
    val_set   = run_splits[run_keys[-1]]

    # DataLoaders
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE)
    test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE)

    # Model setup
    model = BiGRUModel().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    best_val_acc = 0.0
    best_model_state = None

    for epoch in range(1, EPOCHS + 1):
        _train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)

        val_true, val_preds = _collect_predictions(model, val_loader, DEVICE)
        val_acc = accuracy_score(val_true, val_preds)

        # `best_model_state is None` guarantees a checkpoint exists even if
        # validation accuracy never rises above 0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live parameters; clone
            # them, otherwise later epochs overwrite the "best" snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        if epoch % 10 == 0 or epoch == EPOCHS:
            print(f"Epoch {epoch:2d} | Val Acc: {val_acc:.4f} (Best: {best_val_acc:.4f})")

    # Load best model and test on E session
    if best_model_state is not None:  # stays None only when EPOCHS < 1
        model.load_state_dict(best_model_state)
    test_true, test_preds = _collect_predictions(model, test_loader, DEVICE)

    acc = accuracy_score(test_true, test_preds)
    kappa = cohen_kappa_score(test_true, test_preds)

    results[subject_id] = {'accuracy': acc, 'kappa': kappa}
    print(f"{subject_id} Test Accuracy: {acc:.4f} | Kappa: {kappa:.4f}")

# Final summary
accs = [v['accuracy'] for v in results.values()]
kappas = [v['kappa'] for v in results.values()]

print("\n=== FINAL RESULTS (Subject-Dependent) ===")
for subj, res in results.items():
    print(f"{subj}: Acc = {res['accuracy']:.4f}, Kappa = {res['kappa']:.4f}")

print(f"\nAverage Accuracy: {np.mean(accs):.4f} (±{np.std(accs):.3f})")
print(f"Average Kappa:    {np.mean(kappas):.4f} (±{np.std(kappas):.3f})")


=== Processing subject A01 ===
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


  warn(
  warn("Preprocessing choices with lambda functions cannot be saved.")


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition b

ValueError: too many values to unpack (expected 2)