In [1]:
%load_ext autoreload
%autoreload 2
import torch
from modules.competition_dataset import EEGDataset
import random
import numpy as np
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split, cross_val_score
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from scipy.linalg import expm, logm
from sklearn.pipeline import make_pipeline
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.mean import mean_riemann
from pyriemann.utils.distance import distance_riemann
from pyriemann.classification import MDM
from sklearn.feature_selection import mutual_info_classif
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from scipy.signal import sosfiltfilt, butter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device


  from .autonotebook import tqdm as notebook_tqdm


device(type='cpu')

In [2]:
data_path = './data/mtcaic3'
lda_model_path = './checkpoints/mi/models/lda_mi.pkl'

# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_random_seeds(42)

In [3]:
window_length = 256
stride = window_length // 3
batch_size = 64

In [4]:
eeg_channels = [
    "FZ",
    "C3",
    "CZ",
    "C4",
    # "PZ",
    "PO7",
    # "OZ",
    "PO8",
]

dataset_train = EEGDataset(
    data_path,
    window_length=window_length,
    stride=stride,
    task="mi",
    split="train",
    data_fraction=0.5,
    hardcoded_mean=False,
    eeg_channels=eeg_channels,
)

dataset_val = EEGDataset(
    data_path=data_path,
    window_length=window_length,
    stride=stride,
    task='mi',
    split='validation',
    read_labels=True,
    hardcoded_mean=False,
    eeg_channels=eeg_channels,
)

dataset_test = EEGDataset(
    data_path=data_path,
    window_length=window_length,
    stride=stride,
    task='mi',
    split='test',
    read_labels=False,
    hardcoded_mean=False,
    eeg_channels=eeg_channels,
)

dataset_train[0][0].shape

task: mi, split: train, domain: time, data_fraction: 0.5
Using 50.0% of data: 1200/1200 samples
skipped: 1/1200
task: mi, split: validation, domain: time, data_fraction: 1.0
skipped: 0/50
task: mi, split: test, domain: time, data_fraction: 1.0
skipped: 0/50


torch.Size([6, 256])

In [5]:
all_data = torch.cat([torch.stack([x for x,_ in ds]) for ds in (dataset_train, dataset_val, dataset_test)])
X_val_train = torch.cat([torch.stack([x for x,_ in ds]) for ds in (dataset_train, dataset_val)])
y_val_train = torch.cat([torch.stack([y for _,y in ds]) for ds in (dataset_train, dataset_val)])

mean = all_data.mean((0, 2))
std = all_data.std((0, 2))

X_val_train = (X_val_train - mean[None, :, None]) / std[None, :, None]

mean, std

(tensor([0.0317, 0.0496, 0.0636, 0.0403, 0.0301, 0.0531]),
 tensor([0.9850, 0.9914, 1.0514, 0.9993, 0.9921, 0.9901]))

In [6]:
import numpy as np
from sklearn.feature_selection import f_classif

# Concatenate all splits (add dataset_val and dataset_test if needed)
X_all = np.concatenate([
    dataset_train.data.numpy(),
    dataset_val.data.numpy(),
    dataset_test.data.numpy(),
], axis=0)  # shape: [N_total, C, ...]
y_all = np.concatenate([
    dataset_train.labels.numpy(),
    dataset_val.labels.numpy(),
    dataset_test.labels.numpy(),
], axis=0)  # shape: [N_total]

# Detect shape and adapt
if X_all.ndim == 3:
    # [B, C, T]
    num_samples, num_channels, time_points = X_all.shape
    channel_f_scores = []
    for i in range(num_channels):
        channel_data = X_all[:, i, :]  # [N_total, T]
        f_scores_per_timepoint, _ = f_classif(channel_data, y_all)
        aggregated_f_score = np.sum(f_scores_per_timepoint)
        channel_f_scores.append(aggregated_f_score)
