# Train and Apply Models

In [1]:
from ML.model_training import (
    random_train_test_split,
    train_random_forest,
    train_random_forest_regressor,
    omit_patient_video,
    train_knn_regressor,
    train_lstm,
    STSNet
)
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    classification_report,
)
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import numpy as np
import pandas as pd
import math, re, itertools
from ML import utils
import sys
from IPython.display import clear_output
from scipy.stats import pearsonr
import random

# Subject IDs left out of every split (passed to omit_patient_video as
# exclude_users).
remove_list = [
    0, 1, 2, 4, 5, 6, 8, 9, 10, 11,
    14, 16, 17, 18, 19, 22,
]

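Grid-search the LSTM hyperparameters (hidden units × batch size) across the held-out subjects, then retrain and evaluate every subject with the single best shared setting.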

In [None]:
from itertools import product
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    classification_report,
)

# Global LSTM hyperparameter search over several held-out subjects, followed by
# a per-subject evaluation with the single best ("universal") configuration.
#
# NOTE(review): the configuration is selected by accuracy on the very same
# held-out (subject, video) test splits that the evaluation loop reports on,
# so the final metrics are tuned on test data and therefore optimistic. A
# separate validation split (or nested CV over subjects) would be unbiased.
#
# Per-subject data lives in `subject_splits`, so this cell no longer rebinds
# the notebook-level X_train / X_test / arousal_train / arousal_test globals
# that the single-subject cells rely on.

subjects = [3, 7, 12, 13, 15, 20, 21]

param_grid = {
    "lr": [0.001],
    "epochs": [10],
    "units": [32, 64, 128, 256, 512],
    "batch_size": [32, 64, 128, 256, 512],
}

# Ratings strictly above this value are labelled "high", all others "low".
AROUSAL_THRESHOLD = 3.8
# Videos held out (for each subject) as that subject's test split.
HOLDOUT_VIDEOS = (2, 10, 15)


def _binarize_arousal(ratings):
    """Return "high"/"low" labels (pandas "string" dtype) aligned to *ratings*."""
    return pd.Series(
        np.where(ratings > AROUSAL_THRESHOLD, "high", "low"),
        index=ratings.index,
        dtype="string",
    )


def _prepare_subject(subject):
    """Build the balanced split that holds out *subject*'s HOLDOUT_VIDEOS.

    Returns ``(X_train, X_test, y_train, y_test)``: feature frames restricted
    to ``features`` and "high"/"low" label Series, each balanced by
    ``balance``.
    """
    X_tr, X_te, y_tr, y_te = omit_patient_video(
        target="arousal",
        selected_user=subject,
        trials=3,
        holdout_videos=list(HOLDOUT_VIDEOS),
        exclude_users=remove_list,
    )
    # balance() only downsamples: a split that is missing a class entirely
    # stays single-class, so check the printed counts if metrics look odd.
    X_tr, y_tr = balance(X_tr, _binarize_arousal(y_tr))
    X_te, y_te = balance(X_te, _binarize_arousal(y_te))
    return X_tr.loc[:, features], X_te.loc[:, features], y_tr, y_te


def _fit_and_predict(split, lr, epochs, units, batch_size):
    """Train a unidirectional LSTM on *split*; return ``(y_true, y_pred)``.

    ``y_true`` is the evaluation label array returned by ``train_lstm``;
    ``y_pred`` is the predicted probability thresholded at 0.5 into 0/1.
    """
    X_tr, X_te, y_tr, y_te = split
    lstm, X_test_eval, y_test_eval = train_lstm(
        X_tr,
        X_te,
        y_tr,
        y_te,
        lr=lr,
        epochs=epochs,
        units=units,
        batch_size=batch_size,
        bidirectional=False,
    )
    y_prob = lstm.predict(X_test_eval).ravel()
    return y_test_eval, (y_prob >= 0.5).astype(int)


best_params = None
best_mean_acc = -float("inf")

print("Starting global hyperparameter search...\n")

# The split depends only on the subject (fixed holdout videos, fixed balance
# seed), so build each one once instead of once per hyperparameter combo and
# again for the final evaluation.
# NOTE(review): assumes omit_patient_video is deterministic when
# holdout_videos is given explicitly (the logged held-out trials suggest so).
subject_splits = {subject: _prepare_subject(subject) for subject in subjects}

for lr, epochs, units, batch_size in product(
    param_grid["lr"],
    param_grid["epochs"],
    param_grid["units"],
    param_grid["batch_size"],
):
    combo_accs = []  # accuracy for each subject under this hyperparam combo
    for subject in subjects:
        y_true, y_pred = _fit_and_predict(
            subject_splits[subject],
            lr=lr,
            epochs=epochs,
            units=units,
            batch_size=batch_size,
        )
        combo_accs.append(float(accuracy_score(y_true, y_pred)))

    mean_acc = float(np.mean(combo_accs)) if combo_accs else float("nan")
    print(
        f"Params lr={lr}, epochs={epochs}, units={units}, batch_size={batch_size} "
        f"-> mean acc across subjects = {mean_acc:.4f}"
    )

    if mean_acc > best_mean_acc:
        best_mean_acc = mean_acc
        best_params = {
            "lr": lr,
            "epochs": epochs,
            "units": units,
            "batch_size": batch_size,
        }

if best_params is None:
    # Only possible with an empty grid/subject list or all-NaN accuracies.
    raise RuntimeError("Hyperparameter search produced no finite mean accuracy.")

print("\nBest universal hyperparameters:")
print(best_params)
print(f"Mean accuracy across subjects (tuning): {best_mean_acc:.4f}\n")


acc_list = []
f1_list = []
precw_list = []
recw_list = []

y_test_full = []
y_pred_full = []

for subject in subjects:
    split = subject_splits[subject]
    train_labels, test_labels = split[2], split[3]

    print(f"\n=== Subject {subject} ===")
    print("arousal_train counts:\n", train_labels.value_counts(dropna=False))
    print("arousal_test counts:\n", test_labels.value_counts(dropna=False))

    y_true, y_pred = _fit_and_predict(split, **best_params)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="weighted")
    prec = precision_score(y_true, y_pred, average="weighted")
    rec = recall_score(y_true, y_pred, average="weighted")

    print("\nConfusion Matrix (this subject):")
    print(confusion_matrix(y_true, y_pred))
    print(f"Acc={acc:.4f}, F1w={f1:.4f}, Prec={prec:.4f}, Rec={rec:.4f}")

    acc_list.append(float(acc))
    f1_list.append(float(f1))
    precw_list.append(float(prec))
    recw_list.append(float(rec))

    y_test_full.extend(y_true)
    y_pred_full.extend(y_pred.tolist())

avg_acc = float(np.mean(acc_list)) if acc_list else float("nan")
avg_f1w = float(np.mean(f1_list)) if f1_list else float("nan")
avg_prec = float(np.mean(precw_list)) if precw_list else float("nan")
avg_rec = float(np.mean(recw_list)) if recw_list else float("nan")

print(
    "\nLSTM Classification Performance (cross-subject folds, best universal hyperparams)"
)
print("---------------------------------------------------")
print(f"Accuracy: {avg_acc:.4f}")
print(f"F1 (weighted): {avg_f1w:.4f}")
print(f"Precision (weighted): {avg_prec:.4f}")
print(f"Recall (weighted): {avg_rec:.4f}")

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(y_test_full, y_pred_full))

print("\nClassification Report (pooled):")
print(classification_report(y_test_full, y_pred_full, zero_division=0))

Starting global hyperparameter search...

Held-out patient: 3 | Held-out (patient, video) trials: [(3, 2), (3, 10), (3, 15)] | Excluded users: [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 22]
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
Held-out patient: 7 | Held-out (patient, video) trials: [(7, 2), (7, 10), (7, 15)] | Excluded users: [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 22]
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
Held-out patient: 12 | Held-out (patient, video) trials: [(12, 2), (12, 10), (12, 15)] | Excluded users: [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 22]
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
Held-out patient: 13 | Held-out (patient, video) trials: [(13, 2), (13, 10), (13, 15)] | Excluded users: [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 22]
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
Held-out patient: 15 | Held-ou

In [14]:
# Single-subject split: hold out subject 20 and binarize the arousal ratings.
# NOTE(review): no holdout_videos is passed here, so omit_patient_video picks
# the held-out videos itself (the logged run shows 2, 11, 15). This differs
# from the fixed [2, 10, 15] used in the cross-subject search; confirm that
# this is intended.
X_train, X_test, arousal_train, arousal_test = omit_patient_video(
    target="arousal",
    trials=3,
    selected_user=20,
    exclude_users=remove_list,
)
features = utils.filter_features(X_train.columns, remove_bands=["gamma", "delta"])

# Ratings strictly above this value are labelled "high", all others "low".
AROUSAL_THRESHOLD = 3.8


def _binarize_arousal(ratings):
    """Return "high"/"low" labels (pandas "string" dtype) aligned to *ratings*."""
    return pd.Series(
        np.where(ratings > AROUSAL_THRESHOLD, "high", "low"),
        index=ratings.index,
        dtype="string",
    )


arousal_train = _binarize_arousal(arousal_train)
arousal_test = _binarize_arousal(arousal_test)


def balance(X, y, seed=5):
    """Undersample so that every class in ``y`` has the same number of rows.

    Every class larger than the rarest class is randomly downsampled (with
    ``random_state=seed``) to the rarest class's count, and the rarest class
    is kept whole. Rows are selected by index label and returned in sorted
    label order. Both outputs get a fresh ``RangeIndex``, so they stay
    positionally aligned.

    Any label values work (e.g. "high"/"low" or 0/1). Previously, "already
    balanced" was detected only for the literal labels "high" and "low", so
    other labels were returned unbalanced.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature rows, index-aligned with ``y``.
    y : pandas.Series
        Class labels.
    seed : int, default 5
        ``random_state`` forwarded to ``Series.sample``.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.Series)
        The balanced ``X`` and ``y``, with reset indices.

    Raises
    ------
    ValueError
        If balancing is needed and ``y.index`` has duplicate labels. Label-based
        selection would otherwise pull in extra (possibly other-class) rows.
    """
    counts = y.value_counts()
    # Empty input, a single class, or equal class sizes: nothing to drop.
    if counts.nunique() <= 1:
        return X.reset_index(drop=True), y.reset_index(drop=True)
    if not y.index.is_unique:
        raise ValueError("balance() selects rows by index label; y.index must be unique")

    n_min = counts.min()
    oversized = counts[counts > n_min].index
    # Keep every row of the rarest class(es), then add a seeded sample of each
    # oversized class. Index.union returns the labels sorted.
    keep = y[~y.isin(oversized)].index
    for label in oversized:
        keep = keep.union(y[y == label].sample(n_min, random_state=seed).index)
    return X.loc[keep].reset_index(drop=True), y.loc[keep].reset_index(drop=True)


# Downsample the majority class of each split, then show the resulting counts.
X_train, arousal_train = balance(X_train, arousal_train)
X_test, arousal_test = balance(X_test, arousal_test)

for _split_name, _split_labels in (
    ("arousal_train", arousal_train),
    ("arousal_test", arousal_test),
):
    print(f"{_split_name} counts:\n", _split_labels.value_counts(dropna=False))

Held-out patient: 20 | Held-out (patient, video) trials: [(20, 2), (20, 11), (20, 15)] | Excluded users: [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 22]
arousal_train counts:
 low     4286
high    4286
Name: count, dtype: Int64
arousal_test counts:
 high    173
low     173
Name: count, dtype: Int64


In [15]:
# Train the LSTM on the X_train/X_test split from the previous cell for each
# configuration in the (currently single-point) grid. Keep the configuration
# with the best test accuracy, together with its predictions AND its labels.
# NOTE(review): "best" is chosen on the test split itself, so the scores are
# optimistic if the grid ever grows beyond one configuration.
best_model = None
best_acc = -float("inf")  # any real accuracy, including 0.0, must replace this
best_keep = None  # not used in this cell

best_lr = 0
best_f1 = 0
bar_len = 30  # not used in this cell
best_arousal_pred = None
best_y_test_eval = None


def render(bar_str: str, status_str: str, curr):
    """Print ``bar_str`` and ``curr`` on their own lines, then ``status_str``
    without a trailing newline, and flush stdout."""
    print(bar_str)
    print(curr)
    print(status_str, end="")
    sys.stdout.flush()


status = "Best: index= size= | acc= | f1= | prec= | rec="

results = []  # not used in this cell

X_train_sub = X_train.loc[:, features]
X_test_sub = X_test.loc[:, features]

# Training-label class counts; not used below.
n_low = (arousal_train == "low").sum()
n_high = (arousal_train == "high").sum()

for lr, e, u, b_s in itertools.product([0.001], [10], [512], [256]):
    lstm, X_test_eval, y_test_eval = train_lstm(
        X_train_sub,
        X_test_sub,
        arousal_train,
        arousal_test,
        lr=lr,
        epochs=e,
        units=u,
        batch_size=b_s,
        bidirectional=False,
    )
    y_prob = lstm.predict(X_test_eval).ravel()
    arousal_pred = (y_prob >= 0.5).astype(int)

    acc = accuracy_score(y_test_eval, arousal_pred)
    f1 = f1_score(y_test_eval, arousal_pred, average="weighted")
    prec = precision_score(y_test_eval, arousal_pred, average="weighted")
    rec = recall_score(y_test_eval, arousal_pred, average="weighted")

    if acc > best_acc:
        best_acc = acc
        best_model = lstm
        best_lr, best_e, best_u, best_b_s = lr, e, u, b_s
        best_f1 = f1
        best_arousal_pred = arousal_pred
        # Keep the labels that belong to these predictions, instead of
        # reusing whatever y_test_eval the last iteration left behind.
        best_y_test_eval = y_test_eval
        status = (
            f"Best: "
            f"acc={acc:.6f} | f1={f1:.6f} | prec={prec:.6f} | rec={rec:.6f} | lr={best_lr} | epochs={best_e} | units={best_u} | batch_size={best_b_s}"
        )
        print(status)

if best_arousal_pred is None:
    # Only possible with an empty grid or a NaN accuracy.
    raise RuntimeError("No configuration produced a finite accuracy; nothing to report.")

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(best_y_test_eval, best_arousal_pred))

print("\nClassification Report (pooled):")
print(classification_report(best_y_test_eval, best_arousal_pred, zero_division=0))

print(status)

[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
Best: acc=0.933526 | f1=0.933231 | prec=0.941327 | rec=0.933526 | lr=0.001 | epochs=10 | units=512 | batch_size=256

Confusion Matrix (pooled):
[[150  23]
 [  0 173]]

Classification Report (pooled):
              precision    recall  f1-score   support

         0.0       1.00      0.87      0.93       173
         1.0       0.88      1.00      0.94       173

    accuracy                           0.93       346
   macro avg       0.94      0.93      0.93       346
weighted avg       0.94      0.93      0.93       346

Best: acc=0.933526 | f1=0.933231 | prec=0.941327 | rec=0.933526 | lr=0.001 | epochs=10 | units=512 | batch_size=256
