# VAE Volatility Surface Experiments

This notebook provides a clean workflow for:
1. Loading processed volatility surface data
2. Training a grid-wise VAE model (`PointwiseMLPVAE` is imported for pointwise experiments but is not trained in this notebook)
3. Evaluating surface completion with masking (0%, 25%, 50%, 75%, 90%, 100%)
4. Visualizing reconstructions

In [1]:
# Imports
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Project imports
from src.data.dataloaders import create_dataloaders
from src.models import MLPVAE, PointwiseMLPVAE, vae_loss
from src.utils import (
    fit_vae,
    evaluate_vae,
    evaluate_completion_sweep,
    print_completion_summary,
    create_random_mask,
)

## 1. Configuration

In [None]:
# Paths (updated for new folder structure), relative to the notebook's working directory.
# NOTE(review): the imports cell sets PROJECT_ROOT = Path.cwd().parent (one level up),
# while these paths climb two levels ("../../") — confirm both resolve to the intended roots.
PARQUET_PATH = Path("../../data/processed/vae/parquet/AAPL_vsurf_processed.parquet")
CHECKPOINT_DIR = Path("../../artifacts/train")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Reproducibility: seed the global numpy and torch RNGs before any model is
# built or trained (torch.manual_seed also seeds every CUDA device).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model configuration
GRID_SHAPE = (2, 11, 17)  # (channels, maturities, deltas) - will be overridden by bundle.input_shape
# NOTE: derived from the placeholder GRID_SHAPE above; it is NOT recomputed when
# GRID_SHAPE is reassigned from bundle.input_shape in the model cell.
INPUT_DIM = np.prod(GRID_SHAPE)
LATENT_DIM = 8
HIDDEN_DIMS = [256, 128]

# Training configuration
BATCH_SIZE = 32
EPOCHS = 1000
LR = 1e-3
BETA = 1.0  # KL weight
PATIENCE = 100  # Early stopping

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


## 2. Load Data

In [3]:
# Build train/val/test loaders from the processed parquet file: chronological
# split (80% train, 10% val, remainder presumably test), normalized values,
# and each batch yielded as an (x, date) pair because return_date=True.
split_ratios = dict(train_ratio=0.80, val_ratio=0.10)
bundle = create_dataloaders(
    parquet_path=PARQUET_PATH,
    value_col="impl_volatility",
    batch_size=BATCH_SIZE,
    normalize=True,
    return_date=True,
    **split_ratios,
)

# Report the size of each split.
for split_label, split_loader in (
    ("Train:", bundle.train_loader),
    ("Val:  ", bundle.val_loader),
    ("Test: ", bundle.test_loader),
):
    print(f"{split_label} {len(split_loader.dataset)} samples")

# Peek at one training batch to confirm the tensor layout.
batch = next(iter(bundle.train_loader))
x, date = batch
print(f"Batch shape: {x.shape}")

Train: 1943 samples
Val:   242 samples
Test:  244 samples
Batch shape: torch.Size([32, 2, 11, 17])


## 3. Define Model

In [None]:
# Take the surface shape from the loaded data; this replaces the placeholder
# GRID_SHAPE from the config cell.
GRID_SHAPE = bundle.input_shape
print(f"Grid shape from data: {GRID_SHAPE} = {np.prod(GRID_SHAPE)} values")

# Grid-wise VAE built from the config constants and moved to DEVICE.
vae_kwargs = {
    "in_shape": GRID_SHAPE,
    "hidden_dims": HIDDEN_DIMS,
    "latent_dim": LATENT_DIM,
}
model = MLPVAE(**vae_kwargs).to(DEVICE)

print(model)
n_params = sum(param.numel() for param in model.parameters())
print(f"\nTotal parameters: {n_params:,}")

## 4. Train Model

In [None]:
# Fit the VAE with early stopping. CHECKPOINT_DIR is a directory, not a file
# path; the load cell below expects best_model.pt inside it.
optim_settings = dict(
    epochs=EPOCHS,
    lr=LR,
    beta=BETA,
    patience=PATIENCE,
)
history = fit_vae(
    model=model,
    train_loader=bundle.train_loader,
    val_loader=bundle.val_loader,
    device=DEVICE,
    checkpoint_dir=CHECKPOINT_DIR,
    **optim_settings,
)

In [None]:
# Training curves: total loss on the left, reconstruction/KL components on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# history holds one record per epoch (the original notes these are TrainStats
# dataclasses); pull each tracked quantity into its own per-epoch list.
train_loss, val_loss, train_recon, val_recon, train_kl, val_kl = (
    [getattr(epoch_stats, field) for epoch_stats in history]
    for field in ("train_loss", "val_loss", "train_recon", "val_recon", "train_kl", "val_kl")
)

for curve, curve_label in ((train_loss, 'Train'), (val_loss, 'Val')):
    ax1.plot(curve, label=curve_label)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Total Loss')
ax1.legend()

