# In [None]:
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm, gamma, poisson
import jax.random as random
from jax import vmap, jit

print("=== MCMC CONVERGENCE DIAGNOSTICS - COMPLETE IMPLEMENTATION ===\n")

# =============================================================================
# STEP 1: R-hat Calculation
# =============================================================================
# Section title followed by a 50-character rule on its own line
print("STEP 1: R-hat (Potential Scale Reduction Factor)", "-" * 50, sep="\n")

def compute_Rhat(chains):
    """
    Gelman-Rubin potential scale reduction factor (R-hat), one value per parameter.

    Parameters:
    -----------
    chains : array-like, shape (M, S, P)
        M chains, each holding S draws of P parameters.

    Returns:
    --------
    Rhat : array-like, shape (P,)
        sqrt((S - 1)/S + B/(S*W)) for each parameter, where W is the mean of
        the unbiased within-chain variances and B = S * (unbiased variance of
        the chain means) is the between-chain variance.
    """
    _, S, _ = chains.shape

    # Reduce over the draw axis (axis=1) for all parameters at once.
    per_chain_mean = jnp.mean(chains, axis=1)            # (M, P)
    per_chain_var = jnp.var(chains, axis=1, ddof=1)      # (M, P)

    W = per_chain_var.mean(axis=0)                       # within-chain variance, (P,)
    B = S * jnp.var(per_chain_mean, axis=0, ddof=1)      # between-chain variance, (P,)

    return jnp.sqrt((S - 1) / S + B / (S * W))

# Toy two-chain example small enough to check by hand
print("Hand calculation example:")
chain1, chain2 = (
    jnp.array([1.2, 1.4, 1.1, 1.3, 1.2]),
    jnp.array([1.8, 1.6, 1.9, 1.7, 1.5]),
)
# Shape (M=2, S=5, P=1): append a trailing parameter axis
chains_small = jnp.stack([chain1, chain2])[..., None]

R_hat_small = compute_Rhat(chains_small)
print(f"Chains:\n{chains_small[..., 0]}")
print(f"R-hat = {R_hat_small[0]:.3f}")

# =============================================================================
# STEP 2: Effective Sample Size Calculation
# =============================================================================
print("\n\nSTEP 2: Effective Sample Size (ESS)", "-" * 50, sep="\n")

def compute_autocorrelation(x, max_lag=None):
    """Compute the sample autocorrelation function of a 1-D series up to ``max_lag``.

    The lag-k autocovariance is normalised by (n - k) and the lag-0 variance
    c0 by n; every lag is then divided by c0.

    Parameters
    ----------
    x : array-like, shape (n,)
        The series, e.g. one parameter's MCMC draws.
    max_lag : int, optional
        Largest lag to compute; defaults to ``n // 4``. It is clamped to the
        range [0, n - 1], because no overlapping pairs exist for lags >= n.

    Returns
    -------
    autocorr : jax array, shape (clamped max_lag + 1,)
        ``autocorr[0] == 1`` and ``autocorr[k]`` is the lag-k autocorrelation.
        Because of the mixed normalisation, large-lag values are noisy and can
        exceed 1 in magnitude. A constant series (c0 == 0) gives non-finite
        values for lags >= 1.
    """
    x = jnp.asarray(x)
    n = x.shape[0]
    if max_lag is None:
        max_lag = n // 4
    # For lag >= n both slices below are empty, giving 0/0 = NaN (or 0 divided
    # by a non-positive count), so restrict to lags that have data.
    max_lag = max(0, min(int(max_lag), n - 1))

    x_centered = x - jnp.mean(x)
    c0 = jnp.dot(x_centered, x_centered) / n

    autocorr = [1.0]
    for lag in range(1, max_lag + 1):
        ck = jnp.dot(x_centered[:-lag], x_centered[lag:]) / (n - lag)
        autocorr.append(ck / c0)

    return jnp.array(autocorr)

