In [None]:
%load_ext autoreload
%autoreload 2
import torch
from modules.competition_dataset import EEGDataset, LABELS
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.covariance import EmpiricalCovariance
import random
import numpy as np
from sklearn.manifold import TSNE
from sklearn.feature_selection import f_classif
from sklearn.manifold import trustworthiness
from sklearn.metrics import silhouette_score
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.metrics import accuracy_score

from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace, 
from pyriemann.classification import MDM
import mne

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


device(type='cpu')

In [2]:
data_path = './data/mtcaic3'
lda_model_path = './checkpoints/mi/models/lda_mi.pkl'

# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_random_seeds(42)

In [3]:
window_length = 256
stride = window_length // 3
batch_size = 64

In [4]:
eeg_channels = [
    "FZ",
    "C3",
    "CZ",
    "C4",
    "PZ",
    "PO7",
    "OZ",
    "PO8",
]

dataset_train = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=0.2,
    hardcoded_mean=False,
    eeg_channels=eeg_channels,
)

dataset_val = EEGDataset(
    data_path=data_path,
    window_length=window_length,
    stride=stride,
    task='mi',
    split='validation',
    read_labels=True,
    hardcoded_mean=False,
    eeg_channels=eeg_channels,
)

dataset_test = EEGDataset(
    data_path=data_path,
    window_length=window_length,
    stride=stride,
    task='mi',
    split='test',
    read_labels=False,
    hardcoded_mean=False,
    eeg_channels=eeg_channels,
)

task: mi, split: train, domain: time, data_fraction: 0.2
Using 20.0% of data: 480/480 samples
skipped: 1/480
task: mi, split: validation, domain: time, data_fraction: 1.0
skipped: 0/50
task: mi, split: test, domain: time, data_fraction: 1.0
skipped: 0/50


In [5]:
dataset_train[0][0].shape

torch.Size([8, 256])

In [6]:
all_data = torch.cat([torch.stack([x for x,_ in ds]) for ds in (dataset_train, dataset_val, dataset_test)])
X_val_train = torch.cat([torch.stack([x for x,_ in ds]) for ds in (dataset_train, dataset_val)])
y_val_train = torch.cat([torch.stack([y for _,y in ds]) for ds in (dataset_train, dataset_val)])

mean = all_data.mean((0, 2))
std = all_data.std((0, 2))

X_val_train = (X_val_train - mean[None, :, None]) / std[None, :, None]

mean, std

(tensor([-0.1030, -0.0458,  0.1164,  0.0457, -0.0204, -0.0656, -0.0017, -0.0215]),
 tensor([1.1271, 0.9829, 1.0851, 1.0056, 0.9693, 1.0174, 0.9729, 0.9669]))

In [7]:
import numpy as np
from sklearn.feature_selection import f_classif

# Concatenate all splits (add dataset_val and dataset_test if needed)
X_all = np.concatenate([
    dataset_train.data.numpy(),
    dataset_val.data.numpy(),
    dataset_test.data.numpy(),
], axis=0)  # shape: [N_total, C, ...]
y_all = np.concatenate([
    dataset_train.labels.numpy(),
    dataset_val.labels.numpy(),
    dataset_test.labels.numpy(),
], axis=0)  # shape: [N_total]

# Detect shape and adapt
if X_all.ndim == 3:
    # [B, C, T]
    num_samples, num_channels, time_points = X_all.shape
    channel_f_scores = []
    for i in range(num_channels):
        channel_data = X_all[:, i, :]  # [N_total, T]
        f_scores_per_timepoint, _ = f_classif(channel_data, y_all)
        aggregated_f_score = np.sum(f_scores_per_timepoint)
        channel_f_scores.append(aggregated_f_score)
elif X_all.ndim == 4:
    # [B, C, F, T]
    num_samples, num_channels, freq_points, time_points = X_all.shape
    channel_f_scores = []
    for i in range(num_channels):
        # Average over freq and time for each channel
        channel_data = X_all[:, i, :, :].mean(axis=(1, 2))  # [N_total]
        f_score, _ = f_classif(channel_data.reshape(-1, 1), y_all)
        channel_f_scores.append(f_score[0])
else:
    raise ValueError(f"Unsupported data shape: {X_all.shape}")

# Optionally, map to channel names
original_channel_names = eeg_channels
channel_scores_dict = {original_channel_names[i]: channel_f_scores[i] for i in range(num_channels)}

print("\n--- F-scores for each channel (higher score indicates more informativeness) ---")
sorted_channels = sorted(channel_scores_dict.items(), key=lambda item: item[1], reverse=True)
for channel, score in sorted_channels:
    print(f"  {channel}: {score:.2f}")

top_3_channels = [channel for channel, score in sorted_channels[:3]]
print(f"\n--- Recommended Top 3 Channels based on F-score: {top_3_channels} ---")


--- F-scores for each channel (higher score indicates more informativeness) ---
  CZ: 12471.80
  C4: 9797.40
  FZ: 9445.39
  OZ: 5740.98
  C3: 4108.73
  PO7: 3960.75
  PO8: 3181.95
  PZ: 1084.31

--- Recommended Top 3 Channels based on F-score: ['CZ', 'C4', 'FZ'] ---


In [8]:
# Example for train/val/test
X_train = np.stack([x.numpy() for x, y in dataset_train])  # shape: [N, C, T]
y_train = np.array([y[0] for x, y in dataset_train])

X_val = np.stack([x.numpy() for x, y in dataset_val])
y_val = np.array([y[0] for x, y in dataset_val])

X_test = np.stack([x.numpy() for x, y in dataset_test])
y_test = np.array([y[0] for x, y in dataset_test])

