# Convergence Monitoring

## Overview
- **What**: Comprehensive convergence diagnostics for Monte Carlo simulations -- trace plots, R-hat, ESS, loss distribution validation, adaptive stopping, and spectral analysis.
- **Prerequisites**: [../core/04_monte_carlo_engine](../core/04_monte_carlo_engine.ipynb)
- **Estimated runtime**: < 2 minutes
- **Audience**: [Developer]

## Topics Covered
1. MCMC trace plots and convergence diagnostics (R-hat, ESS, ACF)
2. Loss distribution validation (Q-Q plots, K-S tests)
3. Monte Carlo convergence analysis
4. Advanced autocorrelation and spectral density analysis
5. Adaptive stopping criteria
6. Fixed vs. adaptive stopping comparison

In [None]:
"""Google Colab setup: mount Drive and install package dependencies.

Run this cell first. If prompted to restart the runtime, do so, then re-run all cells.
This cell is a no-op when running locally.
"""
import os
import sys
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')

    NOTEBOOK_DIR = '/content/drive/My Drive/Colab Notebooks/ei_notebooks/visualization'

    os.chdir(NOTEBOOK_DIR)
    if NOTEBOOK_DIR not in sys.path:
        sys.path.append(NOTEBOOK_DIR)

    !pip install git+https://github.com/AlexFiliakov/Ergodic-Insurance-Limits.git -q 2>&1 | tail -3
    print('\nSetup complete. If you see numpy/scipy import errors below,')
    print('restart the runtime (Runtime > Restart runtime) and re-run all cells.')

In [None]:
# Setup
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats

from ergodic_insurance.visualization.technical_plots import (
    plot_enhanced_convergence_diagnostics,
    plot_trace_plots,
    plot_loss_distribution_validation,
    plot_monte_carlo_convergence,
    plot_convergence_diagnostics,
)
from ergodic_insurance.convergence import ConvergenceDiagnostics
from ergodic_insurance.convergence_advanced import (
    AdvancedConvergenceDiagnostics,
    SpectralDiagnostics,
    AutocorrelationAnalysis,
)
from ergodic_insurance.convergence_plots import RealTimeConvergencePlotter
from ergodic_insurance.adaptive_stopping import (
    AdaptiveStoppingMonitor,
    StoppingCriteria,
    StoppingRule,
)
from ergodic_insurance.loss_distributions import LognormalLoss

np.random.seed(42)

%matplotlib inline
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

print("Convergence monitoring modules loaded successfully!")

## 1. Synthetic MCMC Chains

Generate chains with a known burn-in period and autocorrelation so we can verify that convergence diagnostics correctly identify converged vs. non-converged regions.
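
The generator in the next cell is an AR(1) recursion, $x_t = \rho\,x_{t-1} + (1-\rho)\,\varepsilon_t$ with $\varepsilon_t \sim N(\mu, \sigma^2)$. After burn-in its stationary mean is $\mu$, its variance is $\sigma^2 (1-\rho)/(1+\rho)$, and its lag-1 autocorrelation is $\rho$. The sketch below checks these properties on a single long chain using the post-burn-in settings for "Premium Rate". `simulate_ar1` is a hypothetical helper written for this illustration, not part of `ergodic_insurance`.

```python
import numpy as np


def simulate_ar1(rho, mu, sigma, n, rng):
    """x_t = rho * x_{t-1} + (1 - rho) * eps_t with eps_t ~ N(mu, sigma^2)."""
    eps = rng.normal(mu, sigma, size=n)
    x = np.empty(n)
    x[0] = mu  # start at the stationary mean, so no burn-in is needed here
    for t in range(1, n):
        x[t] = rho * x[t - 1] + (1 - rho) * eps[t]
    return x


rho, mu, sigma = 0.5, 10.0, 1.0  # post-burn-in settings for "Premium Rate"
x = simulate_ar1(rho, mu, sigma, n=100_000, rng=np.random.default_rng(0))

theory_var = sigma**2 * (1 - rho) / (1 + rho)  # equals sigma^2 / 3 when rho = 0.5
lag1 = np.corrcoef(x[:-1], x[1:])[0, 1]

print(f"mean: {x.mean():.3f} (theory {mu})")
print(f"variance: {x.var():.3f} (theory {theory_var:.3f})")
print(f"lag-1 autocorrelation: {lag1:.3f} (theory {rho})")
```

Because the chain's variance is only a third of the innovation variance, the post-burn-in spread in the trace plots is narrower than the nominal `std` suggests.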

In [None]:
N_CHAINS = 4
N_ITERATIONS = 5000
BURN_IN = 1000
TRUE_VALUES = np.array([10.0, 0.5, 2.0])
PARAM_NAMES = ["Premium Rate", "Deductible", "Risk Factor"]
AUTOCORR = 0.5


def generate_mcmc_chains():
    """Simulate AR(1) chains whose mean drifts toward TRUE_VALUES during burn-in."""
    chains = np.zeros((N_CHAINS, N_ITERATIONS, len(TRUE_VALUES)))
    for chain in range(N_CHAINS):
        for param in range(len(TRUE_VALUES)):
            chains[chain, 0, param] = np.random.randn() + TRUE_VALUES[param] * 0.5
            for t in range(1, N_ITERATIONS):
                if t < BURN_IN:
                    mean = TRUE_VALUES[param] * (0.5 + 0.5 * t / BURN_IN)
                    std = TRUE_VALUES[param] * 0.5
                else:
                    mean = TRUE_VALUES[param]
                    std = TRUE_VALUES[param] * 0.1
                innovation = np.random.normal(mean, std)
                chains[chain, t, param] = (
                    AUTOCORR * chains[chain, t - 1, param]
                    + (1 - AUTOCORR) * innovation
                )
    return chains


chains = generate_mcmc_chains()
print(f"Generated {chains.shape[0]} chains x {chains.shape[1]} iterations x {chains.shape[2]} params")

## 2. Trace Plots

Trace plots give a visual check of chain mixing and help identify the burn-in period.

In [None]:
fig_trace = plot_trace_plots(
    chains,
    parameter_names=PARAM_NAMES,
    burn_in=BURN_IN,
    title="MCMC Trace Plots with Burn-in Period",
    figsize=(14, 8),
)
plt.show()

## 3. Enhanced Convergence Diagnostics

Multi-panel view: R-hat (Gelman-Rubin), Effective Sample Size (ESS), autocorrelation functions, and Monte Carlo Standard Errors (MCSE).
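
For reference, the classic (non-split) Gelman-Rubin statistic can be sketched in a few lines of NumPy. This is an illustration of the textbook formula only; `ConvergenceDiagnostics.calculate_r_hat` may use a different variant, and `gelman_rubin_r_hat` is a hypothetical name.

```python
import numpy as np


def gelman_rubin_r_hat(chains):
    """Classic R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    _, n = chains.shape
    within = chains.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)   # B: scaled variance of chain means
    pooled = (n - 1) / n * within + between / n     # pooled estimate of the target variance
    return float(np.sqrt(pooled / within))


rng = np.random.default_rng(0)
mixed = rng.normal(10.0, 1.0, size=(4, 2000))              # four chains, same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])     # one chain offset by 3

r_mixed = gelman_rubin_r_hat(mixed)
r_stuck = gelman_rubin_r_hat(stuck)
print(f"well-mixed chains: R-hat = {r_mixed:.3f}")
print(f"one offset chain:  R-hat = {r_stuck:.3f}")
```

When all chains sample the same distribution, the pooled and within-chain variances agree and R-hat sits near 1; a chain stuck elsewhere inflates the between-chain term and pushes R-hat well above the threshold.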

In [None]:
fig_diag = plot_enhanced_convergence_diagnostics(
    chains,
    parameter_names=PARAM_NAMES,
    burn_in=BURN_IN,
    title="Enhanced Convergence Diagnostics",
    figsize=(14, 10),
)
plt.show()

# Numeric summary
diag = ConvergenceDiagnostics()
post_chains = chains[:, BURN_IN:, :]
print("R-hat Statistics (target < 1.1):")
for i, name in enumerate(PARAM_NAMES):
    r_hat = diag.calculate_r_hat(post_chains[:, :, i:i + 1])
    status = "PASS" if r_hat < 1.1 else "FAIL"
    print(f"  {name}: {r_hat:.4f} [{status}]")

print("\nEffective Sample Size (target > 1000):")
for i, name in enumerate(PARAM_NAMES):
    ess = diag.calculate_ess(post_chains[:, :, i].flatten())
    status = "PASS" if ess > 1000 else "FAIL"
    print(f"  {name}: {ess:.0f} [{status}]")

## 4. Loss Distribution Validation

Q-Q plots and goodness-of-fit tests for attritional and large losses to verify that the simulated distributions match theoretical models.
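
As a standalone illustration of the K-S check, the sketch below draws lognormal losses with NumPy and tests them against the matching SciPy distribution. It assumes the usual moment-matching parameterization, $\sigma^2 = \ln(1 + \mathrm{cv}^2)$ and $\mu = \ln(\mathrm{mean}) - \sigma^2/2$; whether `LognormalLoss` uses exactly this mapping is an assumption, not something confirmed here.

```python
import numpy as np
from scipy import stats

mean, cv, n = 50_000.0, 0.5, 2000  # same settings as the attritional losses

# Moment matching for a lognormal with a given mean and coefficient of variation
sigma = np.sqrt(np.log(1.0 + cv**2))
mu = np.log(mean) - 0.5 * sigma**2

rng = np.random.default_rng(42)
losses = rng.lognormal(mean=mu, sigma=sigma, size=n)

# Theoretical model the sample is supposed to follow
reference = stats.lognorm(s=sigma, scale=np.exp(mu))

ks_ok = stats.kstest(losses, reference.cdf)          # sample drawn from the model
ks_bad = stats.kstest(2.0 * losses, reference.cdf)   # every loss doubled

print(f"matching sample: D = {ks_ok.statistic:.4f}, p = {ks_ok.pvalue:.3f}")
print(f"doubled losses:  D = {ks_bad.statistic:.4f}, p = {ks_bad.pvalue:.3g}")
```

The K-S statistic `D` is the largest gap between the empirical and theoretical CDFs, so it is most sensitive to shifts in the body of the distribution; the Q-Q plot is the better tool for spotting tail outliers such as those injected in the next cell.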

In [None]:
# Generate synthetic loss data
attritional_gen = LognormalLoss(mean=50_000, cv=0.5, seed=42)
attritional_losses = attritional_gen.generate_severity(2000)

large_gen = LognormalLoss(mean=1_000_000, cv=1.5, seed=43)
large_losses = large_gen.generate_severity(200)

# Add outliers for realism
attritional_losses[::100] *= 3
large_losses[::50] *= 5

fig_val = plot_loss_distribution_validation(
    attritional_losses,
    large_losses,
    title="Loss Distribution Validation",
    figsize=(14, 12),
)
plt.show()

print(f"Attritional: n={len(attritional_losses)}, mean=${np.mean(attritional_losses):,.0f}")
print(f"Large:       n={len(large_losses)}, mean=${np.mean(large_losses):,.0f}")

## 5. Monte Carlo Convergence Analysis

Track running means and confidence intervals for key insurance metrics over increasing iteration counts.
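
A running mean with a normal-approximation confidence band can be computed from cumulative sums, as in the sketch below. `running_mean_ci` is a hypothetical helper for illustration, not the routine used inside `plot_monte_carlo_convergence`, and it assumes independent draws.

```python
import numpy as np


def running_mean_ci(samples, z=1.96):
    """Return the running mean and CI half-width after each draw (i.i.d. assumption)."""
    samples = np.asarray(samples, dtype=float)
    n = np.arange(1, samples.size + 1)
    mean = np.cumsum(samples) / n
    sum_sq = np.cumsum(samples**2)
    var = np.full(samples.size, np.nan)  # sample variance is undefined for n = 1
    var[1:] = (sum_sq[1:] - n[1:] * mean[1:] ** 2) / (n[1:] - 1)
    half_width = z * np.sqrt(np.maximum(var, 0.0) / n)
    return mean, half_width


rng = np.random.default_rng(0)
samples = rng.normal(12.0, 5.0, size=10_000)  # e.g. simulated ROE (%) draws
mean, half_width = running_mean_ci(samples)

for k in (100, 1_000, 10_000):
    print(f"n={k:>6}: {mean[k - 1]:.3f} +/- {half_width[k - 1]:.3f}")
```

The half-width shrinks like $1/\sqrt{n}$, so a tenfold tighter interval costs roughly a hundredfold more iterations.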

In [None]:
def simulate_mc_metrics(n_iterations=10000):
    """Simulate convergence of insurance metrics."""
    targets = {"ROE (%)": 12.0, "Ruin Probability (%)": 0.8,
              "Sharpe Ratio": 1.5, "Premium Adequacy": 1.25}
    noise = {"ROE (%)": 5.0, "Ruin Probability (%)": 0.5,
            "Sharpe Ratio": 0.3, "Premium Adequacy": 0.15}
    history = {}
    for metric, target in targets.items():
        vals = []
        for i in range(n_iterations):
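            # Noise shrinks like 1/sqrt(n), mimicking Monte Carlo error
            vf = 1.0 / np.sqrt(i + 1)
            # Inflate the noise by 50% at every 1000th iteration
            jf = 1.5 if (i % 1000 == 0 and i > 0) else 1.0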
            vals.append(target + np.random.normal(0, noise[metric] * vf * jf))
        history[metric] = vals
    return history


metrics_history = simulate_mc_metrics()
convergence_thresholds = {
    "ROE (%)": 12.0,
    "Ruin Probability (%)": 1.0,
    "Sharpe Ratio": 1.5,
    "Premium Adequacy": 1.25,
}

fig_mc = plot_monte_carlo_convergence(
    metrics_history,
    convergence_thresholds=convergence_thresholds,
    title="Monte Carlo Convergence Analysis",
    figsize=(16, 12),
    log_scale=True,
)
plt.show()

print("Final metric estimates (last 1000 iterations):")
for metric, vals in metrics_history.items():
    print(f"  {metric}: {np.mean(vals[-1000:]):.3f} +/- {np.std(vals[-1000:]):.3f}")

## 6. Advanced Autocorrelation Analysis

Compare FFT, direct, and biased ACF methods. The integrated autocorrelation time quantifies how many iterations are needed per effectively independent sample.
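
To make the FFT approach and the integrated autocorrelation time concrete, here is a minimal NumPy sketch. It is not the implementation behind `calculate_autocorrelation_full`; `acf_fft` and `integrated_autocorr_time` are hypothetical helpers, and truncating the sum at the first non-positive lag is one simple windowing choice among several.

```python
import numpy as np


def acf_fft(x, max_lag):
    """Biased-normalization ACF for lags 0..max_lag, computed with zero-padded FFTs."""
    x = np.asarray(x, dtype=float)
    n = x.size
    centered = x - x.mean()
    n_fft = 1 << (2 * n - 1).bit_length()  # pad to avoid circular wrap-around
    spectrum = np.fft.rfft(centered, n=n_fft)
    acov = np.fft.irfft(spectrum * np.conj(spectrum), n=n_fft)[: max_lag + 1]
    return acov / acov[0]


def integrated_autocorr_time(acf):
    """tau = 1 + 2 * sum of ACF values before the first non-positive lag."""
    nonpos = np.flatnonzero(acf[1:] <= 0)
    cutoff = nonpos[0] + 1 if nonpos.size else acf.size
    return 1.0 + 2.0 * acf[1:cutoff].sum()


# AR(1) test signal with rho = 0.5, for which tau = (1 + rho) / (1 - rho) = 3
rng = np.random.default_rng(0)
rho, n = 0.5, 50_000
noise = rng.normal(size=n)
x = np.empty(n)
x[0] = noise[0]
for t in range(1, n):
    x[t] = rho * x[t - 1] + noise[t]

acf = acf_fft(x, max_lag=100)
tau = integrated_autocorr_time(acf)
print(f"ACF at lag 1: {acf[1]:.3f} (theory {rho})")
print(f"integrated autocorrelation time: {tau:.2f} (theory 3.0)")
```

A value of tau near 3 means roughly three correlated iterations carry the information of one independent draw, so the effective sample size is about `n / tau`.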

In [None]:
adv_diag = AdvancedConvergenceDiagnostics()
chain_sample = chains[0, BURN_IN:, 0]  # Post burn-in, first param

methods = ["fft", "direct", "biased"]
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for i, method in enumerate(methods):
    result = adv_diag.calculate_autocorrelation_full(chain_sample, max_lag=100, method=method)
    axes[i].plot(result.lags, result.acf_values, "b-", linewidth=1.5)
    axes[i].axhline(y=0, color="black", linestyle="-", alpha=0.3)
    ci = 1.96 / np.sqrt(len(chain_sample))  # approximate 95% band for white noise
    axes[i].fill_between(result.lags, -ci, ci, alpha=0.2, color="gray", label="95% CI")
    axes[i].legend(loc="upper right")
    axes[i].set_title(f"ACF ({method.upper()})")
    axes[i].set_xlabel("Lag")
    axes[i].set_ylabel("Autocorrelation")
    axes[i].grid(True, alpha=0.3)
    axes[i].set_ylim(-0.2, 1.0)

plt.suptitle("Autocorrelation Analysis Comparison", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

## 7. Spectral Density and ESS Methods

Spectral density estimation provides an alternative ESS calculation. Comparing multiple ESS methods builds confidence in convergence assessment.
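
One of the alternatives compared below, batch means, is simple enough to sketch directly: split the chain into batches, estimate the long-run variance from the variance of the batch means, and divide. This is the textbook estimator under an assumed batch count of about $\sqrt{n}$, not necessarily what `calculate_ess_batch_means` does; `ess_batch_means` is a hypothetical name.

```python
import numpy as np


def ess_batch_means(x, n_batches=None):
    """ESS = n * Var(x) / (batch_size * Var(batch means))."""
    x = np.asarray(x, dtype=float)
    if n_batches is None:
        n_batches = int(np.sqrt(x.size))  # assumed default: about sqrt(n) batches
    batch_size = x.size // n_batches
    used = x[: n_batches * batch_size]    # drop the remainder that does not fill a batch
    batch_means = used.reshape(n_batches, batch_size).mean(axis=1)
    long_run_var = batch_size * batch_means.var(ddof=1)
    return used.size * used.var(ddof=1) / long_run_var


rng = np.random.default_rng(0)
n, rho = 40_000, 0.5

iid = rng.normal(size=n)
ar1 = np.empty(n)
ar1[0] = iid[0]
for t in range(1, n):
    ar1[t] = rho * ar1[t - 1] + iid[t]

ess_iid = ess_batch_means(iid)
ess_ar1 = ess_batch_means(ar1)
print(f"i.i.d. draws:     ESS/n = {ess_iid / n:.2f} (theory 1.00)")
print(f"AR(1), rho = 0.5: ESS/n = {ess_ar1 / n:.2f} (theory 0.33)")
```

Batch-means estimates are noisy because they rest on only a few hundred batch means, which is one reason to compare several ESS methods rather than trust a single number.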

In [None]:
# Spectral density
spectral_methods = ["welch", "periodogram"]
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for i, method in enumerate(spectral_methods):
    result = adv_diag.calculate_spectral_density(chain_sample, method=method)
    axes[i].semilogy(result.frequencies, result.spectral_density, "b-", lw=1)
    axes[i].set_title(f"Spectral Density ({method.capitalize()})")
    axes[i].set_xlabel("Frequency")
    axes[i].set_ylabel("PSD")
    axes[i].grid(True, alpha=0.3)
    axes[i].text(
        0.6, 0.9,
        f"tau={result.integrated_autocorr_time:.2f}\nESS={result.effective_sample_size:.0f}",
        transform=axes[i].transAxes, fontsize=10,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )
plt.suptitle("Spectral Density Analysis", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

# ESS comparison across methods
ess_methods = {
    "Batch Means": lambda c: adv_diag.calculate_ess_batch_means(c),
    "Overlap Batch": lambda c: adv_diag.calculate_ess_overlapping_batch(c),
    "Spectral": lambda c: adv_diag.calculate_spectral_density(c, method="welch").effective_sample_size,
    "Basic ACF": lambda c: ConvergenceDiagnostics().calculate_ess(c),
}

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for pi, pname in enumerate(PARAM_NAMES):
    ess_vals = {}
    for mname, mfunc in ess_methods.items():
        per_chain = [mfunc(chains[ci, BURN_IN:, pi]) for ci in range(N_CHAINS)]
        ess_vals[mname] = per_chain
    df = pd.DataFrame(ess_vals, index=[f"Chain {c+1}" for c in range(N_CHAINS)])
    df.plot(kind="bar", ax=axes[pi], width=0.8)
    axes[pi].set_title(pname)
    axes[pi].set_ylabel("ESS")
    axes[pi].legend(loc="upper right", fontsize=8)
    axes[pi].grid(True, alpha=0.3)
    axes[pi].axhline(y=1000, color="red", linestyle="--", alpha=0.5)
plt.suptitle("ESS Calculation Methods Comparison", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

## 8. Advanced Diagnostic Tests

The Heidelberger-Welch stationarity test and the Raftery-Lewis diagnostic provide quantitative convergence assessments that complement R-hat.
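
The Raftery-Lewis minimum sample size has a closed form for independent draws: to estimate the quantile $q$ to within $\pm r$ with probability $s$, one needs $N_{\min} = \lceil z^2\, q(1-q) / r^2 \rceil$ with $z = \Phi^{-1}((1+s)/2)$. The sketch below evaluates it for the conventional settings $q = 0.025$, $r = 0.005$, $s = 0.95$. Whether `raftery_lewis_diagnostic` uses these defaults is an assumption, and `raftery_lewis_n_min` is a hypothetical helper.

```python
import math
from statistics import NormalDist


def raftery_lewis_n_min(q=0.025, r=0.005, s=0.95):
    """Minimum number of independent draws to estimate quantile q within +/- r."""
    z = NormalDist().inv_cdf(0.5 * (1.0 + s))
    return math.ceil(z**2 * q * (1.0 - q) / r**2)


n_min = raftery_lewis_n_min()
print(f"N_min for q=0.025, r=0.005, s=0.95: {n_min}")  # prints 3746

# A looser tolerance needs far fewer draws because N_min scales with 1 / r^2
print(f"N_min with r=0.01: {raftery_lewis_n_min(r=0.01)}")
```

The dependence factor reported by the diagnostic is the ratio of the required total run length to $N_{\min}$; values well above 1 indicate strong autocorrelation.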

In [None]:
print("Advanced Diagnostic Tests")
print("=" * 50)

for pi, pname in enumerate(PARAM_NAMES):
    print(f"\n{pname}:")
    print("-" * 30)
    combined = chains[:, BURN_IN:, pi].flatten()

    hw = adv_diag.heidelberger_welch_advanced(combined)
    print(f"  Heidelberger-Welch: stationary={hw['stationary']} (p={hw['pvalue']:.3f})")
    print(f"    Halfwidth test: {'PASSED' if hw['halfwidth_passed'] else 'FAILED'}")
    print(f"    Mean: {hw['mean']:.3f} +/- {hw['mcse']:.4f}")

    rl = adv_diag.raftery_lewis_diagnostic(combined)
    print(f"  Raftery-Lewis: min_n={rl['n_min']:.0f}, sufficient={'YES' if rl['sufficient'] else 'NO'}")
    print(f"    Dependence factor: {rl['dependence_factor']:.2f}")

## 9. Adaptive Stopping

Automatically terminate simulations when convergence criteria are met, saving computation while ensuring quality. The patience mechanism prevents premature stopping.
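
One common way to implement patience is to require the criteria to hold on several consecutive checks before stopping. The sketch below illustrates that idea with a hypothetical `PatienceTracker` class; it is an assumption about the general technique, not a description of how `AdaptiveStoppingMonitor` is implemented.

```python
from dataclasses import dataclass


@dataclass
class PatienceTracker:
    """Signal a stop only after `patience` consecutive passing checks."""

    patience: int
    streak: int = 0

    def update(self, criteria_met: bool) -> bool:
        # A failed check resets the streak, so one lucky pass cannot trigger a stop
        self.streak = self.streak + 1 if criteria_met else 0
        return self.streak >= self.patience


# Outcome of each convergence check (True = all thresholds satisfied)
checks = [False, True, True, False, True, True, True]

tracker = PatienceTracker(patience=3)
decisions = [tracker.update(ok) for ok in checks]
print(decisions)  # only the final check yields True
```

With `patience=3` the two early passes are discarded when the fourth check fails, and the stop is signalled only after the last three checks pass in a row.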

In [None]:
criteria = StoppingCriteria(
    rule=StoppingRule.COMBINED,
    r_hat_threshold=1.05,
    min_ess=1000,
    mcse_relative_threshold=0.05,
    min_iterations=1000,
    max_iterations=10000,
    check_interval=200,
    patience=3,
)
monitor = AdaptiveStoppingMonitor(criteria)

print("Adaptive Stopping Criteria:")
print(f"  Rule: {criteria.rule.value}")
print(f"  R-hat threshold: {criteria.r_hat_threshold}")
print(f"  Min ESS: {criteria.min_ess}")
print(f"  Patience: {criteria.patience} checks")
print(f"  Check interval: every {criteria.check_interval} iterations")

# Simulate adaptive stopping
print("\nIteration | R-hat | ESS    | Status")
print("-" * 50)
for iteration in range(1000, 5001, 200):
    current_chains = chains[:, :iteration, 0]
    status = monitor.check_convergence(iteration, current_chains)
    if iteration % criteria.check_interval == 0 and iteration >= criteria.min_iterations:
        r_hat = status.diagnostics.get("r_hat", 0)
        ess = status.diagnostics.get("ess", 0)
        mark = "PASS" if status.converged else "----"
        print(f"{iteration:8d} | {r_hat:.3f} | {ess:6.0f} | {mark} {status.reason[:30]}...")
        if status.should_stop:
            print(f"\nSTOPPED at iteration {iteration}: {status.reason}")
            break

## 10. Fixed vs. Adaptive Stopping Comparison

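Adaptive stopping typically saves 30-50% of iterations relative to a fixed budget while meeting the same convergence thresholds. The actual savings depend on the chains and the thresholds chosen; the cell below measures them for the synthetic chains from Section 1.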

In [None]:
basic_diag = ConvergenceDiagnostics()

scenarios = {
    "Fixed 5000": {"type": "fixed", "iterations": 5000},
    "Fixed 10000": {"type": "fixed", "iterations": 10000},
    "Adaptive (R<1.1)": {"type": "adaptive", "r_hat": 1.1, "min_ess": 500},
    "Adaptive (R<1.05)": {"type": "adaptive", "r_hat": 1.05, "min_ess": 1000},
    "Adaptive (Strict)": {"type": "adaptive", "r_hat": 1.01, "min_ess": 2000},
}

results = []
for name, cfg in scenarios.items():
    if cfg["type"] == "fixed":
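        # The synthetic chains hold only N_ITERATIONS draws, so cap the budget at
        # what is available; slicing past the end would silently truncate.
        n_iter = min(cfg["iterations"], chains.shape[1])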
        fc = chains[:, :n_iter, 0]
        r_hat = basic_diag.calculate_r_hat(fc[:, :, np.newaxis])
        ess = basic_diag.calculate_ess(fc.flatten())
    else:
        crit = StoppingCriteria(
            rule=StoppingRule.COMBINED,
            r_hat_threshold=cfg["r_hat"],
            min_ess=cfg["min_ess"],
            min_iterations=1000,
            check_interval=100,
            patience=2,
        )
        mon = AdaptiveStoppingMonitor(crit)
        n_iter = 5000
        for ci in range(1000, 5001, 100):
            st = mon.check_convergence(ci, chains[:, :ci, 0])
            if st.should_stop:
                n_iter = ci
                break
        r_hat = st.diagnostics.get("r_hat", np.inf)
        ess = st.diagnostics.get("ess", 0)

    results.append({
        "Scenario": name, "Iterations": n_iter,
        "R-hat": r_hat, "ESS": ess,
        "ESS/Iter": ess / n_iter,
    })

df = pd.DataFrame(results)
print("Performance Comparison:")
print(df.to_string(index=False))

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
colors = ["blue", "blue", "green", "green", "green"]

axes[0].bar(range(len(df)), df["Iterations"], color=colors)
axes[0].set_xticks(range(len(df)))
axes[0].set_xticklabels(df["Scenario"], rotation=45, ha="right")
axes[0].set_ylabel("Iterations")
axes[0].set_title("Computational Cost")
axes[0].grid(True, alpha=0.3)

bar_colors = ["red" if r > 1.1 else "orange" if r > 1.05 else "green" for r in df["R-hat"]]
axes[1].bar(range(len(df)), df["R-hat"], color=bar_colors)
axes[1].axhline(y=1.1, color="red", linestyle="--", alpha=0.5)
axes[1].set_xticks(range(len(df)))
axes[1].set_xticklabels(df["Scenario"], rotation=45, ha="right")
axes[1].set_ylabel("R-hat")
axes[1].set_title("Convergence Quality")
axes[1].grid(True, alpha=0.3)

axes[2].bar(range(len(df)), df["ESS/Iter"])
axes[2].set_xticks(range(len(df)))
axes[2].set_xticklabels(df["Scenario"], rotation=45, ha="right")
axes[2].set_ylabel("ESS per Iteration")
axes[2].set_title("Sampling Efficiency")
axes[2].grid(True, alpha=0.3)

plt.suptitle("Fixed vs Adaptive Stopping Comparison", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

fixed_base = df.loc[df["Scenario"] == "Fixed 10000", "Iterations"].values[0]
adaptive_opt = df.loc[df["Scenario"] == "Adaptive (R<1.05)", "Iterations"].values[0]
savings = (fixed_base - adaptive_opt) / fixed_base * 100
print(f"\nComputational savings with adaptive stopping: {savings:.1f}%")

## Key Takeaways

- R-hat < 1.1 and ESS > 1000 are common rule-of-thumb convergence thresholds (Sections 9 and 10 also use stricter R-hat cutoffs of 1.05 and 1.01); always check both.
- Heidelberger-Welch and Raftery-Lewis tests provide additional statistical rigor.
- Spectral density and overlapping-batch ESS methods offer robust alternatives to basic ACF.
- Adaptive stopping typically saves 30-50% of iterations relative to a fixed budget; the exact figure depends on the thresholds and on how quickly the chains mix.
- Always validate loss distributions with Q-Q plots and K-S tests before trusting simulation output.

## Next Steps

- [05_ruin_analysis_plots](05_ruin_analysis_plots.ipynb) -- ruin cliff and ROE-ruin frontier
- [../core/04_monte_carlo_engine](../core/04_monte_carlo_engine.ipynb) -- Monte Carlo engine details
- [06_scenario_comparison](06_scenario_comparison.ipynb) -- scenario comparison and annotation framework