elif X_all.ndim == 4:
    # [B, C, F, T]
    num_samples, num_channels, freq_points, time_points = X_all.shape
    channel_f_scores = []
    for i in range(num_channels):
        # Average over freq and time for each channel
        channel_data = X_all[:, i, :, :].mean(axis=(1, 2))  # [N_total]
        f_score, _ = f_classif(channel_data.reshape(-1, 1), y_all)
        channel_f_scores.append(f_score[0])
else:
    raise ValueError(f"Unsupported data shape: {X_all.shape}")

# Optionally, map to channel names
original_channel_names = eeg_channels
channel_scores_dict = {original_channel_names[i]: channel_f_scores[i] for i in range(num_channels)}

print("\n--- F-scores for each channel (higher score indicates more informativeness) ---")
sorted_channels = sorted(channel_scores_dict.items(), key=lambda item: item[1], reverse=True)
for channel, score in sorted_channels:
    print(f"  {channel}: {score:.2f}")

top_3_channels = [channel for channel, score in sorted_channels[:3]]
print(f"\n--- Recommended Top 3 Channels based on F-score: {top_3_channels} ---")


--- F-scores for each channel (higher score indicates more informativeness) ---
  CZ: 14698.86
  C4: 11370.12
  PO8: 7775.01
  PO7: 6639.33
  C3: 5295.83
  FZ: 5109.64

--- Recommended Top 3 Channels based on F-score: ['CZ', 'C4', 'PO8'] ---


In [7]:
# Example for train/val/test
X_train = np.stack([x.numpy() for x, y in dataset_train])  # shape: [N, C, T]
y_train = np.array([y[0] for x, y in dataset_train])

X_val = np.stack([x.numpy() for x, y in dataset_val])
y_val = np.array([y[0] for x, y in dataset_val])

X_test = np.stack([x.numpy() for x, y in dataset_test])
y_test = np.array([y[0] for x, y in dataset_test])

In [None]:
# --- A) Define your filter-bank & pre-compute SOS filters ---
bands = [(8, 12), (12, 16), (16, 20), (20, 24), (24, 30)]
sos_bands = [butter(4, (l / 125, h / 125), btype="bandpass", output="sos") for l, h in bands]


def compute_fb_covs(X):
    """X: (n_trials, C, T) → fb_covs: (n_trials, B, C, C)"""
    n, C, _ = X.shape
    B = len(sos_bands)
    fb_covs = np.zeros((n, B, C, C))
    for i, sos in enumerate(sos_bands):
        Xf = sosfiltfilt(sos, X, axis=2)
        fb_covs[:, i] = Covariances(estimator="lwf").transform(Xf)
    return fb_covs


def fb_rts_fsvm(X, y):
    """
    Implements Filter-Bank + Riemannian Tangent Space +
    Feature-weighted SVM (FBRTS + FWSVM).
    """
    # 1) Covariances per band
    fb_covs = compute_fb_covs(X)  # (n_trials, B, C, C)
    n, B, C, _ = fb_covs.shape

    # 2) Flatten bands → SPD list & repeat labels
    covs_flat = fb_covs.reshape(n * B, C, C)
    labels_rep = np.repeat(y, B)

    # 3) Tangent-space mapping
    ts = TangentSpace(metric="riemann").fit(covs_flat, labels_rep)
    Z = ts.transform(covs_flat)  # (n*B, D), with D = C(C+1)/2
    Z = Z.reshape(n, B, -1)  # (n, B, D)

    # 4) Feature-level weights per band
    #    Compute MI between each band’s features and labels
    w = mutual_info_classif(Z.reshape(n, -1), y, discrete_features=False).reshape(B, -1).mean(axis=1)  # flatten all bands → (n, B·D)  # average MI per band → (B,)
    w = w / w.sum()  # normalize to sum=1

    # 5) Apply sqrt-weighted scaling: sqrt so weights scale variances
    Z_weighted = np.concatenate([np.sqrt(w[i]) * Z[:, i, :] for i in range(B)], axis=1)  # → (n, B·D)

    # 6) Train & return a pipelined SVM
    clf = make_pipeline(StandardScaler(), SVC(kernel="linear", probability=True))
    clf.fit(Z_weighted, y)
    return ts, w, clf