In [None]:
# SVM
# Best CV score:    0.5679161628375655
# Best parameters:  {'clf__tol': 0.001, 'clf__kernel': 'rbf', 'clf__gamma': 0.01, 'clf__class_weight': 'balanced', 'clf__C': 10}
# Validation accuracy: 0.5145502645502645

param_grid_svm = {
    "clf__kernel": ["rbf", "linear"],  # Drop poly - research shows overfitting
    "clf__C": [0.01, 0.1, 1, 10, 50],  # Add lower values
    "clf__gamma": [0.001, 0.01, 0.1, "scale"],  # Finer granularity
    "clf__class_weight": ["balanced"],
    "clf__tol": [1e-3, 1e-4],
    # Remove degree/coef0 (irrelevant for non-poly kernels)
}

pipeline = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("tangent", TangentSpace(metric="riemann")),
        ("clf", SVC(random_state=42)),
    ]
)

grid = RandomizedSearchCV(
    estimator=pipeline,
    n_iter=50,
    param_distributions=param_grid_svm,
    # param_grid=param_grid,
    cv=3,
    scoring="accuracy",
    n_jobs=-1,
    verbose=2,
)

grid.fit(X_train, y_train)

# 3) Inspect best params & CV score
print("Best CV score:   ", grid.best_score_)
print("Best parameters: ", grid.best_params_)

# # 4) Evaluate on validation set
best_model = grid.best_estimator_
y_val_pred = best_model.predict(X_val)
val_acc = accuracy_score(y_val, y_val_pred)
print("Validation accuracy:", val_acc)

Fitting 3 folds for each of 50 candidates, totalling 150 fits
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=scale, clf__kernel=linear, clf__tol=0.001; total time=   7.0s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=scale, clf__kernel=linear, clf__tol=0.001; total time=   7.5s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=scale, clf__kernel=linear, clf__tol=0.001; total time=   7.9s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=linear, clf__tol=0.0001; total time=   7.9s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=linear, clf__tol=0.0001; total time=   8.4s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=rbf, clf__tol=0.001; total time=   8.6s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=rbf, clf__tol=0.001; total time=   8.8s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=r

In [41]:
# Logistic Regression 
# Best CV score:    0.5507187961843343
# Best parameters:  {'clf__tol': 0.0001, 'clf__solver': 'sag', 'clf__penalty': 'l2', 'clf__max_iter': 1000, 'clf__class_weight': 'balanced', 'clf__C': 0.001}
# Validation accuracy: 0.5066137566137566

param_grid_lr = {
    "clf__penalty": ["l2", None],
    "clf__C": [0.001, 0.01, 0.1, 1, 10],
    "clf__solver": ["lbfgs", "sag"],
    "clf__max_iter": [1000],
    "clf__class_weight": ["balanced"],
    "clf__tol": [1e-4]
}

pipeline = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("tangent", TangentSpace(metric="riemann")),
        ("clf", LogisticRegression(random_state=42)),
    ]
)

grid_lr = RandomizedSearchCV(
    estimator=pipeline,
    n_iter=70,
    param_distributions=param_grid_lr,
    # param_grid=param_grid,
    cv=3,
    scoring="accuracy",
    n_jobs=-1,
    verbose=2,
)

grid_lr.fit(X_train, y_train)

# 3) Inspect best params & CV score
print("Best CV score:   ", grid_lr.best_score_)
print("Best parameters: ", grid_lr.best_params_)

# # 4) Evaluate on validation set
best_model = grid_lr.best_estimator_
y_val_pred = best_model.predict(X_val)
val_acc = accuracy_score(y_val, y_val_pred)
print("Validation accuracy:", val_acc)



Fitting 3 folds for each of 20 candidates, totalling 60 fits




[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.0s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.0s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.2s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.3s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.7s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.9s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.3s
[CV



[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   5.9s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.2s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.1s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.6s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.5s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.8s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.1s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.4s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.0s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.4s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   6.5s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   8.0s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.3s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.2s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.8s




[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.4s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.8s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.0s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.6s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.9s




[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   6.5s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.0s




[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   6.8s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.9s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.8s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.3s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.8s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.1s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.5s
[CV] END clf__C=1, clf__cla



[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.7s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.0s




[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.5s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.4s




[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   5.5s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   5.4s




[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   5.3s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   4.1s
Best CV score:    0.5507187961843343
Best parameters:  {'clf__tol': 0.0001, 'clf__solver': 'sag', 'clf__penalty': 'l2', 'clf__max_iter': 1000, 'clf__class_weight': 'balanced', 'clf__C': 0.001}
Validation accuracy: 0.5066137566137566


In [45]:
# # MDM

# param_grid_mdm = {
#     "clf__metric": ["riemann"],
#     "clf__n_means": [3, 5, 7],  # Number of power means
#     "clf__h_values": [
#         [-1, 0, 1], 
#         [-0.5, 0, 0.5],
#         [-1, -0.2, 0.2, 1]
#     ],  # Power parameters
#     "clf__mean_type": ["power"]
# }

pipeline = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("clf", MDM(metric="riemann")),
    ]
)

# grid_mdm = RandomizedSearchCV(
#     estimator=pipeline,
#     n_iter=70,
#     param_distributions=param_grid_mdm,
#     # param_grid=param_grid,
#     cv=3,
#     scoring="accuracy",
#     n_jobs=-1,
#     verbose=2,
# )

pipeline.fit(X_train, y_train)

# 3) Inspect best params & CV score
# print("Best CV score:   ", pipeline.best_score_)
# print("Best parameters: ", pipeline.best_params_)

# # 4) Evaluate on validation set
# best_model = pipeline.best_estimator_
y_val_pred = pipeline.predict(X_val)
val_acc = accuracy_score(y_val, y_val_pred)
print("Validation accuracy:", val_acc)

Validation accuracy: 0.5185185185185185
