In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv1D, AveragePooling1D, BatchNormalization, Activation,
    Dropout, Concatenate, Bidirectional, LSTM, Flatten, Dense
)
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint

from tensorflow.keras.models import load_model
from sklearn.metrics import confusion_matrix, classification_report

import pandas as pd
import numpy as np

In [None]:
def build_light_pairs_conv_lstm_model(
    window=60,
    channels=1,
    filters=12,
    kernels=(3,5),
    dilations=(1,8),
    lstm_units=32,
    fc_units=(36, 18),
    conv_dropout=0.05,
    lstm_dropout=0.15,
    final_dropout=0.20,
    num_classes=3,
    lr=3e-4,
    loss="categorical_crossentropy"
):
    """Build and compile the Conv1D + BiLSTM classifier for spread windows.

    Wiring:
      * one convolutional branch per entry of `kernels`; inside a branch,
        every entry of `dilations` contributes two
        Conv1D -> BatchNorm -> ReLU -> AveragePooling1D -> Dropout stages;
      * branch outputs are concatenated along the channel axis and passed to
        a bidirectional LSTM that returns only its last output;
      * the flattened raw input is concatenated to the LSTM output;
      * Dense+Dropout layers (one pair per entry of `fc_units`) followed by a
        softmax layer with `num_classes` units.

    Args:
        window: number of time steps in one input sample.
        channels: number of features per time step.
        filters: number of filters in every Conv1D layer.
        kernels: kernel sizes, one parallel branch each. Values must be
            distinct because they are part of the layer names.
        dilations: dilation rates applied in sequence inside each branch.
            Values must be distinct for the same reason.
        lstm_units: units per LSTM direction.
        fc_units: widths of the hidden Dense layers, in order. Any length is
            accepted; the default reproduces the two-layer head.
        conv_dropout: dropout rate after every conv stage.
        lstm_dropout: input dropout rate of the LSTM.
        final_dropout: dropout rate after the first Dense layer; later Dense
            layers use half of this rate.
        num_classes: size of the softmax output.
        lr: Adam learning rate.
        loss: loss passed to `Model.compile` (a Keras loss name or a
            callable). Defaults to plain categorical cross-entropy.

    Returns:
        The compiled Keras `Model` with input shape (window, channels).
    """

    def conv_stage(x, kernel_size, dilation, tag):
        # Layers are created in the same order as the unrolled version so the
        # auto-generated layer names stay the same.
        x = Conv1D(filters, kernel_size=kernel_size, dilation_rate=dilation,
                   padding="same", activation=None,
                   name=f"{tag}_k{kernel_size}_d{dilation}")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        # stride 1 + "same" padding: smooths neighbours, keeps sequence length
        x = AveragePooling1D(pool_size=2, strides=1, padding="same")(x)
        return Dropout(conv_dropout)(x)

    inp = Input(shape=(window, channels), name="spread_input")

    # ------------------------------------------------------
    # 1) PARALLEL CNN BRANCHES + MEAN POOLING
    # ------------------------------------------------------
    branch_seq_outputs = []

    for k in kernels:
        x = inp

        for d in dilations:
            x = conv_stage(x, k, d, "conv1")
            x = conv_stage(x, k, d, "conv2")

        branch_seq_outputs.append(x)

    # ------------------------------------------------------
    # 2) CONCATENATE BRANCH SEQUENCES
    # ------------------------------------------------------
    fused_seq = Concatenate(axis=-1, name="conv_fusion")(branch_seq_outputs)

    # ------------------------------------------------------
    # 3) BiLSTM OVER FUSED CONV SEQUENCES
    # ------------------------------------------------------
    lstm_out = Bidirectional(
        LSTM(lstm_units, return_sequences=False, dropout=lstm_dropout),
        name="bilstm"
    )(fused_seq)

    # ------------------------------------------------------
    # 4) RAW INPUT SKIP PATH: FLATTEN, THEN CONCATENATE
    # ------------------------------------------------------
    raw_flat = Flatten(name="raw_flatten")(inp)

    combined = Concatenate(name="fusion_with_raw")([lstm_out, raw_flat])

    # ------------------------------------------------------
    # 5) FULLY CONNECTED HEAD
    # ------------------------------------------------------
    x = combined
    for depth, units in enumerate(fc_units):
        x = Dense(units, activation="relu")(x)
        # full rate after the first Dense layer, half rate afterwards
        rate = final_dropout if depth == 0 else final_dropout / 2
        x = Dropout(rate)(x)

    out = Dense(num_classes, activation="softmax", name="output")(x)

    model = Model(inputs=inp, outputs=out, name="LightPairsConvBiLSTM_MeanPool")

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss=loss,
        metrics=["accuracy"]
    )

    return model

In [None]:
# Build the network with its default hyperparameters and print the layer table.
model = build_light_pairs_conv_lstm_model()

