# Full Model Comparison: VAE Variants vs Heston

Compare all 5 models against market IV surfaces:
- **MLP** — MLP VAE (raw IV)
- **MLP-log** — MLP VAE (log-IV)
- **Conv** — Conv VAE (raw IV)
- **Conv-log** — Conv VAE (log-IV)
- **Heston** — Calibrated Heston model

**Prerequisites:**
1. Train 4 VAE variants via `scripts/train_vae.py`
2. Evaluate each via `scripts/eval_vae.py`
3. Calibrate Heston via `scripts/calibrate_heston.py`

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
from pathlib import Path
from collections import OrderedDict

%matplotlib inline
# Global plot appearance for every figure in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.dpi'] = 120

# Directory that receives the figures and the metrics JSON written by later cells.
# Created up front (with parents) so those writes do not fail on a fresh checkout.
OUTPUT_DIR = Path("../../artifacts/comparison")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

## 1. Configuration & Data Loading

In [None]:
# Ticker used to build the Heston surface file names.
TICKER = "AAPL"

# VAE model directories (display name -> surfaces directory).
# The insertion order here is the order in which the models are loaded and reported.
VAE_DIRS = OrderedDict([
    ("MLP",      Path("../../artifacts/eval/mlp/surfaces")),
    ("MLP-log",  Path("../../artifacts/eval/mlp_log/surfaces")),
    ("Conv",     Path("../../artifacts/eval/conv/surfaces")),
    ("Conv-log", Path("../../artifacts/eval/conv_log/surfaces")),
])

# Directory holding the Heston surfaces array and its date index CSV.
HESTON_DIR = Path("../../data/processed/heston/surfaces")

# Colours for consistent plotting (keyed by model display name, including "Heston")
COLOURS = {
    "MLP":      "#1f77b4",
    "MLP-log":  "#9467bd",
    "Conv":     "#2ca02c",
    "Conv-log": "#8c564b",
    "Heston":   "#ff7f0e",
}

In [None]:
# --- Load all VAE surfaces ---
# Maps model name -> (model_surfaces, market_surfaces, dates).
vae_surfaces = {}
# Grid description, read from the first VAE directory that exists.
grid_spec = None

for name, sdir in VAE_DIRS.items():
    if not sdir.exists():
        print(f"⚠  {name}: directory not found ({sdir}) — skipping")
        continue
    model = np.load(sdir / "vae_surfaces.npy")
    market = np.load(sdir / "market_surfaces.npy")
    dates = pd.to_datetime(pd.read_csv(sdir / "vae_surface_dates.csv")["date"])
    # The date index is used later to select rows of both arrays, so it must
    # have exactly one entry per surface; fail here with a clear message
    # instead of an opaque indexing error during alignment.
    if not (len(dates) == model.shape[0] == market.shape[0]):
        raise ValueError(
            f"{name}: {len(dates)} dates but {model.shape[0]} model and "
            f"{market.shape[0]} market surfaces in {sdir}"
        )
    vae_surfaces[name] = (model, market, dates)
    if grid_spec is None:
        with open(sdir / "grid_spec.json", encoding="utf-8") as f:
            grid_spec = json.load(f)
    print(f"✓ {name:<10} {model.shape}  ({len(dates)} dates)")

# Without at least one VAE there is no market surface and no grid spec to
# compare against, so stop before the later cells fail on `None`.
if not vae_surfaces:
    raise FileNotFoundError(
        "No VAE surface directories found; expected at least one of: "
        + ", ".join(str(p) for p in VAE_DIRS.values())
    )

# --- Load Heston ---
heston_surf = np.load(HESTON_DIR / f"{TICKER}_heston_surfaces.npy")
heston_dates = pd.to_datetime(
    pd.read_csv(HESTON_DIR / f"{TICKER}_heston_surface_dates.csv")["date"]
)
if len(heston_dates) != heston_surf.shape[0]:
    raise ValueError(
        f"Heston: {len(heston_dates)} dates but {heston_surf.shape[0]} surfaces "
        f"in {HESTON_DIR}"
    )
print(f"✓ {'Heston':<10} {heston_surf.shape}  ({len(heston_dates)} dates)")

# Grid axes shared by every surface.
days_grid = np.array(grid_spec["days_grid"])
delta_grid = np.array(grid_spec["delta_grid"])
cp_order = grid_spec["cp_order"]
print(f"\nGrid: {cp_order} × {len(days_grid)} maturities × {len(delta_grid)} deltas")

