# Regime Shifts Analysis

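Evaluate how well the trained hedging models generalize when the market regime shifts, i.e. when the Heston parameters `theta` and `xi` differ from those behind the baseline test data.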

In [1]:
import json
import os
import sys
sys.path.append('..')

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.utils.config import load_config, get_device
from src.models.deep_hedging import DeepHedgingNetwork
from src.data.heston import HestonSimulator
from src.data.preprocessor import create_dataloaders
from src.evaluation.metrics import compute_all_metrics

## Load Baseline Data and Generate Regime-Shift Data

In [None]:
# Build the evaluation data for each market regime.
# - 'calm' is the held-out test set produced earlier and loaded from disk.
# - 'high_vol' and 'extreme' are freshly simulated with shifted Heston parameters.
# Result: `regimes` maps regime name -> (S, v, Z).
config = load_config('../config.yaml')
device = get_device(config)

n_paths_per_regime = config['regime_shifts']['n_paths_per_regime']
T = config['data']['T']
n_steps = config['data']['n_steps']
K = config['data']['heston']['K']


def _simulate_regime(regime_key, seed):
    """Simulate paths for one shifted regime.

    Copies the baseline ``config['data']['heston']`` parameters and overrides
    only ``theta`` and ``xi`` with ``config['regime_shifts'][regime_key]``.

    Args:
        regime_key: Key under ``config['regime_shifts']`` holding the shifted
            ``theta`` / ``xi`` values.
        seed: Seed forwarded to ``HestonSimulator.simulate``.

    Returns:
        ``(S, v, Z)`` where ``S, v`` come from ``HestonSimulator.simulate`` and
        ``Z = max(S[:, -1] - K, 0)`` is the payoff on the last column of ``S``.
    """
    params = config['data']['heston'].copy()  # keep the baseline params untouched
    shift = config['regime_shifts'][regime_key]
    params['theta'] = shift['theta']
    params['xi'] = shift['xi']

    simulator = HestonSimulator(params)
    S, v = simulator.simulate(n_paths_per_regime, T, n_steps, seed=seed)
    Z = np.maximum(S[:, -1] - K, 0)
    return S, v, Z


regimes = {}

# Calm regime (baseline): loaded from disk, not simulated.
# NOTE(review): this set keeps its stored size rather than n_paths_per_regime,
# so CVaR estimates across regimes may rest on different sample sizes.
print("Loading calm regime (held-out test set)...")
S_test = np.load('../data/processed/S_test.npy')
v_test = np.load('../data/processed/v_test.npy')
Z_test = np.load('../data/processed/Z_test.npy')
regimes['calm'] = (S_test, v_test, Z_test)

# Shifted regimes: distinct fixed seeds keep each regime reproducible and independent.
print("Generating high volatility regime...")
regimes['high_vol'] = _simulate_regime('high_vol', seed=100)

print("Generating extreme regime...")
regimes['extreme'] = _simulate_regime('extreme', seed=200)

print("All regimes generated")

## Load Models

In [None]:
# Load the trained networks to compare: the dense baseline (required) and the
# 80%-sparsity lottery ticket (optional). Result: `models` maps name -> module.
models = {}


def _load_model(checkpoint_path):
    """Build a DeepHedgingNetwork from ``config['model']`` and load its weights.

    The weights are loaded onto ``device`` and the model is put in eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        RuntimeError: if the checkpoint does not match the architecture
            (``load_state_dict`` is strict by default).
    """
    model = DeepHedgingNetwork(config['model'])
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model = model.to(device)
    # Evaluation only: disable training-mode behaviour (e.g. dropout) in case
    # compute_all_metrics does not switch modes itself.
    model.eval()
    return model


# Dense baseline: required, so any loading error propagates.
models['dense'] = _load_model('../experiments/baseline/best_model.pt')

# Ticket 80%: optional. Only expected failures are handled; a bare `except`
# would also hide interrupts and genuine bugs.
try:
    models['ticket_80'] = _load_model('../experiments/pruning/sparsity_80/model.pt')
except FileNotFoundError:
    print("Ticket 80% not found, using only dense")
except RuntimeError as err:
    # Architecture/key mismatch (e.g. leftover pruning masks): say so instead
    # of misreporting the checkpoint as missing.
    print(f"Ticket 80% could not be loaded ({err}); using only dense")

print(f"Loaded {len(models)} models")

## Evaluate Across Regimes

In [None]:
# Evaluate every loaded model on every regime.
# Result: results_regimes[model_name][regime_name] = metrics, which is also
# persisted to ../experiments/regime_shifts/results.json.
results_regimes = {}
dt = config['data']['dt']
eval_batch_size = 256
num_workers = config['compute']['num_parallel_workers']

for model_name, model in models.items():
    print(f"\nEvaluating {model_name}...")
    results_regimes[model_name] = {}

    for regime_name, (S, v, Z) in regimes.items():
        print(f"  Regime: {regime_name}")

        # create_dataloaders builds train/val/test loaders, but only the test
        # loader is used here. Small slices stand in for the unused train/val splits.
        S_dummy, v_dummy, Z_dummy = S[:100], v[:100], Z[:100]
        _, _, test_loader = create_dataloaders(
            S_dummy, v_dummy, Z_dummy, S_dummy, v_dummy, Z_dummy, S, v, Z,
            eval_batch_size, num_workers
        )

        metrics = compute_all_metrics(model, test_loader, config, K, T, dt, device)
        results_regimes[model_name][regime_name] = metrics

        print(f"    CVaR: {metrics['cvar_005']:.6f}")
        print(f"    Sharpe: {metrics['sharpe_ratio']:.6f}")


def _to_jsonable(obj):
    """Fallback for json.dump: convert NumPy/torch values to plain Python types.

    NOTE(review): the exact value types returned by compute_all_metrics are
    not visible here. This converter only fires for objects json cannot
    serialize natively.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save results; create the output directory first so open() cannot fail on a fresh checkout.
results_dir = '../experiments/regime_shifts'
os.makedirs(results_dir, exist_ok=True)
with open(os.path.join(results_dir, 'results.json'), 'w') as f:
    json.dump(results_regimes, f, indent=2, default=_to_jsonable)

print("\nResults saved")

## Compute Degradation

In [None]:
# Report the relative change in CVaR of each shifted regime vs. the calm baseline.
# Positive values mean the regime's CVaR is larger than the calm one.
# NOTE(review): whether "larger" means "worse" depends on the sign convention
# of compute_all_metrics' 'cvar_005' — confirm.
print("\nPerformance Degradation:")
print("="*60)

for model_name in models:
    baseline_cvar = results_regimes[model_name]['calm']['cvar_005']

    print(f"\n{model_name}:")
    for regime_name in ['high_vol', 'extreme']:
        regime_cvar = results_regimes[model_name][regime_name]['cvar_005']
        if baseline_cvar == 0:
            # A relative change is undefined against a zero baseline; report the
            # absolute change instead of dividing by zero.
            print(f"  {regime_name}: baseline CVaR is 0, absolute change {regime_cvar - baseline_cvar:+.6f}")
            continue
        degradation = ((regime_cvar - baseline_cvar) / abs(baseline_cvar)) * 100

        print(f"  {regime_name}: {degradation:+.2f}% degradation")

## Visualization

In [None]:
# Grouped bar chart of CVaR per regime, one bar per loaded model.
# The figure is saved to ../figures/regime_shifts_performance.pdf.
regimes_list = ['calm', 'high_vol', 'extreme']
model_names_list = list(models.keys())
n_models = len(model_names_list)

fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(regimes_list))
# Share a fixed group width among however many models were loaded. The ticket
# model is optional, so this may be 1. With 2 models this gives the original 0.35.
group_width = 0.7
width = group_width / n_models

for i, model_name in enumerate(model_names_list):
    cvars = [results_regimes[model_name][r]['cvar_005'] for r in regimes_list]
    # Centre each group of bars on its tick, whatever the number of models.
    offset = (i - (n_models - 1) / 2) * width
    ax.bar(x + offset, cvars, width, label=model_name)

ax.set_xlabel('Market Regime')
ax.set_ylabel('CVaR')
ax.set_title('Performance Across Market Regimes')
ax.set_xticks(x)
ax.set_xticklabels(['Calm', 'High Vol', 'Extreme'])
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
figures_dir = '../figures'
os.makedirs(figures_dir, exist_ok=True)  # savefig does not create missing directories
plt.savefig(os.path.join(figures_dir, 'regime_shifts_performance.pdf'), dpi=300, bbox_inches='tight')
plt.show()

## Summary

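Hedging performance degrades as the regime moves from calm to high volatility and extreme conditions. In our runs, the sparse (80% ticket) network shows larger degradation than the dense network. Re-check both statements against the degradation table above when re-running, since the ticket model is skipped if its checkpoint is unavailable.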