# In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scipy.stats import norm, gamma, poisson
import jax.random as random

# Title line for the whole walkthrough; the embedded "\n" leaves a blank line
# after it.
print("=== MCMC HAND CALCULATIONS IN PYTHON ===\n")

# =============================================================================
# PART 1: BIMODAL DISTRIBUTION
# =============================================================================
# Part heading followed by a 60-character rule.
print("PART 1: BIMODAL DISTRIBUTION", "=" * 60, sep="\n")

# 1.1 Theoretical Calculations
# Sub-section heading followed by a 40-character rule.
print("\n1.1 Theoretical Calculations", "-" * 40, sep="\n")

# Define the bimodal distribution
def p_bimodal(x):
    """Evaluate the two-component Gaussian mixture density at ``x``.

    The density is ``0.5 * N(x | -3, 4) + 0.5 * N(x | 1, 2)``, where the
    second argument of each ``N`` is a variance.  ``norm.pdf`` is
    parameterised by ``scale`` (a standard deviation), so the components are
    passed ``2`` and ``sqrt(2)`` respectively.

    Args:
        x: Point or points at which to evaluate the density; forwarded
            unchanged to ``norm.pdf``.

    Returns:
        The mixture density evaluated at ``x``.
    """
    # (weight, mean, standard deviation) for each mixture component.
    components = (
        (0.5, -3, 2),
        (0.5, 1, np.sqrt(2)),
    )
    return sum(
        weight * norm.pdf(x, loc=mean, scale=std)
        for weight, mean, std in components
    )

# State the target density, then its closed-form moments.  For a mixture,
# E[X] is the weighted sum of the component means and E[X²] is the weighted
# sum of each component's (variance + mean²); Var[X] follows from the two.
print("Distribution: p(x) = 0.5 × N(x|-3, 4) + 0.5 × N(x|1, 2)")
print(
    "\nTheoretical moments:",
    "E[X] = 0.5 × (-3) + 0.5 × 1 = -1",
    "E[X²] = 0.5 × (4 + 9) + 0.5 × (2 + 1) = 8",
    "Var[X] = E[X²] - (E[X])² = 8 - 1 = 7",
    sep="\n",
)

# 1.2 Metropolis Sampling - Hand Calculation
# Two Metropolis updates worked by hand.  The "random" draws (ε and u) are
# fixed example values so that the printed arithmetic is reproducible.
print("\n\n1.2 Metropolis Sampling - Hand Calculation")
print("-" * 40)

x0 = 0
# τ multiplies the noise draw ε in the proposal x' = x + τ·ε, so it is a step
# scale rather than a variance; the printed label reflects that.
tau = 0.5
print(f"Initial: x₀ = {x0}")
print(f"Proposal scale: τ = {tau}")

# Iteration 1
print("\nIteration 1:")
epsilon1 = 0.8  # Example random number
x_proposed = x0 + tau * epsilon1
print(f"1. Propose: x' = {x0} + {tau} × {epsilon1} = {x_proposed}")

# Acceptance ratio of target densities at the proposed and current points.
p_x0 = p_bimodal(x0)
p_xprop = p_bimodal(x_proposed)
alpha = p_xprop / p_x0

print("2. Compute acceptance ratio:")
print(f"   p(x₀) = p({x0}) = {p_x0:.4f}")
print(f"   p(x') = p({x_proposed}) = {p_xprop:.4f}")
print(f"   α = {p_xprop:.4f} / {p_x0:.4f} = {alpha:.3f}")

# Decide from the computed ratio instead of assuming α > 1: an uphill move is
# always accepted, otherwise the proposal is accepted with probability α.
if alpha > 1:
    x1 = x_proposed
    print(f"3. Since α = {alpha:.3f} > 1, accept: x₁ = {x_proposed}")
else:
    u1 = 0.3  # Example uniform random number (only needed when α ≤ 1)
    print(f"3. Generate u ~ U(0,1) = {u1}")
    if u1 < alpha:
        x1 = x_proposed
        print(f"   Since {u1} < {alpha:.3f}, accept: x₁ = {x_proposed}")
    else:
        x1 = x0
        print(f"   Since {u1} ≥ {alpha:.3f}, reject: x₁ = {x0}")