## 2. Align to Common Dates

In [None]:
def _to_date_set(dates_series):
    """Return the set of calendar dates contained in a datetime Series."""
    return set(dates_series.dt.date)


def _rows_for_dates(dates_series, wanted_dates):
    """Return the row position of each date in `wanted_dates`, in that order.

    Selecting rows by position (rather than with a boolean mask) puts every
    source in the same date order, whatever order its own date index is in.
    If a date occurs more than once in `dates_series`, its first row is used.
    """
    first_row = {}
    for pos, day in enumerate(dates_series.dt.date):
        first_row.setdefault(day, pos)
    return np.array([first_row[day] for day in wanted_dates], dtype=int)


if not vae_surfaces:
    raise RuntimeError("No VAE surfaces loaded — nothing to align.")

# Intersect the date sets of every VAE and of Heston.
date_sets = [_to_date_set(dates) for _, _, dates in vae_surfaces.values()]
date_sets.append(_to_date_set(heston_dates))
common = sorted(set.intersection(*date_sets))
if not common:
    raise ValueError("The models share no common dates — nothing to compare.")

print(f"Common dates: {len(common)}")
print(f"Range: {common[0]}  →  {common[-1]}")

# Build aligned arrays: row k of every array corresponds to common[k].
model_aligned = {}  # name -> aligned surfaces
for name, (surf, _market, dates) in vae_surfaces.items():
    model_aligned[name] = surf[_rows_for_dates(dates, common)]

# Market (take from first VAE — they all share the same underlying data)
first_vae = next(iter(vae_surfaces.values()))
market_rows = _rows_for_dates(first_vae[2], common)
market_aligned = first_vae[1][market_rows]

# Heston
model_aligned["Heston"] = heston_surf[_rows_for_dates(heston_dates, common)]

# Aligned dates Series for plotting
aligned_dates = first_vae[2].iloc[market_rows].reset_index(drop=True)

NAMES = list(model_aligned.keys())
N = len(common)
print(f"\nAligned {len(NAMES)} models × {N} dates:")
for name in NAMES:
    print(f"  {name}: {model_aligned[name].shape}")

## 3. Summary Metrics Table

In [None]:
def compute_metrics(a, b):
    """Return NaN-aware error metrics between two surfaces (or stacks of them).

    Parameters
    ----------
    a, b : array-like
        Values to compare; they are converted to float arrays and must
        broadcast against each other. A cell counts as invalid when the
        difference ``a - b`` is NaN.

    Returns
    -------
    dict
        ``MSE``, ``MAE``, ``RMSE`` and ``Max`` (largest absolute error) over
        the valid cells, ``MAE (vol pts)`` and ``RMSE (vol pts)`` (the same
        values multiplied by 100), and ``Valid %`` (share of non-NaN cells).
        When there is no valid cell the error metrics are NaN and, for empty
        input, ``Valid %`` is 0.0.
    """
    err = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    valid = ~np.isnan(err)
    n_valid = int(valid.sum())

    if n_valid:
        ae = np.abs(err[valid])
        mae = float(ae.mean())
        mse = float(np.mean(ae ** 2))
        rmse = float(np.sqrt(mse))
        max_err = float(ae.max())
    else:
        # Nothing to average: report NaN explicitly instead of triggering
        # numpy's all-NaN warnings (or a ValueError on empty input).
        mae = mse = rmse = max_err = float("nan")

    valid_pct = 100.0 * n_valid / err.size if err.size else 0.0
    return {
        "MSE":  mse,
        "MAE":  mae,
        "RMSE": rmse,
        "Max":  max_err,
        "MAE (vol pts)": mae * 100,
        "RMSE (vol pts)": rmse * 100,
        "Valid %": float(valid_pct),
    }

# One record per model: all metrics against the market, tagged with the model name.
rows = [
    {**compute_metrics(model_aligned[model_name], market_aligned), "Model": model_name}
    for model_name in NAMES
]

column_order = ["MAE", "MAE (vol pts)", "RMSE", "RMSE (vol pts)", "MSE", "Max", "Valid %"]
df_summary = pd.DataFrame(rows).set_index("Model")[column_order]

# Banner, then the table with the lowest value of each error column highlighted.
banner = "=" * 70
print(banner)
print("  MODEL VS MARKET  (lower is better)")
print(banner)

