In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import factorial, gammaln, xlogy
from scipy.stats import halfnorm, norm, poisson

print("=== BAYESIAN POISSON REGRESSION - HAND CALCULATIONS IN PYTHON ===\n")

# =============================================================================
# STEP 1: Poisson and Half-Normal Distributions
# =============================================================================
# Section heading followed by a 50-character rule, emitted in one call.
print("STEP 1: Probability Distributions", "-" * 50, sep="\n")

# Poisson PMF
def poisson_pmf(y, mu):
    """Poisson probability mass function P(Y = y | mu).

    Evaluated in log space as ``exp(y*log(mu) - mu - log(y!))``. The direct
    form ``mu**y * exp(-mu) / y!`` overflows both numerator and denominator
    for y above ~170 and returns ``inf/inf = nan``.

    Parameters
    ----------
    y : int or array_like
        Observed count(s), y >= 0.
    mu : float or array_like
        Poisson rate(s), mu >= 0; broadcasts against ``y``.

    Returns
    -------
    float or ndarray
        Probability mass; ``poisson_pmf(0, 0) == 1``.
    """
    # xlogy(y, mu) treats 0 * log(0) as 0, so the mu = 0 case is well defined.
    # gammaln(y + 1) == log(y!) without ever materialising y!.
    return np.exp(xlogy(y, mu) - mu - gammaln(y + 1))

def log_poisson_pmf(y, mu):
    """Log Poisson PMF: y*log(mu) - mu - log(y!).

    Uses ``gammaln(y + 1)`` for log(y!) so large counts stay finite
    (``np.log(factorial(y))`` is ``inf`` for y above ~170). Uses ``xlogy``
    so that y = 0, mu = 0 gives 0 instead of ``0 * -inf = nan``.

    Parameters
    ----------
    y : int or array_like
        Observed count(s), y >= 0.
    mu : float or array_like
        Poisson rate(s), mu >= 0; broadcasts against ``y``.

    Returns
    -------
    float or ndarray
        Log probability mass.
    """
    return xlogy(y, mu) - mu - gammaln(y + 1)

# Half-normal PDF
def half_normal_pdf(x, scale=1.0):
    """Half-normal PDF N_+(x | 0, scale**2).

    p(x) = sqrt(2/pi) / scale * exp(-x**2 / (2 * scale**2)) for x >= 0,
    and 0 for x < 0. The default ``scale=1`` matches the original behaviour.

    Accepts a scalar or an array; the former ``if x < 0`` test raised on
    arrays. A scalar input returns a NumPy scalar.

    Raises
    ------
    ValueError
        If ``scale`` is not strictly positive.
    """
    if np.any(np.asarray(scale) <= 0):
        raise ValueError("scale must be positive")
    x = np.asarray(x, dtype=float)
    z = x / scale
    density = np.sqrt(2/np.pi) / scale * np.exp(-z**2 / 2)
    out = np.where(x < 0, 0.0, density)
    # out[()] unwraps a 0-d result to a scalar and is a no-op view otherwise.
    return out[()]

def log_half_normal_pdf(x, scale=1.0):
    """Log half-normal PDF, log N_+(x | 0, scale**2).

    log p(x) = 0.5*log(2/pi) - log(scale) - x**2 / (2 * scale**2) for x >= 0,
    and -inf for x < 0 (outside the support). The default ``scale=1``
    matches the original behaviour.

    Accepts a scalar or an array; a scalar input returns a NumPy scalar.

    Raises
    ------
    ValueError
        If ``scale`` is not strictly positive.
    """
    if np.any(np.asarray(scale) <= 0):
        raise ValueError("scale must be positive")
    x = np.asarray(x, dtype=float)
    z = x / scale
    log_density = 0.5 * np.log(2/np.pi) - np.log(scale) - z**2 / 2
    out = np.where(x < 0, -np.inf, log_density)
    # out[()] unwraps a 0-d result to a scalar and is a no-op view otherwise.
    return out[()]

