In [1]:
import torch
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import preprocess, Preprocessor, create_windows_from_events, exponential_moving_standardize
from braindecode.models import EEGNetv4
from braindecode.classifier import EEGClassifier
from braindecode.util import set_random_seeds
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Subset
from skorch.callbacks import LRScheduler, Checkpoint, EarlyStopping
from skorch.helper import predefined_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import accuracy_score

In [None]:
# Query CUDA availability once; it drives both the device choice and the seeding.
cuda_available = torch.cuda.is_available()
device = 'cuda' if cuda_available else 'cpu'

# Fixed seed so that repeated runs of the notebook are comparable.
set_random_seeds(seed=42, cuda=cuda_available)

# Load subject 1 of the 'BNCI2014001' dataset.
dataset = MOABBDataset(dataset_name='BNCI2014001', subject_ids=[1])

In [None]:
# Preprocessing
def scale_volts_to_microvolts(data):
    """Return ``data`` multiplied by 1e6 (volts -> microvolts).

    Assumes the incoming signal array is expressed in volts -- TODO confirm
    against the loaded recordings.
    """
    return data * 1e6


preprocessors = [
    # Keep only the EEG channels.
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),
    # Rescale before standardizing: exponential_moving_standardize floors its
    # denominator at a small `eps`, and volt-scale EEG amplitudes are presumably
    # below that floor (check the installed default), so without this step the
    # data would be divided by the constant floor instead of its running std.
    Preprocessor(scale_volts_to_microvolts),
    # Band-pass filter between 4 Hz and 38 Hz.
    Preprocessor('filter', l_freq=4., h_freq=38.),
    # Exponential moving standardization of the continuous signal.
    Preprocessor(exponential_moving_standardize, factor_new=0.001, init_block_size=1000)
]
# The return value is not used, so this relies on preprocess() updating
# `dataset` in place. Re-running this cell without reloading `dataset` would
# apply every step a second time.
preprocess(dataset, preprocessors)

In [None]:
# Create windows: cut the recordings into one window per event, with no
# extra offset before the start or after the end of each trial.
windowing_options = dict(
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    preload=True,
)
windows_dataset = create_windows_from_events(dataset, **windowing_options)

# Quick sanity check: number of windows and the set of target labels.
print(f'Windows: {len(windows_dataset)}')
print(f'Classes: {np.unique(windows_dataset.get_metadata().target)}')

In [None]:
# Split: stratified 80/20 partition of the window indices into training and
# validation subsets, with a fixed random_state for reproducibility.
N = len(windows_dataset)
inds = np.arange(N)

split_options = dict(
    test_size=0.2,
    stratify=windows_dataset.get_metadata().target,
    random_state=42,
)
train_inds, valid_inds = train_test_split(inds, **split_options)

# Index-based views onto the same underlying windows dataset.
train_set, valid_set = (
    Subset(windows_dataset, chosen) for chosen in (train_inds, valid_inds)
)

In [None]:
# EEGNet model
# Read the input dimensions off the first window
# (presumably X is laid out as (channels, samples) -- verify).
X, y, metadata = windows_dataset[0]
in_chans, input_window_samples = X.shape[0], X.shape[1]

# One output unit per distinct target label in the windows metadata.
n_classes = np.unique(windows_dataset.get_metadata().target).size

model = EEGNetv4(
    in_chans=in_chans,
    n_classes=n_classes,
    input_window_samples=input_window_samples,
)
model = model.to(device)

In [None]:
# Wrap the network in an EEGClassifier.
# Save the best weights during training and reload them once training ends.
checkpoint = Checkpoint(
    # Monitor the "*_best" flag of the metric this notebook actually logs:
    # the 'accuracy' callback below records `train_accuracy`/`valid_accuracy`
    # (the same keys that are plotted from the history), so the matching flag
    # is `valid_accuracy_best`. The previous value `valid_acc_best` only exists
    # when a separate default `valid_acc` scorer happens to be registered.
    monitor='valid_accuracy_best',
    f_params='best_model.pt',
    load_best=True  # load the best checkpoint automatically after training
)

