# Interactive Volatility Surface Explorer

Compare the **MLP VAE**, **Conv VAE**, **Heston**, and **Market** implied-volatility surfaces side by side.

| Control | Description |
|---------|-------------|
| **Call / Put** | Toggle option type |
| **Date slider** | Scrub through the test dates shared by the MLP VAE, Conv VAE and Heston outputs |
| **VAE Advantage panel** | Market surface coloured per cell by (Heston error − best-VAE error): Blue = the better of the two VAEs is closer to market, Red = Heston closer |
| **Stats panel** | Per-date MAE and cell-level winner breakdown |

Each 3-D surface can be independently rotated, zoomed, and panned.
A wide screen (>1100 px) is recommended.


In [None]:
# Uncomment the line below if plotly / ipywidgets are not installed:
# %pip install plotly ipywidgets --quiet

import numpy as np
import pandas as pd
import json
from pathlib import Path

def _load_vae(d):
    """Load the saved evaluation outputs of one VAE model.

    Parameters
    ----------
    d : str or os.PathLike
        Directory containing ``vae_surfaces.npy``, ``market_surfaces.npy``,
        ``vae_surface_dates.csv`` (which must have a ``date`` column) and
        ``grid_spec.json``.

    Returns
    -------
    tuple
        ``(vae_surfaces, market_surfaces, dates, grid_spec)``: the two arrays
        as returned by ``np.load``, the CSV ``date`` column parsed with
        ``pd.to_datetime``, and the parsed JSON grid specification.
    """
    d = Path(d)
    vae_surfaces = np.load(d / 'vae_surfaces.npy')
    market_surfaces = np.load(d / 'market_surfaces.npy')
    dates = pd.to_datetime(pd.read_csv(d / 'vae_surface_dates.csv')['date'])
    # Use a context manager so the JSON file handle is closed deterministically
    # (a bare ``open()`` inside ``json.load`` leaks it until garbage collection).
    with open(d / 'grid_spec.json', encoding='utf-8') as fh:
        grid_spec = json.load(fh)
    return (vae_surfaces, market_surfaces, dates, grid_spec)

def _load_heston(d, ticker='AAPL'):
    d = Path(d)
    return (np.load(d / f'{ticker}_heston_surfaces.npy'),
            pd.to_datetime(pd.read_csv(d / f'{ticker}_heston_surface_dates.csv')['date']))

mlp_raw, mkt_raw, mlp_d, gs = _load_vae('../../artifacts/eval/mlp/surfaces')
conv_raw, _, conv_d, _       = _load_vae('../../artifacts/eval/conv/surfaces')
hest_raw, hest_d             = _load_heston('../../data/processed/heston/surfaces')

# Each surface array must hold exactly one row per entry of its dates file;
# otherwise the date-based row selection below would pick the wrong surfaces.
for _label, _arr, _dts in (('MLP VAE', mlp_raw, mlp_d),
                           ('Market', mkt_raw, mlp_d),
                           ('Conv VAE', conv_raw, conv_d),
                           ('Heston', hest_raw, hest_d)):
    if len(_arr) != len(_dts):
        raise ValueError(
            f'{_label}: {len(_arr)} surfaces but {len(_dts)} dates')


def _rows_for(dts, keep):
    """Return, for each date in *keep*, the first row index of *dts* holding it."""
    first = {}
    for i, day in enumerate(dts.dt.date):
        first.setdefault(day, i)
    return [first[day] for day in keep]


# Align to common dates.  Rows are picked by explicit date lookup, in
# chronological order, instead of per-source boolean masks: masks only line
# up if every source lists its dates in the same order without duplicates,
# and a mismatch would silently pair surfaces from different days.
common = set(mlp_d.dt.date) & set(conv_d.dt.date) & set(hest_d.dt.date)
if not common:
    raise ValueError('The MLP VAE, Conv VAE and Heston outputs share no dates.')
_common_sorted = sorted(common)

_mlp_rows = _rows_for(mlp_d, _common_sorted)
mlp  = mlp_raw[_mlp_rows]
conv = conv_raw[_rows_for(conv_d, _common_sorted)]
hest = hest_raw[_rows_for(hest_d, _common_sorted)]
mkt  = mkt_raw[_mlp_rows]     # market surfaces are loaded from the MLP directory
dates = mlp_d.iloc[_mlp_rows].reset_index(drop=True)

days    = np.array(gs['days_grid'])
deltas  = np.array(gs['delta_grid'])
cp_list = gs['cp_order']
n_dates = len(dates)

