In [None]:
# notebooks/umi_analysis.py
# ================================================================
#  UMI – Post-Run Analysis
# ================================================================
#
#  Produces:
#    • hyper-param table
#    • parameter inventory (top tensors, totals)
#    • Optuna trial curve (if any)
#    • training-loss curves
#    • prediction-vs-truth plot for user-selected ticker
#    • cumulative NAV curve from equity_curve.csv
# ================================================================

from pathlib import Path
import json
import gzip
import yaml

import torch
import pandas as pd
import matplotlib.pyplot as plt

# Notebook-wide display defaults: ggplot styling for every figure (remove the
# style line to keep matplotlib's default colours) and at most 50 table rows.
plt.style.use("ggplot")
pd.options.display.max_rows = 50

# ---------------------------------------------------------------------------
# 0. Locate latest run folder
# ---------------------------------------------------------------------------
# The YAML config's "model_dir" entry is the root that is searched for runs.
cfg = yaml.safe_load(Path("configs/umi.yaml").read_text())
run_root = Path(cfg["model_dir"]).expanduser()

# Collect every best.pt below run_root, oldest first by modification time;
# the last entry is therefore the most recently written checkpoint.
best_ckpts = sorted(run_root.rglob("best.pt"), key=lambda p: p.stat().st_mtime)
if not best_ckpts:
    raise FileNotFoundError(f"No best.pt under {run_root}")
best_pt   = best_ckpts[-1]
run_dir   = best_pt.parent

# Show the folder relative to the working directory when possible.
# Path.relative_to raises ValueError if run_dir is itself a relative path
# (relative vs. absolute never match) or lies outside the working directory,
# so resolve it first and fall back to the absolute path rather than
# aborting the whole analysis over a cosmetic message.
_run_dir_abs = run_dir.resolve()
try:
    _run_dir_shown = _run_dir_abs.relative_to(Path(".").resolve())
except ValueError:
    _run_dir_shown = _run_dir_abs
print(f"🔍  Analysing run folder:  {_run_dir_shown}")

# ---------------------------------------------------------------------------
# 1. Hyper-parameters
# ---------------------------------------------------------------------------
# hparams.json is parsed as a JSON object; each key/value pair becomes one
# row of a two-column table.
hp_file = run_dir / "hparams.json"
hp = json.loads(hp_file.read_text())
hp_df = pd.DataFrame(list(hp.items()), columns=["param", "value"])
print("\n=== Saved hyper-parameters ===")
display(hp_df)

# ---------------------------------------------------------------------------
# 2. Parameter inventory
# ---------------------------------------------------------------------------
# param_inventory.csv is read with the columns "tensor", "numel" and "bytes".
inv_csv = run_dir / "param_inventory.csv"
param_df = pd.read_csv(inv_csv)

# Totals over every row of the inventory.
total_params = int(param_df["numel"].sum())
total_bytes  = int(param_df["bytes"].sum())
size_mb = total_bytes / 2**20

print(f"\nTotal learnable parameters : {total_params:,}")
print(f"≈ Model size (fp32)        : {size_mb:.1f} MB")

# One descending sort feeds both the top-20 table and the top-10 chart.
by_size = param_df.sort_values("numel", ascending=False)

print("\n=== Largest tensors (top-20) ===")
display(by_size.head(20))

top = by_size.head(10)
fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(top["tensor"], top["numel"])
# Tensor names are long: tilt the tick labels and anchor them on the right.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylabel("# parameters")
ax.set_title("Top-10 tensors by parameter count")
fig.tight_layout()
plt.show()

# ---------------------------------------------------------------------------
# 3. Optuna history (if any)
# ---------------------------------------------------------------------------
# hp_trials.csv is optional; when present, plot "loss_pred" against "trial".
trial_csv = run_dir / "hp_trials.csv"
if not trial_csv.exists():
    print("\n(no hp_trials.csv found – tuning was disabled)")
