In [7]:
%load_ext autoreload
%autoreload 2
import torch
from modules.competition_dataset import EEGDataset
import random
import numpy as np
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split, cross_val_score
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from scipy.linalg import expm, logm
from sklearn.pipeline import make_pipeline
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.mean import mean_riemann
from pyriemann.utils.distance import distance_riemann
from pyriemann.classification import MDM
from sklearn.feature_selection import mutual_info_classif
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from scipy.signal import sosfiltfilt, butter
from sklearn.ensemble import RandomForestClassifier

from sklearn.metrics import accuracy_score, classification_report

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


device(type='cpu')

In [2]:
data_path = './data/mtcaic3'
lda_model_path = './checkpoints/mi/models/lda_mi.pkl'

# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_random_seeds(42)

In [3]:
window_length = 256 * 3
stride = window_length // 3
batch_size = 64

In [4]:
eeg_channels = [
    "FZ",
    "C3",
    "CZ",
    "C4",
    # "PZ",
    # "PO7",
    # "OZ",
    "PO8",
]

dataset_train = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=0.4,
    tmin=250,
    eeg_channels=eeg_channels,
)

dataset_val = EEGDataset(
    data_path=data_path,
    window_length=window_length,
    stride=stride,
    task='mi',
    split='validation',
    read_labels=True,
    tmin=250,
    eeg_channels=eeg_channels,
)

dataset_train[0][0].shape

task: mi, split: train, domain: time, data_fraction: 0.4
Using 40.0% of data: 960/960 samples
skipped: 11/960
task: mi, split: validation, domain: time, data_fraction: 1.0
skipped: 0/50


torch.Size([5, 768])

In [5]:
# Example for train/val/test
X_train = np.stack([x.numpy() for x, y in dataset_train])  # shape: [N, C, T]
y_train = np.array([y[0] for x, y in dataset_train])

X_val = np.stack([x.numpy() for x, y in dataset_val])
y_val = np.array([y[0] for x, y in dataset_val])

In [8]:
# --- A) Define your filter-bank & pre-compute SOS filters ---
bands = [(8, 12), (12, 16), (16, 20), (20, 24), (24, 30)]
sos_bands = [butter(4, (l / 125, h / 125), btype="bandpass", output="sos") for l, h in bands]


def compute_fb_covs(X):
    """X: (n_trials, C, T) → fb_covs: (n_trials, B, C, C)"""
    n, C, _ = X.shape
    B = len(sos_bands)
    fb_covs = np.zeros((n, B, C, C))
    for i, sos in enumerate(sos_bands):
        Xf = sosfiltfilt(sos, X, axis=2)
        fb_covs[:, i] = Covariances(estimator="lwf").transform(Xf)
    return fb_covs


def fb_rts_fsvm(X, y):
    """
    Implements Filter-Bank + Riemannian Tangent Space +
    Feature-weighted SVM (FBRTS + FWSVM).
    """
    fb_covs = compute_fb_covs(X)  # (n_trials, B, C, C)
    n, B, C, _ = fb_covs.shape

    covs_flat = fb_covs.reshape(n * B, C, C)
    labels_rep = np.repeat(y, B)

    ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
    Z = ts.transform(covs_flat)  # (n*B, D), with D = C(C+1)/2
    Z = Z.reshape(n, B, -1)  # (n, B, D)

    w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)  # flatten all bands → (n, B·D)  # average MI per band → (B,)
    w = w / w.sum()

    Z_weighted = np.concatenate([np.sqrt(w[i]) * Z[:, i, :] for i in range(B)], axis=1)  # → (n, B·D)

    clf = make_pipeline(
        StandardScaler(),
        # SVC(kernel="linear", probability=True),
        RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            class_weight='balanced',
            n_jobs=-1,
        )
    )
    clf.fit(Z_weighted, y)
    return ts, w, clf


def predict_fb_rts(ts, w, clf, X):
    """Given fitted ts, weights w, and clf, predict on new X."""
    fb_covs = compute_fb_covs(X)
    n, B, C, _ = fb_covs.shape
    covs_flat = fb_covs.reshape(n * B, C, C)
    Z = ts.transform(covs_flat).reshape(n, B, -1)
    Z_weighted = np.concatenate([np.sqrt(w[i]) * Z[:, i, :] for i in range(B)], axis=1)
    return clf.predict(Z_weighted)


# --- Example usage (no re-initialization of X_train, etc.) ---
ts, w, clf = fb_rts_fsvm(X_train, y_train)
y_pred = predict_fb_rts(ts, w, clf, X_val)


print("Val Acc:", accuracy_score(y_val, y_pred))
print(classification_report(y_val, y_pred))

Val Acc: 0.5142857142857142
              precision    recall  f1-score   support

           0       0.61      0.38      0.47        79
           1       0.46      0.69      0.55        61

    accuracy                           0.51       140
   macro avg       0.54      0.53      0.51       140
weighted avg       0.55      0.51      0.51       140

