# Train and Apply Models

In [6]:
from ML.model_training import (
    train_lstm,
)
from ML.labels import build_video_rating_tables
from ML.splits import single_user_split
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    classification_report,
)
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import numpy as np
import pandas as pd
import math, re, itertools
from ML import utils
import sys
from IPython.display import clear_output
from scipy.stats import pearsonr

Generate all subsets of columns for parameters.

In [7]:
X_train, X_test, arousal_train, arousal_test = single_user_split(
    target="arousal", selected_user=20, k_holdouts=3, holdout_videos=[10, 2, 15]
)
features = utils.filter_features(X_train.columns, remove_bands=["gamma", "delta"])
arousal_train = pd.Series(
    np.where(arousal_train > 3.8, "high", "low"),
    index=arousal_train.index,
    dtype="string",
)
arousal_test = pd.Series(
    np.where(arousal_test > 3.8, "high", "low"),
    index=arousal_test.index,
    dtype="string",
)


def balance(X, y, seed=5):
    c = y.value_counts()
    if c.get("high", 0) == c.get("low", 0):
        return X.reset_index(drop=True), y.reset_index(drop=True)
    maj = c.idxmax()
    m = c.min()
    keep = y[y != maj].index.union(y[y == maj].sample(m, random_state=seed).index)
    return X.loc[keep].reset_index(drop=True), y.loc[keep].reset_index(drop=True)


X_train, arousal_train = balance(X_train, arousal_train, seed=5)
X_test, arousal_test = balance(X_test, arousal_test, seed=5)

print("arousal_train counts:\n", arousal_train.value_counts(dropna=False))
print("arousal_test counts:\n", arousal_test.value_counts(dropna=False))

20 [ 2 10 15]
arousal_train counts:
 low     430
high    430
Name: count, dtype: Int64
arousal_test counts:
 high    143
low     143
Name: count, dtype: Int64


## LSTM LOO

In [9]:
best_model = None
best_acc = 0
best_keep = None

best_lr = 0
best_f1 = 0
bar_len = 30

status = f"Best: index= size= | " f"acc= | f1= | prec= | rec="

results = []
X_train_sub = X_train.loc[:, features]
X_test_sub = X_test.loc[:, features]

n_low = (arousal_train == "low").sum()
n_high = (arousal_train == "high").sum()
for lr in [0.001]:
    for e in [10, 50]:
        for u in [256]:
            for b_s in [128]:
                lstm, X_test_eval, y_test_eval = train_lstm(
                    X_train_sub,
                    X_test_sub,
                    arousal_train,
                    arousal_test,
                    lr=lr,
                    epochs=e,
                    units=u,
                    batch_size=b_s,
                    bidirectional=False,
                )
                y_prob = lstm.predict(X_test_eval).ravel()
                arousal_pred = (y_prob >= 0.5).astype(int)

                acc = accuracy_score(y_test_eval, arousal_pred)
                f1 = f1_score(y_test_eval, arousal_pred, average="weighted")
                prec = precision_score(y_test_eval, arousal_pred, average="weighted")
                rec = recall_score(y_test_eval, arousal_pred, average="weighted")

                if acc > best_acc:
                    best_acc = acc
                    best_model = lstm
                    best_lr = lr
                    best_e = e
                    best_u = u
                    best_b_s = b_s
                    best_f1 = f1
                    best_arousal_pred = arousal_pred
                    status = (
                        f"Best: "
                        f"acc={acc:.6f} | f1={f1:.6f} | prec={prec:.6f} | rec={rec:.6f} | lr={best_lr} | epochs={best_e} | units={best_u} | batch_size={best_b_s}"
                    )
                    print(status)

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(y_test_eval, best_arousal_pred))

print("\nClassification Report (pooled):")
print(classification_report(y_test_eval, best_arousal_pred, zero_division=0))

print(status)