# Fill any remaining NaN in Heston with 0.0.  NOTE: a filled cell is then
# compared with the market like any real value, so it inflates Heston's error.
_n_nan = int(np.isnan(hest).sum())
hest = np.nan_to_num(hest, nan=0.0)
if _n_nan:
    print(f'Heston: replaced {_n_nan} NaN cell(s) with 0.0')

print(f'Loaded {n_dates} common dates  |  Grid: {len(cp_list)} x {len(days)} x {len(deltas)}')
print(f'Range: {dates.iloc[0].date()}  to  {dates.iloc[-1].date()}')


In [None]:

import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display

# ── Figure dimensions ────────────────────────────────────────────────
# Every panel (each surface figure and the stats box) is a W x H pixel square.
W = H = 350

# Shared 3-D scene settings: axis titles, a flattened z aspect ratio and one
# fixed initial camera position, reused by every surface figure.
SCENE = {
    'xaxis_title': 'Delta',
    'yaxis_title': 'Maturity (d)',
    'zaxis_title': 'IV',
    'aspectmode': 'manual',
    'aspectratio': {'x': 1, 'y': 1, 'z': 0.6},
    'camera': {'eye': {'x': 1.8, 'y': -1.8, 'z': 0.9}},
}
# Zero margins except at the top (presumably room for each panel's title).
MARGIN = {'l': 0, 'r': 0, 't': 35, 'b': 0}

# ── Helper to build a surface figure ─────────────────────────────────
def _surf_fig(title):
    """Return a FigureWidget holding a single placeholder Viridis surface.

    The surface is laid out on the (``days`` x ``deltas``) grid and starts
    flat (all zeros); its ``z`` data is meant to be overwritten later.
    The figure uses the shared ``SCENE``/``MARGIN`` settings and a ``W`` x ``H``
    size, with *title* shown in a small font.
    """
    flat = np.zeros((len(days), len(deltas)))
    surface = go.Surface(x=deltas, y=days, z=flat,
                         colorscale='Viridis', showscale=False)
    widget = go.FigureWidget()
    widget.add_trace(surface)
    widget.update_layout(
        width=W,
        height=H,
        margin=MARGIN,
        scene=SCENE,
        title=dict(font=dict(size=11), text=title),
    )
    return widget

# One placeholder surface figure per source, created in display order.
f_mkt, f_mlp, f_conv, f_hest = (
    _surf_fig(label) for label in ('Market', 'MLP VAE', 'Conv VAE', 'Heston')
)

# ── Advantage surface (coloured by who is closer to market) ──────────
# Geometry (z) and colour (surfacecolor) both start as zeros; a visible
# RdBu colour bar explains the colouring.
_grid_shape = (len(days), len(deltas))
_adv_trace = go.Surface(
    x=deltas,
    y=days,
    z=np.zeros(_grid_shape),
    surfacecolor=np.zeros(_grid_shape),
    colorscale='RdBu',
    showscale=True,
    colorbar=dict(title='Adv', len=0.55, thickness=12, tickformat='.3f'),
)
f_adv = go.FigureWidget()
f_adv.add_trace(_adv_trace)
f_adv.update_layout(
    title=dict(text='VAE Advantage', font=dict(size=11)),
    scene=SCENE, margin=MARGIN, height=H, width=W,
)

# ── Stats HTML panel ─────────────────────────────────────────────────
# Same footprint as a surface panel; scrolls if the text overflows.
_panel_box = dict(width=f'{W}px', height=f'{H}px', overflow_y='auto', padding='6px')
info = widgets.HTML(layout=widgets.Layout(**_panel_box))

# ── Controls ─────────────────────────────────────────────────────────
# Date-index slider (its readout is hidden; the date is shown by `dl`),
# the date label, and the Call/Put toggle.
sl = widgets.IntSlider(
    min=0, max=n_dates - 1, value=0, step=1, readout=False,
    layout=widgets.Layout(width='65%'),
)
dl = widgets.Label(layout=widgets.Layout(width='110px'))
cp = widgets.ToggleButtons(options=['Call', 'Put'], value='Call')