def predict_fb_rts(ts, w, clf, X):
    """Given fitted ts, weights w, and clf, predict on new X."""
    fb_covs = compute_fb_covs(X)
    n, B, C, _ = fb_covs.shape
    covs_flat = fb_covs.reshape(n * B, C, C)
    Z = ts.transform(covs_flat).reshape(n, B, -1)
    Z_weighted = np.concatenate([np.sqrt(w[i]) * Z[:, i, :] for i in range(B)], axis=1)
    return clf.predict(Z_weighted)


# --- Example usage (no re-initialization of X_train, etc.) ---
ts, w, clf = fb_rts_fsvm(X_train, y_train)
y_pred = predict_fb_rts(ts, w, clf, X_val)

from sklearn.metrics import accuracy_score, classification_report

print("Val Acc:", accuracy_score(y_val, y_pred))
print(classification_report(y_val, y_pred))

In [10]:
def riemannian_geometric_median(X, max_iter=100, tol=1e-5):
    """
    Compute the Riemannian Geometric Median of a set of SPD matrices.
    This is a robust estimator of centrality, aligning with the paper's
    goal of using a MAD-like approach.

    The algorithm is an adaptation of the Weiszfeld algorithm for
    Riemannian manifolds.
    """
    current_median = mean_riemann(X)
    for i in range(max_iter):
        prev_median = current_median.copy()
        distances = np.array([distance_riemann(current_median, x) for x in X])
        distances[distances < 1e-10] = 1e-10
        
        weights = 1.0 / distances
        weights /= np.sum(weights)

        tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
        weighted_tangent_avg = np.sum(weights[:, np.newaxis, np.newaxis] * tangent_vectors, axis=0)
        current_median = current_median @ expm(weighted_tangent_avg)

        if np.linalg.norm(current_median - prev_median) < tol:
            break
            
    return current_median
    
# 1. Generate Synthetic Motor Imagery Data
n_channels = len(eeg_channels)
sfreq = 250

# 2. Build and Execute the Riemannian Pipeline
cov_estimator = Covariances(estimator='lwf')
X_train_cov = cov_estimator.fit(X_train, y_train).transform(X_train)
X_test_cov = cov_estimator.transform(X_test)

print("\nCalculating the robust reference matrix (Geometric Median)...")
reference_matrix = riemannian_geometric_median(X_train_cov)

print("Mapping covariance matrices to tangent space with custom reference...")
ts_mapper = TangentSpace(metric='riemann')
ts_mapper.fit(X_train_cov) 
ts_mapper.reference_ = reference_matrix
X_train_tan = ts_mapper.transform(X_train_cov)
X_test_tan = ts_mapper.transform(X_test_cov)


print("\nBuilding and training the final classification pipeline...")
final_pipeline = Pipeline([
    ('pca', PCA(n_components=min(n_channels * (n_channels + 1) // 2, X_train_tan.shape[0] -1, 15))),
    ('clf', LDA(solver='lsqr', shrinkage='auto'))
])

# 3. Train the model
final_pipeline.fit(X_train_tan, y_train)

# 4. Evaluate the model on the test set
y_pred = final_pipeline.predict(X_test_tan)
accuracy = accuracy_score(y_test, y_pred)

print("\n--- Results ---")
print(f"Test Set Classification Accuracy: {accuracy * 100:.2f}%")



Calculating the robust reference matrix (Geometric Median)...


  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x in X])
  tangent_vectors = np.array([logm(np.linalg.solve(current_median, x)) for x

KeyboardInterrupt: 

In [121]:
pipeline_svm = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("tangent", TangentSpace(metric="riemann")),
        ("clf", SVC(random_state=42, tol=0.001, kernel='rbf', gamma=0.01, class_weight='balanced', C=10)),
    ]
)

pipeline_lr = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("tangent", TangentSpace(metric="riemann")),
        ("clf", LogisticRegression(random_state=42, tol=0.0001, solver='sag', penalty='l2', max_iter=1000, class_weight='balanced', C=0.001)),
    ]
)