model.summary()

In [None]:
# ============================================================
# ===================== USER PARAMETERS =======================
# ============================================================

DL_WINDOW = 60               # number of spread_<i> columns read per row
HOLD_DOWNSAMPLE_FRAC = 0.5   # share of label == 0 rows that is kept

DATA_FILE = "/content/final_dl_training_dataset.csv"

# ============================================================
# ===================== LOAD & PREPROCESS DATA ===============
# ============================================================

print("Loading dataset...")
df = pd.read_csv(DATA_FILE)
df["date"] = pd.to_datetime(df["date"])

# ----- Downsample hold class -----
# Keep a fixed-seed random subset of the label == 0 rows and every other row,
# then restore chronological order with a fresh 0..N-1 index.
is_hold = df["label"] == 0
df_hold = df[is_hold].sample(frac=HOLD_DOWNSAMPLE_FRAC, random_state=42)
df_trade = df[~is_hold]
df = (
    pd.concat([df_hold, df_trade])
    .sort_values("date")
    .reset_index(drop=True)
)

print("After balancing:")
print(df["label"].value_counts())

# ----- Extract spread windows -----
# Columns are named spread_1 ... spread_<DL_WINDOW>.
spread_cols = [f"spread_{i}" for i in range(1, DL_WINDOW + 1)]
X_raw = df[spread_cols].values
y_raw = df["label"].values

# ----- Window normalization -----
def normalize_window(w, eps=1e-8):
    """Z-score one window using its own mean and standard deviation.

    Args:
        w: 1-D array-like of values for a single sample.
        eps: lower bound applied to the standard deviation.

    Returns:
        Array of the same length as `w` holding (w - mean) / max(std, eps).

    The standard deviation is floored rather than compared to exactly zero:
    for a constant window, rounding in the mean can leave a std around 1e-17,
    and dividing by it would scale that rounding noise up to about +/-1.
    """
    w = np.asarray(w)
    mean = w.mean()
    std = w.std()
    if std < eps:
        std = eps
    return (w - mean) / std

# Normalise every row of X_raw independently.
X_norm = np.array(list(map(normalize_window, X_raw)))

# ----- Add the channel axis: (samples, DL_WINDOW, 1) -----
X = X_norm.reshape((-1, DL_WINDOW, 1))

# ----- Re-index labels -1/0/+1 as 0/1/2, then one-hot encode -----
label_map = {-1: 0, 0: 1, +1: 2}
y = np.array(list(map(label_map.__getitem__, y_raw)))
y_cat = to_categorical(y, num_classes=3)

# ----- Chronological 70 / 15 / 15 split by row position -----
# NOTE(review): rows are ordered by date only, so rows that share a date can
# land on both sides of a cut; confirm this is acceptable.
N = len(df)
train_end, val_end = int(0.7 * N), int(0.85 * N)

X_train, X_val, X_test = X[:train_end], X[train_end:val_end], X[val_end:]
y_train, y_val, y_test = y_cat[:train_end], y_cat[train_end:val_end], y_cat[val_end:]

print("Shapes:")
for tag, feats, targets in (
    ("Train:", X_train, y_train),
    ("Val:  ", X_val, y_val),
    ("Test: ", X_test, y_test),
):
    print(tag, feats.shape, targets.shape)

# ============================================================
# ========================= TRAIN MODEL =======================
# ============================================================

# Epochs without a new best val_loss before training stops.
EARLY_STOP_PATIENCE = 10

# One full-model checkpoint per epoch, so any epoch can be reloaded later.
checkpoint_all = ModelCheckpoint(
    filepath="/content/models_CE/model_epoch_{epoch:02d}_valLoss_{val_loss:.4f}_CE.h5",
    save_weights_only=False,
    save_freq="epoch"
)

# Without a stopping rule the fit runs all 100 epochs even after validation
# loss has turned upwards, and the test evaluation below only runs if nobody
# interrupts the cell. Stop automatically and roll back to the best epoch.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=EARLY_STOP_PATIENCE,
    restore_best_weights=True
)

history = model.fit(
    X_train,
    y_train,
    epochs=100,
    batch_size=64,
    validation_data=(X_val, y_val),
    callbacks=[checkpoint_all, early_stop]
)

# ============================================================
# ========================== TESTING ==========================
# ============================================================

# Evaluated with the weights restored by early stopping.
test_loss, test_acc = model.evaluate(X_test, y_test)
print("\nTEST ACCURACY:", test_acc)

Loading dataset...
After balancing:
label
 0    34496
 1    30438
