# Regime Shifts Analysis

Test model generalization across different market regimes (calm, high volatility, extreme).

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import torch
import json
from pathlib import Path
from copy import deepcopy

from src.utils.config import load_config, get_device
from src.models.deep_hedging import create_model
from src.models.losses import create_loss_function
from src.data.heston import HestonSimulator, get_or_generate_dataset
from src.data.preprocessor import create_dataloaders, compute_features
from src.evaluation.metrics import compute_all_metrics

## Setup

In [None]:
# Load the experiment configuration and resolve the compute device from it.
config = load_config('../configs/config.yaml')
device = get_device(config)
print(f"Using device: {device}")

# Values reused by the cells below: the Heston section of the config, the
# strike K stored inside it, and the total time T / number of steps stored
# under 'data'.
heston_config = config['data']['heston']
K = heston_config['K']
T, n_steps = (config['data'][name] for name in ('T', 'n_steps'))
dt = T / n_steps  # length of one time step

print(f"K={K}, T={T}, n_steps={n_steps}, dt={dt:.6f}")

## Define Market Regimes

In [None]:
# How many paths are simulated for each regime
n_paths_per_regime = 10000

# Heston parameter overrides for each regime.
#   theta -> long-term variance, xi -> vol of vol, v_0 -> initial variance
# 'calm' reads its values from the loaded Heston config and only falls back to
# the literals below when a key is missing; the stressed regimes use fixed,
# progressively larger values.
regime_params = {
    'calm': dict(
        theta=heston_config.get('theta', 0.0398),
        xi=heston_config.get('xi', 0.5751),
        v_0=heston_config.get('v_0', 0.0175),
        description='Normal market conditions',
    ),
    'high_vol': dict(
        theta=0.08,  # roughly 2x the fallback long-term variance
        xi=0.8,
        v_0=0.04,
        description='High volatility regime (e.g., market stress)',
    ),
    'extreme': dict(
        theta=0.15,  # roughly 4x the fallback long-term variance
        xi=1.2,
        v_0=0.08,
        description='Extreme regime (e.g., crisis)',
    ),
}

# Print one summary paragraph per regime (one line per parameter).
print("Market Regimes:")
for regime_name, params in regime_params.items():
    print(
        f"\n{regime_name}:",
        f"  theta (long-term var): {params['theta']}",
        f"  xi (vol of vol):       {params['xi']}",
        f"  v_0 (initial var):     {params['v_0']}",
        f"  Description: {params['description']}",
        sep="\n",
    )

## Generate Regime Shift Data

In [None]:
import zlib  # cell-local: only needed to derive reproducible per-regime seeds

# Simulated paths and payoffs for every regime, keyed by regime name.
regimes_data = {}

for regime_name, params in regime_params.items():
    print(f"\nGenerating {regime_name} regime ({n_paths_per_regime} paths)...")

    # Work on a copy so the base Heston config is left untouched, then
    # override only the regime-specific parameters.
    heston_modified = deepcopy(heston_config)
    heston_modified['theta'] = params['theta']
    heston_modified['xi'] = params['xi']
    heston_modified['v_0'] = params['v_0']

    # Simulate
    simulator = HestonSimulator(heston_modified)
    # Built-in hash() of a str is salted per interpreter process, so it would
    # yield a different seed after every kernel restart. CRC32 of the name is
    # stable across runs and machines, which makes the seed truly
    # deterministic per regime.
    seed = zlib.crc32(regime_name.encode('utf-8')) % 10000
    S, v = simulator.simulate(n_paths_per_regime, T, n_steps, seed=seed)

    # Call payoff evaluated on the last column of S
    # (assumes S is laid out as (paths, time) - TODO confirm against HestonSimulator).
    Z = np.maximum(S[:, -1] - K, 0)

    regimes_data[regime_name] = {
        'S': S,
        'v': v,
        'Z': Z
    }

    print(f"  S range: [{S.min():.2f}, {S.max():.2f}]")
    print(f"  v range: [{v.min():.6f}, {v.max():.6f}]")
    print(f"  ITM ratio: {np.mean(S[:, -1] > K) * 100:.1f}%")

print("\nAll regimes generated successfully")

## Load Models

In [None]:
def _load_checkpointed_model(checkpoint_path):
    """Build a fresh model from `config` and load its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file whose loaded object exposes
            the weights under the 'model_state_dict' key.

    Returns:
        The model with the checkpoint weights loaded, moved to `device`.
    """
    model = create_model(config)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model.to(device)


# Models that were found on disk, keyed by the label used in reports/plots.
models = {}
loss_fn = create_loss_function(config)

# Dense baseline
baseline_path = Path('../experiments/baseline/checkpoints/best.pt')
if baseline_path.exists():
    model_dense = _load_checkpointed_model(baseline_path)
    models['dense'] = model_dense
    print("Loaded dense baseline")
else:
    print(f"Dense baseline not found at {baseline_path}")