pipeline_mdm = Pipeline([
    ('cov', Covariances(estimator="lwf")),
    ("clf", MDM()),
])

pipelines = {
    "SVM": pipeline_svm,
    "Logistic Regression": pipeline_lr,
    "MDM": pipeline_mdm,
}

results = {}

for name, pipe in pipelines.items():
    pipe.fit(X_train, y_train)
    train_acc = accuracy_score(y_train, pipe.predict(X_train))
    val_acc = accuracy_score(y_val, pipe.predict(X_val))
    results[name] = (train_acc, val_acc)

print("=== Model Performance Report ===")
for name, (train_acc, val_acc) in results.items():
    print(f"{name:>20}: Train Acc = {train_acc:.3f} | Val Acc = {val_acc:.3f}")

=== Model Performance Report ===
                 SVM: Train Acc = 0.728 | Val Acc = 0.502
 Logistic Regression: Train Acc = 0.608 | Val Acc = 0.591
                 MDM: Train Acc = 0.539 | Val Acc = 0.520


In [None]:
# SVM
# Best CV score:    0.5679161628375655
# Best parameters:  {'clf__tol': 0.001, 'clf__kernel': 'rbf', 'clf__gamma': 0.01, 'clf__class_weight': 'balanced', 'clf__C': 10}
# Validation accuracy: 0.5145502645502645

param_grid_svm = {
    "clf__kernel": ["rbf", "linear"],  # Drop poly - research shows overfitting
    "clf__C": [0.01, 0.1, 1, 10, 50],  # Add lower values
    "clf__gamma": [0.001, 0.01, 0.1, "scale"],  # Finer granularity
    "clf__class_weight": ["balanced"],
    "clf__tol": [1e-3, 1e-4],
    # Remove degree/coef0 (irrelevant for non-poly kernels)
}

pipeline = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("tangent", TangentSpace(metric="riemann")),
        ("clf", SVC(random_state=42)),
    ]
)

grid = RandomizedSearchCV(
    estimator=pipeline,
    n_iter=50,
    param_distributions=param_grid_svm,
    # param_grid=param_grid,
    cv=3,
    scoring="accuracy",
    n_jobs=-1,
    verbose=2,
)

grid.fit(X_train, y_train)

# 3) Inspect best params & CV score
print("Best CV score:   ", grid.best_score_)
print("Best parameters: ", grid.best_params_)

# # 4) Evaluate on validation set
best_model = grid.best_estimator_
y_val_pred = best_model.predict(X_val)
val_acc = accuracy_score(y_val, y_val_pred)
print("Validation accuracy:", val_acc)

Fitting 3 folds for each of 50 candidates, totalling 150 fits
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=scale, clf__kernel=linear, clf__tol=0.001; total time=   7.0s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=scale, clf__kernel=linear, clf__tol=0.001; total time=   7.5s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=scale, clf__kernel=linear, clf__tol=0.001; total time=   7.9s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=linear, clf__tol=0.0001; total time=   7.9s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=linear, clf__tol=0.0001; total time=   8.4s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=rbf, clf__tol=0.001; total time=   8.6s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=rbf, clf__tol=0.001; total time=   8.8s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__gamma=0.001, clf__kernel=r

In [41]:
# Logistic Regression 
# Best CV score:    0.5507187961843343
# Best parameters:  {'clf__tol': 0.0001, 'clf__solver': 'sag', 'clf__penalty': 'l2', 'clf__max_iter': 1000, 'clf__class_weight': 'balanced', 'clf__C': 0.001}
# Validation accuracy: 0.5066137566137566

