# Train and Apply Models

In [1]:
from ML.model_training import (
    omit_patient_video,
    train_lstm,
    train_lstm_regressor,
    STSNet
)
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    classification_report,
)
import numpy as np
import pandas as pd
from ML import utils
import sys
import random
from itertools import product


# Users to exclude from the data set entirely. An alternative exclusion list is
# kept on the next line for reference.
# remove_list = [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 22]
remove_list = []

# Initial split with user 3 as the held-out patient and three trials requested.
# The arousal targets are still numeric here; they are thresholded further down.
(
    X_train_vals,
    X_test_vals,
    arousal_train_vals,
    arousal_test_vals,
) = omit_patient_video(
    exclude_users=remove_list,
    selected_user=3,
    trials=3,
    target="arousal",
)

# Column selection reused by every model below.
# Presumably drops the columns that belong to the listed bands; confirm in ML.utils.
features = utils.filter_features(
    X_train_vals.columns,
    remove_bands=["gamma", "delta"],
)

def balance(X, y, seed=5):
    """Downsample over-represented classes so every label in ``y`` is equally frequent.

    Works for any label values (not only "high"/"low") and selects rows by
    position, so duplicate index labels cannot multiply rows.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature rows, row-aligned with ``y`` (same length, same order).
    y : pandas.Series
        Class label of each row.
    seed : int, default 5
        ``random_state`` used when sampling rows of an over-represented class.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.Series)
        The balanced ``X`` and ``y`` in their original row order, both
        re-indexed 0..n-1. When ``y`` is empty, holds a single class, or is
        already balanced, the inputs are returned with only the index reset.
        When downsampling happens, rows whose label is missing are dropped.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` do not have the same number of rows.
    """
    if len(X) != len(y):
        raise ValueError(
            f"X and y must be row-aligned: got {len(X)} feature rows and {len(y)} labels"
        )

    counts = y.value_counts()
    # Empty input, a single class, or equal class sizes: nothing to downsample.
    if counts.empty or counts.nunique() == 1:
        return X.reset_index(drop=True), y.reset_index(drop=True)

    target = int(counts.min())
    # Sample row positions rather than index labels so that a non-unique index
    # cannot pull in extra rows.
    positions = pd.Series(np.arange(len(y)))
    kept = []
    for label in counts.index:
        is_member = (y == label).to_numpy(dtype=bool, na_value=False)
        members = positions[is_member]
        if len(members) > target:
            members = members.sample(target, random_state=seed)
        kept.append(members.to_numpy())

    # Sorting restores the original row order of the surviving rows.
    keep = np.sort(np.concatenate(kept))
    return X.iloc[keep].reset_index(drop=True), y.iloc[keep].reset_index(drop=True)

def _as_high_low(scores):
    """Label each arousal score "high" when it exceeds 3.8, otherwise "low"."""
    labels = np.where(scores > 3.8, "high", "low")
    return pd.Series(labels, index=scores.index, dtype="string")


# Binarise the numeric targets, then equalise the class sizes of each split
# independently.
X_train, arousal_train = balance(X_train_vals, _as_high_low(arousal_train_vals))
X_test, arousal_test = balance(X_test_vals, _as_high_low(arousal_test_vals))

# Sanity check: both classes should now be equally frequent in each split.
print("arousal_train counts:\n", arousal_train.value_counts(dropna=False))
print("arousal_test counts:\n", arousal_test.value_counts(dropna=False))

Held-out patient: 3 | Held-out (patient, video) trials: [(3, 1), (3, 9), (3, 11)] | Excluded users: []
arousal_train counts:
 low     20312
high    20312
Name: count, dtype: Int64
arousal_test counts:
 low     32
high    32
Name: count, dtype: Int64


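Grid-search the LSTM hyperparameters (`lr`, `epochs`, `units`, `batch_size`) across the held-out subjects, then retrain and evaluate each subject with the best combination.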

In [2]:
# Subjects that are held out one at a time (one fold per subject).
subjects = [3, 7, 12, 13, 15, 20, 21]
# subjects = []
# for _ in range(5):
#     subjects.append(random.randint(0, 22))
print(subjects)

# Search space: every combination is trained once per held-out subject.
param_grid = {
    "lr": [0.0001],
    "epochs": [50],
    "units": [128, 256, 512],
    "batch_size": [128, 256, 512],
}

# Arousal ratings strictly above this value are labelled "high", the rest "low".
AROUSAL_THRESHOLD = 3.8


def binarize_arousal(scores):
    """Turn numeric arousal scores into a "high"/"low" string Series (same index)."""
    return pd.Series(
        np.where(scores > AROUSAL_THRESHOLD, "high", "low"),
        index=scores.index,
        dtype="string",
    )


def prepare_lstm_fold(subject, trials=18):
    """Build the balanced, binarised split for one held-out subject.

    Parameters
    ----------
    subject : int
        User passed to ``omit_patient_video`` as ``selected_user``.
    trials : int, default 18
        Number of trials requested for the hold-out.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` after ``balance``; the feature
        frames still contain all columns.
    """
    X_tr, X_te, y_tr, y_te = omit_patient_video(
        target="arousal",
        selected_user=subject,
        trials=trials,
        # holdout_videos=[2, 10, 15],
        exclude_users=remove_list,
    )
    # NOTE(review): a test fold that contains a single class is not rejected
    # here. The earlier retry loop was removed because, per the run log, with
    # trials=18 every video of the subject is held out, so retrying would
    # reproduce the same split. Skip the subject instead if that check is needed.
    X_tr, y_tr = balance(X_tr, binarize_arousal(y_tr))
    X_te, y_te = balance(X_te, binarize_arousal(y_te))
    return X_tr, X_te, y_tr, y_te


best_params = None
best_mean_acc = -float("inf")

print("Starting global hyperparameter search...\n")

# NOTE(review): the winning combination is chosen by accuracy on the same
# held-out subjects that are reported below, so the final numbers are
# optimistically biased. Tune on a separate set of validation subjects to get
# an unbiased estimate.
for lr, epochs, units, batch_size in product(
    param_grid["lr"],
    param_grid["epochs"],
    param_grid["units"],
    param_grid["batch_size"],
):
    combo_accs = []  # accuracy for each subject under this hyperparam combo

    for i in subjects:
        X_train, X_test, arousal_train, arousal_test = prepare_lstm_fold(i)

        X_train_sub = X_train.loc[:, features]
        X_test_sub = X_test.loc[:, features]

        lstm, X_test_eval, y_test_eval = train_lstm(
            X_train_sub,
            X_test_sub,
            arousal_train,
            arousal_test,
            lr=lr,
            epochs=epochs,
            units=units,
            batch_size=batch_size,
            bidirectional=True,
        )

        # Threshold the predicted probabilities at 0.5 to get class ids.
        y_prob = lstm.predict(X_test_eval).ravel()
        arousal_pred = (y_prob >= 0.5).astype(int)

        acc = accuracy_score(y_test_eval, arousal_pred)
        print(acc)
        combo_accs.append(float(acc))

    mean_acc = float(np.mean(combo_accs)) if combo_accs else float("nan")
    print(
        f"Params lr={lr}, epochs={epochs}, units={units}, batch_size={batch_size} "
        f"-> mean acc across subjects = {mean_acc:.4f}"
    )

    # A NaN mean never compares greater, so an empty subject list leaves
    # best_params as None.
    if mean_acc > best_mean_acc:
        best_mean_acc = mean_acc
        best_params = {
            "lr": lr,
            "epochs": epochs,
            "units": units,
            "batch_size": batch_size,
        }

print("\nBest universal hyperparameters:")
print(best_params)
print(f"Mean accuracy across subjects (tuning): {best_mean_acc:.4f}\n")


# Per-subject metrics of the final run.
acc_list = []
f1_list = []
precw_list = []
recw_list = []

# Labels and predictions pooled over all subjects.
y_test_full = []
y_pred_full = []

for i in subjects:
    X_train, X_test, arousal_train, arousal_test = prepare_lstm_fold(i)

    print(f"\n=== Subject {i} ===")
    print("arousal_train counts:\n", arousal_train.value_counts(dropna=False))
    print("arousal_test counts:\n", arousal_test.value_counts(dropna=False))

    X_train_sub = X_train.loc[:, features]
    X_test_sub = X_test.loc[:, features]

    lstm, X_test_eval, y_test_eval = train_lstm(
        X_train_sub,
        X_test_sub,
        arousal_train,
        arousal_test,
        lr=best_params["lr"],
        epochs=best_params["epochs"],
        units=best_params["units"],
        batch_size=best_params["batch_size"],
        bidirectional=True,
    )

    y_prob = lstm.predict(X_test_eval).ravel()
    arousal_pred = (y_prob >= 0.5).astype(int)

    # zero_division=0 matches the pooled classification report below and keeps
    # a never-predicted class from emitting warnings (the value is 0 either way).
    acc = accuracy_score(y_test_eval, arousal_pred)
    f1 = f1_score(y_test_eval, arousal_pred, average="weighted", zero_division=0)
    prec = precision_score(
        y_test_eval, arousal_pred, average="weighted", zero_division=0
    )
    rec = recall_score(y_test_eval, arousal_pred, average="weighted", zero_division=0)

    print("\nConfusion Matrix (this subject):")
    print(confusion_matrix(y_test_eval, arousal_pred))
    print(f"Acc={acc:.4f}, F1w={f1:.4f}, Prec={prec:.4f}, Rec={rec:.4f}")

    acc_list.append(float(acc))
    f1_list.append(float(f1))
    precw_list.append(float(prec))
    recw_list.append(float(rec))

    y_test_full.extend(y_test_eval)
    # arousal_pred is the array produced by .astype(int) above.
    y_pred_full.extend(arousal_pred.tolist())

avg_acc = float(np.mean(acc_list)) if acc_list else float("nan")
avg_f1w = float(np.mean(f1_list)) if f1_list else float("nan")
avg_prec = float(np.mean(precw_list)) if precw_list else float("nan")
avg_rec = float(np.mean(recw_list)) if recw_list else float("nan")

print(
    "\nLSTM Classification Performance (cross-subject folds, best universal hyperparams)"
)
print("---------------------------------------------------")
print(f"Accuracy: {avg_acc:.4f}")
print(f"F1 (weighted): {avg_f1w:.4f}")
print(f"Precision (weighted): {avg_prec:.4f}")
print(f"Recall (weighted): {avg_rec:.4f}")

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(y_test_full, y_pred_full))

print("\nClassification Report (pooled):")
print(classification_report(y_test_full, y_pred_full, zero_division=0))

[3, 7, 12, 13, 15, 20, 21]
Starting global hyperparameter search...

Held-out patient: 3 | Held-out (patient, video) trials: [(3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11), (3, 12), (3, 13), (3, 14), (3, 15), (3, 16), (3, 17)] | Excluded users: []
[1m37/37[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step
0.3630849220103986
Held-out patient: 7 | Held-out (patient, video) trials: [(7, 0), (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (7, 10), (7, 11), (7, 12), (7, 13), (7, 14), (7, 15), (7, 16), (7, 17)] | Excluded users: []
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step
0.5058252427184466
Held-out patient: 12 | Held-out (patient, video) trials: [(12, 0), (12, 1), (12, 2), (12, 3), (12, 4), (12, 5), (12, 6), (12, 7), (12, 8), (12, 9), (12, 10), (12, 11), (12, 12), (12, 13), (12, 14), (12, 15), (12, 16), (12, 17)] | Excluded users: []


KeyboardInterrupt: 

## STSNet

In [None]:
# Subjects that are held out one at a time (one fold per subject).
subjects = [3, 7, 12, 13, 15, 20, 21]
# subjects = []
# for _ in range(5):
#     subjects.append(random.randint(0, 22))
print(subjects)

param_grid = {
    "lr": [0.0001],
    "epochs": [15],
    "units": [256, 512],
    "batch_size": [256],
}


def prepare_stsnet_fold(subject, trials):
    """Build the balanced, binarised split for one held-out subject.

    Parameters
    ----------
    subject : int
        User passed to ``omit_patient_video`` as ``selected_user``.
    trials : int
        Number of trials requested for the hold-out.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` after ``balance``; the feature
        frames still contain all columns, and the labels are "high" (score
        above 3.8) or "low".
    """
    X_tr, X_te, y_tr, y_te = omit_patient_video(
        target="arousal",
        selected_user=subject,
        trials=trials,
        # holdout_videos=[2, 10, 15],
        exclude_users=remove_list,
    )
    y_tr = pd.Series(
        np.where(y_tr > 3.8, "high", "low"), index=y_tr.index, dtype="string"
    )
    y_te = pd.Series(
        np.where(y_te > 3.8, "high", "low"), index=y_te.index, dtype="string"
    )
    # NOTE(review): a test fold that contains a single class is not rejected here.
    X_tr, y_tr = balance(X_tr, y_tr)
    X_te, y_te = balance(X_te, y_te)
    return X_tr, X_te, y_tr, y_te


best_params = None
best_mean_acc = -float("inf")

print("Starting global hyperparameter search...\n")

# NOTE(review): STSNet is called without lr / epochs / units / batch_size, so
# every grid combination trains the same configuration and the "search" only
# measures run-to-run variation. Pass the values through once the STSNet
# signature is confirmed to accept them.
# NOTE(review): tuning holds out 5 trials per subject while the final
# evaluation below holds out 18; confirm this difference is intended.
for lr, epochs, units, batch_size in product(
    param_grid["lr"],
    param_grid["epochs"],
    param_grid["units"],
    param_grid["batch_size"],
):
    combo_accs = []  # accuracy for each subject under this hyperparam combo

    for i in subjects:
        X_train, X_test, arousal_train, arousal_test = prepare_stsnet_fold(
            i, trials=5
        )

        X_train_sub = X_train.loc[:, features]
        X_test_sub = X_test.loc[:, features]

        # Argument order differs from train_lstm: the labels follow their features.
        lstm, X_test_eval, y_test_eval = STSNet(
            X_train_sub,
            arousal_train,
            X_test_sub,
            arousal_test
        )

        # Threshold the predicted probabilities at 0.5 to get class ids.
        y_prob = lstm.predict(X_test_eval).ravel()
        arousal_pred = (y_prob >= 0.5).astype(int)

        acc = accuracy_score(y_test_eval, arousal_pred)
        print(acc)
        combo_accs.append(float(acc))

    mean_acc = float(np.mean(combo_accs)) if combo_accs else float("nan")
    print(
        f"Params lr={lr}, epochs={epochs}, units={units}, batch_size={batch_size} "
        f"-> mean acc across subjects = {mean_acc:.4f}"
    )

    if mean_acc > best_mean_acc:
        best_mean_acc = mean_acc
        best_params = {
            "lr": lr,
            "epochs": epochs,
            "units": units,
            "batch_size": batch_size,
        }

print("\nBest universal hyperparameters:")
print(best_params)
print(f"Mean accuracy across subjects (tuning): {best_mean_acc:.4f}\n")


# Per-subject metrics of the final run.
acc_list = []
f1_list = []
precw_list = []
recw_list = []

# Labels and predictions pooled over all subjects.
y_test_full = []
y_pred_full = []

for i in subjects:
    X_train, X_test, arousal_train, arousal_test = prepare_stsnet_fold(i, trials=18)

    print(f"\n=== Subject {i} ===")
    print("arousal_train counts:\n", arousal_train.value_counts(dropna=False))
    print("arousal_test counts:\n", arousal_test.value_counts(dropna=False))

    X_train_sub = X_train.loc[:, features]
    X_test_sub = X_test.loc[:, features]

    # Evaluate the model this section is about. This step used to call
    # train_lstm, so the numbers reported under "STSNet" were LSTM results.
    lstm, X_test_eval, y_test_eval = STSNet(
        X_train_sub,
        arousal_train,
        X_test_sub,
        arousal_test
    )

    y_prob = lstm.predict(X_test_eval).ravel()
    arousal_pred = (y_prob >= 0.5).astype(int)

    # zero_division=0 matches the pooled classification report below and keeps
    # a never-predicted class from emitting warnings (the value is 0 either way).
    acc = accuracy_score(y_test_eval, arousal_pred)
    f1 = f1_score(y_test_eval, arousal_pred, average="weighted", zero_division=0)
    prec = precision_score(
        y_test_eval, arousal_pred, average="weighted", zero_division=0
    )
    rec = recall_score(y_test_eval, arousal_pred, average="weighted", zero_division=0)

    print("\nConfusion Matrix (this subject):")
    print(confusion_matrix(y_test_eval, arousal_pred))
    print(f"Acc={acc:.4f}, F1w={f1:.4f}, Prec={prec:.4f}, Rec={rec:.4f}")

    acc_list.append(float(acc))
    f1_list.append(float(f1))
    precw_list.append(float(prec))
    recw_list.append(float(rec))

    y_test_full.extend(y_test_eval)
    # arousal_pred is the array produced by .astype(int) above.
    y_pred_full.extend(arousal_pred.tolist())

avg_acc = float(np.mean(acc_list)) if acc_list else float("nan")
avg_f1w = float(np.mean(f1_list)) if f1_list else float("nan")
avg_prec = float(np.mean(precw_list)) if precw_list else float("nan")
avg_rec = float(np.mean(recw_list)) if recw_list else float("nan")

print("\nSTSNet Classification Performance (cross-subject folds)")
print("---------------------------------------------------")
print(f"Accuracy: {avg_acc:.4f}")
print(f"F1 (weighted): {avg_f1w:.4f}")
print(f"Precision (weighted): {avg_prec:.4f}")
print(f"Recall (weighted): {avg_rec:.4f}")

print("\nConfusion Matrix (pooled):")
print(confusion_matrix(y_test_full, y_pred_full))

print("\nClassification Report (pooled):")
print(classification_report(y_test_full, y_pred_full, zero_division=0))

In [None]:
# "Best so far" trackers for the sweep below; the numeric ones start at zero.
best_model, best_keep = None, None
best_acc, best_lr, best_f1 = 0, 0, 0

# Not referenced by the code visible in this cell.
# Presumably the width of a text progress bar; TODO confirm.
bar_len = 30


def render(bar_str: str, status_str: str, curr):
    """Print ``bar_str`` and ``curr`` on their own lines, then ``status_str``.

    The status text is written without a trailing newline, and stdout is
    flushed so the output appears immediately.
    """
    for line in (bar_str, curr):
        print(line)
    print(status_str, end="")
    sys.stdout.flush()


status = "Best: index= size= | acc= | f1= | prec= | rec="

results = []

# NOTE(review): X_train, X_test, arousal_train and arousal_test are not created
# in this cell. It uses whatever split the most recently executed cell left
# behind, so the result depends on execution order.
X_train_sub = X_train.loc[:, features]
X_test_sub = X_test.loc[:, features]

n_low = (arousal_train == "low").sum()
n_high = (arousal_train == "high").sum()

# Predictions and matching ground truth of the winning run. They are stored
# together so the final report never pairs the best predictions with labels
# from a different run.
best_arousal_pred = None
best_y_test_eval = None

# Same iteration order as nesting the four loops (lr outermost).
for lr, e, u, b_s in product([0.001, 0.0001], [10], [512], [256]):
    lstm, X_test_eval, y_test_eval = train_lstm(
        X_train_sub,
        X_test_sub,
        arousal_train,
        arousal_test,
        lr=lr,
        epochs=e,
        units=u,
        batch_size=b_s,
        bidirectional=True,
    )
    # Threshold the predicted probabilities at 0.5 to get class ids.
    y_prob = lstm.predict(X_test_eval).ravel()
    arousal_pred = (y_prob >= 0.5).astype(int)

    acc = accuracy_score(y_test_eval, arousal_pred)
    f1 = f1_score(y_test_eval, arousal_pred, average="weighted")
    prec = precision_score(y_test_eval, arousal_pred, average="weighted")
    rec = recall_score(y_test_eval, arousal_pred, average="weighted")

    if acc > best_acc:
        best_acc = acc
        best_model = lstm
        best_lr = lr
        best_e = e
        best_u = u
        best_b_s = b_s
        best_f1 = f1
        best_arousal_pred = arousal_pred
        best_y_test_eval = y_test_eval
        status = (
            f"Best: "
            f"acc={acc:.6f} | f1={f1:.6f} | prec={prec:.6f} | rec={rec:.6f} | lr={best_lr} | epochs={best_e} | units={best_u} | batch_size={best_b_s}"
        )
        print(status)

if best_arousal_pred is None:
    # No run improved on the starting best_acc, so there is nothing to report.
    print("\nNo configuration improved on the initial best_acc; nothing to report.")
else:
    print("\nConfusion Matrix (pooled):")
    print(confusion_matrix(best_y_test_eval, best_arousal_pred))

    print("\nClassification Report (pooled):")
    print(classification_report(best_y_test_eval, best_arousal_pred, zero_division=0))

print(status)

[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 57ms/step
Best: acc=0.973404 | f1=0.973404 | prec=0.973458 | rec=0.973404 | lr=0.001 | epochs=10 | units=512 | batch_size=256
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 49ms/step

Confusion Matrix (pooled):
[[91  3]
 [ 2 92]]

Classification Report (pooled):
              precision    recall  f1-score   support

         0.0       0.98      0.97      0.97        94
         1.0       0.97      0.98      0.97        94

    accuracy                           0.97       188
   macro avg       0.97      0.97      0.97       188
weighted avg       0.97      0.97      0.97       188

Best: acc=0.973404 | f1=0.973404 | prec=0.973458 | rec=0.973404 | lr=0.001 | epochs=10 | units=512 | batch_size=256


In [26]:
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import numpy as np

# NOTE(review): the captured output of this cell lists excluded users although
# remove_list is assigned an empty list at the top of the notebook, so the cell
# was run against a different kernel state. Re-run from a fresh kernel before
# trusting stored results.
best_model = None

best_mean_r2 = -np.inf
best_mean_pcc = -np.inf
best_lr = None
best_e = None
best_u = None
best_b_s = None
best_mean_mae = None
best_mean_rmse = None
# Declared up front so the final report can detect "no winner" instead of
# failing with a NameError.
best_y_test_all = None
best_y_pred_all = None

status = "Best: mean R2= | mean PCC= | mean MAE= | mean RMSE= | lr= | epochs= | units= | batch_size="

results = []

users = [3, 7, 12, 13, 15, 20, 21]


def pearson_or_nan(y_true, y_pred):
    """Pearson correlation of two vectors, or NaN when either one is constant."""
    if np.std(y_true) == 0 or np.std(y_pred) == 0:
        return float("nan")
    return float(np.corrcoef(y_true, y_pred)[0, 1])


# Same iteration order as nesting the four loops (lr outermost).
for lr, e, u, b_s in product([0.0001], [50], [512], [256]):
    fold_r2 = []
    fold_mae = []
    fold_rmse = []
    fold_pcc = []

    y_true_all = []
    y_pred_all = []

    for i in users:
        X_train_vals, X_test_vals, arousal_train_vals, arousal_test_vals = omit_patient_video(
            target="arousal",
            trials=5,
            selected_user=i,
            exclude_users=remove_list,
        )

        X_train_sub = X_train_vals.loc[:, features]
        X_test_sub = X_test_vals.loc[:, features]

        lstm, X_test_eval, y_test_eval = train_lstm_regressor(
            X_train_sub,
            X_test_sub,
            arousal_train_vals,
            arousal_test_vals,
            lr=lr,
            epochs=e,
            units=u,
            batch_size=b_s,
            bidirectional=True,
        )

        y_pred = lstm.predict(X_test_eval).ravel()
        y_true = np.asarray(y_test_eval, dtype=float).ravel()

        r2 = r2_score(y_true, y_pred)
        mae = mean_absolute_error(y_true, y_pred)
        mse = mean_squared_error(y_true, y_pred)
        rmse = np.sqrt(mse)
        pcc = pearson_or_nan(y_true, y_pred)

        fold_r2.append(r2)
        fold_mae.append(mae)
        fold_rmse.append(rmse)
        fold_pcc.append(pcc)

        y_true_all.append(y_true)
        y_pred_all.append(y_pred)

    if not y_true_all:
        # No folds were run (empty `users`), so there is nothing to aggregate.
        continue

    y_true_all = np.concatenate(y_true_all)
    y_pred_all = np.concatenate(y_pred_all)

    mean_r2 = float(np.nanmean(fold_r2))
    mean_mae = float(np.nanmean(fold_mae))
    mean_rmse = float(np.nanmean(fold_rmse))

    # "mean PCC" is now the average of the per-fold correlations, like the
    # other fold-averaged metrics. It used to be the correlation of the pooled
    # predictions while fold_pcc went unused. The pooled value is still
    # recorded in `results`.
    finite_pcc = [p for p in fold_pcc if not np.isnan(p)]
    mean_pcc = float(np.mean(finite_pcc)) if finite_pcc else float("nan")
    pooled_pcc = pearson_or_nan(y_true_all, y_pred_all)

    results.append(
        {
            "lr": lr,
            "epochs": e,
            "units": u,
            "batch_size": b_s,
            "mean_r2": mean_r2,
            "mean_mae": mean_mae,
            "mean_rmse": mean_rmse,
            "mean_pcc": mean_pcc,
            "pooled_pcc": pooled_pcc,
        }
    )

    if mean_r2 > best_mean_r2:
        best_mean_r2 = mean_r2
        best_mean_pcc = mean_pcc
        best_lr = lr
        best_e = e
        best_u = u
        best_b_s = b_s
        best_mean_mae = mean_mae
        best_mean_rmse = mean_rmse

        # NOTE(review): `lstm` is the model trained for the LAST fold of this
        # combination only, not an ensemble or a refit on all users.
        best_model = lstm
        best_y_test_all = y_true_all
        best_y_pred_all = y_pred_all

        status = (
            f"Best: "
            f"mean PCC={mean_pcc:.3f} | mean R2={mean_r2:.6f} | "
            f"mean MAE={mean_mae:.6f} | mean RMSE={mean_rmse:.6f} | "
            f"lr={best_lr} | epochs={best_e} | units={best_u} | batch_size={best_b_s}"
        )
        print(status)

if best_lr is None:
    # Happens when no combination produced a mean R2 above -inf, e.g. no folds
    # were run or every mean was NaN. The best_* values are still None, so
    # formatting them would fail.
    print("\nNo hyperparameter combination produced a usable mean R2; nothing to report.")
else:
    print("\nFinal best model (cross-validated hyperparams) metrics:")
    print(f"mean R2   : {best_mean_r2:.6f}")
    print(f"mean PCC  : {best_mean_pcc:.6f}")
    print(f"mean MAE  : {best_mean_mae:.6f}")
    print(f"mean RMSE : {best_mean_rmse:.6f}")
    print(status)

    utils.plot_regressor_accuracy(best_y_test_all, best_y_pred_all)


Held-out patient: 3 | Held-out (patient, video) trials: [(3, 0), (3, 3), (3, 8), (3, 10), (3, 12)] | Excluded users: [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19, 22]


KeyboardInterrupt: 