In [1]:
#!/usr/bin/env python3
"""
Create interactive 3D plot of SPX volatility surface using model's native output.
Saves to HTML file that can be opened in any web browser.
"""

import os
import sys

import torch
from torch.utils.data import DataLoader
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Add parent directory to path (project imports below depend on this)
sys.path.insert(0, os.path.abspath('../..'))

from volatility_smoothing.utils.train.loss import Loss
from volatility_smoothing.utils.options_data import WRDSOptionsDataset
from volatility_smoothing.utils.train.dataset import GNOOptionsDataset
from volatility_smoothing.utils.train import misc

banner = "=" * 80
print(banner)
print("CREATING INTERACTIVE SPX VOLATILITY SURFACE PLOT")
print(banner)

# --- Configuration ----------------------------------------------------------
# The OPDS_* variables are set (overwriting any existing values) before
# WRDSOptionsDataset() is constructed below; presumably the dataset reads them
# for its cache / raw-data locations -- verify in utils.options_data.
os.environ.update(
    {
        'OPDS_CACHE_DIR': os.path.expanduser('~/.cache/opds'),
        'OPDS_WRDS_DATA_DIR': os.path.abspath("../data/openbb/spx"),
    }
)

device = torch.device("cpu")  # model and inputs are placed on the CPU
step_r, step_z = 0.02, 0.01   # grid spacings later passed to Loss(step_r=..., step_z=...)

# --- Step 1: market data ----------------------------------------------------
print("\n1. Loading dataset...")
dataset = WRDSOptionsDataset()
surface_count = len(dataset)
print(f"   ✓ Loaded {surface_count} surface(s)")
first_quote_time = dataset.quote_datetimes[0]
print(f"   ✓ Date: {first_quote_time}")

# --- Step 2: trained model --------------------------------------------------
print("\n2. Loading model...")
checkpoint_path = "../train/store/9448705/checkpoints/checkpoint_final.pt"
# load_checkpoint yields a pair; only the first element (the model) is needed.
model, _ = misc.load_checkpoint(checkpoint_path, device=device)
model.eval()  # evaluation mode before running predictions
print("   ✓ Model loaded from checkpoint")

# Loss object used for evaluation: it supplies the collate function, the
# evaluation grid (aux['grids']) and the decoding of raw model output.
dev_loss = Loss(step_r=step_r, step_z=step_z)

# Get model predictions
print("\n3. Running model prediction...")
with torch.no_grad():
    gno_dataset = GNOOptionsDataset(dataset)
    batch = gno_dataset[0]

    # A one-element loader is used only to run Loss.collate_fn on a single
    # surface; take its only batch directly instead of looping and breaking
    # (which would leave every name below unbound if the loader were empty).
    dataloader = DataLoader([batch], batch_size=1, collate_fn=dev_loss.collate_fn)
    data, model_input, aux = next(iter(dataloader))

    # Named `model_input` rather than `input` so the builtin is not shadowed
    # for the rest of the (notebook) session.
    model_input = {
        k: v.to(device) if torch.is_tensor(v) else v
        for k, v in model_input.items()
    }
    output = model(**model_input)

    # Decode the raw output with the Loss helper (existing code from the paper);
    # only the first two results are used here.
    iv_predict, iv_surface, *_ = dev_loss.read_output(output, aux)

    # Evaluation grid for the first (only) sample, in normalized coordinates.
    grid = aux['grids'][0]

    # Raw market quotes in normalized coordinates, flattened to 1-D.
    rho_data = data['r'].cpu().numpy().ravel()
    z_data = data['z'].cpu().numpy().ravel()
    iv_market = data['implied_volatility'].cpu().numpy().ravel()

    # Same quotes in market units, used for plotting.
    tau_data = data['time_to_maturity'].cpu().numpy().ravel()
    log_moneyness_data = data['log_moneyness'].cpu().numpy().ravel()

# Extract model's native grid (normalized coordinates)
rho_grid = grid['r'].cpu().numpy()
z_grid = grid['z'].cpu().numpy()
iv_surface_model = iv_surface.cpu().numpy()

# With batch_size=1 the decoded surface may carry a leading singleton batch
# axis (assumption -- depends on Loss.read_output); drop it so the surface
# lines up element-for-element with the 2-D grid.
if iv_surface_model.ndim == rho_grid.ndim + 1 and iv_surface_model.shape[0] == 1:
    iv_surface_model = iv_surface_model[0]
# go.Surface needs z to share the shape of x/y; otherwise the problem would
# only show up later in plotly (or as a garbled plot), so fail loudly here.
if iv_surface_model.shape != rho_grid.shape:
    raise ValueError(
        f"Model IV surface shape {iv_surface_model.shape} does not match "
        f"grid shape {rho_grid.shape}"
    )

