In [1]:
%load_ext autoreload
%autoreload 2
import torch
import random
import numpy as np
import optuna
from optuna.pruners import MedianPruner
from scipy.signal import butter, sosfiltfilt
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.feature_selection import mutual_info_classif
from sklearn.base import BaseEstimator, ClassifierMixin
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from modules.competition_dataset import EEGDataset
from pyriemann.utils.mean import mean_riemann
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import accuracy_score, classification_report
from scipy.linalg import expm, logm, inv
from scipy.optimize import minimize

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_path = './data/mtcaic3'
lda_model_path = './checkpoints/mi/models/lda_mi.pkl'

# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_random_seeds(42)

In [None]:
# -----------------------------------------------------------------------------
# 1) RSFTransformer that learns W_opt using BFGS
# -----------------------------------------------------------------------------
class RSFTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, d=8, eps=1e-6, maxiter=500):
        self.d = d
        self.eps = eps
        self.maxiter = maxiter

    def fit(self, X, y):
        covs = Covariances(estimator="lwf").transform(X)
        C1 = mean_riemann(covs[y == 0])
        C2 = mean_riemann(covs[y == 1])
        Nc = C1.shape[0]
        W0 = np.random.randn(Nc, self.d)

        def J(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            return np.linalg.norm(logm(inv(S1) @ S2), "fro") ** 2

        def grad(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            A = inv(S1) @ S2
            L = logm(A)
            t1 = 2 * C1 @ W @ inv(S1) @ L @ inv(S1)
            t2 = 2 * C2 @ W @ inv(S2) @ L @ inv(S2)
            return (t1 - t2).ravel()

        res = minimize(lambda w: -J(w), W0.ravel(), jac=lambda w: -grad(w), method="BFGS", options={"maxiter": self.maxiter})
        self.W_opt = res.x.reshape(Nc, self.d)
        return self

    def transform(self, X):
        return np.einsum("cd,ect->edt", self.W_opt, X)

class FilterBankRTSClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, bands=None, fs=250, order=4, n_estimators=100, max_depth=None, min_samples_split=2, min_samples_leaf=1, max_features="sqrt", class_weight="balanced", n_jobs=-1):
        self.bands = bands if bands else [(8, 12), (12, 16), (16, 20), (20, 24), (24, 30)]
        self.fs = fs
        self.order = order
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.class_weight = class_weight
        self.n_jobs = n_jobs

    def compute_fb_covs(self, X):
        """X: (n_trials, C, T) → fb_covs: (n_trials, B, C, C)"""
        # Pre-compute SOS filters if not done
        if not hasattr(self, "sos_bands"):
            self.sos_bands = [butter(self.order, (l / (self.fs / 2), h / (self.fs / 2)), btype="bandpass", output="sos") for l, h in self.bands]

        n, C, _ = X.shape
        B = len(self.sos_bands)
        fb_covs = np.zeros((n, B, C, C))
        for i, sos in enumerate(self.sos_bands):
            Xf = sosfiltfilt(sos, X, axis=2)
            fb_covs[:, i] = Covariances(estimator="lwf").transform(Xf)
        return fb_covs

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        # Flatten for tangent space
        covs_flat = fb_covs.reshape(n * B, C, C)
        labels_rep = np.repeat(y, B)

        # Fit tangent space
        self.ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
        Z = self.ts.transform(covs_flat)
        Z = Z.reshape(n, B, -1)

        # Compute mutual information weights
        self.w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)
        self.w = self.w / self.w.sum()

        # Weight features
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        # Train classifier
        self.clf = make_pipeline(
            StandardScaler(),
            RandomForestClassifier(
                n_estimators=self.n_estimators,
                max_depth=self.max_depth,
                min_samples_split=self.min_samples_split,
                min_samples_leaf=self.min_samples_leaf,
                max_features=self.max_features,
                class_weight=self.class_weight,
                n_jobs=self.n_jobs,
                random_state=42,
            ),
        )
        self.clf.fit(Z_weighted, y)
        return self

    def predict(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict(Z_weighted)

    def predict_proba(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict_proba(Z_weighted)


def load_eeg_data(data_path, window_length, stride, tmin, eeg_channels):
    ds = EEGDataset(
        data_path,
        window_length=window_length,
        stride=stride,
        task="mi",
        split="train",
        data_fraction=0.4,
        tmin=tmin,
        eeg_channels=eeg_channels,
    )
    X = np.stack([x.numpy() for x, _ in ds])
    y = np.array([label[0] for _, label in ds])
    return X, y


# Optuna optimization
data_path = "./data/mtcaic3"
cv_folds = 3

window_lengths = [250, 500, 1000, 1500]
strides = [250, 500, 750]
tmins = [0, 250, 500]


def objective(trial):
    # --- Data and channel selection parameters (same as before) ---
    wl = trial.suggest_categorical("window_length", window_lengths)
    st = trial.suggest_categorical("stride", strides)
    t0 = trial.suggest_categorical("tmin", tmins)
    all_channels = ["FZ", "C3", "CZ", "C4", "PZ", "PO7", "OZ", "PO8"]
    active_mask = [trial.suggest_int(f"ch_{ch}", 0, 1) for ch in all_channels]
    chs = [ch for ch, enabled in zip(all_channels, active_mask) if enabled]
    if len(chs) < 2:
        return 0.0

    # --- NEW: Hyperparameter for the RSFTransformer ---
    # The dimension 'd' must be between 2 and the number of selected channels.
    d_rsf = trial.suggest_int("rsf_d", 2, len(chs))

    # --- Filter bank parameters (same as before) ---
    n_bands = trial.suggest_int("n_bands", 3, 8)
    min_freq = trial.suggest_int("min_freq", 4, 12)
    max_freq = trial.suggest_int("max_freq", 25, 40)
    freq_step = (max_freq - min_freq) / n_bands
    bands = [(min_freq + i * freq_step, min_freq + (i + 1) * freq_step) for i in range(n_bands)]
    filter_order = trial.suggest_int("filter_order", 3, 6)
    fs = trial.suggest_categorical("fs", [125, 250, 500])

    # --- Random Forest parameters (same as before) ---
    n_estimators = trial.suggest_int("n_estimators", 50, 500, step=50)
    max_depth = trial.suggest_categorical("max_depth", [None, 5, 10, 15, 20])
    min_samples_split = trial.suggest_int("min_samples_split", 2, 10)
    min_samples_leaf = trial.suggest_int("min_samples_leaf", 1, 5)
    max_features = trial.suggest_categorical("max_features", ["sqrt", "log2", None])

    print(f"\n→ Trying: wl={wl}, st={st}, tmin={t0}, chans={chs}")
    print(f"   RSF d={d_rsf}, Bands: {bands}, RF: n_est={n_estimators}, depth={max_depth}")

    try:
        X, y = load_eeg_data(data_path, wl, st, t0, chs)
    except Exception as e:
        print("Data loading failed:", e)
        return 0.0

    # --- CORE CHANGE: Build the pipeline ---
    # This single object now contains your entire workflow.
    pipeline = make_pipeline(
        RSFTransformer(d=d_rsf),
        FilterBankRTSClassifier(
            bands=bands,
            fs=fs,
            order=filter_order,
            n_estimators=n_estimators,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            max_features=max_features,
            class_weight="balanced",
            n_jobs=-1
        )
    )

    cv = StratifiedKFold(cv_folds, shuffle=True, random_state=42)
    
    # Run cross-validation on the entire pipeline
    scores = cross_validate(pipeline, X, y, cv=cv, scoring="accuracy", return_train_score=True)

    train_acc = scores["train_score"].mean()
    val_acc = scores["test_score"].mean()

    print(f"   → Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")
    trial.set_user_attr("train_acc", train_acc)

    return val_acc
    
    
study = optuna.create_study(direction="maximize", pruner=MedianPruner())
study.optimize(objective, n_trials=100, timeout=7200)

print("\n=== Best trial ===")
best = study.best_trial
print("Val Acc:", best.value)
print("Train Acc:", best.user_attrs["train_acc"])
print("Params:")
for k, v in best.params.items():
    print(f"  {k}: {v}")

[I 2025-06-28 16:45:27,429] A new study created in memory with name: no-name-a919510c-1120-434c-9610-964537f001ae



→ Trying: wl=1000, st=250, tmin=250, chans=['C3', 'CZ', 'C4', 'PZ', 'OZ']
   RSF d=3, Bands: [(8.0, 15.25), (15.25, 22.5), (22.5, 29.75), (29.75, 37.0)], RF: n_est=300, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 16/960


[I 2025-06-28 16:46:17,501] Trial 0 finished with value: 0.5192881036889468 and parameters: {'window_length': 1000, 'stride': 250, 'tmin': 250, 'ch_FZ': 0, 'ch_C3': 1, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 1, 'ch_PO7': 0, 'ch_OZ': 1, 'ch_PO8': 0, 'rsf_d': 3, 'n_bands': 4, 'min_freq': 8, 'max_freq': 37, 'filter_order': 6, 'fs': 250, 'n_estimators': 300, 'max_depth': 10, 'min_samples_split': 5, 'min_samples_leaf': 5, 'max_features': None}. Best is trial 0 with value: 0.5192881036889468.


   → Train acc: 0.957 | Val acc: 0.519

→ Trying: wl=1500, st=750, tmin=250, chans=['FZ', 'C3', 'C4', 'PO7', 'PO8']
   RSF d=2, Bands: [(5.0, 9.166666666666668), (9.166666666666668, 13.333333333333334), (13.333333333333334, 17.5), (17.5, 21.666666666666668), (21.666666666666668, 25.833333333333336), (25.833333333333336, 30.0)], RF: n_est=450, depth=15
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 183/960


[I 2025-06-28 16:46:42,349] Trial 1 finished with value: 0.48648648648648646 and parameters: {'window_length': 1500, 'stride': 750, 'tmin': 250, 'ch_FZ': 1, 'ch_C3': 1, 'ch_CZ': 0, 'ch_C4': 1, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 0, 'ch_PO8': 1, 'rsf_d': 2, 'n_bands': 6, 'min_freq': 5, 'max_freq': 30, 'filter_order': 4, 'fs': 500, 'n_estimators': 450, 'max_depth': 15, 'min_samples_split': 9, 'min_samples_leaf': 5, 'max_features': 'sqrt'}. Best is trial 0 with value: 0.5192881036889468.


   → Train acc: 0.972 | Val acc: 0.486

→ Trying: wl=1500, st=750, tmin=250, chans=['FZ', 'CZ', 'C4', 'PZ']
   RSF d=2, Bands: [(7.0, 12.333333333333332), (12.333333333333332, 17.666666666666664), (17.666666666666664, 23.0), (23.0, 28.333333333333332), (28.333333333333332, 33.666666666666664), (33.666666666666664, 39.0)], RF: n_est=400, depth=15
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 183/960


[I 2025-06-28 16:47:06,857] Trial 2 finished with value: 0.5122265122265123 and parameters: {'window_length': 1500, 'stride': 750, 'tmin': 250, 'ch_FZ': 1, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 1, 'ch_PO7': 0, 'ch_OZ': 0, 'ch_PO8': 0, 'rsf_d': 2, 'n_bands': 6, 'min_freq': 7, 'max_freq': 39, 'filter_order': 6, 'fs': 250, 'n_estimators': 400, 'max_depth': 15, 'min_samples_split': 7, 'min_samples_leaf': 2, 'max_features': 'log2'}. Best is trial 0 with value: 0.5192881036889468.


   → Train acc: 0.992 | Val acc: 0.512

→ Trying: wl=1500, st=250, tmin=0, chans=['C3', 'CZ', 'C4', 'PO8']
   RSF d=2, Bands: [(10.0, 14.285714285714285), (14.285714285714285, 18.57142857142857), (18.57142857142857, 22.857142857142858), (22.857142857142858, 27.142857142857142), (27.142857142857142, 31.428571428571427), (31.428571428571427, 35.714285714285715), (35.714285714285715, 40.0)], RF: n_est=500, depth=20
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 37/960


[I 2025-06-28 16:47:54,096] Trial 3 finished with value: 0.5252896342413421 and parameters: {'window_length': 1500, 'stride': 250, 'tmin': 0, 'ch_FZ': 0, 'ch_C3': 1, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 0, 'ch_PO7': 0, 'ch_OZ': 0, 'ch_PO8': 1, 'rsf_d': 2, 'n_bands': 7, 'min_freq': 10, 'max_freq': 40, 'filter_order': 4, 'fs': 250, 'n_estimators': 500, 'max_depth': 20, 'min_samples_split': 2, 'min_samples_leaf': 3, 'max_features': None}. Best is trial 3 with value: 0.5252896342413421.


   → Train acc: 0.998 | Val acc: 0.525

→ Trying: wl=500, st=500, tmin=0, chans=['C3', 'OZ', 'PO8']
   RSF d=3, Bands: [(7.0, 13.666666666666668), (13.666666666666668, 20.333333333333336), (20.333333333333336, 27.0)], RF: n_est=150, depth=None
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 3/960


[I 2025-06-28 16:48:27,715] Trial 4 finished with value: 0.513255567338282 and parameters: {'window_length': 500, 'stride': 500, 'tmin': 0, 'ch_FZ': 0, 'ch_C3': 1, 'ch_CZ': 0, 'ch_C4': 0, 'ch_PZ': 0, 'ch_PO7': 0, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 3, 'n_bands': 3, 'min_freq': 7, 'max_freq': 27, 'filter_order': 5, 'fs': 250, 'n_estimators': 150, 'max_depth': None, 'min_samples_split': 3, 'min_samples_leaf': 2, 'max_features': 'sqrt'}. Best is trial 3 with value: 0.5252896342413421.


   → Train acc: 0.999 | Val acc: 0.513

→ Trying: wl=500, st=500, tmin=0, chans=['FZ', 'C3', 'OZ', 'PO8']
   RSF d=3, Bands: [(12.0, 16.0), (16.0, 20.0), (20.0, 24.0), (24.0, 28.0), (28.0, 32.0), (32.0, 36.0), (36.0, 40.0)], RF: n_est=150, depth=20
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 3/960


[I 2025-06-28 16:49:33,198] Trial 5 finished with value: 0.5245669848002827 and parameters: {'window_length': 500, 'stride': 500, 'tmin': 0, 'ch_FZ': 1, 'ch_C3': 1, 'ch_CZ': 0, 'ch_C4': 0, 'ch_PZ': 0, 'ch_PO7': 0, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 3, 'n_bands': 7, 'min_freq': 12, 'max_freq': 40, 'filter_order': 4, 'fs': 250, 'n_estimators': 150, 'max_depth': 20, 'min_samples_split': 8, 'min_samples_leaf': 5, 'max_features': None}. Best is trial 3 with value: 0.5252896342413421.


   → Train acc: 0.992 | Val acc: 0.525

→ Trying: wl=250, st=250, tmin=0, chans=['FZ', 'C4', 'PO7', 'OZ', 'PO8']
   RSF d=3, Bands: [(7.0, 11.166666666666668), (11.166666666666668, 15.333333333333334), (15.333333333333334, 19.5), (19.5, 23.666666666666668), (23.666666666666668, 27.833333333333336), (27.833333333333336, 32.0)], RF: n_est=350, depth=20
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 0/960


[I 2025-06-28 16:51:29,622] Trial 6 finished with value: 0.5418525292608368 and parameters: {'window_length': 250, 'stride': 250, 'tmin': 0, 'ch_FZ': 1, 'ch_C3': 0, 'ch_CZ': 0, 'ch_C4': 1, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 3, 'n_bands': 6, 'min_freq': 7, 'max_freq': 32, 'filter_order': 5, 'fs': 500, 'n_estimators': 350, 'max_depth': 20, 'min_samples_split': 10, 'min_samples_leaf': 5, 'max_features': 'sqrt'}. Best is trial 6 with value: 0.5418525292608368.


   → Train acc: 0.983 | Val acc: 0.542

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'C4', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(4.0, 15.666666666666666), (15.666666666666666, 27.333333333333332), (27.333333333333332, 39.0)], RF: n_est=150, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:52:03,671] Trial 7 finished with value: 0.5871394563472291 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 3, 'min_freq': 4, 'max_freq': 39, 'filter_order': 3, 'fs': 125, 'n_estimators': 150, 'max_depth': 10, 'min_samples_split': 5, 'min_samples_leaf': 1, 'max_features': 'sqrt'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.972 | Val acc: 0.587

