# EEGo Models

In [12]:
from pathlib import Path

import numpy as np
import pandas as pd

from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    classification_report,
)

from ML.model_training import (
    train_lstm,
    build_eego_lstm_sequences
)
from ML import utils
import sys
import random
from itertools import product

# Data prep functions
def load_eego_df(
    filename: str | Path = "EEGo_labeled.csv",
) -> pd.DataFrame:
    """
    Load EEGo_labeled.csv and do minimal cleaning.
    """
    df = pd.read_csv(filename)

    # Sort by (user, session, time)
    sort_cols = ["user_id", "session_id"]
    if "time_elapsed" in df.columns:
        sort_cols.append("time_elapsed")
    elif "timestamp" in df.columns:
        sort_cols.append("timestamp")

    df = df.sort_values(sort_cols).reset_index(drop=True)
    return df

def select_eego_features(df: pd.DataFrame) -> list[str]:
    """
    Return the EEG feature columns of ``df`` in their original column order.

    A column counts as an EEG feature when its name starts with one of the
    14 channel names followed by an underscore (e.g. ``"AF3_..."``).
    Metadata columns such as ``affect_minute`` or ``user_id`` are not
    included.
    """
    channel_prefixes = (
        "AF3_", "F7_", "F3_", "FC5_", "T7_", "P7_",
        "O1_", "O2_", "P8_", "T8_", "FC6_", "F4_", "F8_", "AF4_",
    )

    # str.startswith accepts a tuple and matches if any prefix matches.
    return [name for name in df.columns if name.startswith(channel_prefixes)]

def balance_binary_sequences(
    X: np.ndarray, y: np.ndarray, seed: int = 5
) -> tuple[np.ndarray, np.ndarray]:
    """
    Balance a binary dataset by randomly downsampling the larger class.

    Selection happens per sequence: whole entries of ``X`` are kept or
    dropped together with the matching entries of ``y``. Labels equal to
    1.0 form the positive class and labels equal to 0.0 the negative class.

    Returns ``(X, y)`` unchanged (same objects) when either class is empty
    or both classes are already the same size. Otherwise returns the kept
    entries in their original order. In that case, entries whose label is
    neither 0.0 nor 1.0 are dropped.
    """
    rng = np.random.default_rng(seed)

    positives = np.nonzero(y == 1.0)[0]
    negatives = np.nonzero(y == 0.0)[0]

    nothing_to_do = (
        positives.size == 0
        or negatives.size == 0
        or positives.size == negatives.size
    )
    if nothing_to_do:
        return X, y

    if positives.size > negatives.size:
        majority, minority = positives, negatives
    else:
        majority, minority = negatives, positives

    sampled = rng.choice(majority, size=len(minority), replace=False)
    # Sorting restores the original temporal/sequence order of the kept rows.
    kept = np.sort(np.concatenate([sampled, minority]))
    return X[kept], y[kept]

Set up the EEGo dataset: load it, select the EEG feature columns, and identify the recording sessions.
Then create a per-sample feature table using `eegproc`.

In [25]:
# Load EEGo data (rows sorted by user, session and time by load_eego_df)
df_eego = load_eego_df("datasets/EEGo_labeled.csv")
print("EEGo shape:", df_eego.shape)

# Choose columns to use as features
feature_cols = select_eego_features(df_eego)
print("n_features:", len(feature_cols))

# Label columns; THRESH is passed as `thresh` to build_eego_lstm_sequences
# when turning AROUSAL into a binary target.
AROUSAL = "affect_arousal"
VALENCE = "affect_valence"
THRESH = 2.5

session_ids = df_eego["session_id"].unique().tolist()
print("Number of sessions:", len(session_ids))

# Emotion labels grouped into the two arousal classes. On this dataset the
# E + A row count equals the number of rows with affect_arousal == 4.0, so
# "Baseline" is the share of the high-arousal class.
# NOTE(review): re-check this grouping if the label set changes.
HIGH_AROUSAL_EMOTIONS = ["E", "A"]
LOW_AROUSAL_EMOTIONS = ["B", "R"]

counts = df_eego["affect_emotion"].value_counts()
print(counts)
# reindex(fill_value=0) avoids a KeyError when a label is absent (e.g. on a subset).
n_high = counts.reindex(HIGH_AROUSAL_EMOTIONS, fill_value=0).sum()
n_low = counts.reindex(LOW_AROUSAL_EMOTIONS, fill_value=0).sum()
print("Baseline:", n_high / (n_high + n_low))

EEGo shape: (58628, 84)
n_features: 70
Number of sessions: 20
affect_emotion
E    22092
A    15277
R    11950
B     9309
Name: count, dtype: int64
Baseline: 0.6373916899774852


In [14]:
import eegproc as eeg

# Band edges (presumably Hz) used for the band-wise Shannon entropy features.
freqs = {
    "delta": (0.5, 4.0),
    "theta": (4.0, 8.0),
    "alpha": (8.0, 13.0),
    "betaL": (13.0, 20.0),
    "betaH": (20.0, 30.0),
    "gamma": (30.0, 45.0),
}
# The 14 EEG channel names.
ch_names = [
    "AF3", "F7", "F3", "FC5", "T7", "P7", "O1",
    "O2", "P8", "T8", "FC6", "F4", "F8", "AF4",
]
# Non-EEG columns removed before computing the EEG-derived features.
meta_cols = [
    "user_id",
    "session_id",
    "object_count",
    "time_elapsed",
    "arousal",
    "valence",
    "fall_speed",
    "difficulty_type",
    "sensor_contact_quality",
    "timestamp",
    "affect_minute",
    "affect_emotion",
    "affect_valence",
    "affect_arousal",
]

# NOTE(review): presumably the EEG sampling rate in Hz. It is not passed to
# any feature function below; confirm eegproc's default matches.
FS = 128

# Computing the feature table is expensive, so it is cached on disk and only
# rebuilt when the cache file is missing.
FEATURES_CACHE = Path("datasets/eego_features.csv")


def build_eego_features_table(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build the per-sample EEGo feature table, one session at a time.

    For each ``session_id``:
      1. the ``meta_cols`` columns are dropped, leaving the EEG columns;
      2. band-wise Shannon entropy is computed with
         ``eeg.shannons_entropy(..., bands=freqs)``;
      3. asymmetry features are computed with
         ``utils.compute_asymmetry_from_psd``;
      4. everything is joined column-wise with user/session ids, the
         arousal/valence labels and ``affect_minute``.

    Returns all sessions concatenated, sorted by
    (user_id, session_id, affect_minute), with a fresh 0..n-1 index.
    Returns an empty DataFrame when ``df`` has no sessions.
    """
    batches = []  # collected once and concatenated at the end (no quadratic concat)
    for session in df["session_id"].unique():
        session_df = df.loc[df["session_id"] == session, :]
        if session_df.empty:
            continue

        eeg_df = session_df.drop(columns=meta_cols, errors="ignore")
        shannons = eeg.shannons_entropy(eeg_df, bands=freqs)
        da = utils.compute_asymmetry_from_psd(eeg_df)

        # NOTE(review): concat(axis=1) aligns on the index. session_df keeps
        # df's original row labels; confirm shannons/da come back with the
        # same index, otherwise rows misalign and NaNs appear.
        batches.append(
            pd.concat(
                [
                    session_df["user_id"],
                    session_df["session_id"],
                    session_df[AROUSAL],
                    session_df[VALENCE],
                    session_df["affect_minute"],
                    eeg_df,
                    shannons,
                    da,
                ],
                axis=1,
            )
        )

    if not batches:
        return pd.DataFrame()

    return (
        pd.concat(batches, ignore_index=True)
        .sort_values(by=["user_id", "session_id", "affect_minute"], ascending=True)
        .reset_index(drop=True)
    )


if FEATURES_CACHE.exists():
    # Earlier caches were written with the pandas index, which reloads as an
    # uninformative "Unnamed: 0" column; drop it if present.
    features_table = pd.read_csv(FEATURES_CACHE).drop(
        columns="Unnamed: 0", errors="ignore"
    )
else:
    features_table = build_eego_features_table(df_eego)
    features_table.to_csv(FEATURES_CACHE, index=False)

## LSTM LOO Optimizer on EEGo sessions

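Performs leave-one-session-out cross-validation. In each fold, one **session_id** is held out as the test set and the LSTM is trained on all remaining sessions.
To keep run time manageable, only a random subset of sessions is used as held-out test folds, so this approximates full leave-one-out. We search over a small hyperparameter grid and keep the combination with the highest mean held-out accuracy.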

In [None]:
# Hyper-parameter grid for global EEGo LSTM
param_grid = {
    "lr": [0.0001],
    "epochs": [50],
    "units": [256],
    "batch_size": [64],
    "patience": [10],
}

# Held-out session selection. Evaluating every session is slow, so by default
# only a seeded random subset of sessions is used as test folds. Set
# N_HOLDOUT_SESSIONS = None to hold out every session once (full LOO).
SEED = 5
N_HOLDOUT_SESSIONS = 5
SEQ_LEN = 1000  # fixed sequence length (fixed_T) for build_eego_lstm_sequences

if N_HOLDOUT_SESSIONS is None or N_HOLDOUT_SESSIONS >= len(session_ids):
    sessions = list(session_ids)
else:
    # sample() draws distinct sessions and adapts to len(session_ids); the
    # seeded generator makes the chosen folds reproducible.
    sessions = random.Random(SEED).sample(session_ids, N_HOLDOUT_SESSIONS)
assert sessions, "no sessions available to hold out"

# Valence is not the target here; drop it once instead of in every fold.
use_features = features_table.drop(columns=[VALENCE])

best_params = None
best_mean_acc = -np.inf

print("Starting EEGo LOO hyperparameter search over sessions...\n")


for lr, epochs, units, batch_size, patience in product(
    param_grid["lr"],
    param_grid["epochs"],
    param_grid["units"],
    param_grid["batch_size"],
    param_grid["patience"],
):
    combo_accs = []

    # Every selected session is used as a test fold exactly once.
    for sid in sessions:
        print("LEFT OUT:", sid)
        is_test = use_features["session_id"] == sid
        train_df = use_features[~is_test].reset_index(drop=True)
        test_df = use_features[is_test].reset_index(drop=True)

        # Label distribution per split (the classes are imbalanced).
        print(train_df[AROUSAL].value_counts())
        print(test_df[AROUSAL].value_counts())

        # Build sequences for the LSTM
        X_train_seq, y_train_seq = build_eego_lstm_sequences(
            train_df,
            feature_cols=feature_cols,
            target_col=AROUSAL,
            thresh=THRESH,
            fixed_T=SEQ_LEN,
        )
        X_test_seq, y_test_seq = build_eego_lstm_sequences(
            test_df,
            feature_cols=feature_cols,
            target_col=AROUSAL,
            thresh=THRESH,
            fixed_T=SEQ_LEN,
        )

        lstm_model, X_test_eval, y_test_eval = train_lstm(
            X_train_seq,
            X_test_seq,
            y_train_seq,
            y_test_seq,
            lr=lr,
            epochs=epochs,
            units=units,
            batch_size=batch_size,
            patience=patience,
            bidirectional=True,
        )

        y_pred_prob = lstm_model.predict(X_test_eval).ravel()
        y_pred = (y_pred_prob >= 0.5).astype("int32")
        y_true = y_test_eval.astype("int32")

        acc = accuracy_score(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        print(acc)
        print(cm)
        combo_accs.append(acc)

    mean_acc = float(np.mean(combo_accs))
    std_acc = float(np.std(combo_accs))

    print(
        f"lr={lr}, epochs={epochs}, units={units}, batch={batch_size}, "
        f"patience={patience} -> mean acc={mean_acc:.4f} (std={std_acc:.4f})"
    )

    if mean_acc > best_mean_acc:
        best_mean_acc = mean_acc
        best_params = {
            "lr": lr,
            "epochs": epochs,
            "units": units,
            "batch_size": batch_size,
            "patience": patience,
        }

print("\nBest EEGo LOO mean test accuracy:", f"{best_mean_acc:.4f}")
print("Best params:", best_params)

Starting EEGo LOO hyperparameter search over sessions...

LEFT OUT: b8175a99-c266-447b-bffa-b0490a5c337f
affect_arousal
4.0    34488
2.0    20588
Name: count, dtype: int64
affect_arousal
4.0    2881
2.0     671
Name: count, dtype: int64