else:
    trials = pd.read_csv(trial_csv)
    trial_ax = trials.plot(
        x="trial",
        y="loss_pred",
        marker="o",
        figsize=(6, 3),
        title="Optuna trial losses",
    )
    trial_ax.set_ylabel("validation loss")
    plt.show()

# ---------------------------------------------------------------------------
# 4. Training-loss curves
# ---------------------------------------------------------------------------
# train_losses.csv is optional; when present, plot whichever of the known
# loss columns it contains against the "epoch" column.
loss_csv = run_dir / "train_losses.csv"
if loss_csv.exists():
    loss_df = pd.read_csv(loss_csv)
    cols = [c for c in ["loss_pred", "loss_stock", "loss_market"] if c in loss_df.columns]
    if cols:
        loss_df.set_index("epoch")[cols].plot(figsize=(8, 4))
        plt.title("Training losses")
        plt.ylabel("loss")
        plt.show()
    else:
        # Plotting a frame with zero columns raises TypeError ("no numeric
        # data to plot"); report the situation instead of crashing.
        print("\n(train_losses.csv has none of the expected loss columns)")
else:
    print("\n(no train_losses.csv found)")

# ---------------------------------------------------------------------------
# 5. Prediction vs Truth
# ---------------------------------------------------------------------------
# Both back-test CSVs are read with their first column as a parsed date
# index and one column per ticker.
pred_csv  = run_dir / "bt_pred_close.csv"
truth_csv = run_dir / "bt_truth_close.csv"

if pred_csv.exists() and truth_csv.exists():
    pred  = pd.read_csv(pred_csv,  index_col=0, parse_dates=True)
    truth = pd.read_csv(truth_csv, index_col=0, parse_dates=True)

    # Only tickers present in BOTH files can be plotted against each other;
    # offering a pred-only ticker would fail with KeyError on truth[ticker].
    shared = [c for c in pred.columns if c in truth.columns]
    if not shared:
        print("\n(bt_pred_close.csv and bt_truth_close.csv share no ticker columns)")
    else:
        print("\nAvailable tickers :", shared)
        choice = input("Choose a ticker to plot ➜ ").strip()

        # Prefer an exact match, then a case-insensitive one, so that column
        # names which are not all upper-case remain selectable.  Building the
        # lookup in reverse makes the first column win on a case collision.
        by_upper = {str(c).upper(): c for c in reversed(shared)}
        ticker = choice if choice in shared else by_upper.get(choice.upper())
        if ticker is None:
            print("Ticker not found – falling back to first column.")
            ticker = shared[0]

        ax = truth[ticker].plot(label="real close", figsize=(9, 4))
        pred[ticker].plot(ax=ax, label="predicted close")
        ax.set_title(f"{ticker} – 1-bar-ahead close forecast")
        ax.legend()
        plt.show()

        # MSE: mean over rows per ticker, then mean over tickers.
        # IC : per-row correlation across tickers, averaged over rows.
        mse = ((pred - truth).pow(2)).mean().mean()
        ic  = pred.corrwith(truth, axis=1).mean()
        print(f"MSE across back-test : {mse:.6f}")
        print(f"Average IC           : {ic:.4f}")
else:
    print("\n(bt_pred_close.csv / bt_truth_close.csv missing)")

# ---------------------------------------------------------------------------
# 6. Cumulative NAV curve
# ---------------------------------------------------------------------------
# equity_curve.csv is optional; when present, plot its "equity" column
# indexed by the parsed "ts" timestamps.
equity_csv = run_dir / "equity_curve.csv"
if not equity_csv.exists():
    print("\n(equity_curve.csv not found – back-test may not be finished)")
else:
    equity_df = pd.read_csv(equity_csv, parse_dates=["ts"])
    nav = equity_df.set_index("ts")["equity"]
    nav_ax = nav.plot(figsize=(9, 4), title="Cumulative Wealth (Nautilus back-test)")
    nav_ax.set_ylabel("Equity")
    plt.show()

print("\n🎉  Analysis complete.")