→ Trying: wl=1500, st=500, tmin=250, chans=['C3', 'CZ', 'PO7', 'OZ']
   RSF d=2, Bands: [(5.0, 10.8), (10.8, 16.6), (16.6, 22.4), (22.4, 28.2), (28.2, 34.0)], RF: n_est=450, depth=20
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 183/960


[I 2025-06-28 16:52:27,542] Trial 8 finished with value: 0.5173745173745173 and parameters: {'window_length': 1500, 'stride': 500, 'tmin': 250, 'ch_FZ': 0, 'ch_C3': 1, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 0, 'rsf_d': 2, 'n_bands': 5, 'min_freq': 5, 'max_freq': 34, 'filter_order': 6, 'fs': 250, 'n_estimators': 450, 'max_depth': 20, 'min_samples_split': 10, 'min_samples_leaf': 4, 'max_features': 'sqrt'}. Best is trial 7 with value: 0.5871394563472291.
[I 2025-06-28 16:52:27,546] Trial 9 finished with value: 0.0 and parameters: {'window_length': 1500, 'stride': 250, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 0, 'ch_C4': 1, 'ch_PZ': 0, 'ch_PO7': 0, 'ch_OZ': 0, 'ch_PO8': 0}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.963 | Val acc: 0.517

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(4.0, 14.333333333333334), (14.333333333333334, 24.666666666666668), (24.666666666666668, 35.0)], RF: n_est=50, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:53:02,279] Trial 10 finished with value: 0.5774142353614201 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 3, 'min_freq': 4, 'max_freq': 35, 'filter_order': 3, 'fs': 125, 'n_estimators': 50, 'max_depth': 10, 'min_samples_split': 5, 'min_samples_leaf': 1, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.950 | Val acc: 0.577

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(4.0, 14.666666666666666), (14.666666666666666, 25.333333333333332), (25.333333333333332, 36.0)], RF: n_est=50, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:53:35,162] Trial 11 finished with value: 0.5669416302201554 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 3, 'min_freq': 4, 'max_freq': 36, 'filter_order': 3, 'fs': 125, 'n_estimators': 50, 'max_depth': 10, 'min_samples_split': 5, 'min_samples_leaf': 1, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.938 | Val acc: 0.567

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(4.0, 12.0), (12.0, 20.0), (20.0, 28.0), (28.0, 36.0)], RF: n_est=50, depth=5
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:54:11,457] Trial 12 finished with value: 0.5456218890449084 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 4, 'min_freq': 4, 'max_freq': 36, 'filter_order': 3, 'fs': 125, 'n_estimators': 50, 'max_depth': 5, 'min_samples_split': 5, 'min_samples_leaf': 1, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.739 | Val acc: 0.546

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(5.0, 12.0), (12.0, 19.0), (19.0, 26.0), (26.0, 33.0)], RF: n_est=150, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:54:48,628] Trial 13 finished with value: 0.5680580929708982 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 4, 'min_freq': 5, 'max_freq': 33, 'filter_order': 3, 'fs': 125, 'n_estimators': 150, 'max_depth': 10, 'min_samples_split': 4, 'min_samples_leaf': 2, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.969 | Val acc: 0.568

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=4, Bands: [(6.0, 16.333333333333336), (16.333333333333336, 26.666666666666668), (26.666666666666668, 37.0)], RF: n_est=200, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:55:18,705] Trial 14 finished with value: 0.5673182581503458 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 4, 'n_bands': 3, 'min_freq': 6, 'max_freq': 37, 'filter_order': 3, 'fs': 125, 'n_estimators': 200, 'max_depth': 10, 'min_samples_split': 7, 'min_samples_leaf': 1, 'max_features': 'sqrt'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.956 | Val acc: 0.567