# Normal PDF
def log_normal_pdf(x, mu, sigma2):
    """Log density of a Normal with mean ``mu`` and *variance* ``sigma2``.

    log N(x | mu, sigma2) = -1/2 * log(2*pi*sigma2) - (x - mu)**2 / (2*sigma2),
    evaluated with NumPy operations, so array ``x`` is handled elementwise.
    """
    log_norm_const = -0.5 * np.log(2*np.pi*sigma2)
    sq_dev = (x-mu)**2
    return log_norm_const - sq_dev/(2*sigma2)

# Spot-check each density helper at one easy-to-verify point.
_pmf_example = poisson_pmf(3, 2)
_log_pmf_example = log_poisson_pmf(3, 2)
_hn_example = half_normal_pdf(1)
_log_hn_example = log_half_normal_pdf(1)

print("Example calculations:")
print(f"Poisson PMF: P(Y=3|μ=2) = {_pmf_example:.4f}")
print(f"Log Poisson: log P(Y=3|μ=2) = {_log_pmf_example:.4f}")
print(f"Half-normal PDF: p(κ=1) = {_hn_example:.4f}")
print(f"Log half-normal: log p(κ=1) = {_log_hn_example:.4f}")

# =============================================================================
# STEP 2: Small Example Data
# =============================================================================
# Three observations. x_small equals (age_small - 40) / 5, the same scaling
# STEP 10 applies when drawing the fitted curve against age.
age_small = jnp.array([35, 40, 45])
x_small = jnp.array([-1, 0, 1])  # Pre-standardized for simplicity
y_small = jnp.array([3, 4, 5])   # Deaths (observed counts)

# Design matrix: a column of ones for the intercept w_0, then x for the slope w_1.
X_small = jnp.column_stack((jnp.ones(x_small.shape[0]), x_small))

print("\n\nSTEP 2: Small Example with 3 Data Points", "-" * 50, sep="\n")
print(f"Ages: {age_small}")
print(f"Standardized x: {x_small}")
print(f"Deaths y: {y_small}")
print(f"\nDesign matrix X:")
print(X_small)

# =============================================================================
# STEP 3: Forward Pass with Example Parameters
# =============================================================================
print("\n\nSTEP 3: Forward Pass Calculation", "-" * 50, sep="\n")

# Example parameters: intercept w_0 = 1.0, slope w_1 = 0.5, prior scale κ = 1.
w = jnp.array([1.0, 0.5])
kappa = 1.0

print(f"Parameters:")
print(f"  w = {w} (intercept={w[0]}, slope={w[1]})")
print(f"  κ = {kappa}")

# Step 3.1: linear predictor, shown row by row and then as one matrix product.
print("\nStep 3.1: Linear predictor f = Xw")
f = X_small @ w
for i, x_row in enumerate(X_small):
    f_i = x_row @ w
    print(f"  f_{i} = {x_row} · {w} = {f_i:.3f}")
print(f"  f = {f}")

# Step 3.2: the exp (log-link inverse) keeps every Poisson mean positive.
print("\nStep 3.2: Mean μ = exp(f)")
mu = jnp.exp(f)
for i, (f_i, mu_i) in enumerate(zip(f, mu)):
    print(f"  μ_{i} = exp({f_i:.3f}) = {mu_i:.3f}")
print(f"  μ = {mu}")

# =============================================================================
# STEP 4: Log Likelihood Calculation
# =============================================================================
print("\n\nSTEP 4: Log Likelihood Calculation", "-" * 50, sep="\n")

print("log p(y|w) = Σ_n [y_n log(μ_n) - μ_n - log(y_n!)]")
print("\nFor each data point:")

# Expand the Poisson log-likelihood term by term for every observation.
log_lik_terms = []
for i, (y_i, mu_i) in enumerate(zip(y_small, mu)):
    term1 = y_i * np.log(mu_i)          # y_n log(μ_n)
    term2 = -mu_i                       # -μ_n
    term3 = -np.log(factorial(y_i))     # -log(y_n!)
    total = term1 + term2 + term3

    print(f"\nn={i}: y={y_i}, μ={mu_i:.3f}")
    print(f"  {y_i} × log({mu_i:.3f}) = {term1:.3f}")
    print(f"  -μ = {term2:.3f}")
    print(f"  -log({y_i}!) = {term3:.3f}")
    print(f"  Subtotal = {total:.3f}")

    log_lik_terms.append(total)