# Learning-rate scheduler (cosine annealing).
lr_scheduler = LRScheduler(
    policy=CosineAnnealingLR,
    T_max=50  # T_max = total epochs; keep in sync with the epochs passed to fit
)

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    optimizer__lr=0.01,
    batch_size=32,  # initial batch size; can be tuned later with a grid search
    train_split=predefined_split(valid_set),  # use the manually built validation set
    callbacks=[
        'accuracy',
        checkpoint,  # must come after 'accuracy' so the monitored key is already logged
        lr_scheduler
    ],
    device=device,
    classes=np.unique(windows_dataset.get_metadata().target)
)

In [None]:
# Train, then evaluate.
clf.fit(train_set, y=None, epochs=50)

# Validation set: the label is the second item of every sample tuple.
y_valid_true = [valid_set[i][1] for i in range(len(valid_set))]
y_valid_pred = clf.predict(valid_set)
print(f'✅ Final valid accuracy: {accuracy_score(y_valid_true, y_valid_pred):.4f}')

In [None]:
# Visualization of the per-epoch training history recorded by the classifier.
history = pd.DataFrame(clf.history)

# One figure per metric: (train column, valid column, y-axis label, title).
curve_specs = [
    ('train_accuracy', 'valid_accuracy', 'Accuracy',
     'Train/Valid Accuracy with LR Scheduler & Checkpoint'),
    ('train_loss', 'valid_loss', 'Loss', 'Train/Valid Loss'),
]
for train_col, valid_col, y_label, fig_title in curve_specs:
    plt.plot(history['epoch'], history[train_col], label='Train')
    plt.plot(history['epoch'], history[valid_col], label='Valid')
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)
    plt.title(fig_title)
    plt.show()

# Highest validation accuracy reached in any epoch.
print(f"Best valid acc (from history): {max(history['valid_accuracy']):.4f}")

In [None]:
from skorch.helper import SliceDataset
from sklearn.model_selection import GridSearchCV

# Grid search over the batch size.
#
# GridSearchCV clones its estimator for every fit, and cloning copies the
# module *instance* together with its current weights. Re-using `clf` would
# therefore start every candidate from the network already trained on all of
# `train_set` (which contains every CV test fold), inflating the CV scores,
# and each cloned Checkpoint would overwrite 'best_model.pt' from the main run.
# A dedicated estimator around a fresh, never-fitted network avoids both.
GS_EPOCHS = 50  # explicit training length; also used for the cosine schedule

gs_model = EEGNetv4(
    in_chans=in_chans,
    n_classes=n_classes,
    input_window_samples=input_window_samples
)

gs_clf = EEGClassifier(
    gs_model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    optimizer__lr=0.01,
    max_epochs=GS_EPOCHS,
    train_split=predefined_split(valid_set),  # only used for per-epoch monitoring
    callbacks=[
        'accuracy',
        LRScheduler(policy=CosineAnnealingLR, T_max=GS_EPOCHS)
    ],
    device=device,
    classes=np.unique(windows_dataset.get_metadata().target)
)

param_grid = {
    'batch_size': [16, 32, 64]
}

gs = GridSearchCV(
    gs_clf,
    param_grid,
    refit=False,  # only the CV scores are needed; skip the final refit
    scoring='accuracy',
    cv=2,
    verbose=1
)

# GridSearchCV needs indexable X and y: SliceDataset exposes the first item
# of every sample as X, and the labels are collected from the second item.
X_train = SliceDataset(train_set, idx=0)
y_train = np.array([y for _, y, _ in train_set])

gs.fit(X_train, y_train)

print("Grid search results:")
print(gs.cv_results_)
print(f"Best batch size: {gs.best_params_['batch_size']}")

# Mean CV accuracy per batch size, with the std across folds as error bars.
results = pd.DataFrame(gs.cv_results_)
plt.figure(figsize=(8,5))
plt.errorbar(
    results['param_batch_size'].astype(int),
    results['mean_test_score'],
    yerr=results['std_test_score'],
    fmt='-o', capsize=5
)
plt.title('Grid Search Result: Batch Size vs CV Accuracy')
plt.xlabel('Batch Size')
plt.ylabel('Mean CV Accuracy')
plt.grid(True)
plt.show()