→ Trying: wl=1000, st=750, tmin=500, chans=['CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=6, Bands: [(9.0, 11.0), (11.0, 13.0), (13.0, 15.0), (15.0, 17.0), (17.0, 19.0), (19.0, 21.0), (21.0, 23.0), (23.0, 25.0)], RF: n_est=50, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 37/960


[I 2025-06-28 16:55:52,548] Trial 15 finished with value: 0.4994853138175614 and parameters: {'window_length': 1000, 'stride': 750, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 6, 'n_bands': 8, 'min_freq': 9, 'max_freq': 25, 'filter_order': 4, 'fs': 125, 'n_estimators': 50, 'max_depth': 10, 'min_samples_split': 6, 'min_samples_leaf': 3, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.993 | Val acc: 0.499

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PO7', 'OZ', 'PO8']
   RSF d=4, Bands: [(4.0, 9.2), (9.2, 14.4), (14.4, 19.6), (19.6, 24.8), (24.8, 30.0)], RF: n_est=250, depth=None
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:56:34,637] Trial 16 finished with value: 0.5665712934023847 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 4, 'n_bands': 5, 'min_freq': 4, 'max_freq': 30, 'filter_order': 3, 'fs': 125, 'n_estimators': 250, 'max_depth': None, 'min_samples_split': 3, 'min_samples_leaf': 1, 'max_features': 'sqrt'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 1.000 | Val acc: 0.567

→ Trying: wl=250, st=500, tmin=500, chans=['FZ', 'CZ', 'C4', 'PZ', 'PO7']
   RSF d=5, Bands: [(6.0, 14.0), (14.0, 22.0), (22.0, 30.0), (30.0, 38.0)], RF: n_est=100, depth=5
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:57:10,916] Trial 17 finished with value: 0.5396390411338094 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 1, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 0, 'ch_PO8': 0, 'rsf_d': 5, 'n_bands': 4, 'min_freq': 6, 'max_freq': 38, 'filter_order': 3, 'fs': 125, 'n_estimators': 100, 'max_depth': 5, 'min_samples_split': 6, 'min_samples_leaf': 2, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.731 | Val acc: 0.540

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PO7', 'OZ', 'PO8']
   RSF d=4, Bands: [(11.0, 19.0), (19.0, 27.0), (27.0, 35.0)], RF: n_est=250, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:57:41,537] Trial 18 finished with value: 0.517952318400749 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 4, 'n_bands': 3, 'min_freq': 11, 'max_freq': 35, 'filter_order': 5, 'fs': 500, 'n_estimators': 250, 'max_depth': 10, 'min_samples_split': 4, 'min_samples_leaf': 1, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.980 | Val acc: 0.518