log_likelihood = sum(log_lik_terms)
print(f"\nTotal log likelihood = {log_likelihood:.3f}")

# Cross-check the hand calculation against scipy's Poisson log-PMF.
log_lik_scipy = sum(poisson.logpmf(y_small, mu))
print(f"Verification (scipy): {log_lik_scipy:.3f}")

# =============================================================================
# STEP 5: Log Prior Calculation
# =============================================================================
print("\n\nSTEP 5: Log Prior on Weights")
print("-" * 50)

print("log p(w|κ) = log N(w|0, κ²I)")
print("         = -D/2 log(2π) - D log(κ) - ||w||²/(2κ²)")

# Derive D from w rather than hard-coding 2. A fixed D gives a silently wrong
# normalising constant as soon as another column is added to the design.
D = len(w)  # dimension
term1 = -D/2 * np.log(2*np.pi)
term2 = -D * np.log(kappa)
term3 = -np.sum(w**2)/(2*kappa**2)

# Render ||w||² = w_0² + w_1² + ... for however many weights there are.
squared_terms = " + ".join(f"{w_d}²" for w_d in w)

print(f"\nD = {D}")
print(f"  -{D}/2 × log(2π) = {term1:.3f}")
print(f"  -{D} × log({kappa}) = {term2:.3f}")
print(f"  -||w||²/(2κ²) = -({squared_terms})/(2×{kappa}²) = {term3:.3f}")

log_prior_w = term1 + term2 + term3
print(f"\nTotal log prior on w = {log_prior_w:.3f}")

# =============================================================================
# STEP 6: Log Hyperprior Calculation
# =============================================================================
# Unit-scale half-normal, log N_+(κ|0,1), split into its constant and its
# κ-dependent part.
term1_hyper = 0.5 * np.log(2/np.pi)   # normalising constant
term2_hyper = -kappa**2/2             # κ-dependent kernel
log_hyperprior = term1_hyper + term2_hyper

print("\n\nSTEP 6: Log Hyperprior on κ", "-" * 50, sep="\n")
print("log p(κ) = log N_+(κ|0,1)", "        = 0.5 log(2/π) - κ²/2", sep="\n")
print(f"\n  0.5 × log(2/π) = {term1_hyper:.3f}")
print(f"  -κ²/2 = -{kappa}²/2 = {term2_hyper:.3f}")
print(f"\nTotal log hyperprior = {log_hyperprior:.3f}")

# =============================================================================
# STEP 7: Total Log Joint
# =============================================================================
# The log joint is the sum of likelihood, weight prior and hyperprior.
log_joint = log_likelihood + log_prior_w + log_hyperprior

print("\n\nSTEP 7: Total Log Joint Distribution", "-" * 50, sep="\n")
print("log p(y, w, κ) = log p(y|w) + log p(w|κ) + log p(κ)")
print(f"\n  log p(y|w) = {log_likelihood:.3f}")
print(f"  log p(w|κ) = {log_prior_w:.3f}")
print(f"  log p(κ) = {log_hyperprior:.3f}")
print(f"\nTotal = {log_joint:.3f}")

# =============================================================================
# STEP 8: Metropolis-Hastings Step
# =============================================================================
print("\n\nSTEP 8: Metropolis-Hastings Example Step")
print("-" * 50)


def _log_joint_theta(theta):
    """Log joint log p(y, w, κ) at theta = [w_0, w_1, κ] for X_small / y_small.

    Returns -inf when κ <= 0, where the N(w | 0, κ²I) prior is undefined.
    """
    w_t, kappa_t = theta[:2], theta[2]
    if kappa_t <= 0:
        return -np.inf
    mu_t = jnp.exp(X_small @ w_t)
    log_lik_t = jnp.sum(poisson.logpmf(y_small, mu_t))
    log_prior_w_t = jnp.sum(log_normal_pdf(w_t, 0, kappa_t**2))
    return log_lik_t + log_prior_w_t + log_half_normal_pdf(kappa_t)


# Seeded generator so the printed accept/reject decision is reproducible.
mh_rng = np.random.default_rng(0)

