# 03 Pure LSTM

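Train the Pure LSTM volatility model with `MSE` loss and evaluate it with both `MSE` and `QLIKE`.
This notebook checkpoints predictions, training logs, and gate values after every split, so an interrupted run can be resumed by setting `resume=True` in the `run_rolling_experiment` call (it is set to `resume=False` below).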


In [None]:
from __future__ import annotations

from pathlib import Path
import sys

import pandas as pd
import matplotlib.pyplot as plt

# Resolve the repository root: the working directory when it contains a
# ``src`` entry, otherwise one level up.
PROJECT_ROOT = Path.cwd().resolve()
PROJECT_ROOT = (
    PROJECT_ROOT
    if PROJECT_ROOT.joinpath("src").exists()
    else PROJECT_ROOT.parent
)

# Register the root on the import path (only once) so the ``src`` package
# imports that follow can be resolved.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from src.evaluation import evaluate_forecasts
from src.models.rnn import RNNTrainingConfig, run_rolling_experiment
from src.utils import set_seed

# Seed through the project helper before any model work.
# NOTE(review): presumably this seeds the Python/NumPy/framework RNGs for
# repeatable runs — confirm in src.utils. The same value (42) is also passed
# as `seed` in the training config.
set_seed(42)
# Allow pandas to render up to 100 columns when displaying DataFrames.
pd.set_option("display.max_columns", 100)


In [None]:
# Input CSVs, both located under <PROJECT_ROOT>/data/processed.
data_path = PROJECT_ROOT.joinpath("data", "processed", "sp500_log_returns.csv")
splits_path = PROJECT_ROOT.joinpath("data", "processed", "rolling_splits.csv")

df = pd.read_csv(data_path, parse_dates=["date"])

# The split table has a start/end date column for each of the train, val and
# test windows; all six are parsed as datetimes.
splits_df = pd.read_csv(
    splits_path,
    parse_dates=[
        "train_start_date", "train_end_date",
        "val_start_date", "val_end_date",
        "test_start_date", "test_end_date",
    ],
)

# Output locations handed to the rolling experiment further down; the
# directory is created up front so it exists before anything is written.
pred_dir = PROJECT_ROOT.joinpath("reports", "predictions")
pred_dir.mkdir(parents=True, exist_ok=True)
pred_path, log_path, gate_path = (
    pred_dir / filename
    for filename in (
        "pure_lstm_predictions.csv",
        "pure_lstm_train_logs.csv",
        "pure_lstm_gate_values.csv",
    )
)

print(
    f"predictions path: {pred_path}",
    f"train logs path: {log_path}",
    f"gate path: {gate_path}",
    sep="\n",
)

# Last expression of the cell: preview the first rows of the input data.
df.head()


In [None]:
# Hyperparameters for the rolling Pure LSTM runs, bundled in the project's
# RNNTrainingConfig. The per-field notes below are inferred from the field
# names only — NOTE(review): confirm against src.models.rnn.
# NOTE(review): the notebook intro says training uses MSE loss, but no loss is
# set here; presumably it is a default elsewhere — confirm.
cfg = RNNTrainingConfig(
    lookback=21,  # presumably the input window length, in observations
    hidden_units=8,
    dropout=0.10,
    learning_rate=1e-3,
    batch_size=64,
    epochs=35,  # presumably a per-split cap, with `patience` driving early stopping — confirm
    patience=6,
    seed=42,  # same value as the set_seed(42) call in the setup cell
    scale_features=True,
    scale_target=True,
    target_transform="log_standardize",
    log_garch_features=True,
    eps=1e-8,  # presumably a numerical floor (e.g. for logs/divisions) — confirm
    force_linear_output=True,  # NOTE(review): the experiment call also passes output_activation="linear"; check whether both are needed
)
# Bare expression so the notebook renders the config as the cell output.
cfg


In [None]:
# Run the rolling-split experiment for the "pure" LSTM variant. The call
# returns four objects, unpacked here as: predictions, per-split training
# logs, gate values, and the fit history of the last trained split (which may
# be None — see the check below).
predictions_df, train_logs_df, gates_df, last_fit_history = run_rolling_experiment(
    df=df,
    splits_df=splits_df,
    architecture="lstm",
    variant="pure",
    cfg=cfg,
    output_activation="linear",
    verbose_fit=0,
    capture_gates=True,
    # Checkpoint files; per the notebook intro these are updated after every split.
    predictions_path=pred_path,
    train_logs_path=log_path,
    gates_path=gate_path,
    # NOTE(review): hard-coded to False although the intro advertises resuming;
    # presumably True continues from the checkpoint files above — confirm in
    # run_rolling_experiment before flipping.
    resume=False,
    collect_last_history=True,
)

# Quick size summary of the returned tables.
print(f"prediction rows: {len(predictions_df):,}")
print(f"train-log rows: {len(train_logs_df):,}")
print(f"gate rows: {len(gates_df):,}")
# The last-split history is optional; only report its split id when present.
if last_fit_history is not None:
    print(f"last trained split id: {last_fit_history['split_id']}")
# Last expression of the cell: preview the first prediction rows.
predictions_df.head()


In [None]:
# Score the forecasts with the project's evaluate_forecasts helper, grouped by
# variant, architecture and training loss.
# NOTE(review): per the notebook intro the reported metrics should include MSE
# and QLIKE — confirm the actual output columns in src.evaluation.
metrics_df = evaluate_forecasts(
    predictions_df,
    group_cols=["variant", "architecture", "train_loss"],
)
# Bare expression so the notebook renders the metrics table.
metrics_df


In [None]:
# Report the three artifact locations, one per line.
print(
    f"Saved predictions: {pred_path}",
    f"Saved train logs: {log_path}",
    f"Saved gate values: {gate_path}",
    sep="\n",
)


In [None]:
# Overfitting diagnostics, part 1: best-epoch train vs. validation loss for
# every rolling split, plus the gap between the two.
log_plot = train_logs_df.copy()
required_cols = {
    "split_id",
    "best_train_loss",
    "best_val_loss",
    "best_gap_val_minus_train",
    "final_train_loss",
    "final_val_loss",
    "final_gap_val_minus_train",
}
missing = sorted(required_cols - set(log_plot.columns))

if not missing:
    log_plot = log_plot.sort_values("split_id").reset_index(drop=True)
    split_ids = log_plot["split_id"]

    fig, axes = plt.subplots(2, 1, figsize=(14, 8), sharex=True)
    loss_ax, gap_ax = axes

    # Top panel: best train and validation loss per split.
    loss_ax.plot(split_ids, log_plot["best_train_loss"], label="Best Train Loss", linewidth=1.2)
    loss_ax.plot(split_ids, log_plot["best_val_loss"], label="Best Val Loss", linewidth=1.2)
    loss_ax.set(ylabel="Loss", title="Best Loss by Rolling Split")
    loss_ax.grid(alpha=0.2)
    loss_ax.legend()

    # Bottom panel: validation-minus-train gap, with a zero reference line.
    gap_ax.plot(split_ids, log_plot["best_gap_val_minus_train"], label="Best Gap (Val - Train)", linewidth=1.2)
    gap_ax.axhline(0.0, color="gray", linestyle="--", linewidth=1.0)
    gap_ax.set(
        xlabel="Split ID",
        ylabel="Gap",
        title="Overfit Gap by Rolling Split (Best Epoch)",
    )
    gap_ax.grid(alpha=0.2)
    gap_ax.legend()

    fig.tight_layout()
    plt.show()
else:
    # The aggregate columns are absent, so the per-split overview is skipped.
    print(
        "Train log is missing aggregate diagnostics columns. "
        "Using last trained split history instead. "
        f"Missing: {missing}"
    )

# Part 2: epoch-by-epoch loss curves of the last trained split (the most
# direct overfit inspection available in the rolling setup).
if last_fit_history is not None:
    epoch_count = len(last_fit_history["loss"])
    hist_df = pd.DataFrame(
        {
            "epoch": range(1, epoch_count + 1),
            "train_loss": last_fit_history["loss"],
            "val_loss": last_fit_history["val_loss"],
        }
    )

    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(hist_df["epoch"], hist_df["train_loss"], label="Train Loss", linewidth=1.4)
    ax.plot(hist_df["epoch"], hist_df["val_loss"], label="Validation Loss", linewidth=1.4)
    ax.set(
        title=f"Last Split Loss Curves (split_id={last_fit_history['split_id']})",
        xlabel="Epoch",
        ylabel="Loss",
    )
    ax.grid(alpha=0.2)
    ax.legend()
    fig.tight_layout()
    plt.show()
else:
    print(
        "No last split history captured. "
        "If resume=True and no new split was trained, set resume=False (or clear outputs) and rerun."
    )