# Convert normalized grid coordinates to market units for plotting.
tau_grid = rho_grid ** 2  # τ = ρ²
log_moneyness_grid = z_grid * rho_grid  # log_moneyness = z × ρ

print(f"   ✓ Model's native grid: {rho_grid.shape}")
print(f"   ✓ Market data points: {len(rho_data)}")
print(f"   ✓ Time to maturity range: {tau_data.min():.4f} to {tau_data.max():.4f} years")
print(f"   ✓ Log-moneyness range: {log_moneyness_data.min():.4f} to {log_moneyness_data.max():.4f}")

# --- Step 4: two side-by-side 3-D panels ------------------------------------
print("\n4. Creating interactive 3D plot...")

# Both panels plot (maturity, log-moneyness, IV), so they share one hover format.
hover_template = 'Maturity: %{x:.3f} yrs<br>Log-moneyness: %{y:.3f}<br>IV: %{z:.4f}<extra></extra>'
grid_rows = rho_grid.shape[0]
grid_cols = rho_grid.shape[1]

fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{'type': 'surface'}, {'type': 'surface'}]],
    subplot_titles=(
        'Market IV Data Points',
        f'Model IV Surface (Native Grid: {grid_rows}×{grid_cols})',
    ),
    horizontal_spacing=0.1,
)

# Left panel: every market quote as a point, in market units.
market_trace = go.Scatter3d(
    x=tau_data,
    y=log_moneyness_data,
    z=iv_market,
    mode='markers',
    marker={
        'size': 2,
        'color': iv_market,
        'colorscale': 'Viridis',
        'showscale': True,
        'colorbar': {'title': "IV", 'x': -0.07, 'len': 0.8},  # colorbar on the far left
    },
    name='Market Data',
    hovertemplate=hover_template,
)
fig.add_trace(market_trace, row=1, col=1)

# Right panel: the model's IV surface on its own evaluation grid, in market units,
# with IV contour lines projected onto the floor of the scene.
model_trace = go.Surface(
    x=tau_grid,
    y=log_moneyness_grid,
    z=iv_surface_model,
    colorscale='Viridis',
    showscale=True,
    colorbar={'title': "IV", 'x': 1.07, 'len': 0.8},  # colorbar on the far right
    contours={
        'z': {
            'show': True,
            'usecolormap': True,
            'highlightcolor': "limegreen",
            'project': {'z': True},
        }
    },
    name='Model Surface',
    hovertemplate=hover_template,
    opacity=0.95,
)
fig.add_trace(model_trace, row=1, col=2)

def _scene_layout():
    """Return a fresh 3-D scene config; both subplots use the same axes and camera."""
    return dict(
        xaxis_title='Time to Maturity (years)',
        yaxis_title='Log-Moneyness',
        zaxis_title='Implied Volatility',
        camera=dict(eye=dict(x=1.5, y=-1.5, z=1.3)),
        aspectmode='auto',
    )


# Update layout: scene -> left subplot, scene2 -> right subplot.
fig.update_layout(
    title_text="SPX Implied Volatility Surface - Actual Values (Time to Maturity & Log-Moneyness)",
    height=700,
    showlegend=False,
    scene=_scene_layout(),
    scene2=_scene_layout(),
    font=dict(size=11),
)

# Save to a standalone HTML file. The relative name resolves against the
# current working directory, so report the absolute location to the user.
output_file = "spx_volatility_surface_interactive.html"
fig.write_html(output_file)

print("  Plot created successfully")

print(f"\nSaved interactive plot to: {os.path.abspath(output_file)}")



CREATING INTERACTIVE SPX VOLATILITY SURFACE PLOT

1. Loading dataset...
   ✓ Loaded 1 surface(s)
   ✓ Date: 2025-12-07 00:00:00

2. Loading model...
   ✓ Model loaded from checkpoint

3. Running model prediction...
   ✓ Model's native grid: (50, 200)
   ✓ Market data points: 9022
   ✓ Time to maturity range: 0.0027 to 0.9534 years
   ✓ Log-moneyness range: -1.3692 to 0.4615

4. Creating interactive 3D plot...
   ✓ Plot created successfully

SUCCESS!

✓ Saved interactive plot to: spx_volatility_surface_interactive.html

To view the plot:
  1. Open the file: /Users/anighot/Documents/Thesis_New/operator-deep-smoothing-for-implied-volatility-main/volatility_smoothing/eval/spx_volatility_surface_interactive.html
  2. Or double-click the file in Finder to open in browser

Plot features:
  • Rotate: Click and drag
  • Zoom: Scroll or pinch
  • Pan: Right-click and drag
  • Hover: See exact values
  • Reset: Double-click

Data source (thesis-appropriate):
  • Grid: Created by Loss class (step_