# Iteration 2
print("\nIteration 2:")
epsilon2 = -1.2  # Example random number
x_proposed2 = x1 + tau * epsilon2
print(f"1. Propose: x' = {x1} + {tau} × {epsilon2} = {x_proposed2}")

p_x1 = p_bimodal(x1)
p_xprop2 = p_bimodal(x_proposed2)
alpha2 = p_xprop2 / p_x1

print("2. Compute acceptance ratio:")
print(f"   p(x₁) = p({x1}) = {p_x1:.4f}")
print(f"   p(x') = p({x_proposed2}) = {p_xprop2:.4f}")
print(f"   α = {p_xprop2:.4f} / {p_x1:.4f} = {alpha2:.3f}")

# Comparing u with α directly also covers α ≥ 1, because u < 1 ≤ α then
# always holds.
u = 0.3  # Example uniform random number
print(f"3. Generate u ~ U(0,1) = {u}")
if u < alpha2:
    x2 = x_proposed2
    print(f"   Since {u} < {alpha2:.3f}, accept: x₂ = {x_proposed2}")
else:
    x2 = x1
    print(f"   Since {u} ≥ {alpha2:.3f}, reject: x₂ = {x1}")

# 1.3 Multiple Chains and R-hat Calculation
# Computes the potential scale reduction factor R̂ for two short example
# chains, printing every intermediate quantity.
print("\n\n1.3 Multiple Chains and R-hat Calculation")
print("-" * 40)

# Example chains
chain1 = np.array([0.4, -0.2, -0.5, 0.8, 1.2])
chain2 = np.array([-3.1, -2.8, -3.2, -2.9, -3.0])

print(f"Chain 1: {chain1}")
print(f"Chain 2: {chain2}")
print("Notice: Chain 2 is stuck in the left mode!")

# Compute R-hat
# Stack the chains so the chain count M and per-chain length S come from the
# data itself; np.stack also raises if the chains differ in length.
chains = np.stack([chain1, chain2])
M, S = chains.shape  # number of chains, samples per chain

# Chain means
chain_means = chains.mean(axis=1)
chain1_mean, chain2_mean = chain_means
overall_mean = chain_means.mean()

print("\nChain means:")
print(f"θ̄₁ = {chain1_mean:.3f}")
print(f"θ̄₂ = {chain2_mean:.3f}")
print(f"θ̄ = {overall_mean:.3f}")

# Within-chain variances: W is the average of the per-chain sample variances.
chain_variances = chains.var(axis=1, ddof=1)
s1_sq, s2_sq = chain_variances
W = chain_variances.mean()

print("\nWithin-chain variances:")
print(f"s₁² = {s1_sq:.3f}")
print(f"s₂² = {s2_sq:.3f}")
print(f"W = {W:.3f}")

# Between-chain variance: S times the sample variance of the chain means.
B = S * chain_means.var(ddof=1)

print("\nBetween-chain variance:")
print(f"B = {B:.3f}")

# R-hat
R_sq = (S - 1) / S + B / (S * W)
R_hat = np.sqrt(R_sq)

print("\nR̂² = (S-1)/S + B/(S×W)")
print(f"   = {(S-1)/S:.3f} + {B/(S*W):.3f}")
print(f"   = {R_sq:.3f}")
print(f"R̂ = {R_hat:.3f}")

# Report the verdict from the computed value, using the same 1.1 cut-off as
# the summary at the end of the script, rather than asserting it.
if R_hat > 1.1:
    print(f"\nThis high R̂ = {R_hat:.3f} indicates poor convergence!")
else:
    print(f"\nR̂ = {R_hat:.3f} is close to 1, which is consistent with convergence.")

# =============================================================================
# PART 2: CHANGE POINT DETECTION
# =============================================================================
print("\n\nPART 2: CHANGE POINT DETECTION", "=" * 60, sep="\n")

