In [8]:
%load_ext autoreload
%autoreload 2
import torch
import random
import numpy as np
import optuna
from optuna.pruners import MedianPruner
from scipy.signal import butter, sosfiltfilt
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.feature_selection import mutual_info_classif
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.pipeline import Pipeline
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from modules.competition_dataset import EEGDataset
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, classification_report

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
data_path = './data/mtcaic3'
lda_model_path = './checkpoints/mi/models/lda_mi.pkl'

# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_random_seeds(42)

In [10]:
class FilterBankRTSClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, fs=250, order=4, n_estimators=100, max_depth=None, bands=None, min_samples_split=2, min_samples_leaf=1, max_features="sqrt", class_weight="balanced", n_jobs=-1):
        self.bands = [(8, 12), (12, 16), (16, 20), (20, 24), (24, 30)] if bands is None else bands
        self.fs = fs
        self.order = order
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.class_weight = class_weight
        self.n_jobs = n_jobs

    def compute_fb_covs(self, X):
        """X: (n_trials, C, T) → fb_covs: (n_trials, B, C, C)"""
        if not hasattr(self, "sos_bands"):
            self.sos_bands = [
                butter(self.order, (l / (self.fs / 2), h / (self.fs / 2)), btype="bandpass", output="sos")
                for l, h in self.bands
            ]

        n, C, _ = X.shape
        B = len(self.sos_bands)
        fb_covs = np.zeros((n, B, C, C))
        for i, sos in enumerate(self.sos_bands):
            Xf = sosfiltfilt(sos, X, axis=2)
            covs = Covariances(estimator="lwf").transform(Xf)
            # Add regularization to the diagonal
            covs += np.eye(C)[None, :, :] * 1e-6
            fb_covs[:, i] = covs
        return fb_covs

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        # Flatten for tangent space
        covs_flat = fb_covs.reshape(n * B, C, C)
        labels_rep = np.repeat(y, B)

        # Fit tangent space
        self.ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
        Z = self.ts.transform(covs_flat)
        Z = Z.reshape(n, B, -1)

        # Compute mutual information weights
        self.w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)
        self.w = self.w / self.w.sum()

        # Weight features
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        # Train classifier
        self.clf = make_pipeline(
            StandardScaler(),
            RandomForestClassifier(
                n_estimators=self.n_estimators,
                max_depth=self.max_depth,
                min_samples_split=self.min_samples_split,
                min_samples_leaf=self.min_samples_leaf,
                max_features=self.max_features,
                class_weight=self.class_weight,
                n_jobs=self.n_jobs,
                random_state=42,
            ),
        )
        self.clf.fit(Z_weighted, y)
        return self

    def predict(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict(Z_weighted)

    def predict_proba(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict_proba(Z_weighted)



In [None]:
def load_eeg_data(data_path, window_length, stride, tmin, eeg_channels):
    ds = EEGDataset(
        data_path,
        window_length=window_length,
        stride=stride,
        task="ssvep",
        split="train",
        data_fraction=0.4,
        tmin=tmin,
        eeg_channels=eeg_channels,
    )
    X = np.stack([x.numpy() for x, _ in ds])
    y = np.array([label[0] for _, label in ds])
    return X, y


# Optuna optimization
data_path = "./data/mtcaic3"
cv_folds = 3

window_length = 1000
stride = 85
eeg_channels = ['FZ', 'CZ', 'PZ', 'C3', 'OZ']
tmin = 60

X, y = load_eeg_data(data_path, window_length=window_length, stride=stride, tmin=tmin, eeg_channels=eeg_channels)

def objective(trial):
    # Data parameters
    tmin = trial.suggest_int("tmin", 0, 250, step=10)

    # Filter bank parameters
    filter_order = trial.suggest_int("filter_order", 3, 6)
    fs = trial.suggest_int("fs", 70, 300, step=10)

    # Random Forest parameters
    n_estimators = trial.suggest_int("n_estimators", 50, 500, step=50)
    min_samples_split = trial.suggest_int("min_samples_split", 2, 10)
    min_samples_leaf = trial.suggest_int("min_samples_leaf", 1, 5)

    clf = FilterBankRTSClassifier(fs=fs, order=filter_order, n_estimators=n_estimators, max_depth=None, class_weight="balanced", n_jobs=-1)

    # Add RF-specific parameters
    clf.min_samples_split = min_samples_split
    clf.min_samples_leaf = min_samples_leaf

    # Override the classifier creation in fit method
    original_fit = clf.fit

    def custom_fit(X, y):
        # Store the classes - this is required by scikit-learn
        clf.classes_ = np.unique(y)

        fb_covs = clf.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        labels_rep = np.repeat(y, B)

        clf.ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
        Z = clf.ts.transform(covs_flat)
        Z = Z.reshape(n, B, -1)

        clf.w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)
        clf.w = clf.w / clf.w.sum()

        Z_weighted = np.concatenate([np.sqrt(clf.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        clf.clf = make_pipeline(
            StandardScaler(),
            RandomForestClassifier(
                n_estimators=n_estimators,
                max_depth=None,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples_leaf,
                max_features='sqrt',
                class_weight="balanced",
                n_jobs=-1,
                random_state=42,
            ),
        )
        clf.clf.fit(Z_weighted, y)
        return clf

    clf.fit = custom_fit

    cv = StratifiedKFold(cv_folds, shuffle=True, random_state=42)
    scores = cross_validate(clf, X, y, cv=cv, scoring="accuracy", return_train_score=True)

    train_acc = scores["train_score"].mean()
    val_acc = scores["test_score"].mean()

    print(f"   → Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")
    trial.set_user_attr("train_acc", train_acc)

    return val_acc


study = optuna.create_study(direction="maximize", pruner=MedianPruner())
study.optimize(objective, n_trials=50 , timeout=7200)

print("\n=== Best trial ===")
best = study.best_trial
print("Val Acc:", best.value)
print("Train Acc:", best.user_attrs["train_acc"])
print("Params:")
for k, v in best.params.items():
    print(f"  {k}: {v}")

task: ssvep, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 51/960


[I 2025-06-28 22:26:22,449] A new study created in memory with name: no-name-58a9c27b-ac71-48c7-b6d7-7846b4aab08c
[W 2025-06-28 22:26:35,640] Trial 0 failed with parameters: {'tmin': 120, 'filter_order': 3, 'fs': 70, 'n_estimators': 450, 'min_samples_split': 9, 'min_samples_leaf': 2} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/zeyadcode/.pyenv/versions/icmtc_venv/lib/python3.12/site-packages/optuna/study/_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_93986/2837652162.py", line 88, in objective
    scores = cross_validate(clf, X, y, cv=cv, scoring="accuracy", return_train_score=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeyadcode/.pyenv/versions/icmtc_venv/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 213, in wrapper
    return func(*args, **kwargs)
       

KeyboardInterrupt: 

In [12]:
# [I 2025-06-28 19:47:51,512] Trial 30 finished with value: 0.6598790349742903 and parameters: {'tmin': 60, 'filter_order': 3, 'fs': 100, 'n_estimators': 300, 'min_samples_split': 7, 'min_samples_leaf': 3}. Best is trial 30 with value: 0.6598790349742903.

window_length = 1000
stride = 20
tmin = 250
eeg_channels = ['C4', 'CZ']
n_bands = 4
filter_order = 3
fs = 100
n_estimators = 300
max_depth = None
min_samples_split = 7
min_samples_leaf = 3
max_features = 'sqrt'
data_path = "./data/mtcaic3"
bands = [
    (6, 8),   # around 7 Hz
    (7, 9),   # around 8 Hz
    (9, 11),  # around 10 Hz
    (12, 14), # around 13 Hz
]



# Load data with besjt parameters
ds_train = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="ssvep",
    split="train",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_train = np.stack([x.numpy() for x, _ in ds_train])
y_train = np.array([label[0] for _, label in ds_train])

# Load data with besjt parameters
ds_val = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="ssvep",
    split="validation",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_val = np.stack([x.numpy() for x, _ in ds_val])
y_val = np.array([label[0] for _, label in ds_val])


# Create frequency bands

# Create FilterBank classifier with best parameters
clf = FilterBankRTSClassifier(
    fs=fs,
    order=filter_order,
    n_estimators=n_estimators,
    max_depth=max_depth,
    min_samples_split=min_samples_split,
    min_samples_leaf=min_samples_leaf,
    bands=bands,
)

# Fit on training data
clf.fit(X_train, y_train)

# Calculate accuracy
y_pred = clf.predict(X_val)
val_acc = accuracy_score(y_val, y_pred)

print(f"Validation accuracy: {val_acc:.4f}")

# Classification report
print("\nClassification Report:")
print(classification_report(y_val, y_pred)) 

task: ssvep, split: train, domain: time, data_fraction: 1
skipped: 117/2400
task: ssvep, split: validation, domain: time, data_fraction: 1
skipped: 3/50
Validation accuracy: 0.2540

Classification Report:
              precision    recall  f1-score   support

           0       0.19      0.30      0.23       295
           1       0.61      0.33      0.43       212
           2       0.19      0.20      0.19       302
           3       0.35      0.20      0.25       266

    accuracy                           0.25      1075
   macro avg       0.33      0.26      0.28      1075
weighted avg       0.31      0.25      0.26      1075