→ Trying: wl=500, st=750, tmin=500, chans=['CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=6, Bands: [(6.0, 11.0), (11.0, 16.0), (16.0, 21.0), (21.0, 26.0), (26.0, 31.0)], RF: n_est=100, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 11/960


[I 2025-06-28 16:58:16,415] Trial 19 finished with value: 0.5457659017713365 and parameters: {'window_length': 500, 'stride': 750, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 6, 'n_bands': 5, 'min_freq': 6, 'max_freq': 31, 'filter_order': 4, 'fs': 125, 'n_estimators': 100, 'max_depth': 10, 'min_samples_split': 7, 'min_samples_leaf': 4, 'max_features': 'sqrt'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.968 | Val acc: 0.546

→ Trying: wl=1000, st=500, tmin=500, chans=['FZ', 'PO7']
   RSF d=2, Bands: [(9.0, 17.333333333333336), (17.333333333333336, 25.666666666666668), (25.666666666666668, 34.0)], RF: n_est=200, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 37/960


[I 2025-06-28 16:58:32,656] Trial 20 finished with value: 0.5037684899248417 and parameters: {'window_length': 1000, 'stride': 500, 'tmin': 500, 'ch_FZ': 1, 'ch_C3': 0, 'ch_CZ': 0, 'ch_C4': 0, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 0, 'ch_PO8': 0, 'rsf_d': 2, 'n_bands': 3, 'min_freq': 9, 'max_freq': 34, 'filter_order': 3, 'fs': 125, 'n_estimators': 200, 'max_depth': 10, 'min_samples_split': 4, 'min_samples_leaf': 2, 'max_features': None}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.987 | Val acc: 0.504

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(5.0, 12.0), (12.0, 19.0), (19.0, 26.0), (26.0, 33.0)], RF: n_est=150, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:59:09,377] Trial 21 finished with value: 0.5770359298012512 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 4, 'min_freq': 5, 'max_freq': 33, 'filter_order': 3, 'fs': 125, 'n_estimators': 150, 'max_depth': 10, 'min_samples_split': 4, 'min_samples_leaf': 2, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.969 | Val acc: 0.577

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(4.0, 10.25), (10.25, 16.5), (16.5, 22.75), (22.75, 29.0)], RF: n_est=100, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 16:59:46,086] Trial 22 finished with value: 0.5639499965608585 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 4, 'min_freq': 4, 'max_freq': 29, 'filter_order': 3, 'fs': 125, 'n_estimators': 100, 'max_depth': 10, 'min_samples_split': 3, 'min_samples_leaf': 1, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.978 | Val acc: 0.564

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=4, Bands: [(5.0, 14.0), (14.0, 23.0), (23.0, 32.0)], RF: n_est=200, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 17:00:16,484] Trial 23 finished with value: 0.5624531312124735 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 4, 'n_bands': 3, 'min_freq': 5, 'max_freq': 32, 'filter_order': 4, 'fs': 125, 'n_estimators': 200, 'max_depth': 10, 'min_samples_split': 5, 'min_samples_leaf': 2, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.946 | Val acc: 0.562

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=5, Bands: [(4.0, 11.5), (11.5, 19.0), (19.0, 26.5), (26.5, 34.0)], RF: n_est=100, depth=10
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 7/960