# 2.1 Problem Setup
print("\n2.1 Problem Setup", "-" * 40, sep="\n")

# Observed yearly counts and the Gamma(α, β) hyperparameters shared by both
# rates.
data = np.array([3, 2, 1, 5, 6, 7])
N = data.shape[0]
alpha, beta = 1, 1

print(
    f"Data: x = {data}",
    f"N = {N} years",
    f"Hyperparameters: α = {alpha}, β = {beta}",
    sep="\n",
)

# 2.2 Gibbs Sampling - Hand Calculation
print("\n\n2.2 Gibbs Sampling - Iteration 1", "-" * 40, sep="\n")

# Initial values
c0, lambda1_0, lambda2_0 = 3, 2.0, 5.0

print(f"Initial values: c⁽⁰⁾ = {c0}, λ₁⁽⁰⁾ = {lambda1_0}, λ₂⁽⁰⁾ = {lambda2_0}")

# Step 1: Sample λ₁
# The first c0 observations belong to the first regime: the Gamma shape gains
# their total count and the Gamma rate gains the number of observations.
print("\nStep 1: Sample λ₁ | x, c, λ₂")
sum_before = data[:c0].sum()
alpha_1 = alpha + sum_before
beta_1 = beta + c0

print(
    f"Sum of data before c = {c0}: {sum_before}",
    f"α' = {alpha} + {sum_before} = {alpha_1}",
    f"β' = {beta} + {c0} = {beta_1}",
    f"λ₁⁽¹⁾ ~ Gamma({alpha_1}, {beta_1})",
    sep="\n",
)

# Sample (using mean for demonstration): shape / rate instead of a random draw.
lambda1_1 = alpha_1 / beta_1
print(f"Sample λ₁⁽¹⁾ = {lambda1_1:.2f} (using mean for demo)")

# Step 2: Sample λ₂
# Same update using the remaining N - c0 observations.
print("\nStep 2: Sample λ₂ | x, c, λ₁")
sum_after = data[c0:].sum()
alpha_2 = alpha + sum_after
beta_2 = beta + (N - c0)

print(
    f"Sum of data after c = {c0}: {sum_after}",
    f"α' = {alpha} + {sum_after} = {alpha_2}",
    f"β' = {beta} + ({N} - {c0}) = {beta_2}",
    f"λ₂⁽¹⁾ ~ Gamma({alpha_2}, {beta_2})",
    sep="\n",
)

# Sample (using mean for demonstration)
lambda2_1 = alpha_2 / beta_2
print(f"Sample λ₂⁽¹⁾ = {lambda2_1:.2f} (using mean for demo)")

# Step 3: Sample c
print("\nStep 3: Sample c | x, λ₁, λ₂")
print(f"Using λ₁ = {lambda1_1:.2f}, λ₂ = {lambda2_1:.2f}")

# Running totals give the "before" sum for every candidate change point; the
# "after" sum is whatever remains of the grand total.
running_totals = np.cumsum(data)
grand_total = running_totals[-1]
log_rate_1 = np.log(lambda1_1)
log_rate_2 = np.log(lambda2_1)

log_probs = np.empty(N - 1)
for k in range(1, N):
    sum_before_k = running_totals[k - 1]
    sum_after_k = grand_total - sum_before_k
    # Unnormalised log-probability of c = k: the Poisson log-likelihood of
    # both segments with the terms that do not depend on k dropped.
    log_p = (sum_before_k * log_rate_1 - k * lambda1_1 +
             sum_after_k * log_rate_2 - (N - k) * lambda2_1)
    log_probs[k - 1] = log_p
    print(
        f"\nc = {k}:",
        f"  Sum before: {sum_before_k}, Sum after: {sum_after_k}",
        f"  log p ∝ {sum_before_k}×log({lambda1_1:.2f}) - {k}×{lambda1_1:.2f} + "
        f"{sum_after_k}×log({lambda2_1:.2f}) - {N-k}×{lambda2_1:.2f}",
        f"  log p = {log_p:.3f}",
        sep="\n",
    )

