# Train and Apply Models

In [1]:
from ML.model_training import (
    train_lstm,
)
from ML.labels import build_video_rating_tables
from ML.splits import single_user_split
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    classification_report,
)
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import numpy as np
import pandas as pd
import math, re, itertools
from ML import utils
import sys
from IPython.display import clear_output
from scipy.stats import pearsonr

Generate all subsets of columns for parameters.

In [None]:
X_train, X_test, arousal_train, arousal_test = single_user_split(
    target="valence", selected_user=20, k_holdouts=3, holdout_videos=[10, 2, 15]
)
features = utils.filter_features(X_train.columns, remove_bands=["gamma", "delta"])
arousal_train = pd.Series(
    np.where(arousal_train > 3.8, "high", "low"),
    index=arousal_train.index,
    dtype="string",
)
arousal_test = pd.Series(
    np.where(arousal_test > 3.8, "high", "low"),
    index=arousal_test.index,
    dtype="string",
)


def balance(X, y, seed=5):
    c = y.value_counts()
    if c.get("high", 0) == c.get("low", 0):
        return X.reset_index(drop=True), y.reset_index(drop=True)
    maj = c.idxmax()
    m = c.min()
    keep = y[y != maj].index.union(y[y == maj].sample(m, random_state=seed).index)
    return X.loc[keep].reset_index(drop=True), y.loc[keep].reset_index(drop=True)


X_train, arousal_train = balance(X_train, arousal_train, seed=5)
X_test, arousal_test = balance(X_test, arousal_test, seed=5)

print("arousal_train counts:\n", arousal_train.value_counts(dropna=False))
print("arousal_test counts:\n", arousal_test.value_counts(dropna=False))

20 [ 2 10 15]
arousal_train counts:
 low     430
high    430
Name: count, dtype: Int64
arousal_test counts:
 high    143
low     143
Name: count, dtype: Int64


## LSTM LOSO

In [6]:
best_model = None
best_acc = 0
best_keep = None

best_lr = 0
best_f1 = 0
bar_len = 30

status = f"Best: index= size= | " f"acc= | f1= | prec= | rec="

results = []
X_train_sub = X_train.loc[:, features]
X_test_sub = X_test.loc[:, features]

n_low = (arousal_train == "low").sum()
n_high = (arousal_train == "high").sum()
for lr in [0.001]:
    for e in [10]:
        for u in [128]:
            for b_s in [256]:
                lstm, X_test_eval, y_test_eval = train_lstm(
                    X_train_sub,
                    X_test_sub,
                    arousal_train,
                    arousal_test,
                    lr=lr,
                    epochs=e,
                    units=u,
                    batch_size=b_s,
                    bidirectional=False,
                )
                y_prob = lstm.predict(X_test_eval).ravel()
                arousal_pred = (y_prob >= 0.5).astype(int)

                acc = accuracy_score(y_test_eval, arousal_pred)
                f1 = f1_score(y_test_eval, arousal_pred, average="weighted")
                prec = precision_score(y_test_eval, arousal_pred, average="weighted")
                rec = recall_score(y_test_eval, arousal_pred, average="weighted")

                if acc > best_acc:
                    best_acc = acc
                    best_model = lstm
                    best_lr = lr
                    best_e = e
                    best_u = u
                    best_b_s = b_s
                    best_f1 = f1
                    best_arousal_pred = arousal_pred
                    status = (
                        f"Best: "
                        f"acc={acc:.6f} | f1={f1:.6f} | prec={prec:.6f} | rec={rec:.6f} | lr={best_lr} | epochs={best_e} | units={best_u} | batch_size={best_b_s}"
                    )
                    print(status)

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(y_test_eval, best_arousal_pred))

print("\nClassification Report (pooled):")
print(classification_report(y_test_eval, best_arousal_pred, zero_division=0))

print(status)

