# 8. Model Comparison

Compare VAE, GAN, and normalizing-flow models, and their ensemble, for generating macroeconomic scenarios.

## Contents
1. Model Overview
2. Generation Quality
3. Statistical Comparison
4. Ensemble Performance

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from scipy import stats

from privatecredit.models import MacroVAE, MacroGAN, MacroFlow, MacroEnsemble
from privatecredit.models.macro_vae import MacroVAEConfig
from privatecredit.models.macro_gan import MacroGANConfig
from privatecredit.models.macro_flow import MacroFlowConfig
from privatecredit.models.ensemble import EnsembleConfig, EnsembleMethod

# Prefer a CUDA device when PyTorch can see one, otherwise fall back to CPU.
# NOTE(review): none of the visible cells move the models or tensors to
# `device`, so the computations below actually run on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using device: {}".format(device))

## 1. Model Overview

| Model | Architecture | Key Feature |
|-------|-------------|-------------|
| VAE | LSTM Encoder-Decoder | ELBO optimization, latent interpolation |
| GAN | WGAN-GP | Wasserstein distance, sharper outputs |
| Flow | Real NVP | Exact likelihood, invertible |
| Ensemble | Weighted combination | Robustness, uncertainty |


In [None]:
# Build one instance of each generator family with identical data dimensions
# (n_macro_vars=9, seq_length=60, n_scenarios=4).
shared_dims = dict(n_macro_vars=9, seq_length=60, n_scenarios=4)
vae_config = MacroVAEConfig(**shared_dims)
gan_config = MacroGANConfig(**shared_dims)
flow_config = MacroFlowConfig(**shared_dims)

vae = MacroVAE(vae_config)
gan = MacroGAN(gan_config)
flow = MacroFlow(flow_config)

# Report the total number of scalar parameters in each model.
print("Model Parameters:")
for label, model in (("VAE", vae), ("GAN", gan), ("Flow", flow)):
    print(f"  {label}: {sum(p.numel() for p in model.parameters()):,}")

## 2. Generation Quality

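Compare samples generated by each model. The models here are untrained (randomly initialised weights). These cells therefore only demonstrate the comparison workflow and say nothing about the models' real generation quality.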

In [None]:
# Generate samples from each model for the baseline scenario.
n_samples = 100
scenario = torch.tensor([0])  # Baseline

# Put every model in inference mode before sampling so that layers such as
# dropout or batch-norm (if the models contain any) do not run in training mode.
for _model in (vae, gan, flow):
    _model.eval()

with torch.no_grad():
    vae_samples = vae.generate(scenario, seq_length=60, n_samples=n_samples)
    gan_samples = gan.generate(scenario, n_samples=n_samples)
    flow_samples = flow.generate(scenario, n_samples=n_samples)

# `.cpu()` is a no-op for CPU tensors but keeps `.numpy()` working if the
# models are ever moved to a GPU.
vae_samples = vae_samples.cpu().numpy()
gan_samples = gan_samples.cpu().numpy()
flow_samples = flow_samples.cpu().numpy()

print(f"Generated shapes: VAE={vae_samples.shape}, GAN={gan_samples.shape}, Flow={flow_samples.shape}")

In [None]:
# One panel per model: histogram of the first variable (GDP growth), pooled
# over all samples and time steps, with its mean marked.
model_samples = {'VAE': vae_samples, 'GAN': gan_samples, 'Flow': flow_samples}
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for panel, name in enumerate(model_samples):
    ax = axes[panel]
    gdp = model_samples[name][:, :, 0].flatten()
    gdp_mean = gdp.mean()
    ax.hist(gdp, bins=50, density=True, alpha=0.7, color='steelblue')
    ax.set_xlabel('Value')
    ax.set_ylabel('Density')
    ax.set_title(f'{name} - GDP Growth Distribution')
    ax.axvline(gdp_mean, color='red', linestyle='--', label=f'Mean: {gdp_mean:.3f}')
    ax.legend()

plt.tight_layout()
plt.show()

## 3. Statistical Comparison

In [None]:
# Compute statistics for each model
def compute_stats(samples):
    """Summarise generated samples variable by variable.

    Assumes a 3-D array whose last axis indexes variables (presumably
    (n_samples, seq_len, n_vars) -- confirm against the generators). Mean and
    std reduce over axes 0 and 1; skew and kurtosis are taken over the same
    values after flattening everything but the last axis.

    Returns a dict with keys 'mean', 'std' (population std, ddof=0), 'skew'
    and 'kurtosis' (scipy defaults: biased, Fisher/excess kurtosis), each a
    1-D array with one entry per variable.
    """
    n_vars = samples.shape[-1]
    pooled = samples.reshape(-1, n_vars)
    summary = {}
    summary['mean'] = samples.mean(axis=(0, 1))
    summary['std'] = samples.std(axis=(0, 1))
    summary['skew'] = stats.skew(pooled, axis=0)
    summary['kurtosis'] = stats.kurtosis(pooled, axis=0)
    return summary

vae_stats = compute_stats(vae_samples)
gan_stats = compute_stats(gan_samples)
flow_stats = compute_stats(flow_samples)

var_names = ['GDP', 'Unemp', 'Infl', 'Policy', 'Y10', 'IG', 'HY', 'Prop', 'Equity']

# Labels are matched to variables by position, so the counts must agree;
# otherwise rows would be silently dropped or an IndexError raised mid-loop.
if len(var_names) != len(vae_stats['mean']):
    raise ValueError(
        f"var_names has {len(var_names)} labels but the samples have "
        f"{len(vae_stats['mean'])} variables"
    )

# Create comparison table: one row per variable, columns grouped by statistic
# (Mean, Std, Skew, Kurt) and, within each group, by model (VAE, GAN, Flow).
# Skew and Kurt were already computed by compute_stats but previously never
# shown; Kurt is scipy's default (excess) kurtosis.
model_stats = {'VAE': vae_stats, 'GAN': gan_stats, 'Flow': flow_stats}
stat_columns = {'mean': 'Mean', 'std': 'Std', 'skew': 'Skew', 'kurtosis': 'Kurt'}

comparison = []
for i, var in enumerate(var_names):
    row = {'Variable': var}
    for stat_key, stat_label in stat_columns.items():
        for model_name, model_stat in model_stats.items():
            row[f'{model_name}_{stat_label}'] = model_stat[stat_key][i]
    comparison.append(row)

df_comp = pd.DataFrame(comparison)
print("Model Comparison Statistics:")
print(df_comp.round(4).to_string(index=False))

In [None]:
# Autocorrelation comparison
def compute_autocorr(samples, lag=1):
    """Compute the autocorrelation at time offset ``lag`` for each variable.

    ``samples`` is indexed as ``samples[:, :, v]``: a 3-D array whose second
    axis is time and whose last axis indexes variables (presumably
    (n_samples, seq_len, n_vars)). For each variable, the pairs
    (x[t], x[t + lag]) from every sample path are pooled and a single
    Pearson correlation is computed over them.

    Args:
        samples: 3-D array of generated paths.
        lag: Non-negative time offset, strictly smaller than the sequence
            length (default 1).

    Returns:
        1-D numpy array with one correlation per variable. A variable that is
        constant over the pooled pairs yields ``nan`` (numpy warns).

    Raises:
        ValueError: if ``lag`` is negative or not smaller than the sequence
            length.
    """
    seq_len = samples.shape[1]
    if not 0 <= lag < seq_len:
        raise ValueError(f"lag must satisfy 0 <= lag < {seq_len}, got {lag}")
    # Use explicit end indices: ``x[:, :-lag]`` would be empty when lag == 0.
    leading = samples[:, :seq_len - lag, :]
    trailing = samples[:, lag:, :]
    acfs = [
        np.corrcoef(leading[:, :, v].ravel(), trailing[:, :, v].ravel())[0, 1]
        for v in range(samples.shape[-1])
    ]
    return np.array(acfs)

# Lag-1 autocorrelation per variable for each model.
vae_acf, gan_acf, flow_acf = (
    compute_autocorr(vae_samples),
    compute_autocorr(gan_samples),
    compute_autocorr(flow_samples),
)

# Plot: grouped bars, one group per variable, one bar per model.
fig, ax = plt.subplots(figsize=(12, 5))
positions = np.arange(len(var_names))
bar_width = 0.25

bar_specs = [
    (-1, vae_acf, 'VAE', 'steelblue'),
    (0, gan_acf, 'GAN', 'coral'),
    (1, flow_acf, 'Flow', 'green'),
]
for offset, acf_values, label, colour in bar_specs:
    ax.bar(positions + offset * bar_width, acf_values, bar_width, label=label, color=colour)

ax.set_xticks(positions)
ax.set_xticklabels(var_names)
ax.set_ylabel('Lag-1 Autocorrelation')
ax.set_title('Temporal Autocorrelation by Model')
ax.legend()
ax.axhline(0, color='gray', linestyle='--')
plt.tight_layout()
plt.show()

## 4. Ensemble Performance

In [None]:
# Combine the three generators into a weighted ensemble with the same data
# dimensions as the individual models.
ensemble_config = EnsembleConfig(
    method=EnsembleMethod.WEIGHTED,
    n_macro_vars=9,
    seq_length=60,
    n_scenarios=4,
)

members = {'vae_model': vae, 'gan_model': gan, 'flow_model': flow}
ensemble = MacroEnsemble(config=ensemble_config, **members)

print("Ensemble with {} models".format(ensemble.n_models))

In [None]:
# Compute model disagreement (uncertainty) for the baseline scenario.
# Like the other sampling cells, run without autograd so no graph is built,
# and reuse the notebook-wide sample count.
with torch.no_grad():
    disagreement = ensemble.compute_disagreement(scenario, n_samples=n_samples)

print("Model Disagreement (Between-Model Variance):")
# NOTE(review): assumes disagreement['variance'] keeps variables on its last
# axis, like the sample arrays -- confirm against MacroEnsemble.
mean_var = disagreement['variance'].mean(axis=(0, 1))
for i, var in enumerate(var_names):
    print(f"  {var}: {mean_var[i]:.4f}")

In [None]:
# Generate ensemble predictions for the same scenario and sample count as the
# individual models.
with torch.no_grad():
    ensemble_samples = ensemble.generate(scenario, n_samples=n_samples)

# `.cpu()` keeps `.numpy()` working if the models are ever moved to a GPU.
ensemble_samples = ensemble_samples.cpu().numpy()

# Compare distributions of the first variable (GDP growth). Colours match the
# autocorrelation chart.
series = [
    ('VAE', vae_samples, 'steelblue'),
    ('GAN', gan_samples, 'coral'),
    ('Flow', flow_samples, 'green'),
    ('Ensemble', ensemble_samples, 'red'),
]
gdp_series = [(name, samples[:, :, 0].flatten(), color) for name, samples, color in series]

# Overlaid histograms need common bin edges: with `bins=30` each call picks its
# own edges from its own data range, so the bars would not line up.
shared_edges = np.histogram_bin_edges(
    np.concatenate([data for _, data, _ in gdp_series]), bins=30
)

fig, ax = plt.subplots(figsize=(10, 5))
for name, data, color in gdp_series:
    ax.hist(data, bins=shared_edges, density=True, alpha=0.3, label=name, color=color)

ax.set_xlabel('GDP Growth')
ax.set_ylabel('Density')
ax.set_title('GDP Growth Distribution by Model')
ax.legend()
plt.tight_layout()
plt.show()

## Summary

| Model | Strengths | Weaknesses |
|-------|-----------|------------|
| VAE | Smooth interpolation, fast | Posterior collapse, over-smoothed (blurry) outputs |
| GAN | Sharp outputs, flexible | Training instability, mode collapse |
| Flow | Exact likelihood, invertible | Computational cost |
| Ensemble | Robust, uncertainty | Requires all models |

**Recommendations:**
- Use VAE for fast prototyping and interpolation
- Use GAN for sharper scenario boundaries
- Use Flow for tail risk and exact likelihood
- Use Ensemble for production with uncertainty quantification