# Normalize probabilities
# Subtract the maximum before exponentiating so the largest term is exp(0).
log_probs_shifted = log_probs - log_probs.max()
unnormalised = np.exp(log_probs_shifted)
probs = unnormalised / unnormalised.sum()

print("\nNormalized probabilities:")
for candidate_c, candidate_prob in zip(range(1, N), probs):
    print(f"P(c = {candidate_c}) = {candidate_prob:.3f}")

# probs[0] corresponds to c = 1, hence the +1.
most_likely_c = probs.argmax() + 1
print(f"\nMost likely: c⁽¹⁾ = {most_likely_c}")

# 2.3 Posterior Predictive
# Draws one Poisson prediction for year 7 from each example posterior sample
# and averages the draws.
print("\n\n2.3 Posterior Predictive Distribution")
print("-" * 40)

# Simulate posterior samples (small example)
samples = {
    'c': [3, 3, 4, 3, 2],
    'lambda1': [1.9, 2.1, 1.8, 2.0, 2.2],
    'lambda2': [5.8, 6.2, 5.5, 6.0, 5.9],
}
# Take the sample count from the data so the two cannot drift apart.
n_samples = len(samples['c'])

print("Posterior samples (example):")
for i, (c_i, lambda1_i, lambda2_i) in enumerate(
        zip(samples['c'], samples['lambda1'], samples['lambda2'])):
    print(f"Sample {i+1}: c={c_i}, "
          f"λ₁={lambda1_i:.1f}, λ₂={lambda2_i:.1f}")

# Predict for year 7
print("\nPredicting accidents for year 7:")
predictions = []

# A seeded Generator keeps the printed walkthrough identical from run to run;
# the unseeded global np.random state would give different numbers each time.
rng = np.random.default_rng(0)

for i, (c_i, lambda1_i, lambda2_i) in enumerate(
        zip(samples['c'], samples['lambda1'], samples['lambda2'])):
    # Year 7 uses λ₁ only if it falls at or before the change point;
    # otherwise it belongs to the second regime.
    lambda_7 = lambda1_i if 7 <= c_i else lambda2_i
    pred = rng.poisson(lambda_7)

    predictions.append(pred)
    print(f"Sample {i+1}: c={c_i}, 7>c: {7>c_i}, "
          f"λ={lambda_7:.1f}, x₇={pred}")

print(f"\nPosterior predictive mean: E[x₇|x] = {np.mean(predictions):.1f}")

# 2.4 Decision Making
# Estimates P(λ₁ > λ₂ | x) as the fraction of posterior samples in which the
# first-regime rate exceeds the second-regime rate, then words a conclusion
# whose strength matches that fraction.
print("\n\n2.4 Decision Making")
print("-" * 40)

# Check if accident rate decreased
print("Question: Has the accident rate decreased?")
print("Compute: P(λ₁ > λ₂ | x)")

# Only the first n_samples draws are counted, matching the denominator below.
count_decreased = sum(
    lambda1_i > lambda2_i
    for lambda1_i, lambda2_i in zip(samples['lambda1'][:n_samples],
                                    samples['lambda2'][:n_samples])
)
prob_decreased = count_decreased / n_samples

print(f"\nSamples where λ₁ > λ₂: {count_decreased}/{n_samples}")
print(f"P(λ₁ > λ₂ | x) = {prob_decreased:.2f}")

# "Strong" is reserved for a posterior probability at or below this cut-off
# (a conventional 5% level; adjust to taste).  An exact 50/50 split is
# reported as inconclusive instead of being counted as a decrease.
strong_cutoff = 0.05

if prob_decreased > 0.5:
    print("\nConclusion: Evidence that accident rate DECREASED")
elif prob_decreased == 0.5:
    print("\nConclusion: No clear evidence of a change in either direction")
elif prob_decreased <= strong_cutoff:
    print("\nConclusion: Strong evidence that accident rate INCREASED")