-1    30384
Name: count, dtype: int64
Shapes:
Train: (66722, 60, 1) (66722, 3)
Val:   (14298, 60, 1) (14298, 3)
Test:  (14298, 60, 1) (14298, 3)
Epoch 1/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.5856 - loss: 0.8851



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 26ms/step - accuracy: 0.5858 - loss: 0.8849 - val_accuracy: 0.7060 - val_loss: 0.7074
Epoch 2/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6840 - loss: 0.7580



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6840 - loss: 0.7580 - val_accuracy: 0.7144 - val_loss: 0.6922
Epoch 3/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6926 - loss: 0.7354



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6926 - loss: 0.7354 - val_accuracy: 0.7135 - val_loss: 0.6949
Epoch 4/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6972 - loss: 0.7307



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.6972 - loss: 0.7307 - val_accuracy: 0.7130 - val_loss: 0.6924
Epoch 5/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6976 - loss: 0.7162



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6976 - loss: 0.7162 - val_accuracy: 0.7124 - val_loss: 0.6862
Epoch 6/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6990 - loss: 0.7152



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6990 - loss: 0.7152 - val_accuracy: 0.7118 - val_loss: 0.6834
Epoch 7/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.7014 - loss: 0.7044



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7014 - loss: 0.7044 - val_accuracy: 0.7148 - val_loss: 0.6800
Epoch 8/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7032 - loss: 0.7031



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 24ms/step - accuracy: 0.7032 - loss: 0.7031 - val_accuracy: 0.7177 - val_loss: 0.6757
Epoch 9/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7022 - loss: 0.7005



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 24ms/step - accuracy: 0.7022 - loss: 0.7005 - val_accuracy: 0.7134 - val_loss: 0.6800
Epoch 10/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7089 - loss: 0.6828



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 24ms/step - accuracy: 0.7089 - loss: 0.6828 - val_accuracy: 0.7151 - val_loss: 0.6811
Epoch 11/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7064 - loss: 0.6858



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 24ms/step - accuracy: 0.7064 - loss: 0.6858 - val_accuracy: 0.7127 - val_loss: 0.6801
Epoch 12/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7078 - loss: 0.6825



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 24ms/step - accuracy: 0.7078 - loss: 0.6825 - val_accuracy: 0.7121 - val_loss: 0.6797
Epoch 13/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7100 - loss: 0.6776



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7100 - loss: 0.6776 - val_accuracy: 0.7118 - val_loss: 0.6817
Epoch 14/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7114 - loss: 0.6769



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 26ms/step - accuracy: 0.7114 - loss: 0.6769 - val_accuracy: 0.7121 - val_loss: 0.6833
Epoch 15/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7125 - loss: 0.6691



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 24ms/step - accuracy: 0.7125 - loss: 0.6691 - val_accuracy: 0.7116 - val_loss: 0.6874
Epoch 16/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7143 - loss: 0.6648



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.7143 - loss: 0.6648 - val_accuracy: 0.7092 - val_loss: 0.6838
Epoch 17/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7203 - loss: 0.6526



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7203 - loss: 0.6526 - val_accuracy: 0.7059 - val_loss: 0.7001
Epoch 18/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7162 - loss: 0.6571



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7162 - loss: 0.6571 - val_accuracy: 0.7058 - val_loss: 0.7019
Epoch 19/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7232 - loss: 0.6463



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.7232 - loss: 0.6463 - val_accuracy: 0.7123 - val_loss: 0.6972
Epoch 20/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7219 - loss: 0.6484



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.7219 - loss: 0.6484 - val_accuracy: 0.7079 - val_loss: 0.6955
Epoch 21/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7240 - loss: 0.6415



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7240 - loss: 0.6415 - val_accuracy: 0.7051 - val_loss: 0.6969
Epoch 22/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7239 - loss: 0.6431



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.7239 - loss: 0.6431 - val_accuracy: 0.7094 - val_loss: 0.7069
Epoch 23/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 25ms/step - accuracy: 0.7271 - loss: 0.6328



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 26ms/step - accuracy: 0.7271 - loss: 0.6328 - val_accuracy: 0.7106 - val_loss: 0.7004
Epoch 24/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 25ms/step - accuracy: 0.7304 - loss: 0.6306



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 28ms/step - accuracy: 0.7304 - loss: 0.6306 - val_accuracy: 0.7103 - val_loss: 0.7086
Epoch 25/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7315 - loss: 0.6249



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 25ms/step - accuracy: 0.7315 - loss: 0.6249 - val_accuracy: 0.7081 - val_loss: 0.7036
Epoch 26/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7293 - loss: 0.6254



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7293 - loss: 0.6254 - val_accuracy: 0.7014 - val_loss: 0.7207
Epoch 27/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7328 - loss: 0.6186



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.7328 - loss: 0.6186 - val_accuracy: 0.7030 - val_loss: 0.7185
Epoch 28/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7335 - loss: 0.6177



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7335 - loss: 0.6177 - val_accuracy: 0.7023 - val_loss: 0.7244
Epoch 29/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7348 - loss: 0.6139



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.7348 - loss: 0.6139 - val_accuracy: 0.7031 - val_loss: 0.7287
Epoch 30/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7375 - loss: 0.6100



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.7375 - loss: 0.6100 - val_accuracy: 0.7093 - val_loss: 0.7215
Epoch 31/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7364 - loss: 0.6073



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7364 - loss: 0.6073 - val_accuracy: 0.7076 - val_loss: 0.7374
Epoch 32/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.7411 - loss: 0.6011



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.7411 - loss: 0.6011 - val_accuracy: 0.7019 - val_loss: 0.7446
Epoch 33/100
[1m 247/1043[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m17s[0m 22ms/step - accuracy: 0.7496 - loss: 0.5907

KeyboardInterrupt: 

Test-set report for the selected checkpoint (average-pooling model)

In [None]:
# ----------------------------------------------------
# LOAD ANY SAVED MODEL FILE
# ----------------------------------------------------
# The file name embeds the epoch and validation loss of one particular run,
# so it has to be updated after every new training run.
model_path = "/content/models_CE/model_epoch_12_valLoss_0.6797_CE.h5"   # <--- change this
model = load_model(model_path)
print("Loaded model:", model_path)

# Class indices produced by the -1/0/+1 -> 0/1/2 remapping, with the original
# labels as display names.
CLASS_IDS = [0, 1, 2]
CLASS_NAMES = ["-1", "0 (hold)", "+1"]

# ----------------------------------------------------
# RUN PREDICTION ON TEST SET
# ----------------------------------------------------
y_pred_probs = model.predict(X_test)     # shape: (N, 3)
y_pred = np.argmax(y_pred_probs, axis=1) # most probable class index

y_true = np.argmax(y_test, axis=1)       # one-hot back to class index

# ----------------------------------------------------
# CONFUSION MATRIX
# ----------------------------------------------------
# Passing labels keeps the matrix 3x3 (rows = true, columns = predicted)
# even if a class is missing from this slice.
cm = confusion_matrix(y_true, y_pred, labels=CLASS_IDS)
print("\nConfusion Matrix:")
print(cm)

# ----------------------------------------------------
# CLASSIFICATION REPORT
# ----------------------------------------------------
print("\nClassification Report:")
print(classification_report(
    y_true,
    y_pred,
    labels=CLASS_IDS,
    target_names=CLASS_NAMES,
    digits=4
))




Loaded model: /content/models_CE/model_epoch_12_valLoss_0.6797_CE.h5
[1m447/447[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 13ms/step

Confusion Matrix:
[[3791  916  152]
 [ 815 3062  775]
 [ 377 1081 3329]]

Classification Report:
              precision    recall  f1-score   support

           0     0.7608    0.7802    0.7704      4859
           1     0.6053    0.6582    0.6306      4652
           2     0.7822    0.6954    0.7363      4787

    accuracy                         0.7121     14298
   macro avg     0.7161    0.7113    0.7124     14298
weighted avg     0.7173    0.7121    0.7135     14298



Custom loss function: misclassification-weighted categorical cross-entropy

In [None]:
import tensorflow.keras.backend as K

In [None]:
def weighted_categorical_crossentropy(weight_matrix):
    """Create a loss that scales cross-entropy by a (true, predicted) cost.

    weight_matrix: 2D matrix indexed as [true_class][predicted_class].

    Returns a function `loss(y_true, y_pred)` that yields one value per
    sample: the categorical cross-entropy multiplied by the matrix entry
    selected by the argmax of `y_true` and the argmax of `y_pred`.

    NOTE(review): the weight is picked from the hard argmax of the
    prediction, so it changes in steps as the predicted class flips.
    """
    cost_table = tf.constant(weight_matrix, dtype=tf.float32)

    def loss(y_true, y_pred):
        # Both tensors are (batch, num_classes); reduce each to a class index.
        true_cls = tf.argmax(y_true, axis=-1)
        pred_cls = tf.argmax(y_pred, axis=-1)

        # One (row, column) pair per sample, used to look up its cost.
        lookup = tf.stack([true_cls, pred_cls], axis=1)
        per_sample_cost = tf.gather_nd(cost_table, lookup)

        # Unweighted cross-entropy, one value per sample.
        base_ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)

        return base_ce * per_sample_cost

    return loss

In [None]:
# Cost table passed to weighted_categorical_crossentropy: rows are the true
# class and columns the predicted class, using the 0/1/2 indices that the
# notebook's label_map assigns to the original -1/0/+1 labels.
penalty_matrix = [
    [1.0, 2.0, 4.0],   # true class 0 = -1
    [1.5, 1.0, 1.5],   # true class 1 =  0
    [4.0, 2.0, 1.0]    # true class 2 = +1
]

In [None]:
def build_light_pairs_conv_lstm_model(
    window=60,
    channels=1,
    filters=12,
    kernels=(3,5),
    dilations=(1,8),
    lstm_units=32,
    fc_units=(36, 18),
    conv_dropout=0.05,
    lstm_dropout=0.15,
    final_dropout=0.20,
    num_classes=3,
    lr=3e-4,
    loss=None
):
    """Build the Conv1D + BiLSTM classifier, compiled with the weighted loss.

    Wiring:
      * one convolutional branch per entry of `kernels`; inside a branch,
        every entry of `dilations` contributes two
        Conv1D -> BatchNorm -> ReLU -> AveragePooling1D -> Dropout stages;
      * branch outputs are concatenated along the channel axis and passed to
        a bidirectional LSTM that returns only its last output;
      * the flattened raw input is concatenated to the LSTM output;
      * Dense+Dropout layers (one pair per entry of `fc_units`) followed by a
        softmax layer with `num_classes` units.

    Args:
        window: number of time steps in one input sample.
        channels: number of features per time step.
        filters: number of filters in every Conv1D layer.
        kernels: kernel sizes, one parallel branch each. Values must be
            distinct because they are part of the layer names.
        dilations: dilation rates applied in sequence inside each branch.
            Values must be distinct for the same reason.
        lstm_units: units per LSTM direction.
        fc_units: widths of the hidden Dense layers, in order. Any length is
            accepted; the default reproduces the two-layer head.
        conv_dropout: dropout rate after every conv stage.
        lstm_dropout: input dropout rate of the LSTM.
        final_dropout: dropout rate after the first Dense layer; later Dense
            layers use half of this rate.
        num_classes: size of the softmax output.
        lr: Adam learning rate.
        loss: loss passed to `Model.compile`. When None (the default),
            `weighted_categorical_crossentropy(penalty_matrix)` is used.

    Returns:
        The compiled Keras `Model` with input shape (window, channels).
    """

    if loss is None:
        # Resolved at call time so the current penalty_matrix is picked up.
        loss = weighted_categorical_crossentropy(penalty_matrix)

    def conv_stage(x, kernel_size, dilation, tag):
        # Layers are created in the same order as the unrolled version so the
        # auto-generated layer names stay the same.
        x = Conv1D(filters, kernel_size=kernel_size, dilation_rate=dilation,
                   padding="same", activation=None,
                   name=f"{tag}_k{kernel_size}_d{dilation}")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        # stride 1 + "same" padding: smooths neighbours, keeps sequence length
        x = AveragePooling1D(pool_size=2, strides=1, padding="same")(x)
        return Dropout(conv_dropout)(x)

    inp = Input(shape=(window, channels), name="spread_input")

    # ------------------------------------------------------
    # 1) PARALLEL CNN BRANCHES + MEAN POOLING
    # ------------------------------------------------------
    branch_seq_outputs = []

    for k in kernels:
        x = inp

        for d in dilations:
            x = conv_stage(x, k, d, "conv1")
            x = conv_stage(x, k, d, "conv2")

        branch_seq_outputs.append(x)

    # ------------------------------------------------------
    # 2) CONCATENATE BRANCH SEQUENCES
    # ------------------------------------------------------
    fused_seq = Concatenate(axis=-1, name="conv_fusion")(branch_seq_outputs)

    # ------------------------------------------------------
    # 3) BiLSTM OVER FUSED CONV SEQUENCES
    # ------------------------------------------------------
    lstm_out = Bidirectional(
        LSTM(lstm_units, return_sequences=False, dropout=lstm_dropout),
        name="bilstm"
    )(fused_seq)

    # ------------------------------------------------------
    # 4) RAW INPUT SKIP PATH: FLATTEN, THEN CONCATENATE
    # ------------------------------------------------------
    raw_flat = Flatten(name="raw_flatten")(inp)

    combined = Concatenate(name="fusion_with_raw")([lstm_out, raw_flat])

    # ------------------------------------------------------
    # 5) FULLY CONNECTED HEAD
    # ------------------------------------------------------
    x = combined
    for depth, units in enumerate(fc_units):
        x = Dense(units, activation="relu")(x)
        # full rate after the first Dense layer, half rate afterwards
        rate = final_dropout if depth == 0 else final_dropout / 2
        x = Dropout(rate)(x)

    out = Dense(num_classes, activation="softmax", name="output")(x)

    model = Model(inputs=inp, outputs=out, name="LightPairsConvBiLSTM_MeanPool")

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss=loss,
        metrics=["accuracy"]
    )

    return model

In [None]:
# Rebuild the network (this rebinds `model`), now compiled with the weighted
# loss, and print the layer table.
model = build_light_pairs_conv_lstm_model()

model.summary()

In [None]:
# ============================================================
# ===================== USER PARAMETERS =======================
# ============================================================

DL_WINDOW = 60               # number of spread_<i> columns read per row
HOLD_DOWNSAMPLE_FRAC = 0.5   # share of label == 0 rows that is kept

DATA_FILE = "/content/final_dl_training_dataset.csv"

# ============================================================
# ===================== LOAD & PREPROCESS DATA ===============
# ============================================================

print("Loading dataset...")
df = pd.read_csv(DATA_FILE)
df["date"] = pd.to_datetime(df["date"])

# ----- Downsample hold class -----
# Keep a fixed-seed random subset of the label == 0 rows and every other row,
# then restore chronological order with a fresh 0..N-1 index.
is_hold = df["label"] == 0
df_hold = df[is_hold].sample(frac=HOLD_DOWNSAMPLE_FRAC, random_state=42)
df_trade = df[~is_hold]
df = (
    pd.concat([df_hold, df_trade])
    .sort_values("date")
    .reset_index(drop=True)
)

print("After balancing:")
print(df["label"].value_counts())

# ----- Extract spread windows -----
# Columns are named spread_1 ... spread_<DL_WINDOW>.
spread_cols = [f"spread_{i}" for i in range(1, DL_WINDOW + 1)]
X_raw = df[spread_cols].values
y_raw = df["label"].values

# ----- Window normalization -----
def normalize_window(w, eps=1e-8):
    """Z-score one window using its own mean and standard deviation.

    Args:
        w: 1-D array-like of values for a single sample.
        eps: lower bound applied to the standard deviation.

    Returns:
        Array of the same length as `w` holding (w - mean) / max(std, eps).

    The standard deviation is floored rather than compared to exactly zero:
    for a constant window, rounding in the mean can leave a std around 1e-17,
    and dividing by it would scale that rounding noise up to about +/-1.
    """
    w = np.asarray(w)
    mean = w.mean()
    std = w.std()
    if std < eps:
        std = eps
    return (w - mean) / std

# Normalise every row of X_raw independently.
X_norm = np.array(list(map(normalize_window, X_raw)))

# ----- Add the channel axis: (samples, DL_WINDOW, 1) -----
X = X_norm.reshape((-1, DL_WINDOW, 1))

# ----- Re-index labels -1/0/+1 as 0/1/2, then one-hot encode -----
label_map = {-1: 0, 0: 1, +1: 2}
y = np.array(list(map(label_map.__getitem__, y_raw)))
y_cat = to_categorical(y, num_classes=3)

# ----- Chronological 70 / 15 / 15 split by row position -----
# NOTE(review): rows are ordered by date only, so rows that share a date can
# land on both sides of a cut; confirm this is acceptable.
N = len(df)
train_end, val_end = int(0.7 * N), int(0.85 * N)

X_train, X_val, X_test = X[:train_end], X[train_end:val_end], X[val_end:]
y_train, y_val, y_test = y_cat[:train_end], y_cat[train_end:val_end], y_cat[val_end:]

print("Shapes:")
for tag, feats, targets in (
    ("Train:", X_train, y_train),
    ("Val:  ", X_val, y_val),
    ("Test: ", X_test, y_test),
):
    print(tag, feats.shape, targets.shape)

# ============================================================
# ========================= TRAIN MODEL =======================
# ============================================================

# Epochs without a new best val_loss before training stops.
EARLY_STOP_PATIENCE = 10

# One full-model checkpoint per epoch, so any epoch can be reloaded later.
# NOTE(review): the model is compiled with a custom loss function; reloading
# these files probably needs `custom_objects` or `compile=False` - confirm.
checkpoint_all = ModelCheckpoint(
    filepath="/content/models/model_epoch_{epoch:02d}_valLoss_{val_loss:.4f}.h5",
    save_weights_only=False,
    save_freq="epoch"
)

# Without a stopping rule the fit runs all 100 epochs even after validation
# loss has turned upwards. Stop automatically and roll back to the best epoch.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=EARLY_STOP_PATIENCE,
    restore_best_weights=True
)

history = model.fit(
    X_train,
    y_train,
    epochs=100,
    batch_size=64,
    validation_data=(X_val, y_val),
    callbacks=[checkpoint_all, early_stop]
)

# ============================================================
# ========================== TESTING ==========================
# ============================================================

# Evaluated with the weights restored by early stopping.
test_loss, test_acc = model.evaluate(X_test, y_test)
print("\nTEST ACCURACY:", test_acc)

Loading dataset...
After balancing:
label
 0    34496
 1    30438
-1    30384
Name: count, dtype: int64
Shapes:
Train: (66722, 60, 1) (66722, 3)
Val:   (14298, 60, 1) (14298, 3)
Test:  (14298, 60, 1) (14298, 3)
Epoch 1/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.5659 - loss: 1.5936



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 25ms/step - accuracy: 0.5659 - loss: 1.5935 - val_accuracy: 0.6952 - val_loss: 1.2032
Epoch 2/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6694 - loss: 1.2898



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 24ms/step - accuracy: 0.6694 - loss: 1.2898 - val_accuracy: 0.7051 - val_loss: 1.1702
Epoch 3/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6739 - loss: 1.2713



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6739 - loss: 1.2713 - val_accuracy: 0.7079 - val_loss: 1.1586
Epoch 4/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6817 - loss: 1.2406



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6817 - loss: 1.2406 - val_accuracy: 0.7054 - val_loss: 1.1733
Epoch 5/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6819 - loss: 1.2285



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6819 - loss: 1.2284 - val_accuracy: 0.7041 - val_loss: 1.1523
Epoch 6/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6851 - loss: 1.2164



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6851 - loss: 1.2164 - val_accuracy: 0.7032 - val_loss: 1.1496
Epoch 7/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6804 - loss: 1.2142



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6804 - loss: 1.2142 - val_accuracy: 0.6958 - val_loss: 1.1723
Epoch 8/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6818 - loss: 1.1995



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6818 - loss: 1.1995 - val_accuracy: 0.7072 - val_loss: 1.1372
Epoch 9/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6866 - loss: 1.1799



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6866 - loss: 1.1799 - val_accuracy: 0.7074 - val_loss: 1.1427
Epoch 10/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6836 - loss: 1.1891



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6836 - loss: 1.1890 - val_accuracy: 0.7058 - val_loss: 1.1409
Epoch 11/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6878 - loss: 1.1733



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6878 - loss: 1.1732 - val_accuracy: 0.7013 - val_loss: 1.1514
Epoch 12/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6847 - loss: 1.1659



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6847 - loss: 1.1659 - val_accuracy: 0.7038 - val_loss: 1.1494
Epoch 13/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6852 - loss: 1.1593



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.6852 - loss: 1.1593 - val_accuracy: 0.7097 - val_loss: 1.1488
Epoch 14/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6905 - loss: 1.1327



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6905 - loss: 1.1327 - val_accuracy: 0.7076 - val_loss: 1.1530
Epoch 15/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6895 - loss: 1.1304



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6895 - loss: 1.1304 - val_accuracy: 0.7032 - val_loss: 1.1427
Epoch 16/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6891 - loss: 1.1189



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6891 - loss: 1.1189 - val_accuracy: 0.6876 - val_loss: 1.1968
Epoch 17/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6903 - loss: 1.1156



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6903 - loss: 1.1156 - val_accuracy: 0.7051 - val_loss: 1.1974
Epoch 18/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6863 - loss: 1.1146



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6863 - loss: 1.1146 - val_accuracy: 0.7061 - val_loss: 1.1967
Epoch 19/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6915 - loss: 1.1050



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6915 - loss: 1.1050 - val_accuracy: 0.7049 - val_loss: 1.1721
Epoch 20/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6872 - loss: 1.1105



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.6872 - loss: 1.1105 - val_accuracy: 0.7020 - val_loss: 1.2306
Epoch 21/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6875 - loss: 1.1007



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6875 - loss: 1.1007 - val_accuracy: 0.7015 - val_loss: 1.2210
Epoch 22/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6892 - loss: 1.0966



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6892 - loss: 1.0966 - val_accuracy: 0.6728 - val_loss: 1.2606
Epoch 23/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6904 - loss: 1.0916



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6904 - loss: 1.0916 - val_accuracy: 0.7027 - val_loss: 1.2368
Epoch 24/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6894 - loss: 1.0868



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6894 - loss: 1.0868 - val_accuracy: 0.6825 - val_loss: 1.2514
Epoch 25/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6934 - loss: 1.0713



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 25ms/step - accuracy: 0.6934 - loss: 1.0714 - val_accuracy: 0.7000 - val_loss: 1.2440
Epoch 26/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6872 - loss: 1.0773



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6872 - loss: 1.0773 - val_accuracy: 0.7045 - val_loss: 1.2333
Epoch 27/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6894 - loss: 1.0756



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 25ms/step - accuracy: 0.6894 - loss: 1.0756 - val_accuracy: 0.6885 - val_loss: 1.2889
Epoch 28/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6936 - loss: 1.0588



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 25ms/step - accuracy: 0.6936 - loss: 1.0589 - val_accuracy: 0.6826 - val_loss: 1.2697
Epoch 29/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6900 - loss: 1.0641



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 26ms/step - accuracy: 0.6900 - loss: 1.0641 - val_accuracy: 0.7039 - val_loss: 1.2033
Epoch 30/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6938 - loss: 1.0485



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6938 - loss: 1.0485 - val_accuracy: 0.7022 - val_loss: 1.2629
Epoch 31/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 25ms/step - accuracy: 0.6898 - loss: 1.0528



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6898 - loss: 1.0528 - val_accuracy: 0.6841 - val_loss: 1.2957
Epoch 32/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6937 - loss: 1.0401



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6936 - loss: 1.0401 - val_accuracy: 0.6892 - val_loss: 1.3289
Epoch 33/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6925 - loss: 1.0464



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6925 - loss: 1.0464 - val_accuracy: 0.6929 - val_loss: 1.2820
Epoch 34/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6947 - loss: 1.0369



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 28ms/step - accuracy: 0.6947 - loss: 1.0369 - val_accuracy: 0.6884 - val_loss: 1.3096
Epoch 35/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 25ms/step - accuracy: 0.6945 - loss: 1.0359



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 28ms/step - accuracy: 0.6945 - loss: 1.0359 - val_accuracy: 0.7034 - val_loss: 1.3563
Epoch 36/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 25ms/step - accuracy: 0.6948 - loss: 1.0235



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6948 - loss: 1.0235 - val_accuracy: 0.6823 - val_loss: 1.3577
Epoch 37/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6918 - loss: 1.0295



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 27ms/step - accuracy: 0.6918 - loss: 1.0295 - val_accuracy: 0.7002 - val_loss: 1.3164
Epoch 38/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 25ms/step - accuracy: 0.6914 - loss: 1.0227



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 26ms/step - accuracy: 0.6914 - loss: 1.0227 - val_accuracy: 0.6996 - val_loss: 1.3024
Epoch 39/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 25ms/step - accuracy: 0.6951 - loss: 1.0194



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 26ms/step - accuracy: 0.6951 - loss: 1.0194 - val_accuracy: 0.7029 - val_loss: 1.3686
Epoch 40/100
[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6964 - loss: 1.0132



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6964 - loss: 1.0132 - val_accuracy: 0.6783 - val_loss: 1.3479
Epoch 41/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6933 - loss: 1.0129



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 26ms/step - accuracy: 0.6933 - loss: 1.0129 - val_accuracy: 0.6876 - val_loss: 1.3390
Epoch 42/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6993 - loss: 1.0048



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6993 - loss: 1.0048 - val_accuracy: 0.6984 - val_loss: 1.3353
Epoch 43/100
[1m1042/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 23ms/step - accuracy: 0.6925 - loss: 1.0129



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6925 - loss: 1.0129 - val_accuracy: 0.7039 - val_loss: 1.3286
Epoch 44/100
[1m1041/1043[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 24ms/step - accuracy: 0.6982 - loss: 0.9949



[1m1043/1043[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 25ms/step - accuracy: 0.6982 - loss: 0.9949 - val_accuracy: 0.7035 - val_loss: 1.3921
Epoch 45/100
[1m 457/1043[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m13s[0m 24ms/step - accuracy: 0.7013 - loss: 0.9867

KeyboardInterrupt: 

In [None]:
# ====================================================
# Evaluate a saved checkpoint on the held-out test set
# ====================================================
# Steps: restore the model from disk, predict on X_test, reduce both the
# predictions and the targets to integer class indices, then print a
# confusion matrix and a per-class classification report.

# Checkpoint to evaluate; edit the path to inspect a different epoch.
model_path = "/content/models/model_epoch_13_valLoss_1.1488.h5"   # <--- change this

# custom_objects maps the name "loss" to weighted_categorical_crossentropy so
# Keras can resolve that name while deserializing the saved model.
model = load_model(model_path, custom_objects={"loss": weighted_categorical_crossentropy})
print("Loaded model:", model_path)

# Ground-truth class index per sample: argmax along axis 1 of y_test
# (assumes y_test holds one row per sample, one-hot encoded -- TODO confirm).
y_true = np.argmax(y_test, axis=1)

# Per-class scores for every test sample, then the index of the best-scoring
# class in each row as the predicted label.
y_pred_probs = model.predict(X_test)
y_pred = y_pred_probs.argmax(axis=1)

# Confusion matrix: header and matrix are emitted on consecutive lines.
cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
print("\nConfusion Matrix:", cm, sep="\n")

# Classification report, formatted with four decimal places.
print("\nClassification Report:")
print(classification_report(y_true=y_true, y_pred=y_pred, digits=4))




Loaded model: /content/models/model_epoch_13_valLoss_1.1488.h5
[1m447/447[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 10ms/step

Confusion Matrix:
[[4006  650  203]
 [ 925 2701 1026]
 [ 500  778 3509]]

Classification Report:
              precision    recall  f1-score   support

           0     0.7376    0.8244    0.7786      4859
           1     0.6542    0.5806    0.6152      4652
           2     0.7406    0.7330    0.7368      4787

    accuracy                         0.7145     14298
   macro avg     0.7108    0.7127    0.7102     14298
weighted avg     0.7115    0.7145    0.7114     14298