param_grid_lr = {
    "clf__penalty": ["l2", None],
    "clf__C": [0.001, 0.01, 0.1, 1, 10],
    "clf__solver": ["lbfgs", "sag"],
    "clf__max_iter": [1000],
    "clf__class_weight": ["balanced"],
    "clf__tol": [1e-4]
}

pipeline = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("tangent", TangentSpace(metric="riemann")),
        ("clf", LogisticRegression(random_state=42)),
    ]
)

grid_lr = RandomizedSearchCV(
    estimator=pipeline,
    n_iter=70,
    param_distributions=param_grid_lr,
    # param_grid=param_grid,
    cv=3,
    scoring="accuracy",
    n_jobs=-1,
    verbose=2,
)

grid_lr.fit(X_train, y_train)

# 3) Inspect best params & CV score
print("Best CV score:   ", grid_lr.best_score_)
print("Best parameters: ", grid_lr.best_params_)

# # 4) Evaluate on validation set
best_model = grid_lr.best_estimator_
y_val_pred = best_model.predict(X_val)
val_acc = accuracy_score(y_val, y_val_pred)
print("Validation accuracy:", val_acc)



Fitting 3 folds for each of 20 candidates, totalling 60 fits




[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.0s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.0s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.2s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.3s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.7s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.9s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.3s
[CV



[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   5.9s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.2s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.1s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=0.001, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.6s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.5s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.8s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.1s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.4s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.0s




[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.4s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   6.5s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=0.01, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   8.0s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.3s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.2s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.8s




[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.4s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.8s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.0s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.6s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.9s




[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   6.5s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.0s




[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   6.8s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.9s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.8s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.3s
[CV] END clf__C=0.1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   7.8s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.1s
[CV] END clf__C=1, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.5s
[CV] END clf__C=1, clf__cla



[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   6.7s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.0s




[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   7.2s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=l2, clf__solver=sag, clf__tol=0.0001; total time=   7.5s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   6.4s




[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=lbfgs, clf__tol=0.0001; total time=   5.5s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   5.4s




[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   5.3s
[CV] END clf__C=10, clf__class_weight=balanced, clf__max_iter=1000, clf__penalty=None, clf__solver=sag, clf__tol=0.0001; total time=   4.1s
Best CV score:    0.5507187961843343
Best parameters:  {'clf__tol': 0.0001, 'clf__solver': 'sag', 'clf__penalty': 'l2', 'clf__max_iter': 1000, 'clf__class_weight': 'balanced', 'clf__C': 0.001}
Validation accuracy: 0.5066137566137566


In [45]:
# # MDM

# param_grid_mdm = {
#     "clf__metric": ["riemann"],
#     "clf__n_means": [3, 5, 7],  # Number of power means
#     "clf__h_values": [
#         [-1, 0, 1], 
#         [-0.5, 0, 0.5],
#         [-1, -0.2, 0.2, 1]
#     ],  # Power parameters
#     "clf__mean_type": ["power"]
# }

pipeline = Pipeline(
    [
        ("cov", Covariances(estimator="lwf")),
        ("clf", MDM(metric="riemann")),
    ]
)

# grid_mdm = RandomizedSearchCV(
#     estimator=pipeline,
#     n_iter=70,
#     param_distributions=param_grid_mdm,
#     # param_grid=param_grid,
#     cv=3,
#     scoring="accuracy",
#     n_jobs=-1,
#     verbose=2,
# )

pipeline.fit(X_train, y_train)

# 3) Inspect best params & CV score
# print("Best CV score:   ", pipeline.best_score_)
# print("Best parameters: ", pipeline.best_params_)

# # 4) Evaluate on validation set
# best_model = pipeline.best_estimator_
y_val_pred = pipeline.predict(X_val)
val_acc = accuracy_score(y_val, y_val_pred)
print("Validation accuracy:", val_acc)

Validation accuracy: 0.5185185185185185