[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
Best: acc=0.965035 | f1=0.964992 | prec=0.967320 | rec=0.965035 | lr=0.001 | epochs=10 | units=256 | batch_size=128
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step

Confusion Matrix (pooled):
[[133  10]
 [  0 143]]

Classification Report (pooled):
              precision    recall  f1-score   support

         0.0       1.00      0.93      0.96       143
         1.0       0.93      1.00      0.97       143

    accuracy                           0.97       286
   macro avg       0.97      0.97      0.96       286
weighted avg       0.97      0.97      0.96       286

Best: acc=0.965035 | f1=0.964992 | prec=0.967320 | rec=0.965035 | lr=0.001 | epochs=10 | units=256 | batch_size=128


## Choose participants for training

In [5]:
acc_list = []
f1_list = []
precw_list = []
recw_list = []
n_list = []

num_folds = 5
y_test_full = []
y_pred_full = []
removed = []

for i in range(0, 23):

    while True:
        X_train, X_test, arousal_train, arousal_test = single_user_split(
            target="arousal", k_holdouts=3, selected_user=i, holdout_videos=[10, 2, 15]
        )
        print(i)
        arousal_train = pd.Series(
            np.where(arousal_train > 3.8, "high", "low"),
            index=arousal_train.index,
            dtype="string",
        )
        arousal_test = pd.Series(
            np.where(arousal_test > 3.8, "high", "low"),
            index=arousal_test.index,
            dtype="string",
        )

        c = arousal_test.value_counts()
        # if c.get("high", 0) == 0 or c.get("low", 0) == 0:
        #     continue
        # if c.get("low", 0) * 5 >= arousal_train.value_counts().get("low", 0):
        #     print(c, arousal_train.value_counts())
        #     continue


        X_train, arousal_train = balance(X_train, arousal_train)
        X_test, arousal_test = balance(X_test, arousal_test)
        break

    print("arousal_train counts:\n", arousal_train.value_counts(dropna=False))
    print("arousal_test counts:\n", arousal_test.value_counts(dropna=False))


    lstm, X_test_eval, y_test_eval = train_lstm(
            X_train,
            X_test,
            arousal_train,
            arousal_test,
            lr=best_lr,
            epochs=best_e,
            units=best_u,
            batch_size=best_b_s,
            bidirectional=False,
    )
    y_prob = lstm.predict(X_test_eval).ravel()
    arousal_pred = (y_prob >= 0.5).astype(int)
    
    acc = accuracy_score(y_test_eval, arousal_pred)
    if acc < 0.59:
        print(f"Remove participant {i} acc:", acc)
        removed.append(i)
    else:
        f1 = f1_score(y_test_eval, arousal_pred, average="weighted")
        prec = precision_score(y_test_eval, arousal_pred, average="weighted")
        rec = recall_score(y_test_eval, arousal_pred, average="weighted")

        print("\nConfusion Matrix (pooled):")
        print(confusion_matrix(y_test_eval, arousal_pred))

        acc_list.append(float(acc))
        print(f"Participant {i} acc:", acc)
        f1_list.append(float(f1))
        precw_list.append(float(prec))
        recw_list.append(float(rec))

        y_test_full.extend(y_test_eval)
        y_pred_full.extend(arousal_pred.tolist() if hasattr(arousal_pred, "tolist") else list(arousal_pred))

# averages across folds
avg_acc = float(np.mean(acc_list)) if acc_list else float("nan")
avg_f1w = float(np.mean(f1_list)) if f1_list else float("nan")
avg_prec = float(np.mean(precw_list)) if precw_list else float("nan")
avg_rec = float(np.mean(recw_list)) if recw_list else float("nan")

print("LSTM Classification Performance (cross-subject folds)")
print("---------------------------------------------------")
print(f"Accuracy: {avg_acc:.4f}")
print(f"F1 (weighted): {avg_f1w:.4f}")
print(f"Precision (weighted): {avg_prec:.4f}")
print(f"Recall (weighted): {avg_rec:.4f}")
print(f"Removed {len(removed)} Participants: {removed}")

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(y_test_full, y_pred_full))

print("\nClassification Report (pooled):")
print(classification_report(y_test_full, y_pred_full, zero_division=0))

0 [ 2 10 15]
0
arousal_train counts:
 low     407
high    407
Name: count, dtype: Int64
arousal_test counts:
 high    143
low     143
Name: count, dtype: Int64
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
Remove participant 0 acc: 0.17132867132867133
1 [ 2 10 15]
1
arousal_train counts:
 low     725
high    725
Name: count, dtype: Int64
arousal_test counts:
 high    47
low     47
Name: count, dtype: Int64
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 57ms/step
Remove participant 1 acc: 0.19148936170212766
2 [ 2 10 15]
2
arousal_train counts:
 high    690
low     690
Name: count, dtype: Int64
arousal_test counts:
 high    143
low     143
Name: count, dtype: Int64
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
Remove participant 2 acc: 0.46853146853146854
3 [ 2 10 15]
3
arousal_train counts:
 low     404
high    404
Name: count, dtype: Int64
arousal_test counts:
 high    143
low     143
Name: count, dtype: Int64
[1m9/9

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


8 [ 2 10 15]
8
arousal_train counts:
 low     467
high    467
Name: count, dtype: Int64
arousal_test counts:
 high    47
low     47
Name: count, dtype: Int64
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step
Remove participant 8 acc: 0.574468085106383
9 [ 2 10 15]
9
arousal_train counts:
 low     613
high    613
Name: count, dtype: Int64
arousal_test counts:
 low     96
high    96
Name: count, dtype: Int64
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step  
Remove participant 9 acc: 0.5208333333333334
10 [ 2 10 15]
10
arousal_train counts:
 low     624
high    624
Name: count, dtype: Int64
arousal_test counts:
 high    143
low     143
Name: count, dtype: Int64
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
Remove participant 10 acc: 0.5699300699300699
11 [ 2 10 15]
11
arousal_train counts:
 low     621
high    621
Name: count, dtype: Int64
arousal_test counts:
 low    316
Name: count, dtype: Int64
[1m10/10[0m [32m━━