# Sparse ticket 80%
ticket_path = Path('../experiments/pruning/sparsity_80/checkpoints/best.pt')
if ticket_path.exists():
    model_ticket = _load_checkpointed_model(ticket_path)
    models['ticket_80%'] = model_ticket
    print("Loaded ticket 80%")
else:
    print("Ticket 80% not found, skipping")

# Adversarially trained model
adv_path = Path('../experiments/adversarial_training/checkpoints/best.pt')
if adv_path.exists():
    model_adv = _load_checkpointed_model(adv_path)
    models['robust_ticket'] = model_adv
    print("Loaded robust ticket")
else:
    print("Robust ticket not found, skipping")

# A missing checkpoint is not an error: later cells simply iterate over
# whatever ended up in `models`.
print(f"\nTotal models loaded: {len(models)}")

## Helper Function: Evaluate Model

In [None]:
def evaluate_on_regime(model, loss_fn, S, v, Z, config, device, batch_size=256):
    """
    Compute hedging metrics for one model on one regime's simulated data.

    The model is switched to eval mode and run without gradient tracking over
    consecutive slices of at most `batch_size` rows. The PnL of every slice is
    obtained from `loss_fn.compute_pnl`, the slices are concatenated, and the
    result is summarised by `compute_all_metrics`.

    Args:
        model: Callable invoked as `model(features, S_batch)`; it must return
            a pair whose first element is the hedge deltas (the second
            element is ignored here).
        loss_fn: Object providing `compute_pnl(deltas, S_batch, Z_batch, dt)`.
        S, v, Z: Regime data sliced along the first axis; `S.shape[0]`
            determines the number of rows that are evaluated.
        config: Configuration providing `['data']['heston']['K']`,
            `['data']['T']` and `['data']['n_steps']`.
        device: Device on which the batch tensors are created.
        batch_size: Maximum number of rows per forward pass.

    Returns:
        Dictionary of metrics
    """
    data_cfg = config['data']
    strike = data_cfg['heston']['K']
    horizon = data_cfg['T']
    step_size = horizon / data_cfg['n_steps']

    def to_device_tensor(values):
        # Every batch is materialised as float32 directly on the target device.
        return torch.tensor(values, dtype=torch.float32, device=device)

    model.eval()
    pnl_chunks = []
    n_rows = S.shape[0]

    with torch.no_grad():
        for start in range(0, n_rows, batch_size):
            # The last slice may be shorter than batch_size.
            stop = min(start + batch_size, n_rows)

            S_batch = to_device_tensor(S[start:stop])
            v_batch = to_device_tensor(v[start:stop])
            Z_batch = to_device_tensor(Z[start:stop])

            features = compute_features(S_batch, v_batch, strike, horizon, step_size)
            deltas, _ = model(features, S_batch)
            batch_pnl = loss_fn.compute_pnl(deltas, S_batch, Z_batch, step_size)

            # Collect on the CPU so the final concatenation can become NumPy.
            pnl_chunks.append(batch_pnl.cpu())

    return compute_all_metrics(torch.cat(pnl_chunks).numpy())

## Evaluate Across Regimes

In [None]:
# results_regimes[model_name][regime_name] -> metrics returned by
# evaluate_on_regime for that model/regime pair.
results_regimes = {}

for model_name, model in models.items():
    print(f"\nEvaluating {model_name}...")
    results_regimes[model_name] = {}

    for regime_name, data in regimes_data.items():
        # Keep the cursor on this line so the metrics are printed right after it.
        print(f"  Regime: {regime_name}...", end=" ")

        metrics = evaluate_on_regime(
            model,
            loss_fn,
            *(data[field] for field in ('S', 'v', 'Z')),
            config,
            device,
        )
        results_regimes[model_name][regime_name] = metrics

        print(f"CVaR={metrics['cvar_05']:.4f}, Sharpe={metrics['sharpe_ratio']:.4f}")

# Persist the metrics as JSON; default=float converts values that the json
# module cannot serialise natively.
output_dir = Path('../experiments/regime_shifts')
output_dir.mkdir(parents=True, exist_ok=True)
results_path = output_dir / 'results.json'

with open(results_path, 'w') as f:
    json.dump(results_regimes, f, indent=2, default=float)

print(f"\nResults saved to {results_path}")

## Compute Performance Degradation

In [None]:
def _relative_change_pct(value, baseline, eps=1e-8):
    """Return the signed change of `value` relative to `baseline`, in percent.

    `eps` is added to the magnitude of the baseline, so the denominator is
    always strictly positive and a zero baseline cannot cause a division by
    zero.
    """
    return (value - baseline) / (abs(baseline) + eps) * 100


print("\n" + "=" * 80)
print("PERFORMANCE DEGRADATION ANALYSIS")
print("=" * 80)

# degradation_results[model_name][regime_name] -> percentage change of CVaR
# and Sharpe relative to the same model's 'calm' regime.
degradation_results = {}