number_formats = {
    "MAE": "{:.6f}", "RMSE": "{:.6f}", "MSE": "{:.6f}", "Max": "{:.6f}",
    "MAE (vol pts)": "{:.2f}%", "RMSE (vol pts)": "{:.2f}%", "Valid %": "{:.1f}%",
}
highlighted = df_summary.style.highlight_min(
    subset=["MAE", "RMSE", "MSE", "Max"],
    axis=0,
    props="font-weight:bold; background-color:#d4edda",
)
display(highlighted.format(number_formats))

In [None]:
# Winner: the model with the lowest overall MAE against the market.
mae_dict = {name: compute_metrics(model_aligned[name], market_aligned)["MAE"]
            for name in NAMES}

# NaN compares as neither smaller nor larger, so min() over values that
# include NaN depends on position. Rank only the models with a finite MAE.
ranked = {n: v for n, v in mae_dict.items() if np.isfinite(v)}
if not ranked:
    raise ValueError("No model has a finite MAE — cannot pick a winner.")
winner = min(ranked, key=ranked.get)

print(f"\n🏆  WINNER (lowest MAE): {winner}  —  MAE = {mae_dict[winner]:.6f}"
      f"  ({mae_dict[winner]*100:.2f} vol pts)")
for n in NAMES:
    if n == winner:
        continue
    if n not in ranked:
        print(f"     vs {n}: MAE unavailable")
        continue
    # How much lower the winner's MAE is, relative to model n's MAE.
    # A zero MAE for n means an exact tie at zero error (avoids 0/0).
    if mae_dict[n] > 0:
        gap = (mae_dict[n] - mae_dict[winner]) / mae_dict[n] * 100
    else:
        gap = 0.0
    print(f"     vs {n}: {gap:.1f}% lower MAE")

## 4. Pairwise Model Differences

In [None]:
# Compare every unordered pair of models directly with each other.
pw_rows = []
for pos, left in enumerate(NAMES):
    for right in NAMES[pos + 1:]:
        pair_metrics = compute_metrics(model_aligned[left], model_aligned[right])
        pw_rows.append({
            "Pair": f"{left} vs {right}",
            "MAE": pair_metrics["MAE"],
            "RMSE": pair_metrics["RMSE"],
        })

df_pw = pd.DataFrame(pw_rows).set_index("Pair")
pair_formats = {"MAE": "{:.6f}", "RMSE": "{:.6f}"}
display(df_pw.style.format(pair_formats))

## 5. Error Heatmaps (per model, per option type)

In [None]:
# Mean absolute error of each model per grid cell, averaged over dates (axis 0).
error_maps = {name: np.nanmean(np.abs(model_aligned[name] - market_aligned), axis=0)
              for name in NAMES}

n_models = len(NAMES)
xtick_positions = list(range(0, len(delta_grid), 3))  # label every third delta

for c, cp in enumerate(cp_order):
    # squeeze=False keeps `axes` indexable even when there is only one model.
    fig, axes = plt.subplots(1, n_models, figsize=(4.5 * n_models, 4), squeeze=False)
    axes = axes[0]

    # One colour scale shared by all models. nanmax skips cells without data;
    # a plain .max() would return NaN as soon as a single cell is NaN.
    peaks = [np.nanmax(error_maps[n][c]) for n in NAMES
             if not np.all(np.isnan(error_maps[n][c]))]
    vmax = max(peaks) if peaks else None

    for i, name in enumerate(NAMES):
        im = axes[i].imshow(error_maps[name][c], aspect="auto", origin="lower",
                            cmap="Reds", vmin=0, vmax=vmax)
        axes[i].set_title(f"{name} ({cp})", fontsize=9)
        axes[i].set_yticks(range(len(days_grid)))
        axes[i].set_yticklabels([int(d) for d in days_grid], fontsize=7)
        if i == 0:
            axes[i].set_ylabel("Maturity (days)")
        axes[i].set_xticks(xtick_positions)
        axes[i].set_xticklabels([f"{delta_grid[k]:.2f}" for k in xtick_positions],
                                fontsize=7, rotation=45)
        axes[i].set_xlabel("Delta")
        plt.colorbar(im, ax=axes[i], format="%.3f", shrink=0.85)

    fig.suptitle(f"Mean Absolute Error — {cp}", fontsize=12)
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / f"error_heatmap_{cp}.png", dpi=150, bbox_inches="tight")
    plt.show()

## 6. Error Time Series

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), sharex=True)