# Current state
theta_current = jnp.array([1.0, 0.5, 1.0])  # [w_0, w_1, κ]
print(f"Current θ = {theta_current}")

# Random-walk proposal θ' = θ + step_size · noise. The noise here is a fixed
# example; a real sampler would draw it afresh each step. For a symmetric
# proposal (e.g. Gaussian noise) q(θ'|θ) = q(θ|θ'), so the Hastings correction
# cancels and only the ratio of log joints is needed.
step_size = 0.1
noise = jnp.array([0.1, -0.1, -0.05])  # Example noise
theta_proposed = theta_current + step_size * noise
print(f"Proposed θ' = {theta_proposed}")

w_prop = theta_proposed[:2]
kappa_prop = theta_proposed[2]

# κ must be strictly positive: κ = 0 makes the prior on w degenerate.
if kappa_prop > 0:
    log_joint_prop = _log_joint_theta(theta_proposed)
    # Evaluate the current state with the same function, rather than reusing
    # the STEP 7 total, so the ratio stays correct if theta_current is edited.
    log_joint_current = _log_joint_theta(theta_current)

    print(f"\nProposed log joint = {log_joint_prop:.3f}")
    print(f"Current log joint = {log_joint_current:.3f}")

    # Acceptance ratio (log scale). Clamping at 0 before exp() avoids overflow
    # for large positive ratios while giving the same min(1, ratio).
    log_ratio = log_joint_prop - log_joint_current
    accept_prob = np.exp(min(0.0, float(log_ratio)))

    print(f"\nLog ratio = {log_ratio:.3f}")
    print(f"Acceptance probability = {accept_prob:.3f}")

    # Accept/reject
    u = mh_rng.random()
    print(f"Random u = {u:.3f}")
    if u < accept_prob:
        print("ACCEPT")
    else:
        print("REJECT")
else:
    print("\nREJECT (κ' ≤ 0)")

# =============================================================================
# STEP 9: Posterior Predictive Calculation
# =============================================================================
print("\n\nSTEP 9: Posterior Predictive Example")
print("-" * 50)

# Suppose we have 3 posterior samples (hand-written for illustration).
w_samples = jnp.array([[2.0, 0.3],
                       [1.8, 0.4],
                       [2.1, 0.2]])
kappa_samples = jnp.array([0.8, 0.9, 0.7])

print("Posterior samples:")
for i in range(3):
    print(f"  Sample {i+1}: w={w_samples[i]}, κ={kappa_samples[i]}")

# Prediction at age 75
age_75 = 75
# NOTE(review): the (age - 40) / 5 scaling used for the STEP 2 data and the
# STEP 10 curve would give x* = 7.0 here. 2.5 presumably comes from a
# different (full-dataset) standardization; confirm before relying on it.
x_75_standardized = 2.5  # Pre-computed for example
X_75 = jnp.array([1, x_75_standardized])

print(f"\nPrediction at age {age_75} (x*={x_75_standardized}):")
print(f"X* = {X_75}")

# Seeded generator so the sampled y* values and summary are reproducible.
pred_rng = np.random.default_rng(1)

f_star_samples = []
mu_star_samples = []
y_star_samples = []

# κ is not used below: given w, y* ~ Poisson(exp(x*·w)) does not involve κ.
for i in range(3):
    w_i = w_samples[i]

    # Linear predictor
    f_star = X_75 @ w_i
    f_star_samples.append(f_star)

    # Mean
    mu_star = jnp.exp(f_star)
    mu_star_samples.append(mu_star)

    # Sample from Poisson
    y_star = int(pred_rng.poisson(float(mu_star)))
    y_star_samples.append(y_star)

    print(f"\nSample {i+1}:")
    print(f"  f* = {X_75} · {w_i} = {f_star:.3f}")
    print(f"  μ* = exp({f_star:.3f}) = {mu_star:.3f}")
    print(f"  y* ~ Poisson({mu_star:.3f}) → {y_star}")