for model_name in models:
    calm_metrics = results_regimes[model_name]['calm']
    baseline_cvar = calm_metrics['cvar_05']
    baseline_sharpe = calm_metrics['sharpe_ratio']

    print(f"\n{model_name}:")
    print(f"  Baseline (calm): CVaR={baseline_cvar:.4f}, Sharpe={baseline_sharpe:.4f}")

    degradation_results[model_name] = {}

    for regime_name in ['high_vol', 'extreme']:
        regime_metrics = results_regimes[model_name][regime_name]
        regime_cvar = regime_metrics['cvar_05']
        regime_sharpe = regime_metrics['sharpe_ratio']

        # Signed change versus the calm regime. The notebook reads a negative
        # value as "worse"; that presumes larger cvar_05 / sharpe_ratio values
        # are better - TODO confirm against compute_all_metrics.
        cvar_degradation = _relative_change_pct(regime_cvar, baseline_cvar)
        sharpe_degradation = _relative_change_pct(regime_sharpe, baseline_sharpe)

        degradation_results[model_name][regime_name] = {
            'cvar_degradation_pct': cvar_degradation,
            'sharpe_degradation_pct': sharpe_degradation
        }

        print(f"  {regime_name}: CVaR={regime_cvar:.4f} ({cvar_degradation:+.1f}%), "
              f"Sharpe={regime_sharpe:.4f} ({sharpe_degradation:+.1f}%)")

print("=" * 80)

## Visualization: Performance Across Regimes

In [None]:
# Grouped bar charts: one panel per metric, one group per regime, one bar per
# loaded model.
if models:
    regimes_list = ['calm', 'high_vol', 'extreme']
    regime_labels = ['Calm', 'High Vol', 'Extreme']
    model_names_list = list(models.keys())
    n_models = len(model_names_list)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    x = np.arange(len(regimes_list))
    width = 0.8 / n_models  # all bars of one group share 80% of the slot
    palette = ['#2563eb', '#16a34a', '#dc2626', '#9333ea']
    # Cycle through the palette so that more than four models reuse colours
    # instead of raising IndexError.
    colors = [palette[idx % len(palette)] for idx in range(n_models)]

    # (metric key in results_regimes, y-axis label, panel title)
    panels = [
        ('cvar_05', 'CVaR 5%', 'CVaR Across Market Regimes'),
        ('sharpe_ratio', 'Sharpe Ratio', 'Sharpe Ratio Across Market Regimes'),
    ]

    for ax, (metric_key, y_label, title) in zip(axes, panels):
        for i, model_name in enumerate(model_names_list):
            values = [results_regimes[model_name][r][metric_key] for r in regimes_list]
            # Shift each model's bars so the whole group is centred on the tick.
            ax.bar(x + i * width - (n_models - 1) * width / 2, values, width,
                   label=model_name, color=colors[i], alpha=0.8)

        ax.set_xlabel('Market Regime')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(regime_labels)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    # Make sure the target directory exists before saving; savefig does not
    # create missing parent directories.
    figure_path = Path('../figures/regime_shifts_performance.pdf')
    figure_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(figure_path, dpi=300)
    plt.show()
else:
    print("No models to visualize")

## Visualization: Degradation Heatmap

In [None]:
# Heatmap of the CVaR degradation percentage: rows are models, columns are the
# stressed regimes.
if degradation_results:
    model_names_list = list(degradation_results.keys())
    regimes_stressed = ['high_vol', 'extreme']

    # Create degradation matrix (one row per model, one column per regime)
    matrix = np.array(
        [
            [degradation_results[name][regime]['cvar_degradation_pct']
             for regime in regimes_stressed]
            for name in model_names_list
        ],
        dtype=float,
    )

    fig, ax = plt.subplots(figsize=(8, 6))

    # The colour scale is fixed to [-50, 10]; values outside that range are
    # drawn with the colour of the nearest end of the scale.
    im = ax.imshow(matrix, cmap='RdYlGn', aspect='auto', vmin=-50, vmax=10)

    ax.set_xticks(range(len(regimes_stressed)))
    ax.set_xticklabels(['High Vol', 'Extreme'])
    ax.set_yticks(range(len(model_names_list)))
    ax.set_yticklabels(model_names_list)

    # Write the numeric value into every cell (x is the column, y is the row)
    for i, row in enumerate(matrix):
        for j, value in enumerate(row):
            ax.text(j, i, f'{value:.1f}%',
                    ha='center', va='center', color='black', fontsize=12)

    ax.set_title('CVaR Degradation (%) by Model and Regime')
    plt.colorbar(im, ax=ax, label='Degradation %')

    plt.tight_layout()

    # Make sure the target directory exists before saving; savefig does not
    # create missing parent directories.
    figure_path = Path('../figures/regime_degradation_heatmap.pdf')
    figure_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(figure_path, dpi=300)
    plt.show()

## Summary

Key findings:
- All models show performance degradation in stressed regimes
- Sparse networks may show larger degradation than dense networks
- Adversarially trained models are more robust to regime shifts
- This highlights the importance of robust training for deployment