[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
Best: acc=0.923077 | f1=0.922771 | prec=0.929888 | rec=0.923077 | lr=0.001 | epochs=10 | units=128 | batch_size=256

Confusion Matrix (pooled):
[[123  20]
 [  2 141]]

Classification Report (pooled):
              precision    recall  f1-score   support

         0.0       0.98      0.86      0.92       143
         1.0       0.88      0.99      0.93       143

    accuracy                           0.92       286
   macro avg       0.93      0.92      0.92       286
weighted avg       0.93      0.92      0.92       286

Best: acc=0.923077 | f1=0.922771 | prec=0.929888 | rec=0.923077 | lr=0.001 | epochs=10 | units=128 | batch_size=256


## Choose participants for training

In [9]:
acc_list = []
f1_list = []
precw_list = []
recw_list = []
n_list = []

num_folds = 5
y_test_full = []
y_pred_full = []
removed = []

for i in range(0, 23):

    while True:
        X_train, X_test, arousal_train, arousal_test = single_user_split(
            target="valence", k_holdouts=3, selected_user=i, holdout_videos=[10, 2, 15]
        )
        print(i)
        arousal_train = pd.Series(
            np.where(arousal_train > 3.8, "high", "low"),
            index=arousal_train.index,
            dtype="string",
        )
        arousal_test = pd.Series(
            np.where(arousal_test > 3.8, "high", "low"),
            index=arousal_test.index,
            dtype="string",
        )

        c = arousal_test.value_counts()
        # if c.get("high", 0) == 0 or c.get("low", 0) == 0:
        #     continue
        # if c.get("low", 0) * 5 >= arousal_train.value_counts().get("low", 0):
        #     print(c, arousal_train.value_counts())
        #     continue


        X_train, arousal_train = balance(X_train, arousal_train)
        X_test, arousal_test = balance(X_test, arousal_test)
        break

    print("arousal_train counts:\n", arousal_train.value_counts(dropna=False))
    print("arousal_test counts:\n", arousal_test.value_counts(dropna=False))


    lstm, X_test_eval, y_test_eval = train_lstm(
            X_train,
            X_test,
            arousal_train,
            arousal_test,
            lr=best_lr,
            epochs=best_e,
            units=best_u,
            batch_size=best_b_s,
            bidirectional=False,
    )
    y_prob = lstm.predict(X_test_eval).ravel()
    arousal_pred = (y_prob >= 0.5).astype(int)
    
    acc = accuracy_score(y_test_eval, arousal_pred)
    if acc < 0.59:
        print(f"Remove participant {i} acc:", acc)
        removed.append(i)
    else:
        f1 = f1_score(y_test_eval, arousal_pred, average="weighted")
        prec = precision_score(y_test_eval, arousal_pred, average="weighted")
        rec = recall_score(y_test_eval, arousal_pred, average="weighted")

        print("\nConfusion Matrix (pooled):")
        print(confusion_matrix(y_test_eval, arousal_pred))

        acc_list.append(float(acc))
        print(f"Participant {i} acc:", acc)
        f1_list.append(float(f1))
        precw_list.append(float(prec))
        recw_list.append(float(rec))

        y_test_full.extend(y_test_eval)
        y_pred_full.extend(arousal_pred.tolist() if hasattr(arousal_pred, "tolist") else list(arousal_pred))

# averages across folds
avg_acc = float(np.mean(acc_list)) if acc_list else float("nan")
avg_f1w = float(np.mean(f1_list)) if f1_list else float("nan")
avg_prec = float(np.mean(precw_list)) if precw_list else float("nan")
avg_rec = float(np.mean(recw_list)) if recw_list else float("nan")

print("LSTM Classification Performance (cross-subject folds)")
print("---------------------------------------------------")
print(f"Accuracy: {avg_acc:.4f}")
print(f"F1 (weighted): {avg_f1w:.4f}")
print(f"Precision (weighted): {avg_prec:.4f}")
print(f"Recall (weighted): {avg_rec:.4f}")
print(f"Removed {len(removed)} Participants: {removed}")

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(y_test_full, y_pred_full))

print("\nClassification Report (pooled):")
print(classification_report(y_test_full, y_pred_full, zero_division=0))

0 [ 2 10 15]
0
arousal_train counts:
 high    646
low     646
Name: count, dtype: Int64
arousal_test counts:
 high    96
low     96
Name: count, dtype: Int64
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step  
Remove participant 0 acc: 0.5625
1 [ 2 10 15]
1
arousal_train counts:
 low     367
high    367
Name: count, dtype: Int64
arousal_test counts:
 high    96
