# Topic 6: Bayesian Hierarchical Models and Missing Data

## Learning Objectives
- Understand hierarchical model structure and benefits
- Compare no pooling, complete pooling, and partial pooling
- Handle missing data in Bayesian framework
- Apply model checking for hierarchical models

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pymc as pm
import arviz as az

# 'seaborn-v0_8' is the name matplotlib >= 3.6 gives the old seaborn style;
# older releases only know it as 'seaborn', so fall back when it is missing.
_plot_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_plot_style)
np.random.seed(42)  # fix the global NumPy RNG so the simulated data is reproducible

## 1. Hierarchical Model Example: School Performance

### Model Structure:
- **Level 1**: Student scores within schools
- **Level 2**: School-specific parameters
- **Level 3**: Population-level hyperparameters

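$$y_{ij} \sim N(\mu_j, \sigma^2)$$
$$\mu_j \sim N(\alpha, \tau^2)$$
$$\alpha \sim N(70, 10^2), \quad \tau \sim \text{HalfNormal}(10), \quad \sigma \sim \text{HalfNormal}(10)$$

where $y_{ij}$ is the score of student $i$ in school $j$, $\mu_j$ is the mean of school $j$, $\alpha$ and $\tau$ are the population mean and between-school SD, and $\sigma$ is the within-school SD.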

In [None]:
# Simulate two-level data: one latent mean per school drawn around the
# population mean, then individual student scores around each school mean.
n_schools = 8
n_students_per_school = [20, 15, 25, 18, 22, 16, 19, 21]
true_alpha = 75  # Overall mean
true_tau = 8     # Between-school SD
true_sigma = 12  # Within-school SD

# Level 2: latent school means
true_school_means = np.random.normal(true_alpha, true_tau, n_schools)

school_names = [f'School_{k + 1}' for k in range(n_schools)]

# Level 1: student scores, drawn school by school in list order
scores = np.concatenate([
    np.random.normal(mean_k, true_sigma, size_k)
    for size_k, mean_k in zip(n_students_per_school, true_school_means)
])
# School index for every student: [0, 0, ..., 1, 1, ..., n_schools - 1]
schools = np.repeat(np.arange(n_schools), n_students_per_school)

print(f"Generated data: {len(scores)} students across {n_schools} schools")
print(f"True parameters: α={true_alpha}, τ={true_tau}, σ={true_sigma}")

# Partial-pooling model: school means share a population distribution whose
# centre ('alpha') and spread ('tau') are learned jointly with the data.
with pm.Model() as hierarchical_model:
    # Population-level hyperpriors
    pop_mean = pm.Normal('alpha', mu=70, sigma=10)       # overall mean
    between_sd = pm.HalfNormal('tau', sigma=10)          # between-school SD
    within_sd = pm.HalfNormal('sigma', sigma=10)         # within-school SD

    # One latent mean per school (centered parameterisation)
    school_mean_rv = pm.Normal('mu_school', mu=pop_mean, sigma=between_sd, shape=n_schools)

    # Each student's score is centred on the mean of their own school
    likelihood = pm.Normal('y_obs', mu=school_mean_rv[schools], sigma=within_sd, observed=scores)

    trace_hier = pm.sample(draws=2000, random_seed=42, return_inferencedata=True)

print("\nHierarchical Model Results:")
print(az.summary(trace_hier, var_names=['alpha', 'tau', 'sigma']))

# Baseline 1 -- complete pooling: a single mean and SD shared by every
# student, ignoring which school they attend.
with pm.Model() as pooled_model:
    mu_pooled = pm.Normal('mu_pooled', mu=70, sigma=10)
    sigma_pooled = pm.HalfNormal('sigma_pooled', sigma=10)
    y_obs = pm.Normal('y_obs', mu=mu_pooled, sigma=sigma_pooled, observed=scores)
    trace_pooled = pm.sample(draws=1000, random_seed=42, return_inferencedata=True)

# Baseline 2 -- no pooling: each school's raw sample mean, estimated in isolation.
unpooled_means = [np.mean(scores[schools == k]) for k in range(n_schools)]

# Visualization: raw scores, the three estimators side by side, the shrinkage
# effect, and the hierarchical posterior for each school mean.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Raw data.
# Boxes are drawn at x = 1..n_schools. Label those ticks explicitly rather than
# passing `labels=` to boxplot, which matplotlib 3.9 renamed to `tick_labels=`
# (the old keyword warns and is scheduled for removal); this works on all versions.
school_data = [scores[schools == i] for i in range(n_schools)]
axes[0,0].boxplot(school_data)
axes[0,0].set_xticks(np.arange(1, n_schools + 1))
axes[0,0].set_xticklabels(school_names)
axes[0,0].set_ylabel('Test Scores')
axes[0,0].set_title('Raw Data by School')
axes[0,0].tick_params(axis='x', rotation=45)
axes[0,0].grid(True, alpha=0.3)

# Compare estimates: posterior means of the hierarchical model, the raw
# per-school means, and the single pooled mean.
hier_means = trace_hier.posterior['mu_school'].mean(dim=['chain', 'draw']).values
pooled_mean = trace_pooled.posterior['mu_pooled'].mean().values
alpha_hat = trace_hier.posterior['alpha'].mean().values

x_pos = np.arange(n_schools)
width = 0.25

axes[0,1].bar(x_pos - width, true_school_means, width, label='True', alpha=0.7)
axes[0,1].bar(x_pos, unpooled_means, width, label='No Pooling', alpha=0.7)
axes[0,1].bar(x_pos + width, hier_means, width, label='Partial Pooling', alpha=0.7)
axes[0,1].axhline(pooled_mean, color='red', linestyle='--', label='Complete Pooling')

axes[0,1].set_xlabel('School')
axes[0,1].set_ylabel('Mean Score')
axes[0,1].set_title('Comparison of Estimates')
axes[0,1].set_xticks(x_pos)
axes[0,1].set_xticklabels(school_names, rotation=45)
axes[0,1].legend()
axes[0,1].grid(True, alpha=0.3)

# Shrinkage plot: distance from the y = x line shows how far partial pooling
# moved each school's estimate away from its raw mean. The reference line
# spans the actual data (a fixed 60-90 window can miss some schools).
all_means = np.concatenate([np.asarray(unpooled_means), hier_means])
lo, hi = all_means.min() - 2, all_means.max() + 2
axes[1,0].scatter(unpooled_means, hier_means, s=60, alpha=0.7)
axes[1,0].plot([lo, hi], [lo, hi], 'k--', alpha=0.5)
axes[1,0].axhline(alpha_hat, color='red', linestyle=':', label='Population mean')
axes[1,0].set_xlabel('No Pooling Estimate')
axes[1,0].set_ylabel('Partial Pooling Estimate')
axes[1,0].set_title('Shrinkage Effect')
axes[1,0].legend()
axes[1,0].grid(True, alpha=0.3)

# Posterior distributions.
# ArviZ defaults to a 94% HDI, so request 95% explicitly to make the title
# accurate, and merge chains so each school gets a single interval.
az.plot_forest(trace_hier, var_names=['mu_school'], hdi_prob=0.95, combined=True, ax=axes[1,1])
axes[1,1].set_title('School-Specific Means (95% HDI)')

plt.tight_layout()
plt.show()

# Calculate shrinkage.
# Empirical shrinkage of school j is the fraction of the gap between its raw
# mean and the population mean alpha that partial pooling removes:
#     B_j = 1 - (hier_j - alpha) / (raw_j - alpha)
# This ratio becomes unstable when a raw mean lies close to alpha. It is
# therefore printed next to the model-implied factor of the normal-normal model,
#     B_j = (sigma^2 / n_j) / (sigma^2 / n_j + tau^2),
# evaluated by plugging in posterior means (an approximation). The implied
# factor depends only on n_j, which makes "smaller schools shrink more" explicit.
alpha_hat = trace_hier.posterior['alpha'].mean().values
tau_hat = trace_hier.posterior['tau'].mean().values
sigma_hat = trace_hier.posterior['sigma'].mean().values

shrinkage = 1 - (hier_means - alpha_hat) / (np.array(unpooled_means) - alpha_hat)

sampling_var = sigma_hat**2 / np.array(n_students_per_school)  # variance of each raw mean
expected_shrinkage = sampling_var / (sampling_var + tau_hat**2)

print("\nShrinkage Analysis:")
for name, shr, exp_shr, n_students in zip(school_names, shrinkage,
                                          expected_shrinkage, n_students_per_school):
    print(f"{name}: {shr:.3f} (model-implied {exp_shr:.3f}, n={n_students})")

## 2. Missing Data Handling

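Bayesian methods handle missing data naturally: each missing value is treated as an unknown parameter with its own posterior distribution. Imputation and model fitting therefore happen jointly, and the uncertainty about imputed values carries through to the parameter estimates. Ignoring why values are missing is valid when they are missing completely at random (MCAR) or at random (MAR); the example below simulates MCAR.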

In [None]:
# Simulate y = 2 + 3x + noise, then delete values at random.
np.random.seed(42)
n = 100
x_complete = np.random.normal(0, 1, n)
noise = np.random.normal(0, 1, n)
y_complete = 2 + 3 * x_complete + noise

# MCAR: each x and each y is dropped independently with probability missing_prob
missing_prob = 0.3
x_missing_mask = np.random.random(n) < missing_prob
y_missing_mask = np.random.random(n) < missing_prob

x_observed = np.where(x_missing_mask, np.nan, x_complete)
y_observed = np.where(y_missing_mask, np.nan, y_complete)

# Rows where both values are missing carry no information, so drop them
complete_missing = x_missing_mask & y_missing_mask
keep_mask = ~complete_missing

x_obs, y_obs = x_observed[keep_mask], y_observed[keep_mask]
n_keep = keep_mask.sum()

print(f"Original data: {n} observations")
print(f"After removing completely missing: {n_keep} observations")
print(f"Missing x values: {np.isnan(x_obs).sum()}")
print(f"Missing y values: {np.isnan(y_obs).sum()}")

# Joint model of x and y. The `observed` arrays contain NaN for missing
# entries; PyMC treats those entries as latent values and samples them along
# with the model parameters.
with pm.Model() as missing_data_model:
    # Regression coefficients and residual SD
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=2)

    # Distribution of the covariate, so that missing x values can be imputed
    mu_x = pm.Normal('mu_x', mu=0, sigma=2)
    sigma_x = pm.HalfNormal('sigma_x', sigma=2)

    x_imputed = pm.Normal('x_imputed', mu=mu_x, sigma=sigma_x,
                          shape=n_keep, observed=x_obs)

    # y is regressed on the (partly imputed) x, so observed y values also
    # inform the imputation of missing x values
    y_imputed = pm.Normal('y_imputed', mu=alpha + beta * x_imputed, sigma=sigma,
                          shape=n_keep, observed=y_obs)

    trace_missing = pm.sample(draws=2000, random_seed=42, return_inferencedata=True)

print("\nMissing Data Model Results:")
print(az.summary(trace_missing, var_names=['alpha', 'beta', 'sigma', 'mu_x', 'sigma_x']))

# Baseline: complete-case analysis. Drop every row with a missing x or y and
# fit ordinary least squares to the rows that remain.
complete_cases = ~(np.isnan(x_obs) | np.isnan(y_obs))
x_complete_cases = x_obs[complete_cases]
y_complete_cases = y_obs[complete_cases]

print(f"\nComplete case analysis: {np.sum(complete_cases)} observations")
print("True parameters: α=2, β=3, σ=1")

# Simple linear regression on complete cases
from scipy import stats as scipy_stats
ols_fit = scipy_stats.linregress(x_complete_cases, y_complete_cases)
slope, intercept = ols_fit.slope, ols_fit.intercept

# Residual SD with n - 2 degrees of freedom, comparable to the Bayesian sigma
residuals = y_complete_cases - (intercept + slope * x_complete_cases)
sigma_cc = np.sqrt(np.sum(residuals**2) / (len(residuals) - 2))
print(f"Complete case estimates: α={intercept:.3f}, β={slope:.3f}, σ={sigma_cc:.3f}")

# Visualization: missingness patterns, observed vs imputed points with the
# fitted lines, posteriors of the regression coefficients, and the imputation
# posterior of one missing x value.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Missing data pattern (True = value observed)
missing_pattern = pd.DataFrame({
    'x': ~np.isnan(x_obs),
    'y': ~np.isnan(y_obs)
})

pattern_counts = missing_pattern.value_counts().sort_index()
pattern_labels = [
    f"x:{'obs' if x else 'missing'}, y:{'obs' if y else 'missing'}"
    for (x, y) in pattern_counts.index
]

axes[0,0].bar(range(len(pattern_counts)), pattern_counts.values)
axes[0,0].set_xticks(range(len(pattern_counts)))
axes[0,0].set_xticklabels(pattern_labels, rotation=45)
axes[0,0].set_ylabel('Count')
axes[0,0].set_title('Missing Data Patterns')
axes[0,0].grid(True, alpha=0.3)

# Observed vs imputed: one row per posterior draw, one column per kept row
x_imputed_samples = trace_missing.posterior['x_imputed'].values.reshape(-1, n_keep)
y_imputed_samples = trace_missing.posterior['y_imputed'].values.reshape(-1, n_keep)

x_imputed_mean = np.nanmean(x_imputed_samples, axis=0)
y_imputed_mean = np.nanmean(y_imputed_samples, axis=0)

# Rows where both x and y were observed
obs_mask = ~np.isnan(x_obs) & ~np.isnan(y_obs)
axes[0,1].scatter(x_obs[obs_mask], y_obs[obs_mask], alpha=0.7, label='Observed', s=30)

# Rows with at least one imputed coordinate, shown at the posterior-mean imputation
imp_mask = ~obs_mask
if np.any(imp_mask):
    axes[0,1].scatter(x_imputed_mean[imp_mask], y_imputed_mean[imp_mask],
                      alpha=0.7, label='Imputed', s=30, marker='s')

# Regression lines
x_plot = np.linspace(-3, 3, 100)
alpha_mean = trace_missing.posterior['alpha'].mean().values
beta_mean = trace_missing.posterior['beta'].mean().values

axes[0,1].plot(x_plot, alpha_mean + beta_mean * x_plot, 'r-',
               label='Bayesian (with imputation)', linewidth=2)
axes[0,1].plot(x_plot, intercept + slope * x_plot, 'g--',
               label='Complete case analysis', linewidth=2)
axes[0,1].plot(x_plot, 2 + 3 * x_plot, 'k:',
               label='True relationship', linewidth=2)

axes[0,1].set_xlabel('x')
axes[0,1].set_ylabel('y')
axes[0,1].set_title('Observed vs Imputed Data')
axes[0,1].legend()
axes[0,1].grid(True, alpha=0.3)

# Posterior distributions.
# az.plot_posterior draws one variable per Axes and stops after the first
# when given a single Axes, so beta would silently never appear. Overlay both
# marginal posteriors on this panel instead, each with its true value.
for param, true_val, color in (('alpha', 2, 'C0'), ('beta', 3, 'C1')):
    draws = trace_missing.posterior[param].values.ravel()
    axes[1,0].hist(draws, bins=40, density=True, alpha=0.6, color=color,
                   label=f'{param} posterior')
    axes[1,0].axvline(true_val, color=color, linestyle='--',
                      label=f'true {param} = {true_val}')
axes[1,0].set_xlabel('Parameter value')
axes[1,0].set_ylabel('Density')
axes[1,0].set_title('Posterior Distributions')
axes[1,0].legend()
axes[1,0].grid(True, alpha=0.3)

# Imputation uncertainty for the first missing x value
missing_x_idx = np.flatnonzero(np.isnan(x_obs))
if missing_x_idx.size > 0:
    idx = missing_x_idx[0]
    x_samples = x_imputed_samples[:, idx]
    axes[1,1].hist(x_samples, bins=30, alpha=0.7, density=True)
    # x_obs was built as x_observed[keep_mask], so this recovers the hidden value
    axes[1,1].axvline(x_complete[keep_mask][idx], color='red', linestyle='--',
                      label='True value')
    axes[1,1].set_xlabel('Imputed x value')
    axes[1,1].set_ylabel('Density')
    axes[1,1].set_title(f'Imputation Uncertainty (obs {idx})')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)
else:
    axes[1,1].text(0.5, 0.5, 'No missing x values', ha='center', va='center',
                   transform=axes[1,1].transAxes)
    axes[1,1].set_axis_off()

plt.tight_layout()
plt.show()

## Key Takeaways

### Hierarchical Models:
- **Partial pooling** balances individual and group information
- **Shrinkage** pulls each school's estimate toward the population mean; schools with fewer students are shrunk more
- **Borrowing strength** improves estimates for small groups
- **Natural regularization** prevents overfitting

### Missing Data:
- **Bayesian imputation** naturally quantifies uncertainty
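- **Joint modeling** of the covariate distribution and the outcome lets observed $y$ values inform missing $x$ values (the missingness mechanism itself can be ignored under MCAR or MAR)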
- **Multiple imputation** through posterior sampling
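- **Assumptions matter**: MCAR, MAR, and MNAR — ignoring the missingness mechanism is justified only under MCAR or MAR; MNAR data require modeling the mechanism explicitly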

## Next: Topic 7 - MCMC Methods