# MAE per date for each model (mean over every axis except the date axis),
# computed once and shared by both panels.
daily_mae = {name: np.nanmean(np.abs(model_aligned[name] - market_aligned), axis=(1, 2, 3))
             for name in NAMES}

# Daily MAE
for name in NAMES:
    ax1.plot(aligned_dates, daily_mae[name], label=name, alpha=0.6, linewidth=0.8,
             color=COLOURS.get(name))
ax1.set_ylabel("MAE")
ax1.set_title("Daily MAE vs Market")
ax1.legend(fontsize=8)
ax1.grid(True, alpha=0.3)

# Rolling 20-day average (the first window - 1 points are NaN and are not drawn)
window = 20
for name in NAMES:
    rolling = pd.Series(daily_mae[name]).rolling(window).mean()
    ax2.plot(aligned_dates, rolling, label=f"{name}", linewidth=1.5,
             color=COLOURS.get(name))
ax2.set_ylabel("MAE")
ax2.set_xlabel("Date")
ax2.set_title(f"{window}-Day Rolling Average MAE")
ax2.legend(fontsize=8)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "error_timeseries.png", dpi=150, bbox_inches="tight")
plt.show()

## 7. Per-Maturity and Per-Delta Breakdown

In [None]:
# MAE by maturity (left) and by delta (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

for name in NAMES:
    # One absolute-error array per model, reduced twice: over everything but
    # the maturity axis (2), then over everything but the delta axis (3).
    abs_err = np.abs(model_aligned[name] - market_aligned)
    mae_by_mat = np.nanmean(abs_err, axis=(0, 1, 3))
    mae_by_delta = np.nanmean(abs_err, axis=(0, 1, 2))
    ax1.plot(days_grid, mae_by_mat, 'o-', label=name, color=COLOURS.get(name), markersize=4)
    ax2.plot(delta_grid, mae_by_delta, 'o-', label=name, color=COLOURS.get(name), markersize=4)

ax1.set_xlabel("Maturity (days)")
ax1.set_ylabel("MAE")
ax1.set_title("MAE by Maturity")
ax1.legend(fontsize=8)
ax1.grid(True, alpha=0.3)

ax2.set_xlabel("Delta")
ax2.set_ylabel("MAE")
ax2.set_title("MAE by Delta")
ax2.legend(fontsize=8)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "mae_breakdown.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Tables
def _print_mae_table(title, first_col, grid, axis, fmt_label):
    """Print one MAE table: a row per grid value, a column per model, then the row winner.

    `axis` is the array axis that `grid` indexes; `fmt_label` formats the
    grid value shown in the first column.
    """
    print(title)
    print(f"{first_col:<8}" + "".join(f"{model:<14}" for model in NAMES) + f"{'Winner':<10}")
    print("-" * (8 + 14 * len(NAMES) + 10))

    for k, grid_value in enumerate(grid):
        # Select index k along `axis`, keeping the leading axes whole.
        pick = (slice(None),) * axis + (k,)
        row_mae = {
            model: np.nanmean(np.abs(model_aligned[model][pick] - market_aligned[pick]))
            for model in NAMES
        }
        cells = "".join(f"{value:<14.6f}" for value in row_mae.values())
        best = min(row_mae, key=row_mae.get)
        print(fmt_label(grid_value) + cells + f"{best:<10}")


_print_mae_table("MAE by Maturity:", "Days", days_grid, 2,
                 lambda days: f"{int(days):<8}")
_print_mae_table("\n\nMAE by Delta:", "Delta", delta_grid, 3,
                 lambda delta: f"{delta:<8.2f}")

## 8. Sample Surface Comparisons

In [None]:
def plot_surface_comparison(idx, cp_idx=0):
    """Plot the market surface and every model surface side by side for one date.

    Parameters
    ----------
    idx : int
        Position along the aligned date axis.
    cp_idx : int, default 0
        Index into ``cp_order`` selecting which option type to show.

    All panels share the colour limits of the market surface, so differences
    between a model and the market are visible as colour differences.
    """
    date_str = str(aligned_dates.iloc[idx].date())
    cp = cp_order[cp_idx]
    n_models = len(NAMES)
    market_slice = market_aligned[idx, cp_idx]

    fig, axes = plt.subplots(1, n_models + 1, figsize=(4 * (n_models + 1), 4))

    # Colour limits from the market surface. The nan-aware reductions ignore
    # missing cells; plain .min()/.max() would return NaN if any cell is NaN.
    if np.all(np.isnan(market_slice)):
        vmin = vmax = None  # nothing to scale by; let matplotlib decide
    else:
        vmin = np.nanmin(market_slice)
        vmax = np.nanmax(market_slice)

    # Market
    im0 = axes[0].imshow(market_slice, aspect="auto", origin="lower",
                         cmap="viridis", vmin=vmin, vmax=vmax)
    axes[0].set_title(f"Market ({cp})", fontsize=9)
    axes[0].set_ylabel("Maturity")
    plt.colorbar(im0, ax=axes[0], format="%.3f", shrink=0.85)

    # Models, each titled with its MAE against the market on this date
    for i, name in enumerate(NAMES):
        model_slice = model_aligned[name][idx, cp_idx]
        im = axes[i+1].imshow(model_slice, aspect="auto",
                              origin="lower", cmap="viridis", vmin=vmin, vmax=vmax)
        mae_val = np.nanmean(np.abs(model_slice - market_slice))
        axes[i+1].set_title(f"{name}\nMAE={mae_val:.4f}", fontsize=9)
        plt.colorbar(im, ax=axes[i+1], format="%.3f", shrink=0.85)

    fig.suptitle(f"{date_str} — {cp}", fontsize=12)
    plt.tight_layout()
    plt.show()

# Plot first, middle, last. The set removes repeats when N is 1 or 2, and the
# range check skips everything when there are no aligned dates (N == 0).
for idx in sorted(i for i in {0, N // 2, N - 1} if 0 <= i < N):
    plot_surface_comparison(idx, cp_idx=0)

## 9. Per-Cell Win Count

In [None]:
# For each grid cell, which model has the lowest MAE?
for c, cp in enumerate(cp_order):
    print(f"\n{cp}:")
    stacked = np.stack([error_maps[n][c] for n in NAMES], axis=0)  # (n_models, *grid)

    # np.nanargmin raises ValueError when a cell is NaN for every model, so
    # treat NaN as +inf for the argmin and count only cells that have data.
    missing = np.isnan(stacked)
    has_data = ~np.all(missing, axis=0)
    best_idx = np.argmin(np.where(missing, np.inf, stacked), axis=0)
    total = int(has_data.sum())

    for i, name in enumerate(NAMES):
        count = int(np.sum((best_idx == i) & has_data))
        share = 100 * count / total if total else 0.0
        print(f"  {name:<12} best in {count:>3}/{total} cells ({share:.1f}%)")

## 10. Final Summary

In [None]:
# Metrics of every model against the market, computed once and used for both
# the printed table and the JSON export.
metrics_out = {name: compute_metrics(model_aligned[name], market_aligned) for name in NAMES}

print("=" * 70)
print("  FINAL COMPARISON SUMMARY")
print("=" * 70)
print(f"\n  Test period : {common[0]}  →  {common[-1]}")
print(f"  Surfaces    : {N}")
print(f"  Grid        : {cp_order} × {len(days_grid)} mat × {len(delta_grid)} delta")
print(f"  Models      : {', '.join(NAMES)}")
print()
print(f"  {'Model':<14} {'MAE':>10}  {'(vol pts)':>10}  {'RMSE':>10}  {'(vol pts)':>10}")
print("  " + "-" * 58)

for name in NAMES:
    m = metrics_out[name]
    print(f"  {name:<14} {m['MAE']:>10.6f}  {m['MAE']*100:>9.2f}%  "
          f"{m['RMSE']:>10.6f}  {m['RMSE']*100:>9.2f}%")

print("  " + "-" * 58)
print(f"\n  🏆  WINNER: {winner}  (MAE = {mae_dict[winner]*100:.2f} vol pts)")
print("=" * 70)

# Save metrics JSON for reference; "_meta" is added after the table loop so it
# is never treated as a model.
# NOTE(review): json.dump writes NaN metrics as the bare token NaN, which strict
# JSON parsers reject — confirm downstream readers accept it.
metrics_out["_meta"] = {
    "n_dates": N,
    "date_range": [str(common[0]), str(common[-1])],
    "winner": winner,
}
with open(OUTPUT_DIR / "comparison_metrics.json", "w", encoding="utf-8") as f:
    json.dump(metrics_out, f, indent=2)
print(f"\nMetrics also saved to {OUTPUT_DIR / 'comparison_metrics.json'}")