low     96
Name: count, dtype: Int64
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step  
Remove participant 1 acc: 0.5625
2 [ 2 10 15]
2
arousal_train counts:
 low     584
high    584
Name: count, dtype: Int64
arousal_test counts:
 high    316
Name: count, dtype: Int64
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step

Confusion Matrix (pooled):
[[  0   0]
 [ 86 230]]
Participant 2 acc: 0.7278481012658228


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


3 [ 2 10 15]
3
arousal_train counts:
 low     256
high    256
Name: count, dtype: Int64
arousal_test counts:
 high    143
low     143
Name: count, dtype: Int64
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step

Confusion Matrix (pooled):
[[142   1]
 [ 14 129]]
Participant 3 acc: 0.9475524475524476
4 [ 2 10 15]
4
arousal_train counts:
 high    465
low     465
Name: count, dtype: Int64
arousal_test counts:
 high    96
low     96
Name: count, dtype: Int64
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step  
Remove participant 4 acc: 0.3229166666666667
5 [ 2 10 15]
5
arousal_train counts:
 low     526
high    526
Name: count, dtype: Int64
arousal_test counts:
 high    316
Name: count, dtype: Int64
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
Remove participant 5 acc: 0.5411392405063291
6 [ 2 10 15]
6
arousal_train counts:
 low     367
high    367
Name: count, dtype: Int64
arousal_test counts:
 high    47
low     47
Name

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


10 [ 2 10 15]
10
arousal_train counts:
 low     434
high    434
Name: count, dtype: Int64
arousal_test counts:
 high    96
low     96
Name: count, dtype: Int64
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step  
Remove participant 10 acc: 0.5416666666666666
11 [ 2 10 15]
11
arousal_train counts:
 high    626
low     626
Name: count, dtype: Int64
arousal_test counts:
 high    316
Name: count, dtype: Int64
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step
Remove participant 11 acc: 0.43037974683544306
12 [ 2 10 15]
12
arousal_train counts:
 high    757
low     757
Name: count, dtype: Int64
arousal_test counts:
 high    47
low     47
Name: count, dtype: Int64
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
Remove participant 12 acc: 0.5638297872340425
13 [ 2 10 15]
13
arousal_train counts:
 low     503
high    503
Name: count, dtype: Int64
arousal_test counts:
 high    316
Name: count, dtype: Int64
[1m10/10[0m [32m━━━

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


16 [ 2 10 15]
16
arousal_train counts:
 low     431
high    431
Name: count, dtype: Int64
arousal_test counts:
 high    96
low     96
Name: count, dtype: Int64
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step  
Remove participant 16 acc: 0.359375
17 [ 2 10 15]
17
arousal_train counts:
 low     367
high    367
Name: count, dtype: Int64
arousal_test counts:
 high    316
Name: count, dtype: Int64
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step
Remove participant 17 acc: 0.38924050632911394
18 [ 2 10 15]
18
arousal_train counts:
 low     651
high    651
Name: count, dtype: Int64
arousal_test counts:
 high    316
Name: count, dtype: Int64
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step

Confusion Matrix (pooled):
[[  0   0]
 [ 62 254]]
Participant 18 acc: 0.8037974683544303


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


19 [ 2 10 15]
19
arousal_train counts:
 high    465
low     465
Name: count, dtype: Int64
arousal_test counts:
 low     143
high    143
Name: count, dtype: Int64
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
Remove participant 19 acc: 0.4825174825174825
20 [ 2 10 15]
20
arousal_train counts:
 high    470
low     470
Name: count, dtype: Int64
arousal_test counts:
 high    96
low     96
Name: count, dtype: Int64
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step  

Confusion Matrix (pooled):
[[93  3]
 [37 59]]
Participant 20 acc: 0.7916666666666666
21 [ 2 10 15]
21
arousal_train counts:
 low     409
high    409
Name: count, dtype: Int64
arousal_test counts:
 high    47
low     47
Name: count, dtype: Int64
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step

Confusion Matrix (pooled):
[[36 11]
 [18 29]]
Participant 21 acc: 0.6914893617021277
22 [ 2 10 15]
22
arousal_train counts:
 high    626
low     626
Name: count, dtype: