In [2]:
%load_ext autoreload
%autoreload 2
import torch
import random
import numpy as np
import optuna
from optuna.pruners import MedianPruner
from scipy.signal import butter, sosfiltfilt
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.feature_selection import mutual_info_classif
from sklearn.base import BaseEstimator, ClassifierMixin
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from modules.competition_dataset import EEGDataset
from pyriemann.utils.mean import mean_riemann
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import accuracy_score, classification_report
from scipy.linalg import expm, logm, inv
from scipy.optimize import minimize
from Models import FilterBankRTSClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data_path = './data/mtcaic3'
lda_model_path = './checkpoints/mi/models/lda_mi.pkl'

# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_random_seeds(42)

In [4]:
class FilterBankRTSClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, fs=250, order=4, C=1.0, kernel="rbf", gamma="scale", degree=3, coef0=0.0, probability=True, class_weight="balanced", random_state=42):
        self.bands = [(8, 12), (12, 16), (16, 20), (20, 24), (24, 30)]
        self.fs = fs
        self.order = order
        self.C = C
        self.kernel = kernel
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
        self.probability = probability
        self.class_weight = class_weight
        self.random_state = random_state

    def compute_fb_covs(self, X):
        """X: (n_trials, C, T) → fb_covs: (n_trials, B, C, C)"""
        if not hasattr(self, "sos_bands"):
            self.sos_bands = [butter(self.order, (l / (self.fs / 2), h / (self.fs / 2)), btype="bandpass", output="sos") for l, h in self.bands]
        n, C, _ = X.shape
        B = len(self.sos_bands)
        fb_covs = np.zeros((n, B, C, C))
        for i, sos in enumerate(self.sos_bands):
            Xf = sosfiltfilt(sos, X, axis=2)
            fb_covs[:, i] = Covariances(estimator="lwf").transform(Xf)
        return fb_covs

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        labels_rep = np.repeat(y, B)

        self.ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
        Z = self.ts.transform(covs_flat)
        Z = Z.reshape(n, B, -1)

        self.w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)
        self.w = self.w / self.w.sum()

        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        from sklearn.svm import SVC

        self.clf = make_pipeline(
            StandardScaler(),
            SVC(
                C=self.C,
                kernel=self.kernel,
                gamma=self.gamma,
                degree=self.degree,
                coef0=self.coef0,
                probability=self.probability,
                class_weight=self.class_weight,
                random_state=self.random_state,
            ),
        )
        self.clf.fit(Z_weighted, y)
        return self

    def predict(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict(Z_weighted)

    def predict_proba(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict_proba(Z_weighted)


# -----------------------------------------------------------------------------
# 1) RSFTransformer that learns W_opt using BFGS
# -----------------------------------------------------------------------------
class RSFTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, d=8, eps=1e-6, maxiter=500):
        self.d = d
        self.eps = eps
        self.maxiter = maxiter

    def fit(self, X, y):
        covs = Covariances(estimator="lwf").transform(X)
        C1 = mean_riemann(covs[y == 0])
        C2 = mean_riemann(covs[y == 1])
        Nc = C1.shape[0]
        W0 = np.random.randn(Nc, self.d)

        def J(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            return np.linalg.norm(logm(inv(S1) @ S2), "fro") ** 2

        def grad(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            A = inv(S1) @ S2
            L = logm(A)
            t1 = 2 * C1 @ W @ inv(S1) @ L @ inv(S1)
            t2 = 2 * C2 @ W @ inv(S2) @ L @ inv(S2)
            return (t1 - t2).ravel()

        res = minimize(lambda w: -J(w), W0.ravel(), jac=lambda w: -grad(w), method="BFGS", options={"maxiter": self.maxiter})
        self.W_opt = res.x.reshape(Nc, self.d)
        return self

    def transform(self, X):
        return np.einsum("cd,ect->edt", self.W_opt, X)


def load_eeg_data(data_path, window_length, stride, tmin, eeg_channels):
    ds = EEGDataset(
        data_path,
        window_length=window_length,
        stride=stride,
        task="mi",
        split="train",
        data_fraction=0.4,
        tmin=tmin,
        eeg_channels=eeg_channels,
    )
    X = np.stack([x.numpy() for x, _ in ds])
    y = np.array([label[0] for _, label in ds])
    return X, y


# Optuna optimization
data_path = "./data/mtcaic3"
cv_folds = 3

window_length = 1000
stride = 85
tmin = 60
eeg_channels = ["FZ", "CZ", "PZ", "C3", "OZ"]

X, y = load_eeg_data(data_path, window_length, stride, tmin, eeg_channels)


def objective(trial):
    # --- Data and channel selection parameters (same as before) ---
    d_rsf = trial.suggest_int("rsf_d", 2, len(eeg_channels))

    # --- Filter bank parameters (same as before) ---
    filter_order = trial.suggest_int("filter_order", 3, 6)
    fs = trial.suggest_categorical("fs", [125, 250, 500])

    # --- SVM parameters ---
    C = trial.suggest_float("C", 1e-3, 1e3, log=True)
    kernel = trial.suggest_categorical("kernel", ["linear", "rbf", "poly", "sigmoid"])
    gamma = trial.suggest_categorical("gamma", ["scale", "auto"])
    degree = trial.suggest_int("degree", 2, 5)  # only for 'poly'
    coef0 = trial.suggest_float("coef0", 0.0, 1.0)

    # --- CORE CHANGE: Build the pipeline ---
    pipeline = make_pipeline(
        RSFTransformer(d=d_rsf),
        FilterBankRTSClassifier(
            fs=fs,
            order=filter_order,
            C=C,
            kernel=kernel,
            gamma=gamma,
            degree=degree,
            coef0=coef0,
            probability=True,
        ),
    )

    cv = StratifiedKFold(cv_folds, shuffle=True, random_state=42)

    # Run cross-validation on the entire pipeline
    scores = cross_validate(pipeline, X, y, cv=cv, scoring="accuracy", return_train_score=True)

    train_acc = scores["train_score"].mean()
    val_acc = scores["test_score"].mean()

    print(f"   → Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")
    trial.set_user_attr("train_acc", train_acc)

    return val_acc


study = optuna.create_study(direction="maximize", pruner=MedianPruner())
study.optimize(objective, n_trials=100, timeout=7200)

print("\n=== Best trial ===")
best = study.best_trial
print("Val Acc:", best.value)
print("Train Acc:", best.user_attrs["train_acc"])
print("Params:")
for k, v in best.params.items():
    print(f"  {k}: {v}")

task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 11/960


[I 2025-06-28 22:25:28,127] A new study created in memory with name: no-name-50228663-b5ca-4e00-8b7a-1bedb61f580c
[I 2025-06-28 22:28:48,938] Trial 0 finished with value: 0.5598784371887556 and parameters: {'rsf_d': 4, 'filter_order': 6, 'fs': 125, 'C': 17.796913417063408, 'kernel': 'linear', 'gamma': 'scale', 'degree': 4, 'coef0': 0.6522116873223839}. Best is trial 0 with value: 0.5598784371887556.


   → Train acc: 0.569 | Val acc: 0.560


[I 2025-06-28 22:30:30,466] Trial 1 finished with value: 0.5234123596888054 and parameters: {'rsf_d': 2, 'filter_order': 3, 'fs': 250, 'C': 0.24949163879331557, 'kernel': 'linear', 'gamma': 'auto', 'degree': 3, 'coef0': 0.8115646169421447}. Best is trial 0 with value: 0.5598784371887556.


   → Train acc: 0.531 | Val acc: 0.523


[I 2025-06-28 22:32:03,531] Trial 2 finished with value: 0.5225600691932295 and parameters: {'rsf_d': 2, 'filter_order': 5, 'fs': 500, 'C': 1.8007429871212501, 'kernel': 'linear', 'gamma': 'scale', 'degree': 5, 'coef0': 0.0514839356423179}. Best is trial 0 with value: 0.5598784371887556.


   → Train acc: 0.534 | Val acc: 0.523


[I 2025-06-28 22:36:39,037] Trial 3 finished with value: 0.5635379553218661 and parameters: {'rsf_d': 5, 'filter_order': 6, 'fs': 500, 'C': 364.72059659157196, 'kernel': 'poly', 'gamma': 'scale', 'degree': 3, 'coef0': 0.5001319564232748}. Best is trial 3 with value: 0.5635379553218661.


   → Train acc: 0.979 | Val acc: 0.564


[I 2025-06-28 22:44:54,457] Trial 4 finished with value: 0.5386589243625531 and parameters: {'rsf_d': 5, 'filter_order': 3, 'fs': 500, 'C': 30.107461402825148, 'kernel': 'linear', 'gamma': 'scale', 'degree': 2, 'coef0': 0.17589855276013622}. Best is trial 3 with value: 0.5635379553218661.


   → Train acc: 0.568 | Val acc: 0.539


[I 2025-06-28 22:47:03,313] Trial 5 finished with value: 0.5178049529297425 and parameters: {'rsf_d': 3, 'filter_order': 4, 'fs': 500, 'C': 0.002133930524120991, 'kernel': 'sigmoid', 'gamma': 'auto', 'degree': 3, 'coef0': 0.6289868713099951}. Best is trial 3 with value: 0.5635379553218661.


   → Train acc: 0.519 | Val acc: 0.518


[I 2025-06-28 22:49:08,086] Trial 6 finished with value: 0.5243903635901143 and parameters: {'rsf_d': 4, 'filter_order': 5, 'fs': 500, 'C': 0.0759630945832062, 'kernel': 'linear', 'gamma': 'scale', 'degree': 4, 'coef0': 0.14198198294205988}. Best is trial 3 with value: 0.5635379553218661.


   → Train acc: 0.555 | Val acc: 0.524


[I 2025-06-28 22:51:03,647] Trial 7 finished with value: 0.5140233794814844 and parameters: {'rsf_d': 2, 'filter_order': 5, 'fs': 500, 'C': 0.015084259382027471, 'kernel': 'rbf', 'gamma': 'auto', 'degree': 5, 'coef0': 0.18499795340352065}. Best is trial 3 with value: 0.5635379553218661.


   → Train acc: 0.524 | Val acc: 0.514


[I 2025-06-28 22:52:34,786] Trial 8 finished with value: 0.5364640700112143 and parameters: {'rsf_d': 2, 'filter_order': 3, 'fs': 500, 'C': 3.7996061484882855, 'kernel': 'rbf', 'gamma': 'auto', 'degree': 4, 'coef0': 0.14095658777607734}. Best is trial 3 with value: 0.5635379553218661.


   → Train acc: 0.589 | Val acc: 0.536


[W 2025-06-28 22:53:04,247] Trial 9 failed with parameters: {'rsf_d': 2, 'filter_order': 6, 'fs': 500, 'C': 7.28625061913845, 'kernel': 'linear', 'gamma': 'scale', 'degree': 4, 'coef0': 0.09279765268273243} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/zeyadcode/.pyenv/versions/icmtc_venv/lib/python3.12/site-packages/optuna/study/_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_94761/3783384118.py", line 184, in objective
    scores = cross_validate(pipeline, X, y, cv=cv, scoring="accuracy", return_train_score=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeyadcode/.pyenv/versions/icmtc_venv/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 213, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeyadcode/.pyenv/versions/i

KeyboardInterrupt: 

In [None]:
# Trial 7 finished with value: 0.5871394563472291 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 3, 'min_freq': 4, 'max_freq': 39, 'filter_order': 3, 'fs': 125, 'n_estimators': 150, 'max_depth': 10, 'min_samples_split': 5, 'min_samples_leaf': 1, 'max_features': 'sqrt'}. Best is trial 7 with value: 0.5871394563472291.

class FilterBankRTSClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, bands=None, fs=250, order=4, n_estimators=100, max_depth=None, min_samples_split=2, min_samples_leaf=1, max_features="sqrt", class_weight="balanced", n_jobs=-1):
        self.bands = bands if bands else [(8, 12), (12, 16), (16, 20), (20, 24), (24, 30)]
        self.fs = fs
        self.order = order
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.class_weight = class_weight
        self.n_jobs = n_jobs

    def compute_fb_covs(self, X):
        """X: (n_trials, C, T) → fb_covs: (n_trials, B, C, C)"""
        # Pre-compute SOS filters if not done
        if not hasattr(self, "sos_bands"):
            self.sos_bands = [butter(self.order, (l / (self.fs / 2), h / (self.fs / 2)), btype="bandpass", output="sos") for l, h in self.bands]

        n, C, _ = X.shape
        B = len(self.sos_bands)
        fb_covs = np.zeros((n, B, C, C))
        for i, sos in enumerate(self.sos_bands):
            Xf = sosfiltfilt(sos, X, axis=2)
            fb_covs[:, i] = Covariances(estimator="lwf").transform(Xf)
        return fb_covs

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        # Flatten for tangent space
        covs_flat = fb_covs.reshape(n * B, C, C)
        labels_rep = np.repeat(y, B)

        # Fit tangent space
        self.ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
        Z = self.ts.transform(covs_flat)
        Z = Z.reshape(n, B, -1)

        # Compute mutual information weights
        self.w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)
        self.w = self.w / self.w.sum()

        # Weight features
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        # Train classifier
        self.clf = make_pipeline(
            StandardScaler(),
            RandomForestClassifier(
                n_estimators=self.n_estimators,
                max_depth=self.max_depth,
                min_samples_split=self.min_samples_split,
                min_samples_leaf=self.min_samples_leaf,
                max_features=self.max_features,
                class_weight=self.class_weight,
                n_jobs=self.n_jobs,
                random_state=42,
            ),
        )
        self.clf.fit(Z_weighted, y)
        return self

    def predict(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict(Z_weighted)

    def predict_proba(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict_proba(Z_weighted)

class RSFTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, d=8, eps=1e-6, maxiter=500):
        self.d = d
        self.eps = eps
        self.maxiter = maxiter

    def fit(self, X, y):
        covs = Covariances(estimator="lwf").transform(X)
        C1 = mean_riemann(covs[y == 0])
        C2 = mean_riemann(covs[y == 1])
        Nc = C1.shape[0]
        W0 = np.random.randn(Nc, self.d)

        def J(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            return np.linalg.norm(logm(inv(S1) @ S2), "fro") ** 2

        def grad(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            A = inv(S1) @ S2
            L = logm(A)
            t1 = 2 * C1 @ W @ inv(S1) @ L @ inv(S1)
            t2 = 2 * C2 @ W @ inv(S2) @ L @ inv(S2)
            return (t1 - t2).ravel()

        res = minimize(lambda w: -J(w), W0.ravel(), jac=lambda w: -grad(w), method="BFGS", options={"maxiter": self.maxiter})
        self.W_opt = res.x.reshape(Nc, self.d)
        return self

    def transform(self, X):
        return np.einsum("cd,ect->edt", self.W_opt, X)


window_length = 1000
stride = 85
tmin = 500
eeg_channels = ['CZ', 'FZ', 'PZ', 'C3', 'OZ']
d_rsf = 5
n_bands = 3
min_freq = 4
max_freq = 39
filter_order = 3
fs = 125
n_estimators = 150
max_depth = 10
min_samples_split = 5
min_samples_leaf = 1
max_features = 'sqrt'
data_path = "./data/mtcaic3"


# Load data with besjt parameters
ds_train = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_train = np.stack([x.numpy() for x, _ in ds_train])
y_train = np.array([label[0] for _, label in ds_train])

# Load data with besjt parameters
freq_step = (max_freq - min_freq) / n_bands
bands = [(min_freq + i * freq_step, min_freq + (i + 1) * freq_step) for i in range(n_bands)]
ds_val = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_val = np.stack([x.numpy() for x, _ in ds_val])
y_val = np.array([label[0] for _, label in ds_val])

freq_step = (max_freq - min_freq) / n_bands
bands = [(min_freq + i * freq_step, min_freq + (i + 1) * freq_step) for i in range(n_bands)]

# Create frequency bands

# Create FilterBank classifier with best parameters
clf = make_pipeline(
    # This single object now contains your entire workflow.
        RSFTransformer(d=d_rsf),
        FilterBankRTSClassifier(
            bands=bands,
            fs=fs,
            order=filter_order,
            n_estimators=n_estimators,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            max_features=max_features,
            class_weight="balanced",
            n_jobs=-1
        )
    )

# Fit on training data
clf.fit(X_train, y_train)

# Calculate accuracy
y_pred = clf.predict(X_val)
val_acc = accuracy_score(y_val, y_pred)

print(f"Validation accuracy: {val_acc:.4f}")

# Classification report
print("\nClassification Report:")
print(classification_report(y_val, y_pred))

task: mi, split: train, domain: time, data_fraction: 1
skipped: 21/2400
task: mi, split: train, domain: time, data_fraction: 1
skipped: 21/2400
Validation accuracy: 0.5602

Classification Report:
              precision    recall  f1-score   support

           0       0.55      0.61      0.58      3317
           1       0.57      0.51      0.54      3374

    accuracy                           0.56      6691
   macro avg       0.56      0.56      0.56      6691
weighted avg       0.56      0.56      0.56      6691