# ── Update callback ─────────────────────────────────────────────────
def _upd(*_):
    """Redraw every panel for the slider's date and the selected option type.

    Accepts and ignores any positional arguments so it can serve as an
    ``observe`` callback as well as be called with no arguments.

    Side effects: sets ``dl.value``; writes the ``z`` data and a shared
    z-axis range into the four surface figures; updates the geometry,
    colouring and colour limits of ``f_adv``; and rewrites ``info.value``.
    Returns nothing.
    """
    t  = sl.value
    ci = cp_list.index('C') if cp.value == 'Call' else cp_list.index('P')
    dl.value = str(dates.iloc[t].date())

    m      = mkt[t, ci]
    v_mlp  = mlp[t, ci]
    v_conv = conv[t, ci]
    v_hest = hest[t, ci]

    # One z-range for all panels so heights are visually comparable.  The
    # NaN-aware reductions keep the range finite if a surface has gaps.
    stack = (m, v_mlp, v_conv, v_hest)
    lo = float(min(np.nanmin(s) for s in stack)) - 0.005
    hi = float(max(np.nanmax(s) for s in stack)) + 0.005

    for fig, z in [(f_mkt, m), (f_mlp, v_mlp), (f_conv, v_conv), (f_hest, v_hest)]:
        with fig.batch_update():
            fig.data[0].z = z
            fig.layout.scene.zaxis.range = [lo, hi]

    # Per-cell absolute errors against the market surface.
    e_mlp  = np.abs(v_mlp  - m)
    e_conv = np.abs(v_conv - m)
    e_hest = np.abs(v_hest - m)
    e_best = np.minimum(e_mlp, e_conv)       # best VAE error per cell
    adv    = e_hest - e_best                  # >0 means VAE is closer

    with f_adv.batch_update():
        f_adv.data[0].z = m
        f_adv.data[0].surfacecolor = adv
        # Symmetric colour limits centre the diverging scale on 0 (no
        # advantage); only finite cells determine the limit.
        finite = np.abs(adv[np.isfinite(adv)])
        va = float(finite.max()) if finite.size else 0.0
        va = va or 0.01   # avoid a degenerate cmin == cmax range
        f_adv.data[0].cmin = -va
        f_adv.data[0].cmax = va
        f_adv.layout.scene.zaxis.range = [lo, hi]

    # Stats
    names  = ['MLP VAE', 'Conv VAE', 'Heston']
    maes   = [float(np.nanmean(e)) for e in (e_mlp, e_conv, e_hest)]
    best_i = int(np.argmin(maes))
    total  = m.size
    n_vae  = int((e_best < e_hest).sum())
    n_hest = int((e_hest < e_best).sum())
    # Exact ties (and cells where an error is NaN) favour neither side, so
    # they are reported separately rather than counted as Heston wins.
    n_tie  = total - n_vae - n_hest

    rows = '<br>'.join(
        f'{n}: {v:.4f}  ({v*100:.2f} vp)' for n, v in zip(names, maes)
    )
    tie_row = (
        f"<b>Tie / undefined:</b> {n_tie}/{total} "
        f"({100*n_tie/total:.0f}%)<br>"
    ) if n_tie else ""
    info.value = (
        "<div style='font-family:monospace;font-size:12px;"
        "background:#f8f8f8;padding:10px;border-radius:6px'>"
        f"<b style='font-size:13px'>{dates.iloc[t].date()} &mdash; "
        f"{cp.value}s</b><hr>"
        f"<b>MAE vs Market</b><br>{rows}<br><br>"
        f"<b>Best this date:</b> {names[best_i]}<br><br>"
        f"<b>VAE closer:</b> {n_vae}/{total} ({100*n_vae/total:.0f}%)<br>"
        f"<b>Heston closer:</b> {n_hest}/{total} "
        f"({100*n_hest/total:.0f}%)<br>"
        f"{tie_row}<br>"
        "<b>Advantage legend</b><br>"
        "<span style='color:#2166ac'>&#9632; Blue</span> = VAE better<br>"
        "<span style='color:#b2182b'>&#9632; Red</span> = Heston better"
        "</div>"
    )

# Re-render whenever the date slider moves or the Call/Put toggle flips,
# then draw the initial state once.
for _ctrl in (sl, cp):
    _ctrl.observe(_upd, 'value')
_upd()

# ── Layout ───────────────────────────────────────────────────────────
# Row 1: controls; row 2: market + both VAE surfaces; row 3: Heston,
# the advantage surface and the stats panel.
_controls_row = widgets.HBox([cp, sl, dl],
                             layout=widgets.Layout(align_items='center'))
_upper_row = widgets.HBox([f_mkt, f_mlp, f_conv])
_lower_row = widgets.HBox([f_hest, f_adv, info])
display(widgets.VBox([_controls_row, _upper_row, _lower_row]))