[I 2025-06-28 17:00:58,223] Trial 24 finished with value: 0.5587141133977208 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 4, 'min_freq': 4, 'max_freq': 34, 'filter_order': 3, 'fs': 125, 'n_estimators': 100, 'max_depth': 10, 'min_samples_split': 6, 'min_samples_leaf': 1, 'max_features': 'log2'}. Best is trial 7 with value: 0.5871394563472291.


   → Train acc: 0.963 | Val acc: 0.559

→ Trying: wl=250, st=500, tmin=500, chans=['CZ', 'PZ', 'PO7', 'OZ', 'PO8']
   RSF d=4, Bands: [(5.0, 16.0), (16.0, 27.0), (27.0, 38.0)], RF: n_est=150, depth=15
task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples


In [2]:
# Trial 7 finished with value: 0.5871394563472291 and parameters: {'window_length': 250, 'stride': 500, 'tmin': 500, 'ch_FZ': 0, 'ch_C3': 0, 'ch_CZ': 1, 'ch_C4': 1, 'ch_PZ': 0, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 1, 'rsf_d': 5, 'n_bands': 3, 'min_freq': 4, 'max_freq': 39, 'filter_order': 3, 'fs': 125, 'n_estimators': 150, 'max_depth': 10, 'min_samples_split': 5, 'min_samples_leaf': 1, 'max_features': 'sqrt'}. Best is trial 7 with value: 0.5871394563472291.

class FilterBankRTSClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, bands=None, fs=250, order=4, n_estimators=100, max_depth=None, min_samples_split=2, min_samples_leaf=1, max_features="sqrt", class_weight="balanced", n_jobs=-1):
        self.bands = bands if bands else [(8, 12), (12, 16), (16, 20), (20, 24), (24, 30)]
        self.fs = fs
        self.order = order
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.class_weight = class_weight
        self.n_jobs = n_jobs

    def compute_fb_covs(self, X):
        """X: (n_trials, C, T) → fb_covs: (n_trials, B, C, C)"""
        # Pre-compute SOS filters if not done
        if not hasattr(self, "sos_bands"):
            self.sos_bands = [butter(self.order, (l / (self.fs / 2), h / (self.fs / 2)), btype="bandpass", output="sos") for l, h in self.bands]

        n, C, _ = X.shape
        B = len(self.sos_bands)
        fb_covs = np.zeros((n, B, C, C))
        for i, sos in enumerate(self.sos_bands):
            Xf = sosfiltfilt(sos, X, axis=2)
            fb_covs[:, i] = Covariances(estimator="lwf").transform(Xf)
        return fb_covs

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        # Flatten for tangent space
        covs_flat = fb_covs.reshape(n * B, C, C)
        labels_rep = np.repeat(y, B)

        # Fit tangent space
        self.ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
        Z = self.ts.transform(covs_flat)
        Z = Z.reshape(n, B, -1)

        # Compute mutual information weights
        self.w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)
        self.w = self.w / self.w.sum()

        # Weight features
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        # Train classifier
        self.clf = make_pipeline(
            StandardScaler(),
            RandomForestClassifier(
                n_estimators=self.n_estimators,
                max_depth=self.max_depth,
                min_samples_split=self.min_samples_split,
                min_samples_leaf=self.min_samples_leaf,
                max_features=self.max_features,
                class_weight=self.class_weight,
                n_jobs=self.n_jobs,
                random_state=42,
            ),
        )
        self.clf.fit(Z_weighted, y)
        return self

    def predict(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict(Z_weighted)

    def predict_proba(self, X):
        fb_covs = self.compute_fb_covs(X)
        n, B, C, _ = fb_covs.shape

        covs_flat = fb_covs.reshape(n * B, C, C)
        Z = self.ts.transform(covs_flat).reshape(n, B, -1)
        Z_weighted = np.concatenate([np.sqrt(self.w[i]) * Z[:, i, :] for i in range(B)], axis=1)

        return self.clf.predict_proba(Z_weighted)

class RSFTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, d=8, eps=1e-6, maxiter=500):
        self.d = d
        self.eps = eps
        self.maxiter = maxiter

    def fit(self, X, y):
        covs = Covariances(estimator="lwf").transform(X)
        C1 = mean_riemann(covs[y == 0])
        C2 = mean_riemann(covs[y == 1])
        Nc = C1.shape[0]
        W0 = np.random.randn(Nc, self.d)

        def J(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            return np.linalg.norm(logm(inv(S1) @ S2), "fro") ** 2

        def grad(w_flat):
            W = w_flat.reshape(Nc, self.d)
            S1 = W.T @ C1 @ W + np.eye(self.d) * self.eps
            S2 = W.T @ C2 @ W + np.eye(self.d) * self.eps
            A = inv(S1) @ S2
            L = logm(A)
            t1 = 2 * C1 @ W @ inv(S1) @ L @ inv(S1)
            t2 = 2 * C2 @ W @ inv(S2) @ L @ inv(S2)
            return (t1 - t2).ravel()

        res = minimize(lambda w: -J(w), W0.ravel(), jac=lambda w: -grad(w), method="BFGS", options={"maxiter": self.maxiter})
        self.W_opt = res.x.reshape(Nc, self.d)
        return self

    def transform(self, X):
        return np.einsum("cd,ect->edt", self.W_opt, X)


window_length = 250
stride = 500
tmin = 500
eeg_channels = ['CZ', 'C4', 'PO7', 'OZ', 'PO8']
d_rsf = 5
n_bands = 3
min_freq = 4
max_freq = 39
filter_order = 3
fs = 125
n_estimators = 150
max_depth = 10
min_samples_split = 5
min_samples_leaf = 1
max_features = 'sqrt'
data_path = "./data/mtcaic3"


# Load data with besjt parameters
ds_train = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_train = np.stack([x.numpy() for x, _ in ds_train])
y_train = np.array([label[0] for _, label in ds_train])

# Load data with besjt parameters
freq_step = (max_freq - min_freq) / n_bands
bands = [(min_freq + i * freq_step, min_freq + (i + 1) * freq_step) for i in range(n_bands)]
ds_val = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=1,
    tmin=tmin,
    eeg_channels=eeg_channels,
)
X_val = np.stack([x.numpy() for x, _ in ds_val])
y_val = np.array([label[0] for _, label in ds_val])