# Monte Carlo estimates over the 3 draws; np.std is the population SD (ddof=0).
print(f"\nPosterior predictive summary:")
print(f"  E[f*|y] = {np.mean(f_star_samples):.3f}")
print(f"  E[μ*|y] = {np.mean(mu_star_samples):.3f}")
print(f"  E[y*|y] = {np.mean(y_star_samples):.1f}")
print(f"  SD[y*|y] = {np.std(y_star_samples):.1f}")

# =============================================================================
# STEP 10: Visualization of Calculations
# =============================================================================
print("\n\nSTEP 10: Visualization")
print("-" * 50)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Data and model fit
ax = axes[0, 0]
ax.plot(age_small, y_small, 'ko', markersize=10, label='Data')
age_range = np.linspace(30, 50, 100)
x_range = (age_range - 40) / 5  # Simple standardization
X_range = np.column_stack([np.ones_like(x_range), x_range])
mu_range = np.exp(X_range @ w)
ax.plot(age_range, mu_range, 'b-', label=f'μ = exp({w[0]:.1f} + {w[1]:.1f}x)')
ax.set_xlabel('Age')
ax.set_ylabel('Deaths')
ax.set_title('Data and Model Fit')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Log-likelihood surface
ax = axes[0, 1]
w0_range = np.linspace(0.5, 1.5, 50)
w1_range = np.linspace(0, 1, 50)
# Default 'xy' indexing: W0 and W1 have shape (len(w1_range), len(w0_range)).
W0, W1 = np.meshgrid(w0_range, w1_range)

# Evaluate the whole grid at once instead of 2,500 separate matmul + scipy
# calls. Column g of W_grid is the weight vector (w_0, w_1) at grid point g.
W_grid = np.vstack([W0.ravel(), W1.ravel()])            # (2, G)
MU_grid = np.exp(np.asarray(X_small) @ W_grid)          # (n_obs, G)
LL = (
    poisson.logpmf(np.asarray(y_small)[:, None], MU_grid)
    .sum(axis=0)
    .reshape(W0.shape)
)

contour = ax.contour(W0, W1, LL, levels=20)
ax.clabel(contour, inline=True, fontsize=8)
ax.plot(w[0], w[1], 'r*', markersize=15, label='Current w')
ax.set_xlabel('$w_0$')
ax.set_ylabel('$w_1$')
ax.set_title('Log-Likelihood Surface')
ax.legend()

# Plot 3: Prior distributions
ax = axes[1, 0]
kappa_range = np.linspace(0, 3, 100)
prior_kappa = [half_normal_pdf(k) for k in kappa_range]
ax.plot(kappa_range, prior_kappa, 'g-', linewidth=2, label='Prior p(κ)')
ax.axvline(kappa, color='r', linestyle='--', label=f'Current κ={kappa}')
ax.set_xlabel('κ')
ax.set_ylabel('Density')
ax.set_title('Hyperprior Distribution')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Components of log joint
ax = axes[1, 1]
components = ['Log Likelihood', 'Log Prior', 'Log Hyperprior', 'Total']
values = [log_likelihood, log_prior_w, log_hyperprior, log_joint]
colors = ['blue', 'green', 'orange', 'red']
bars = ax.bar(components, values, color=colors, alpha=0.7)
ax.set_ylabel('Log Probability')
ax.set_title('Components of Log Joint Distribution')
# Label each bar just beyond its end: above positive bars, below negative ones.
for bar, val in zip(bars, values):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.1 * np.sign(height),
            f'{val:.3f}', ha='center', va='bottom' if val > 0 else 'top')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# =============================================================================
# SUMMARY
# =============================================================================
print("\n\n=== SUMMARY OF HAND CALCULATIONS ===", "-" * 50, sep="\n")
print("We computed step-by-step:")
# One numbered line per quantity derived above.
print(
    "1. Linear predictor: f = Xw",
    "2. Mean parameter: μ = exp(f)",
    "3. Log likelihood: Σ[y log(μ) - μ - log(y!)]",
    "4. Log prior on weights: Normal distribution",
    "5. Log hyperprior on κ: Half-normal distribution",
    "6. Total log joint: sum of all components",
    "7. Metropolis-Hastings acceptance probability",
    "8. Posterior predictive sampling",
    sep="\n",
)
print("\nAll calculations shown manually with verification!")