# Section 5: Model Comparison and Diagnostics

#### PyData London 2025 - Bayesian Time Series Analysis with PyMC

---

## Why Model Diagnostics Matter

Model diagnostics are crucial for ensuring that:
- **MCMC chains have converged** to the target distribution
- **Models fit the data appropriately** without systematic biases
- **Model assumptions are reasonable** for the given data
- **Model selection is based on sound criteria** rather than just fit

### The Bayesian Workflow

1. **Build model** → 2. **Check convergence** → 3. **Validate fit** → 4. **Compare models** → 5. **Iterate**

This loop is iterative: a model that fails the checks in step 2 or 3 goes back to step 1 before any comparison is made.

In [None]:
# Import necessary libraries for Section 5
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
import warnings
from scipy import stats

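# Configure plotting; silence only FutureWarnings so that diagnostic warnings
# (for example ArviZ's Pareto k warnings from LOO) stay visible
plt.style.use('seaborn-v0_8')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 100
warnings.filterwarnings('ignore', category=FutureWarning)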

# Set random seed for reproducibility
np.random.seed(42)
RANDOM_SEED = 42

print("🔧 Section 5 libraries loaded successfully!")
print("Ready to diagnose and compare Bayesian time series models")

In [None]:
# Load data and build several models for comparison
births_data = pl.read_csv('../data/births.csv', null_values=['null', 'NA', '', 'NULL'])
births_data = births_data.filter(pl.col('day').is_not_null())

monthly_births = (births_data
    .group_by(['year', 'month'])
    .agg(pl.col('births').sum())
    .sort(['year', 'month'])
)

births_subset = (monthly_births
    .filter((pl.col('year') >= 1970) & (pl.col('year') <= 1990))
    .with_row_index('index')
)

original_data = births_subset['births'].to_numpy()
births_standardized = (original_data - original_data.mean()) / original_data.std()
n_obs = len(births_standardized)

print(f"📊 Data prepared: {n_obs} observations")

## Building Models for Comparison

Let's build several models that we can compare and diagnose.

In [None]:
# Model 1: Simple Normal Model (baseline)
with pm.Model() as normal_model:
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfNormal('sigma', sigma=1)
    obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=births_standardized)
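    # Store the pointwise log-likelihood; WAIC/LOO need it and PyMC 5 does not save it by default
    trace_normal = pm.sample(1000, tune=1000, random_seed=RANDOM_SEED, chains=2,
                             idata_kwargs={'log_likelihood': True})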

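# Model 2: AR(1) Model
with pm.Model() as ar1_model:
    # Map a Beta(1, 1) draw on (0, 1) to an AR coefficient on (-1, 1)
    rho = pm.Beta('rho', alpha=1, beta=1)
    phi = pm.Deterministic('phi', 2 * rho - 1)
    sigma = pm.HalfNormal('sigma', sigma=1)
    # An AR(1) process has steps + 1 values, so steps=n_obs-2 gives the
    # n_obs-1 latent states that match births_standardized[1:]
    ar1 = pm.AR('ar1', rho=phi, sigma=sigma, constant=False, steps=n_obs-2)
    # Observation noise is fixed at 0.1 on the standardized scale
    obs = pm.Normal('obs', mu=ar1, sigma=0.1, observed=births_standardized[1:])
    trace_ar1 = pm.sample(1000, tune=1000, random_seed=RANDOM_SEED, chains=2,
                          idata_kwargs={'log_likelihood': True})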

# Model 3: Random Walk Model
with pm.Model() as rw_model:
    sigma_walk = pm.HalfNormal('sigma_walk', sigma=1.0)
    init_dist = pm.Normal.dist(mu=0, sigma=1)
    walk = pm.GaussianRandomWalk('walk', mu=0, sigma=sigma_walk, 
                                init_dist=init_dist, steps=n_obs-1)
    sigma_obs = pm.HalfNormal('sigma_obs', sigma=1.0)
    obs = pm.Normal('obs', mu=walk, sigma=sigma_obs, observed=births_standardized)
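    # Store the pointwise log-likelihood; WAIC/LOO need it and PyMC 5 does not save it by default
    trace_rw = pm.sample(1000, tune=1000, random_seed=RANDOM_SEED, chains=2,
                         idata_kwargs={'log_likelihood': True})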

print("✅ Built three models for comparison")

## 1. Convergence Diagnostics

Before interpreting results, we must ensure that MCMC chains have converged to the target distribution.

### Key Metrics

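- **R-hat (Gelman-Rubin statistic)**: Compares between-chain and within-chain variance; should be < 1.01 for every parameter
- **Effective Sample Size (ESS)**: Bulk and tail ESS should both be > 400 for reliable inference
- **Monte Carlo Standard Error (MCSE)**: Should be small relative to the posterior SD

These thresholds are usually quoted for runs with four or more chains. The models above use `chains=2` to keep sampling fast, so treat a pass here as weaker evidence of convergence.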
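To make the R-hat criterion concrete, here is a minimal NumPy sketch of the classic split R-hat computed by hand on synthetic chains. It is an illustration only: the helper `split_rhat` and the arrays `mixed` and `shifted` are made up for this example, and ArviZ's `r_hat` uses a rank-normalized variant, so its values will not match these exactly.

```python
import numpy as np

def split_rhat(chains):
    """Classic split R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split each chain in half so that drift within a chain also shows up
    halves = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)     # B: variance of the chain means
    var_plus = (n - 1) / n * within + between / n     # pooled variance estimate
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(42)
mixed = rng.normal(size=(2, 1000))            # two chains sampling the same distribution
shifted = mixed + np.array([[0.0], [1.0]])    # second chain offset by one SD

rhat_mixed = split_rhat(mixed)
rhat_shifted = split_rhat(shifted)
print(f"well-mixed chains: R-hat = {rhat_mixed:.3f}")
print(f"shifted chains:    R-hat = {rhat_shifted:.3f}")
```

When the chains agree, the between-chain term adds almost nothing and R-hat sits near 1. When one chain is stuck in a different region, the pooled variance exceeds the within-chain variance and R-hat rises well above 1.01.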

In [None]:
# Check convergence diagnostics for all models
models_traces = {
    'Normal Model': trace_normal,
    'AR(1) Model': trace_ar1,
    'Random Walk': trace_rw
}

print("🔍 **Convergence Diagnostics**")
print("="*50)

for name, trace in models_traces.items():
    print(f"\n**{name}**:")
    summary = az.summary(trace)
    
    # Check R-hat values
    max_rhat = summary['r_hat'].max()
    print(f"   Max R-hat: {max_rhat:.4f} {'✅' if max_rhat < 1.01 else '❌'}")
    
    # Check effective sample size (bulk and tail)
    min_ess = min(summary['ess_bulk'].min(), summary['ess_tail'].min())
    print(f"   Min ESS: {min_ess:.0f} {'✅' if min_ess > 400 else '❌'}")
    
    # Check for divergences
    divergences = int(trace.sample_stats['diverging'].sum())
    print(f"   Divergences: {divergences} {'✅' if divergences == 0 else '❌'}")

In [None]:
# Visualize trace plots for convergence assessment
# az.plot_trace needs a (n_variables, 2) grid of axes, so each model gets two rows
trace_specs = [
    ('Normal Model', trace_normal, ['mu', 'sigma']),
    ('AR(1) Model', trace_ar1, ['phi', 'sigma']),
    ('Random Walk', trace_rw, ['sigma_walk', 'sigma_obs']),
]

fig, axes = plt.subplots(6, 2, figsize=(15, 18))

for i, (name, trace, var_names) in enumerate(trace_specs):
    model_axes = axes[2 * i:2 * i + 2, :]
    az.plot_trace(trace, var_names=var_names, axes=model_axes)
    # Prefix the model name while keeping the variable name ArviZ puts in each title
    for ax in model_axes.ravel():
        ax.set_title(f'{name}: {ax.get_title()}')

plt.tight_layout()
plt.show()

print("\n💡 **Trace Plot Interpretation**:")
print("   • **Left panels**: Marginal posterior densities, one curve per chain")
print("   • **Right panels**: Sampled parameter values over iterations")
print("   • **Good mixing**: Chains should overlap and explore the space efficiently")
print("   • **Convergence**: Multiple chains should converge to the same distribution")

## 2. Information Criteria for Model Comparison

Information criteria help us compare models by balancing fit quality with model complexity.

### WAIC vs LOO

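- **WAIC (Widely Applicable Information Criterion)**: Estimates out-of-sample predictive accuracy from the pointwise log-likelihood; asymptotically equivalent to leave-one-out cross-validation
- **LOO (Leave-One-Out Cross-Validation)**: ArviZ approximates it with Pareto-smoothed importance sampling (PSIS), so no refitting is needed; the Pareto k diagnostic flags observations where the approximation is unreliable
- **Reading the numbers**: ArviZ reports the expected log pointwise predictive density (`elpd_waic`, `elpd_loo`) on the log scale by default, where **higher values indicate better models**; lower is better only on the deviance scale (`scale='deviance'`)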
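As a rough illustration of what WAIC measures, the sketch below computes it by hand from a matrix of pointwise log-likelihoods. Everything here is an assumption made for the example: the data are synthetic, the "posterior draws" are stand-ins generated with NumPy rather than sampled with PyMC, and `waic_elpd` and `normal_logpdf` are hypothetical helpers, not ArviZ functions.

```python
import numpy as np

def waic_elpd(log_lik):
    """Return (elpd_waic, p_waic) from a (n_samples, n_obs) log-likelihood matrix."""
    n_samples = log_lik.shape[0]
    # lppd: log of the posterior-averaged likelihood, one term per observation
    lppd = np.logaddexp.reduce(log_lik, axis=0) - np.log(n_samples)
    # p_waic: posterior variance of the log-likelihood, one term per observation
    p_waic = log_lik.var(axis=0, ddof=1)
    return float(np.sum(lppd - p_waic)), float(np.sum(p_waic))

def normal_logpdf(y, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2

rng = np.random.default_rng(42)
y = rng.normal(loc=0.0, scale=1.0, size=200)      # synthetic observations

# Stand-in posterior draws of the mean for two hypothetical models
posterior_sd = 1 / np.sqrt(len(y))
mu_good = rng.normal(y.mean(), posterior_sd, size=2000)
mu_bad = rng.normal(1.0, posterior_sd, size=2000)  # deliberately biased

ll_good = normal_logpdf(y[None, :], mu_good[:, None], 1.0)
ll_bad = normal_logpdf(y[None, :], mu_bad[:, None], 1.0)

elpd_good, p_good = waic_elpd(ll_good)
elpd_bad, p_bad = waic_elpd(ll_bad)
print(f"good model: elpd_waic = {elpd_good:.1f}, p_waic = {p_good:.2f}")
print(f"bad model:  elpd_waic = {elpd_bad:.1f}, p_waic = {p_bad:.2f}")
```

The penalty term `p_waic` acts as an effective number of parameters, which is how the criterion trades fit against complexity. On this log scale the model with the higher `elpd_waic` is the better predictor.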

In [None]:
# Model comparison using information criteria
print("📊 **Model Comparison using Information Criteria**")
print("="*60)

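# az.compare needs every model scored on the same observations. The AR(1)
# model is fit to births_standardized[1:], so keep only the trailing
# observations that all models share.
def last_n_obs(idata, n):
    dim = idata.log_likelihood['obs'].dims[-1]
    size = idata.log_likelihood['obs'].sizes[dim]
    return idata.isel(**{dim: slice(size - n, None)})

n_common = min(t.log_likelihood['obs'].shape[-1] for t in models_traces.values())
aligned_traces = {name: last_n_obs(t, n_common) for name, t in models_traces.items()}

# Compute WAIC and LOO for all models
comparison_waic = az.compare(aligned_traces, ic='waic')
comparison_loo = az.compare(aligned_traces, ic='loo')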

print("\n**WAIC Comparison**:")
print(comparison_waic)

print("\n**LOO Comparison**:")
print(comparison_loo)

# Visualize model comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

az.plot_compare(comparison_waic, ax=ax1)
ax1.set_title('Model Comparison: WAIC')

az.plot_compare(comparison_loo, ax=ax2)
ax2.set_title('Model Comparison: LOO')

plt.tight_layout()
plt.show()

print("\n💡 **Interpretation**:")
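print("   • **rank**: 0 is the best model (highest estimated elpd)")
print("   • **elpd_diff**: Difference in elpd from the top-ranked model (d_waic/d_loo in older ArviZ releases)")
print("   • **weight**: Stacking weight, not a posterior model probability")
print("   • **se / dse**: Standard error of the elpd estimate / of its difference from the best model")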

## 3. Posterior Predictive Checks

Posterior predictive checks help us assess whether our models can reproduce key features of the observed data.
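One way to make "key features" concrete is to pick a test statistic and compare its observed value with its distribution across replicated datasets. The sketch below does this for lag-1 autocorrelation. It is only an illustration under stated assumptions: the observed series is synthetic, the replicated draws are a NumPy stand-in for posterior predictive samples from an i.i.d. normal model, and `lag1_autocorr` is a hypothetical helper.

```python
import numpy as np

def lag1_autocorr(x):
    """Correlation between a series and itself shifted by one step."""
    x = np.asarray(x, dtype=float)
    return float(np.corrcoef(x[:-1], x[1:])[0, 1])

rng = np.random.default_rng(42)
n = 200

# Synthetic "observed" series with strong autocorrelation (AR(1), coefficient 0.8)
observed = np.zeros(n)
for t in range(1, n):
    observed[t] = 0.8 * observed[t - 1] + rng.normal()

# Stand-in predictive draws from an i.i.d. normal model matched to the
# observed mean and SD, which by construction ignores time dependence
replicated = rng.normal(observed.mean(), observed.std(), size=(1000, n))

t_obs = lag1_autocorr(observed)
t_rep = np.array([lag1_autocorr(draw) for draw in replicated])

# Posterior predictive p-value: share of replicates at least as extreme as the data
p_value = float(np.mean(t_rep >= t_obs))
print(f"observed lag-1 autocorrelation: {t_obs:.2f}")
print(f"replicated range: [{t_rep.min():.2f}, {t_rep.max():.2f}]")
print(f"posterior predictive p-value: {p_value:.3f}")
```

A p-value near 0 or 1 means the model almost never reproduces that feature of the data. Here the marginal distribution matches well, so a density overlay alone would look fine; the statistic exposes the missing temporal structure.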

In [None]:
# Posterior predictive checks for model validation
print("🔍 **Posterior Predictive Checks**")
print("="*40)

# Generate posterior predictive samples
with normal_model:
    ppc_normal = pm.sample_posterior_predictive(trace_normal, random_seed=RANDOM_SEED)

with ar1_model:
    ppc_ar1 = pm.sample_posterior_predictive(trace_ar1, random_seed=RANDOM_SEED)

with rw_model:
    ppc_rw = pm.sample_posterior_predictive(trace_rw, random_seed=RANDOM_SEED)

# Plot posterior predictive checks
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Normal model PPC
az.plot_ppc(ppc_normal, num_pp_samples=50, ax=axes[0])
axes[0].set_title('Normal Model - PPC')

# AR(1) model PPC
az.plot_ppc(ppc_ar1, num_pp_samples=50, ax=axes[1])
axes[1].set_title('AR(1) Model - PPC')

# Random walk model PPC
az.plot_ppc(ppc_rw, num_pp_samples=50, ax=axes[2])
axes[2].set_title('Random Walk Model - PPC')

plt.tight_layout()
plt.show()

print("\n💡 **PPC Interpretation**:")
print("   • **Black line**: Observed data distribution")
print("   • **Thin colored lines**: Individual posterior predictive samples")
print("   • **Dashed line**: Posterior predictive mean")
print("   • **Good fit**: Observed data should be typical of predictive samples")
print("   • **Poor fit**: Systematic deviations indicate model misspecification")

## 4. Time Series-Specific Diagnostics

Time series models require additional diagnostics to check for temporal patterns in residuals.
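Beyond lag 1, it often helps to scan several lags and compare each against an approximate 95% white-noise band of ±1.96/√n. The following sketch assumes synthetic residuals and a hypothetical helper `acf`; it is not taken from the notebook's models.

```python
import numpy as np

def acf(x, max_lag):
    """Sample autocorrelation at lags 1..max_lag."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    denom = np.dot(x, x)
    return np.array([np.dot(x[:-k], x[k:]) / denom for k in range(1, max_lag + 1)])

rng = np.random.default_rng(0)
n = 240
months = np.arange(n)

white = rng.normal(size=n)                               # residuals with no structure
seasonal = white + 2 * np.sin(2 * np.pi * months / 12)   # leftover 12-month cycle

bound = 1.96 / np.sqrt(n)   # approximate 95% band under white noise

for name, resid in [("white noise", white), ("seasonal leftover", seasonal)]:
    r = acf(resid, max_lag=12)
    flagged = np.flatnonzero(np.abs(r) > bound) + 1      # lags outside the band
    print(f"{name}: lags outside ±{bound:.3f}: {flagged.tolist()}")
```

For monthly data, a large positive spike at lag 12 together with a negative one at lag 6 points to seasonality the model has not captured, which a lag-1 check alone can miss.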

In [None]:
# Time series specific diagnostics
print("📈 **Time Series-Specific Diagnostics**")
print("="*45)

# Function to compute residuals
def compute_residuals(observed, predicted_samples):
    """Compute residuals from posterior predictive samples"""
    pred_mean = predicted_samples.mean(axis=0)
    return observed - pred_mean

# Compute residuals for each model
residuals_normal = compute_residuals(births_standardized, 
                                   ppc_normal.posterior_predictive['obs'].values.reshape(-1, n_obs))
residuals_ar1 = compute_residuals(births_standardized[1:], 
                                ppc_ar1.posterior_predictive['obs'].values.reshape(-1, n_obs-1))
residuals_rw = compute_residuals(births_standardized, 
                               ppc_rw.posterior_predictive['obs'].values.reshape(-1, n_obs))

# Plot residual diagnostics
fig, axes = plt.subplots(3, 2, figsize=(15, 12))

models_residuals = [
    ('Normal Model', residuals_normal),
    ('AR(1) Model', residuals_ar1),
    ('Random Walk', residuals_rw)
]

for i, (name, residuals) in enumerate(models_residuals):
    # Time series plot of residuals
    axes[i, 0].plot(residuals, 'o-', alpha=0.7)
    axes[i, 0].axhline(0, color='red', linestyle='--', alpha=0.7)
    axes[i, 0].set_title(f'{name} - Residuals vs Time')
    axes[i, 0].set_ylabel('Residuals')
    axes[i, 0].grid(True, alpha=0.3)
    
    # Q-Q plot for normality check
    stats.probplot(residuals, dist="norm", plot=axes[i, 1])
    axes[i, 1].set_title(f'{name} - Q-Q Plot')
    axes[i, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Compute autocorrelation of residuals
print("\n**Residual Autocorrelation (Lag 1)**:")
for name, residuals in models_residuals:
    if len(residuals) > 1:
        autocorr = np.corrcoef(residuals[:-1], residuals[1:])[0, 1]
        print(f"   {name}: {autocorr:.4f} {'✅' if abs(autocorr) < 0.1 else '❌'}")

print("\n💡 **Residual Diagnostics Interpretation**:")
print("   • **Time plot**: Should show no patterns or trends")
print("   • **Q-Q plot**: Points should follow diagonal line for normality")
print("   • **Autocorrelation**: Should be close to zero for good models")

## Summary and Best Practices

### Model Diagnostic Checklist

✅ **Convergence**:
- R-hat < 1.01 for all parameters
- ESS > 400 for reliable inference
- No divergent transitions

✅ **Model Fit**:
- Posterior predictive checks show good agreement
- Residuals show no systematic patterns
- Information criteria (WAIC/LOO) favor the chosen model over the alternatives

✅ **Time Series Specific**:
- Residuals show no autocorrelation
- No obvious temporal patterns in residuals
- Model captures key data features (trend, seasonality)

### When Models Fail Diagnostics

- **Poor convergence**: Increase `tune`/`draws`, raise `target_accept` if there are divergences, reparameterize, or use a different sampler
- **Bad fit**: Add missing components (trend, seasonality, AR terms)
- **Residual patterns**: Consider more complex models or different distributions

**Next**: In Section 6, we'll use our validated models for forecasting and practical applications.

---

**Key Takeaways**:
- Always check convergence before interpreting results
- Use multiple criteria for model comparison
- Posterior predictive checks are essential for validation
- Time series models require specialized residual diagnostics
- Iterate and improve models based on diagnostic results