# KL curves are dashed to separate them from the reconstruction curves.
component_curves = [
    (train_recon, 'Train Recon', {}),
    (val_recon, 'Val Recon', {}),
    (train_kl, 'Train KL', {'linestyle': '--'}),
    (val_kl, 'Val KL', {'linestyle': '--'}),
]
for curve, curve_label, line_kw in component_curves:
    ax2.plot(curve, label=curve_label, **line_kw)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss Component')
ax2.set_title('Loss Components')
ax2.legend()

plt.tight_layout()
plt.show()

## 5. Load Best Model & Evaluate

In [None]:
# Restore the best checkpoint written during training.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
best_ckpt_path = CHECKPOINT_DIR / "best_model.pt"
ckpt = torch.load(best_ckpt_path, map_location=DEVICE)
model.load_state_dict(ckpt['model_state_dict'])

best_epoch = ckpt['epoch']
best_val_loss = ckpt['val_loss']
print(f"Loaded best model from epoch {best_epoch}")
print(f"Best val loss: {best_val_loss:.6f}")

In [None]:
# Unmasked reconstruction metrics on the held-out test split.
metrics = evaluate_vae(model, bundle.test_loader, DEVICE)

print("\n=== Test Set Metrics ===")
# Labels are padded so the numbers line up in one column.
for metric_label, metric_value in (
    ("MSE: ", metrics.mse),
    ("MAE: ", metrics.mae),
    ("RMSE:", metrics.rmse),
):
    print(f"{metric_label} {metric_value:.6f}")

## 6. Surface Completion Evaluation (Masking)

Following Bergeron et al., we evaluate the model's ability to complete partially observed surfaces.
We mask 0%, 25%, 50%, 75%, 90%, and 100% of the surface points and measure reconstruction error on the masked points.

In [None]:
# Surface-completion sweep on the test split: one evaluation per mask ratio,
# with a fixed seed. The bundle's scaler is passed through — presumably so the
# helper can report errors in original units; confirm in evaluate_completion_sweep.
sweep_mask_ratios = (0.0, 0.25, 0.50, 0.75, 0.90, 1)
completion_results = evaluate_completion_sweep(
    model=model,
    loader=bundle.test_loader,
    grid_shape=GRID_SHAPE,
    device=DEVICE,
    mask_ratios=sweep_mask_ratios,
    seed=42,
    scaler=bundle.scaler,
)

# Tabular summary of the sweep.
print_completion_summary(completion_results)

In [None]:
# Masked-point error as a function of the masking percentage.
ratios, maes, rmses = (
    [getattr(result, field) for result in completion_results]
    for field in ("mask_ratio", "mae_masked", "rmse_masked")
)
pct_masked = [ratio * 100 for ratio in ratios]

fig, ax = plt.subplots(figsize=(8, 5))
for err_values, fmt, err_label in (
    (maes, 'o-', 'MAE (masked points)'),
    (rmses, 's--', 'RMSE (masked points)'),
):
    ax.plot(pct_masked, err_values, fmt, label=err_label, markersize=8)
ax.set_xlabel('Masking Percentage (%)')
# NOTE(review): the sweep receives scaler=bundle.scaler, so these errors may be in
# original (un-normalized) units — confirm before relying on this axis label.
ax.set_ylabel('Error (normalized)')
ax.set_title('Surface Completion Performance')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 7. Visualize Reconstructions