freq_step = (max_freq - min_freq) / n_bands
bands = [(min_freq + i * freq_step, min_freq + (i + 1) * freq_step) for i in range(n_bands)]

# Create frequency bands

# Create FilterBank classifier with best parameters
clf = make_pipeline(
    # This single object now contains your entire workflow.
        RSFTransformer(d=d_rsf),
        FilterBankRTSClassifier(
            bands=bands,
            fs=fs,
            order=filter_order,
            n_estimators=n_estimators,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            max_features=max_features,
            class_weight="balanced",
            n_jobs=-1
        )
    )

# Fit on training data
clf.fit(X_train, y_train)

# Calculate accuracy
y_pred = clf.predict(X_val)
val_acc = accuracy_score(y_val, y_pred)

print(f"Validation accuracy: {val_acc:.4f}")

# Classification report
print("\nClassification Report:")
print(classification_report(y_val, y_pred))

task: mi, split: train, domain: time, data_fraction: 1
skipped: 21/2400
task: mi, split: train, domain: time, data_fraction: 1
skipped: 21/2400
Validation accuracy: 0.5602

Classification Report:
              precision    recall  f1-score   support

           0       0.55      0.61      0.58      3317
           1       0.57      0.51      0.54      3374

    accuracy                           0.56      6691
   macro avg       0.56      0.56      0.56      6691
weighted avg       0.56      0.56      0.56      6691

