In [2]:
%load_ext autoreload
%autoreload 2
import torch
import random
import numpy as np
import optuna
from optuna.pruners import MedianPruner
from scipy.signal import butter, sosfiltfilt
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.feature_selection import mutual_info_classif
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.pipeline import Pipeline
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from modules.competition_dataset import EEGDataset
from Models import FilterBankRTSClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data_path = './data/mtcaic3'
lda_model_path = './checkpoints/mi/models/lda_mi.pkl'

# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_random_seeds(42)

In [7]:
def load_eeg_data(data_path, window_length, stride, tmin, eeg_channels):
    ds = EEGDataset(
        data_path,
        window_length=window_length,
        stride=stride,
        task="mi",
        split="train",
        data_fraction=0.4,
        tmin=tmin,
        eeg_channels=eeg_channels,
    )
    X = np.stack([x.numpy() for x, _ in ds])
    y = np.array([label[0] for _, label in ds])
    return X, y


# Optuna optimization
data_path = "./data/mtcaic3"
cv_folds = 3

window_length = 1000
stride = 85
eeg_channels = ['FZ', 'CZ', 'PZ', 'C3', 'OZ']
tmin = 60

X, y = load_eeg_data(data_path, window_length=window_length, stride=stride, tmin=tmin, eeg_channels=eeg_channels)

def objective(trial):
    # Data parameters
    tmin = trial.suggest_int("tmin", 0, 250, step=10)

    # Filter bank parameters
    filter_order = trial.suggest_int("filter_order", 3, 6)
    fs = trial.suggest_int("fs", 70, 300, step=10)

    # Random Forest parameters
    n_estimators = trial.suggest_int("n_estimators", 50, 500, step=50)
    min_samples_split = trial.suggest_int("min_samples_split", 2, 10)
    min_samples_leaf = trial.suggest_int("min_samples_leaf", 1, 5)

    clf = FilterBankRTSClassifier(fs=fs, order=filter_order, n_estimators=n_estimators, max_depth=None, class_weight="balanced", n_jobs=-1)

    # Add RF-specific parameters
    clf.min_samples_split = min_samples_split
    clf.min_samples_leaf = min_samples_leaf

    # Override the classifier creation in fit method
    original_fit = clf.fit

    def custom_fit(X, y):
        # Store the classes - this is required by scikit-learn
        clf.classes_ = np.unique(y)

        fb_covs = clf.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        labels_rep = np.repeat(y, B)

        clf.ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
        Z = clf.ts.transform(covs_flat)
        Z = Z.reshape(n, B, -1)

        clf.w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)
        clf.w = clf.w / clf.w.sum()

        Z_weighted = np.concatenate([np.sqrt(clf.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        clf.clf = make_pipeline(
            StandardScaler(),
            RandomForestClassifier(
                n_estimators=n_estimators,
                max_depth=None,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples_leaf,
                max_features='sqrt',
                class_weight="balanced",
                n_jobs=-1,
                random_state=42,
            ),
        )
        clf.clf.fit(Z_weighted, y)
        return clf

    clf.fit = custom_fit

    cv = StratifiedKFold(cv_folds, shuffle=True, random_state=42)
    scores = cross_validate(clf, X, y, cv=cv, scoring="accuracy", return_train_score=True)

    train_acc = scores["train_score"].mean()
    val_acc = scores["test_score"].mean()

    print(f"   → Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")
    trial.set_user_attr("train_acc", train_acc)

    return val_acc


study = optuna.create_study(direction="maximize", pruner=MedianPruner())
study.optimize(objective, n_trials=50 , timeout=7200)

print("\n=== Best trial ===")
best = study.best_trial
print("Val Acc:", best.value)
print("Train Acc:", best.user_attrs["train_acc"])
print("Params:")
for k, v in best.params.items():
    print(f"  {k}: {v}")

task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 11/960


[I 2025-06-28 22:13:14,772] A new study created in memory with name: no-name-473c3efe-591a-4c69-aa42-94c50a4c2329
[W 2025-06-28 22:13:42,769] Trial 0 failed with parameters: {'tmin': 180, 'filter_order': 4, 'fs': 170, 'n_estimators': 400, 'min_samples_split': 2, 'min_samples_leaf': 4} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/zeyadcode/.pyenv/versions/icmtc_venv/lib/python3.12/site-packages/optuna/study/_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_41439/3691641310.py", line 88, in objective
    scores = cross_validate(clf, X, y, cv=cv, scoring="accuracy", return_train_score=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeyadcode/.pyenv/versions/icmtc_venv/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 213, in wrapper
    return func(*args, **kwargs)
      

KeyboardInterrupt: 

In [5]:
window_length = 1000
stride = 85
tmin = 60
eeg_channels = ['FZ', 'CZ', 'PZ', 'PO7', 'OZ']
n_bands = 4
filter_order = 3
fs = 100
n_estimators = 300
max_depth = None
min_samples_split = 7
min_samples_leaf = 3
max_features = 'sqrt'
data_path = "./data/mtcaic3"



# Load data with besjt parameters
ds_train = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_train = np.stack([x.numpy() for x, _ in ds_train])
y_train = np.array([label[0] for _, label in ds_train])

# Load data with besjt parameters
ds_val = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="validation",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_val = np.stack([x.numpy() for x, _ in ds_val])
y_val = np.array([label[0] for _, label in ds_val])


# Create frequency bands

# Create FilterBank classifier with best parameters
clf = FilterBankRTSClassifier(
    fs=fs,
    order=filter_order,
    n_estimators=n_estimators,
    max_depth=max_depth,
    min_samples_split=min_samples_split,
    min_samples_leaf=min_samples_leaf# Trial 67 finished with value: 0.6263345734944465 and parameters: {'window_length': 250, 'stride': 250, 'tmin': 0, 'ch_FZ': 1, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 0, 'n_bands': 4, 'min_freq': 11, 'max_freq': 40, 'filter_order': 3, 'fs': 125, 'n_estimators': 200, 'max_depth': None, 'min_samples_split': 6, 'min_samples_leaf': 2, 'max_features': 'sqrt'}. Best is trial 67 with value: 0.6263345734944465.
)

# Fit on training data
clf.fit(X_train, y_train)

# Calculate accuracy
y_pred = clf.predict(X_val)
val_acc = accuracy_score(y_val, y_pred)

print(f"Validation accuracy: {val_acc:.4f}")

# Classification report
print("\nClassification Report:")
print(classification_report(y_val, y_pred)) 

task: mi, split: train, domain: time, data_fraction: 1
skipped: 33/2400
task: mi, split: validation, domain: time, data_fraction: 1
skipped: 0/50
Validation accuracy: 0.4015

Classification Report:
              precision    recall  f1-score   support

           0       0.48      0.34      0.40       232
           1       0.35      0.49      0.41       169

    accuracy                           0.40       401
   macro avg       0.41      0.41      0.40       401
weighted avg       0.42      0.40      0.40       401