In [None]:
def plot_surface_comparison(x_orig, x_recon, mask=None, title="Surface Comparison",
                            channel=0, channel_name="Calls"):
    """Plot an original surface next to its reconstruction for one channel.

    Panels are: original, reconstruction, absolute error and, when ``mask`` is
    given, the mask for the same channel. The original and reconstruction share
    one colour scale so that equal colours mean equal values across the two panels.

    Args:
        x_orig: Array-like where ``x_orig[channel]`` is a 2-D grid plotted with
            maturities on the y-axis and deltas on the x-axis.
        x_recon: Reconstruction with the same layout as ``x_orig``.
        mask: Optional array-like with the same leading channel axis; plotted
            as-is (in this notebook's convention 1 = hidden).
        title: Figure super-title.
        channel: Channel index to plot (default 0).
        channel_name: Label used in the panel titles (default "Calls", matching
            the channel-0 label used originally).

    Returns:
        The created matplotlib Figure.
    """
    orig = np.asarray(x_orig)[channel]
    recon = np.asarray(x_recon)[channel]

    fig, axes = plt.subplots(1, 3 if mask is None else 4, figsize=(14, 4))

    # Shared limits: with independent autoscaling each panel gets its own colour
    # mapping and visually identical colours can hide large differences.
    vmin = min(np.nanmin(orig), np.nanmin(recon))
    vmax = max(np.nanmax(orig), np.nanmax(recon))

    # Original surface
    im0 = axes[0].imshow(orig, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
    axes[0].set_title(f'Original ({channel_name})')
    axes[0].set_xlabel('Delta')
    axes[0].set_ylabel('Maturity')
    fig.colorbar(im0, ax=axes[0])

    # Reconstructed surface, on the same colour scale as the original
    im1 = axes[1].imshow(recon, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
    axes[1].set_title(f'Reconstructed ({channel_name})')
    axes[1].set_xlabel('Delta')
    fig.colorbar(im1, ax=axes[1])

    # Pointwise absolute error
    error = np.abs(orig - recon)
    im2 = axes[2].imshow(error, aspect='auto', cmap='Reds')
    axes[2].set_title('Absolute Error')
    axes[2].set_xlabel('Delta')
    fig.colorbar(im2, ax=axes[2])

    if mask is not None:
        im3 = axes[3].imshow(np.asarray(mask)[channel], aspect='auto', cmap='gray')
        axes[3].set_title('Mask (1=hidden)')
        axes[3].set_xlabel('Delta')
        fig.colorbar(im3, ax=axes[3])

    fig.suptitle(title)
    plt.tight_layout()
    return fig

In [None]:
# Reconstruct the first surface of one test batch (no masking) and compare it
# against the input.
model.eval()
with torch.no_grad():
    x, date = next(iter(bundle.test_loader))
    x = x.to(DEVICE)
    recon, mu, logvar = model(x)

    idx = 0  # which sample in the batch to show
    x_np, recon_np = (tensor[idx].cpu().numpy() for tensor in (x, recon))

    sample_title = f"Test Sample (Date: {date[idx]})"
    plot_surface_comparison(x_np, recon_np, title=sample_title)
    plt.show()

In [None]:
# Visualize masked reconstruction: hide a random fraction of grid points, then
# reconstruct the full surface from the remaining ones.
VIS_MASK_RATIO = 0.5  # fraction of grid points hidden for this figure

mask = create_random_mask(GRID_SHAPE, mask_ratio=VIS_MASK_RATIO, seed=42, device=DEVICE)

# Set eval mode here so the cell does not depend on an earlier cell having done it.
model.eval()
with torch.no_grad():
    x, date = next(iter(bundle.test_loader))
    x = x.to(DEVICE)

    # Zero out hidden entries (mask == 1); the mask is broadcast over the batch.
    # NOTE(review): zero is the fill value in normalized space — confirm it matches
    # how evaluate_completion_sweep fills masked points.
    x_masked = x * (1 - mask.unsqueeze(0))

    # Reconstruct from the masked input
    recon, mu, logvar = model(x_masked)

    # Plot the first sample of the batch; the title is derived from the ratio so
    # the two cannot drift apart.
    idx = 0
    plot_surface_comparison(
        x[idx].cpu().numpy(),
        recon[idx].cpu().numpy(),
        mask.cpu().numpy(),
        title=f"{VIS_MASK_RATIO:.0%} Masked Reconstruction (Date: {date[idx]})"
    )
    plt.show()

## 8. Latent Space Visualization

In [None]:
# Encode every test surface and keep the encoder means (mu) as its latent code,
# along with the date of each sample.
mu_chunks = []
dates = []

model.eval()
with torch.no_grad():
    for batch in bundle.test_loader:
        x, d = batch
        x = x.to(DEVICE)
        mu, logvar = model.encode(x)
        mu_chunks.append(mu.cpu().numpy())
        dates.extend(d)

# Stack the per-batch means into one (n_test_samples, latent_dim) array.
latent_codes = np.concatenate(mu_chunks, axis=0)
print(f"Latent codes shape: {latent_codes.shape}")

In [None]:
# 2-D view of the latent means. PCA is only applied when the codes have more than
# two dimensions. The width is read from latent_codes itself rather than the
# LATENT_DIM config constant, so the branch always matches the data being plotted.
n_latent = latent_codes.shape[1]
time_index = np.arange(len(latent_codes))

if n_latent > 2:
    pca = PCA(n_components=2)
    latent_2d = pca.fit_transform(latent_codes)
    print(f"PCA explained variance: {pca.explained_variance_ratio_.sum()*100:.1f}%")
    xlabel, ylabel = 'PC 1', 'PC 2'
elif n_latent == 2:
    latent_2d = latent_codes
    xlabel, ylabel = 'Latent Dim 1', 'Latent Dim 2'
else:
    # A 1-D latent has no second axis; plot it against the sample index instead.
    latent_2d = np.column_stack([time_index, latent_codes[:, 0]])
    xlabel, ylabel = 'Time index', 'Latent Dim 1'

# NOTE(review): colouring by position assumes test_loader yields samples in date
# order (the split is chronological, but confirm the test loader does not shuffle).
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(latent_2d[:, 0], latent_2d[:, 1],
                     c=time_index, cmap='viridis', alpha=0.6, s=20)
fig.colorbar(scatter, ax=ax, label='Time index')
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title('Latent Space of Test Surfaces (colored by time)')
plt.show()