def compute_effective_sample_size(chains):
    """
    Effective sample size (ESS) per parameter, computed from the pooled draws.

    For each parameter the M chains are concatenated into one series of
    n = M*S draws. Its autocorrelation rho_k is computed up to lag
    min(n // 4, 1000), and ESS = n / (1 + 2 * sum_k rho_k), where the sum runs
    over lags k >= 1 and stops at the first negative rho_k.

    Parameters:
    -----------
    chains : array-like, shape (M, S, P)
        M chains of S draws for P parameters.

    Returns:
    --------
    ESS : array-like, shape (P,)
        Effective sample size for each parameter.
    """
    _, _, n_params = chains.shape
    ess = jnp.zeros(n_params)

    for j in range(n_params):
        pooled = chains[:, :, j].ravel()
        n_total = len(pooled)
        rho = compute_autocorrelation(pooled, min(n_total // 4, 1000))

        # Accumulate 1 + 2*rho_k until the first negative autocorrelation.
        # NaN entries are not "< 0", so they are accumulated rather than stopping the sum.
        denom = 1.0
        k = 1
        while k < len(rho) and not (rho[k] < 0):
            denom += 2 * rho[k]
            k += 1

        ess = ess.at[j].set(n_total / denom)

    return ess

# ESS of the same toy chains used for the R-hat hand calculation
print("ESS calculation example:")
ess_small = compute_effective_sample_size(chains_small)
print(f"Total samples: {chains_small.size}\n"
      f"ESS = {ess_small[0]:.3f}\n"
      f"Efficiency = {ess_small[0]/chains_small.size:.3f}")

# =============================================================================
# STEP 3: Multiple Metropolis Chains
# =============================================================================
print("\n\nSTEP 3: Multiple Metropolis Chains", "-" * 50, sep="\n")

def metropolis(log_target, num_params, tau, num_iter, theta_init=None, seed=0):
    """Random-walk Metropolis sampler with isotropic Gaussian proposals.

    Each step proposes ``theta' = theta + tau * eps`` with ``eps ~ N(0, I)``,
    so ``tau`` is the proposal *standard deviation* (scale). The proposal is
    accepted when ``log(u) < log_target(theta') - log_target(theta)`` with
    ``u ~ Uniform(0, 1)``.

    Parameters
    ----------
    log_target : callable
        Maps a parameter vector of shape (num_params,) to a scalar log density
        (an additive constant does not matter).
    num_params : int
        Dimension of the parameter vector.
    tau : float
        Standard deviation of the Gaussian random-walk proposal.
    num_iter : int
        Number of Metropolis steps.
    theta_init : array-like, optional
        Starting point, broadcast to shape (num_params,); defaults to zeros.
    seed : int
        Seed for ``jax.random.PRNGKey``.

    Returns
    -------
    samples : jax array, shape (num_iter + 1, num_params)
        The chain, with the initial state stored as row 0.
    accept_rate : float
        Fraction of accepted proposals; NaN when ``num_iter == 0``, because no
        proposals were made.
    """
    if theta_init is None:
        theta_init = jnp.zeros(num_params)

    key = random.PRNGKey(seed)
    # Same dtype/shape the initial state had when it was written into a
    # jnp.zeros((num_iter + 1, num_params)) buffer.
    current = jnp.broadcast_to(jnp.asarray(theta_init, dtype=float), (num_params,))
    current_log_prob = log_target(theta_init)

    # JAX arrays are immutable: `.at[i].set` copies the whole buffer on every
    # step (O(num_iter^2) overall), so collect rows and stack once at the end.
    chain = [current]
    n_accepted = 0

    for _ in range(num_iter):
        key, subkey = random.split(key)

        # Propose new state
        proposal = current + tau * random.normal(subkey, shape=(num_params,))

        # Log acceptance ratio (symmetric proposal, so only the target matters)
        proposal_log_prob = log_target(proposal)
        log_ratio = proposal_log_prob - current_log_prob

        # Accept/reject; a rejection repeats the current state
        key, accept_key = random.split(key)
        if jnp.log(random.uniform(accept_key)) < log_ratio:
            current, current_log_prob = proposal, proposal_log_prob
            n_accepted += 1
        chain.append(current)

    samples = jnp.stack(chain)
    accept_rate = n_accepted / num_iter if num_iter > 0 else float("nan")
    return samples, accept_rate

def metropolis_multiple_chains(log_target, num_params, num_chains, tau, num_iter,
                               theta_init, seeds, warm_up=0):
    """Run several independent Metropolis chains and stack their draws.

    The chains are run one after another, not concurrently. Chain ``m``
    starts from ``theta_init[m]`` with seed ``seeds[m]``, and a progress line
    is printed before each chain.

    Parameters
    ----------
    log_target, num_params, tau, num_iter :
        Passed unchanged to ``metropolis``.
    num_chains : int
        Number of chains to run.
    theta_init : indexable
        ``theta_init[m]`` is the starting point of chain m.
    seeds : indexable
        ``seeds[m]`` is the seed of chain m.
    warm_up : int
        Number of leading stored states dropped from every chain.

    Returns
    -------
    thetas : jax array, shape (num_chains, num_iter + 1 - warm_up, num_params)
    accept_rates : jax array, shape (num_chains,)
    """
    runs = []
    for m in range(num_chains):
        print(f'Running chain {m + 1}/{num_chains}')
        runs.append(metropolis(log_target, num_params, tau, num_iter,
                               theta_init=theta_init[m], seed=seeds[m]))

    chain_draws, chain_rates = zip(*runs)
    # Stack to (chains, draws, params), then drop the warm-up prefix of each chain
    thetas = jnp.stack(chain_draws, axis=0)[:, warm_up:, :]
    return thetas, jnp.array(chain_rates)

# =============================================================================
# STEP 4: Bimodal Distribution Example
# =============================================================================
print("\n\nSTEP 4: Bimodal Distribution Example", "-" * 50, sep="\n")

# Define bimodal distribution
def log_bimodal(x):
    """Log density of the mixture 0.5 * N(-3, 4) + 0.5 * N(1, 2).

    Components are written as N(mean, variance). The two weighted component
    log densities are combined with log-sum-exp elementwise, and the result
    is summed over all elements of ``x``.
    """
    mixture_components = ((-3, 2), (1, jnp.sqrt(2)))  # (loc, scale) pairs
    weighted = [jnp.log(0.5) + norm.logpdf(x, loc=mu, scale=sigma)
                for mu, sigma in mixture_components]
    return jnp.logaddexp(*weighted).sum()

# True statistics of the mixture:
# mean = 0.5*(-3) + 0.5*1 = -1;  var = E[x^2] - mean^2 = 0.5*(4 + 9) + 0.5*(2 + 1) - 1 = 7
true_mean = -1.0
true_var = 7.0
print(f"True mean: {true_mean}")
print(f"True variance: {true_var}")

# MCMC settings
num_chains = 4
num_iter = 1000
# `metropolis` multiplies a standard-normal step by tau, so this value is the
# proposal standard deviation (not a variance).
proposal_std = 0.1
num_params = 1
warm_up = 500

# Initial values: widely spread starts, one per chain
key = random.PRNGKey(1)
theta_init = 5 * random.normal(key, shape=(num_chains, num_params))
seeds = jnp.arange(num_chains)

# Run chains
print("\nRunning MCMC for bimodal distribution...")
chains_bimodal, accepts = metropolis_multiple_chains(
    log_bimodal, num_params, num_chains, proposal_std,
    num_iter, theta_init, seeds, warm_up)

# Compute diagnostics
Rhat_bimodal = compute_Rhat(chains_bimodal)
ESS_bimodal = compute_effective_sample_size(chains_bimodal)

# Results
merged_samples = chains_bimodal.ravel()
estimated_mean = jnp.mean(merged_samples)
estimated_var = jnp.var(merged_samples)
# Each chain stores its initial state too, so it keeps num_iter + 1 - warm_up
# draws; count the kept draws from the array instead of num_iter - warm_up.
num_kept_bimodal = chains_bimodal.shape[0] * chains_bimodal.shape[1]

print(f"\nEstimated mean: {estimated_mean:.3f}")
print(f"Estimated variance: {estimated_var:.3f}")
print(f"R-hat: {Rhat_bimodal[0]:.3f}")
print(f"ESS: {ESS_bimodal[0]:.0f}")
print(f"Relative efficiency: {ESS_bimodal[0]/num_kept_bimodal:.3f}")

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Trace plots
ax = axes[0]
for i in range(num_chains):
    ax.plot(chains_bimodal[i, :, 0], alpha=0.7, label=f'Chain {i+1}')
ax.set_xlabel('Iteration')
ax.set_ylabel('Sample value')
ax.set_title(f'Trace plots (R̂={Rhat_bimodal[0]:.3f})')
ax.legend()

# Histogram vs true density
ax = axes[1]
x_range = jnp.linspace(-8, 6, 1000)
true_density = 0.5 * norm.pdf(x_range, -3, 2) + 0.5 * norm.pdf(x_range, 1, jnp.sqrt(2))
ax.hist(merged_samples, bins=50, density=True, alpha=0.7, label='Samples')
ax.plot(x_range, true_density, 'r-', linewidth=2, label='True density')
ax.axvline(true_mean, color='g', linestyle='--', label='True mean')
ax.axvline(estimated_mean, color='b', linestyle='--', label='Est. mean')
ax.set_title('Posterior samples')
ax.legend()

# Autocorrelation of the pooled draws
ax = axes[2]
autocorr = compute_autocorrelation(merged_samples, max_lag=50)
ax.plot(autocorr, 'o-')
ax.axhline(0, color='k', linestyle='--')
ax.set_xlabel('Lag')
ax.set_ylabel('Autocorrelation')
ax.set_title(f'Autocorrelation (ESS={ESS_bimodal[0]:.0f})')
ax.grid(True)

plt.tight_layout()
plt.show()

# =============================================================================
# STEP 5: Change Point Detection Model
# =============================================================================
print("\n\nSTEP 5: Change Point Detection Model", "-" * 50, sep="\n")

# Simulated counts: rate true_lambda1 for the first true_c years, then true_lambda2
np.random.seed(42)
N, true_c = 50, 25
true_lambda1, true_lambda2 = 2.0, 5.0

# The two regimes are drawn in order from the global NumPy RNG.
accident_counts = jnp.concatenate([
    jnp.array(np.random.poisson(rate, size))
    for rate, size in ((true_lambda1, true_c), (true_lambda2, N - true_c))
])

year = jnp.arange(1850, 1850 + N)

print(f"Simulated data: N={N}, true change point={true_c}")
print(f"True λ₁={true_lambda1}, true λ₂={true_lambda2}")

# Data with the true change point marked
fig, ax = plt.subplots(figsize=(12, 5))
ax.scatter(year, accident_counts, alpha=0.7)
ax.axvline(year[true_c], color='r', linestyle='--', label='True change point')
ax.set(xlabel='Year', ylabel='Accident count', title='Simulated Accident Data')
ax.legend()
plt.show()

# Gibbs sampler
def cpd_gibbs_sampler(x, alpha, beta, num_iter, c_init, lambda1_init,
                      lambda2_init, warmup=0, seed=0):
    """Gibbs sampler for a single change point in Poisson count data.

    Model as implemented: x[:c] ~ Poisson(lambda1), x[c:] ~ Poisson(lambda2),
    lambda1, lambda2 ~ Gamma(alpha, rate=beta), and c uniform on {0, ..., N-1}
    (no prior term enters the c update). Each sweep draws lambda1 | c, then
    lambda2 | c, then c | lambda1, lambda2.

    Parameters
    ----------
    x : array-like, shape (N,)
        Observed counts.
    alpha, beta : float
        Shape and rate of the Gamma prior on both rates.
    num_iter : int
        Number of Gibbs sweeps.
    c_init, lambda1_init, lambda2_init :
        Initial state (stored as the first sample).
    warmup : int
        Number of leading stored states to discard, including the initial state.
    seed : int
        Seed for ``jax.random.PRNGKey``.

    Returns
    -------
    (lambda1, lambda2, c) : tuple of jax arrays
        Each has length ``max(0, num_iter + 1 - warmup)`` when ``warmup >= 0``.
    """
    x = jnp.asarray(x)
    N = len(x)

    # left_sums[ci] = sum(x[:ci]) and right_sums[ci] = sum(x[ci:]) for every
    # candidate ci, computed once so each c-update is O(N) rather than O(N^2).
    prefix = jnp.concatenate([jnp.zeros(1, dtype=x.dtype), jnp.cumsum(x)])
    candidates = jnp.arange(N)
    left_sums = prefix[:N]
    right_sums = prefix[N] - left_sums

    # Storage
    lambda1_samples = [lambda1_init]
    lambda2_samples = [lambda2_init]
    c_samples = [c_init]

    key = random.PRNGKey(seed)
    # Progress is reported ~5 times; guard against modulo-by-zero when num_iter < 5.
    report_every = max(1, num_iter // 5)

    for k in range(num_iter):
        key, subkey = random.split(key)
        key1, key2, key3 = random.split(subkey, num=3)

        c_k = int(c_samples[k])

        # λ₁ | c ~ Gamma(alpha + sum(x[:c]), rate=beta + c); sampled as scale * Gamma(shape)
        a1 = alpha + jnp.sum(x[:c_k])
        b1 = 1.0 / (beta + c_k)
        lambda1_new = b1 * random.gamma(key1, a1)
        lambda1_samples.append(lambda1_new)

        # λ₂ | c ~ Gamma(alpha + sum(x[c:]), rate=beta + N - c)
        a2 = alpha + jnp.sum(x[c_k:])
        b2 = 1.0 / (beta + N - c_k)
        lambda2_new = b2 * random.gamma(key2, a2)
        lambda2_samples.append(lambda2_new)

        # c | λ₁, λ₂: Poisson log-likelihood (up to a constant) of every split point
        log_prob_c = (left_sums * jnp.log(lambda1_new) - candidates * lambda1_new
                      + right_sums * jnp.log(lambda2_new) - (N - candidates) * lambda2_new)
        log_prob_c = log_prob_c - jnp.max(log_prob_c)  # shift before exp for stability
        prob_c = jnp.exp(log_prob_c)
        prob_c = prob_c / jnp.sum(prob_c)

        c_new = random.choice(key3, candidates, p=prob_c)
        c_samples.append(c_new)

        if (k + 1) % report_every == 0:
            print(f'Iteration {k + 1}/{num_iter}')

    # Discard warmup
    if warmup > 0:
        lambda1_samples = lambda1_samples[warmup:]
        lambda2_samples = lambda2_samples[warmup:]
        c_samples = c_samples[warmup:]

    return (jnp.array(lambda1_samples),
            jnp.array(lambda2_samples),
            jnp.array(c_samples))

# Run multiple chains
print("\nRunning Gibbs sampler with multiple chains...")
alpha = 1.0
beta = 1.0
num_iter = 2000
num_chains = 4
warmup = 1000

# Storage for chains
all_samples = []

key = random.PRNGKey(1)
for chain_idx in range(num_chains):
    print(f'\nChain {chain_idx + 1}/{num_chains}')

    # Random initial values: c uniform over indices, rates drawn as (1/beta) * Gamma(alpha)
    key, subkey = random.split(key)
    key1, key2, key3 = random.split(subkey, num=3)

    c_init = random.choice(key1, jnp.arange(N))
    l1_init = 1/beta * random.gamma(key2, alpha)
    l2_init = 1/beta * random.gamma(key3, alpha)

    # Run sampler
    l1, l2, c = cpd_gibbs_sampler(accident_counts, alpha, beta, num_iter,
                                  c_init, l1_init, l2_init, warmup, seed=chain_idx)

    # Columns in the order of parameter_names below: (c, λ₁, λ₂)
    chain_samples = jnp.stack([c, l1, l2], axis=1)
    all_samples.append(chain_samples)

# Stack all chains -> (num_chains, draws kept per chain, 3)
all_samples = jnp.stack(all_samples, axis=0)

# The sampler also stores its initial state, so each chain keeps
# num_iter + 1 - warmup draws; count kept draws from the array itself.
num_kept_cpd = all_samples.shape[0] * all_samples.shape[1]

# Compute diagnostics
print("\nComputing convergence diagnostics...")
Rhat_cpd = compute_Rhat(all_samples)
ESS_cpd = compute_effective_sample_size(all_samples)

parameter_names = ['c', 'λ₁', 'λ₂']
print("\nParameter | R-hat | ESS | Rel. Efficiency")
print("-" * 40)
for i, name in enumerate(parameter_names):
    rel_eff = ESS_cpd[i] / num_kept_cpd
    print(f"{name:9s} | {Rhat_cpd[i]:.3f} | {ESS_cpd[i]:.0f} | {rel_eff:.3f}")

# Posterior analysis on the pooled draws of all chains
merged_samples = all_samples.reshape(-1, 3)
c_samples = merged_samples[:, 0].astype(int)
lambda1_samples = merged_samples[:, 1]
lambda2_samples = merged_samples[:, 2]

print("\nPosterior means:")
print(f"c: {jnp.mean(c_samples):.1f} (true: {true_c})")
print(f"λ₁: {jnp.mean(lambda1_samples):.3f} (true: {true_lambda1})")
print(f"λ₂: {jnp.mean(lambda2_samples):.3f} (true: {true_lambda2})")

# Visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Trace plots
for i, (param_name, true_val) in enumerate(zip(parameter_names,
                                               [true_c, true_lambda1, true_lambda2])):
    ax = axes[0, i]
    for chain_idx in range(num_chains):
        ax.plot(all_samples[chain_idx, :, i], alpha=0.7, label=f'Chain {chain_idx+1}')
    ax.axhline(true_val, color='r', linestyle='--', label='True value')
    ax.set_title(f'{param_name} (R̂={Rhat_cpd[i]:.3f})')
    ax.set_xlabel('Iteration')
    ax.set_ylabel(param_name)
    if i == 0:
        ax.legend()

# Posterior distributions
for i, (param_name, samples, true_val) in enumerate(zip(
    parameter_names,
    [c_samples, lambda1_samples, lambda2_samples],
    [true_c, true_lambda1, true_lambda2])):

    ax = axes[1, i]
    ax.hist(samples, bins=30, density=True, alpha=0.7)
    ax.axvline(true_val, color='r', linestyle='--', label='True value')
    ax.axvline(jnp.mean(samples), color='g', linestyle='--', label='Posterior mean')
    ax.set_xlabel(param_name)
    ax.set_title(f'Posterior of {param_name}')
    ax.legend()

plt.tight_layout()
plt.show()

# =============================================================================
# STEP 6: Advanced Analysis - Parameter Efficiency
# =============================================================================
print("\n\nSTEP 6: Parameter Efficiency Analysis")
print("-" * 50)

# Sweep the random-walk proposal scale tau for the bimodal target.
# `metropolis` uses tau as the proposal *standard deviation* (step = tau * N(0, I)).
taus = jnp.logspace(-2, 2, 20)
R_effs = []

# Same starting points for every tau; computed once instead of on every pass.
sweep_theta_init = 5 * random.normal(random.PRNGKey(123), shape=(4, 1))

print("Testing different proposal scales...")
for idx_tau, tau in enumerate(taus):
    print(f'Testing τ = {tau:.3f} ({idx_tau+1}/{len(taus)})')

    chains, _ = metropolis_multiple_chains(
        log_bimodal, 1, 4, tau, 1000,
        sweep_theta_init,
        jnp.arange(4), warm_up=500)

    ESS = compute_effective_sample_size(chains)
    total_samples = chains.size
    R_effs.append(ESS[0] / total_samples)

R_effs = jnp.array(R_effs)
# Ignore NaN efficiencies (e.g. a sweep point whose pooled draws have zero
# variance) so they cannot be reported as the optimum.
idx_optimal = jnp.nanargmax(R_effs)

print(f"\nOptimal proposal std. dev.: {taus[idx_optimal]:.3f}")
print(f"Maximum efficiency: {R_effs[idx_optimal]:.3f}")

# Plot efficiency curve
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.semilogx(taus, R_effs, 'o-')
ax.axvline(taus[idx_optimal], color='r', linestyle='--',
           label=f'Optimal τ = {taus[idx_optimal]:.3f}')
ax.set_xlabel('Proposal std. dev. τ')
ax.set_ylabel('Relative efficiency')
ax.set_title('Efficiency vs Proposal Scale')
ax.legend()
ax.grid(True)
plt.show()

# =============================================================================
# SUMMARY
# =============================================================================
print("\n\n=== SUMMARY ===")
print("-" * 50)
print("MCMC Convergence Diagnostics:")
print("1. R-hat measures convergence by comparing within/between chain variance")
print("2. ESS accounts for autocorrelation to give effective sample size")
print("3. Multiple chains help diagnose convergence issues")
print("\nKey findings:")
print(f"- Bimodal distribution: R̂={Rhat_bimodal[0]:.3f}, ESS={ESS_bimodal[0]:.0f}")
# Report the worst parameter: a mean across parameters can hide one
# parameter whose chains have not converged.
print(f"- Change point model: max R̂={Rhat_cpd.max():.3f}, min ESS={ESS_cpd.min():.0f}")
print("- Optimal proposal scale maximizes efficiency")
print("- Gibbs sampling often more efficient than Metropolis for structured models")