else:
    print("\nConclusion: Weak evidence that accident rate INCREASED")

# =============================================================================
# PART 3: SUMMARY AND VISUALIZATION
# =============================================================================
print("\n\nPART 3: SUMMARY AND VISUALIZATION", "=" * 60, sep="\n")


def _decorate_axes(axis, title, xlabel, ylabel, show_legend=True):
    """Apply the title, axis labels, optional legend and faint grid to ``axis``."""
    axis.set_title(title)
    axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    if show_legend:
        axis.legend()
    axis.grid(True, alpha=0.3)


# One 2x2 figure: target density, chain traces, raw counts, and the
# change-point probabilities from the single Gibbs iteration.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Bimodal distribution
ax = axes[0, 0]
x_range = np.linspace(-8, 6, 1000)
y_bimodal = p_bimodal(x_range)
ax.plot(x_range, y_bimodal, 'b-', linewidth=2)
# Dashed markers at the two component means and at the mixture mean.
for position, colour, legend_text in (
    (-3, 'r', 'Mode 1'),
    (1, 'g', 'Mode 2'),
    (-1, 'k', 'True mean'),
):
    ax.axvline(position, color=colour, linestyle='--', alpha=0.5,
               label=legend_text)
_decorate_axes(ax, 'Bimodal Distribution', 'x', 'p(x)')

# Plot 2: Chain traces
ax = axes[0, 1]
for trace, line_format, legend_text in (
    (chain1, 'b-', 'Chain 1'),
    (chain2, 'r-', 'Chain 2'),
):
    ax.plot(trace, line_format, label=legend_text, marker='o')
_decorate_axes(ax, f'Chain Traces (R̂={R_hat:.2f})', 'Iteration', 'Sample value')

# Plot 3: Change point data
ax = axes[1, 0]
ax.scatter(range(1, N+1), data, s=100, alpha=0.7)
# The +0.5 places the marker between year c0 and year c0 + 1.
ax.axvline(c0 + 0.5, color='r', linestyle='--', label=f'Initial c={c0}')
_decorate_axes(ax, 'Change Point Data', 'Year', 'Accident count')

# Plot 4: Posterior for c
ax = axes[1, 1]
c_values = range(1, N)
ax.bar(c_values, probs, alpha=0.7, color='blue')
# This panel has no labelled artists, so no legend is drawn.
_decorate_axes(ax, 'Posterior P(c|x) - One iteration', 'Change point c',
               'Probability', show_legend=False)

plt.tight_layout()
plt.show()

# =============================================================================
# SUMMARY
# =============================================================================
print("\n\n=== SUMMARY ===", "=" * 60, sep="\n")

# Verdict wording uses a 1.1 cut-off on the R̂ computed earlier in the script.
convergence_verdict = 'poor' if R_hat > 1.1 else 'good'

# Each entry pairs a section heading with its bullet lines, in print order.
summary_sections = (
    ("\n1. BIMODAL DISTRIBUTION:", (
        "   - Purpose: Test MCMC convergence on challenging distribution",
        "   - Problem: Multiple modes can trap chains",
        "   - Diagnostic: R̂ compares within/between chain variance",
        f"   - Result: R̂ = {R_hat:.2f} indicates {convergence_verdict} convergence",
    )),
    ("\n2. CHANGE POINT DETECTION:", (
        "   - Purpose: Find when accident rates changed",
        "   - Problem: Unknown change point and rates",
        "   - Method: Gibbs sampling (conditional distributions)",
        "   - Result: Can identify change point and quantify uncertainty",
    )),
    ("\n3. KEY INSIGHTS:", (
        "   - Multiple chains essential for convergence diagnosis",
        "   - R̂ near 1 indicates good mixing",
        "   - ESS accounts for autocorrelation",
        "   - Gibbs sampling efficient for conditional conjugacy",
        "   - Posterior predictive checks model fit",
    )),
)

for heading, bullet_lines in summary_sections:
    print(heading)
    for bullet_line in bullet_lines:
        print(bullet_line)