# Bayesian Inference with Differentiable Models

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/OJWatson/emidm/blob/main/docs/notebooks/bayesian_inference.ipynb)

This tutorial demonstrates how to perform Bayesian inference on epidemiological model parameters using **emidm**'s differentiable models and BlackJAX for Hamiltonian Monte Carlo (HMC) sampling.

## Why Bayesian Inference?

While gradient-based optimization (as shown in the calibration tutorial) gives us point estimates, Bayesian inference provides:

1. **Uncertainty quantification**: Full posterior distributions over parameters
2. **Prior incorporation**: Include domain knowledge about plausible parameter values
3. **Model comparison**: Compare models via marginal likelihoods (not covered here; these need extra computation beyond the MCMC samples themselves)

## Setup

This tutorial requires JAX and BlackJAX:

```bash
pip install emidm[jax] blackjax
```

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from emidm.diff import DiffConfig, run_diff_sir

# Check if blackjax is available
try:
    import blackjax
    HAS_BLACKJAX = True
except ImportError:
    HAS_BLACKJAX = False
    print("BlackJAX not installed. Install with: pip install blackjax")

## Step 1: Generate Synthetic Data

We'll create synthetic epidemic data with known parameters, then try to recover them with uncertainty estimates.

In [None]:
# True parameters
BETA_TRUE = 0.35
GAMMA = 0.1
N_AGENTS = 200
T = 40
I0 = 5

# Generate "observed" data
key = jax.random.PRNGKey(42)
observed = run_diff_sir(
    N_agents=N_AGENTS,
    I0=I0,
    beta=BETA_TRUE,
    gamma=GAMMA,
    T=T,
    config=DiffConfig(tau=0.5, hard=True),
    key=key,
)

observed_I = observed["I"]

# Visualize
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(observed["t"], observed["S"], label="S", color="blue")
ax.plot(observed["t"], observed["I"], label="I", color="red")
ax.plot(observed["t"], observed["R"], label="R", color="green")
ax.set_xlabel("Time (days)")
ax.set_ylabel("Population")
ax.set_title(f"Synthetic Epidemic (β = {BETA_TRUE})")
ax.legend()
ax.grid(alpha=0.3)
plt.show()

## Step 2: Define the Log-Posterior

For Bayesian inference, we need to define:

1. **Likelihood**: How probable is the observed data given parameters?
2. **Prior**: What do we believe about parameters before seeing data?

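By Bayes' theorem, the posterior is proportional to the likelihood times the prior: `posterior ∝ likelihood × prior`. On the log scale this becomes a sum, `log posterior = log likelihood + log prior + constant`, which is the form the code below computes.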
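As a quick sanity check of that identity, the sketch below uses a hypothetical one-parameter problem (a Normal prior and a single Normal measurement, not the SIR model) whose posterior is known in closed form. Adding log-prior and log-likelihood on a grid, exponentiating and normalizing should recover the exact conjugate posterior mean and standard deviation. All names here are illustrative and chosen so they do not overwrite the tutorial's own functions.

```python
import numpy as np

# Hypothetical toy problem: one noisy measurement y of an unknown quantity mu.
mu_prior_mean, mu_prior_std = 0.3, 0.15   # prior:      mu ~ Normal(0.3, 0.15)
y_obs, noise_std = 0.4, 0.1               # likelihood: y  ~ Normal(mu, 0.1)

mu_grid = np.linspace(-0.5, 1.1, 20001)
lp_prior = -0.5 * ((mu_grid - mu_prior_mean) / mu_prior_std) ** 2
lp_lik = -0.5 * ((y_obs - mu_grid) / noise_std) ** 2

# Adding log-densities multiplies densities; dropped constants vanish on normalization.
lp_post = lp_prior + lp_lik
weights = np.exp(lp_post - lp_post.max())  # subtract the max to avoid underflow
weights /= weights.sum()                   # normalize over the uniform grid

grid_mean = np.sum(mu_grid * weights)
grid_std = np.sqrt(np.sum((mu_grid - grid_mean) ** 2 * weights))

# Closed-form (conjugate Normal-Normal) posterior for comparison.
post_precision = 1 / mu_prior_std**2 + 1 / noise_std**2
exact_mean = (mu_prior_mean / mu_prior_std**2 + y_obs / noise_std**2) / post_precision
exact_std = post_precision ** -0.5

print(grid_mean, exact_mean, grid_std, exact_std)
```

Because β is a single parameter, the same grid approach could in principle be used to cross-check the MCMC results in this tutorial, by evaluating `log_posterior` over a range of β values (each evaluation runs one simulation).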

In [None]:
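# Reuse one fixed key for every likelihood evaluation so the stochastic simulator
# is a deterministic function of beta (HMC requires a fixed target density).
# The observed data above used a different key (42), so the fit must absorb
# differences between two random realizations.
MODEL_KEY = jax.random.PRNGKey(0)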

def log_likelihood(beta):
    """Gaussian log-likelihood for the infection curve."""
    pred = run_diff_sir(
        N_agents=N_AGENTS,
        I0=I0,
        beta=beta,
        gamma=GAMMA,
        T=T,
        config=DiffConfig(tau=0.5, hard=True),
        key=MODEL_KEY,
    )
    
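    # Gaussian observation model with fixed sigma = 10 (constant terms dropped).
    # A larger sigma flattens the log-likelihood (it scales as 1 / sigma**2),
    # which widens the posterior and reduces the curvature HMC has to handle.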
    sigma = 10.0
    residuals = pred["I"] - observed_I
    return -0.5 * jnp.sum((residuals / sigma) ** 2)


def log_prior(beta):
    """Unnormalized log-prior: beta ~ Normal(0.3, 0.15), softly penalized outside [0.05, 0.8]."""
    # Gaussian prior centered at 0.3
    prior_mean = 0.3
    prior_std = 0.15
    log_p = -0.5 * ((beta - prior_mean) / prior_std) ** 2
    
    # Soft constraint to keep beta in reasonable range
    # (penalize values outside [0.05, 0.8])
    penalty = jnp.where(beta < 0.05, -100 * (0.05 - beta) ** 2, 0.0)
    penalty += jnp.where(beta > 0.8, -100 * (beta - 0.8) ** 2, 0.0)
    
    return log_p + penalty


def log_posterior(beta):
    """Log-posterior = log-likelihood + log-prior."""
    return log_likelihood(beta) + log_prior(beta)


# Test the log-posterior
test_beta = jnp.array(0.3)
print(f"Log-posterior at β=0.3: {log_posterior(test_beta):.2f}")
print(f"Log-posterior at β=0.35 (true): {log_posterior(jnp.array(BETA_TRUE)):.2f}")
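HMC relies on the gradient of the log-posterior, so it is worth checking that automatic differentiation gives sensible values. Running `run_diff_sir` requires emidm, so the self-contained sketch below swaps in a hypothetical deterministic (mean-field) SIR, `toy_sir_infected`, which is not part of emidm. It uses the same Gaussian likelihood (σ = 10) and Normal(0.3, 0.15) prior (without the soft penalty) and compares `jax.grad` with a central finite difference.

```python
import jax
import jax.numpy as jnp


def toy_sir_infected(beta, gamma=0.1, n=200.0, i0=5.0, steps=40):
    """Hypothetical deterministic (mean-field) SIR, standing in for run_diff_sir."""

    def step(carry, _):
        s, i, r = carry
        new_inf = beta * s * i / n  # expected new infections this day
        new_rec = gamma * i         # expected recoveries this day
        carry = (s - new_inf, i + new_inf - new_rec, r + new_rec)
        return carry, carry[1]      # record the infected count after the update

    init = (jnp.float32(n - i0), jnp.float32(i0), jnp.float32(0.0))
    _, infected = jax.lax.scan(step, init, None, length=steps)
    return infected


def toy_log_posterior(beta, data, sigma=10.0):
    """Gaussian likelihood (sigma = 10) plus the Normal(0.3, 0.15) prior, no penalty."""
    resid = toy_sir_infected(beta) - data
    log_lik = -0.5 * jnp.sum((resid / sigma) ** 2)
    log_pri = -0.5 * ((beta - 0.3) / 0.15) ** 2
    return log_lik + log_pri


toy_data = toy_sir_infected(0.35)       # noise-free "observations" at beta = 0.35
toy_grad = jax.grad(toy_log_posterior)  # derivative w.r.t. the first argument (beta)

beta0, h = jnp.float32(0.30), 1e-3
autodiff = float(toy_grad(beta0, toy_data))
finite_diff = float(
    (toy_log_posterior(beta0 + h, toy_data) - toy_log_posterior(beta0 - h, toy_data)) / (2 * h)
)
print(autodiff, finite_diff)  # a positive gradient means the log-posterior rises with beta
```

With emidm installed, the same finite-difference comparison could be run on this tutorial's `log_posterior` (e.g. `jax.grad(log_posterior)(jnp.array(0.3))`).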

## Step 3: Run MCMC Sampling

We'll use BlackJAX's NUTS (No-U-Turn Sampler), a variant of Hamiltonian Monte Carlo that chooses the trajectory length automatically at each iteration. The step size and mass matrix are tuned separately during warmup by BlackJAX's `window_adaptation`.
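To show where gradients enter, here is a minimal sketch of *plain* HMC: a fixed step size and a fixed number of leapfrog steps, with no U-turn criterion and no adaptation, so it is not NUTS and not BlackJAX's implementation. The target is a hypothetical 1-D Gaussian, Normal(0.35, 0.05), standing in for the β posterior.

```python
import jax
import jax.numpy as jnp


def toy_log_density(x):
    """Hypothetical 1-D target: Normal(0.35, 0.05), unnormalized."""
    return -0.5 * ((x - 0.35) / 0.05) ** 2


grad_log_density = jax.grad(toy_log_density)


def hmc_step(key, x, step_size=0.005, n_leapfrog=16):
    """One HMC transition with a fixed step size and trajectory length."""
    key_momentum, key_accept = jax.random.split(key)
    p0 = jax.random.normal(key_momentum)  # resample momentum ~ Normal(0, 1)

    # Leapfrog: initial half momentum kick, then alternating full position/momentum steps.
    def leapfrog(_, state):
        x_, p_ = state
        x_ = x_ + step_size * p_
        p_ = p_ + step_size * grad_log_density(x_)
        return x_, p_

    p_half = p0 + 0.5 * step_size * grad_log_density(x)
    x_new, p_new = jax.lax.fori_loop(0, n_leapfrog, leapfrog, (x, p_half))
    # The loop applied a full final kick; remove half of it to finish the scheme.
    p_new = p_new - 0.5 * step_size * grad_log_density(x_new)

    # Metropolis correction using the Hamiltonian H = -log p(x) + p**2 / 2.
    h_old = -toy_log_density(x) + 0.5 * p0**2
    h_new = -toy_log_density(x_new) + 0.5 * p_new**2
    accept = jnp.log(jax.random.uniform(key_accept)) < h_old - h_new
    return jnp.where(accept, x_new, x)


def run_hmc_chain(key, x0, n_draws):
    """Run n_draws HMC transitions with lax.scan and return the visited positions."""

    def one_transition(x, step_key):
        x = hmc_step(step_key, x)
        return x, x

    keys = jax.random.split(key, n_draws)
    _, draws = jax.lax.scan(one_transition, x0, keys)
    return draws


hmc_draws = run_hmc_chain(jax.random.PRNGKey(0), jnp.float32(0.30), 2000)
print(hmc_draws.mean(), hmc_draws.std())
```

The step size and leapfrog count here were picked by hand so that one trajectory covers roughly a quarter period of this particular Gaussian. NUTS removes the need to choose `n_leapfrog` by extending each trajectory until it starts to turn back, and window adaptation tunes the step size during warmup.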

In [None]:
if HAS_BLACKJAX:
    # Initialize sampler
    rng_key = jax.random.PRNGKey(123)
    initial_position = jnp.array(0.30)  # Starting point
    
    # Set up NUTS with window adaptation for warmup
    # Note: num_steps is passed to .run(), not the constructor (BlackJAX >= 1.0)
    warmup = blackjax.window_adaptation(
        blackjax.nuts,
        log_posterior,
        initial_step_size=0.001,  # Small initial step size for stability
    )
    
    # Run warmup
    print("Running warmup (this may take a minute)...")
    rng_key, warmup_key = jax.random.split(rng_key)
    (state, params), _ = warmup.run(warmup_key, initial_position, num_steps=500)
    print(f"Adapted step size: {params['step_size']:.6f}")
    
    # Set up sampling kernel
    nuts_kernel = blackjax.nuts(log_posterior, **params).step
    
    # Sampling loop
    def one_step(state, step_key):
        # Advance the chain by one NUTS transition; the info object is unused here
        state, _ = nuts_kernel(step_key, state)
        return state, state.position
    
    # Run sampling
    print("Running sampling...")
    n_samples = 500
    rng_key, sample_key = jax.random.split(rng_key)
    keys = jax.random.split(sample_key, n_samples)
    
    final_state, samples = jax.lax.scan(one_step, state, keys)
    samples = np.array(samples)
    
    print(f"\nCollected {len(samples)} samples")
    print(f"Posterior mean: {samples.mean():.4f}")
    print(f"Posterior std: {samples.std():.4f}")
    print(f"True value: {BETA_TRUE}")
else:
    print("Skipping MCMC - BlackJAX not installed")
    samples = None

## Step 4: Analyze the Posterior

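Let's visualize the posterior distribution and check the chain for signs of poor mixing. A single chain cannot establish convergence on its own, but trace and autocorrelation plots reveal obvious problems.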
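A common complementary check is split-R̂, computed across several independent chains (for example by repeating the Step 3 sampling with different keys and initial positions; this tutorial runs only one). The function below is a minimal NumPy sketch of the classic, non-rank-normalized split-R̂; it is not part of emidm or BlackJAX, and `arviz.rhat` provides a more robust rank-normalized version. The demo chains are synthetic.

```python
import numpy as np


def split_rhat(chains):
    """Classic split-R-hat for draws of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    half = n_draws // 2
    # Split every chain into first and second halves (dropping the middle draw if odd).
    halves = np.concatenate([chains[:, :half], chains[:, n_draws - half:]], axis=0)
    n = halves.shape[1]

    within = halves.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between_over_n = halves.mean(axis=1).var(ddof=1)  # B / n: variance of chain means
    var_plus = (n - 1) / n * within + between_over_n  # pooled posterior variance estimate
    return np.sqrt(var_plus / within)


demo_rng = np.random.default_rng(0)
mixed = demo_rng.normal(size=(4, 1000))                   # four well-mixed chains
stuck = mixed + np.array([0.0, 0.0, 0.0, 3.0])[:, None]  # one chain stuck elsewhere
print(split_rhat(mixed), split_rhat(stuck))
```

Values close to 1 are consistent with the chains sampling the same distribution; values well above 1 indicate disagreement. Applied to this tutorial's single chain as `split_rhat(samples[None, :])`, it only compares the two halves of that chain, which is a weaker check.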

In [None]:
if samples is not None:
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    
    # Trace plot
    ax = axes[0]
    ax.plot(samples, alpha=0.7, linewidth=0.5)
    ax.axhline(y=BETA_TRUE, color="red", linestyle="--", label=f"True β = {BETA_TRUE}")
    ax.set_xlabel("Sample")
    ax.set_ylabel("β")
    ax.set_title("Trace Plot")
    ax.legend()
    
    # Histogram
    ax = axes[1]
    ax.hist(samples, bins=40, density=True, alpha=0.7, color="steelblue", edgecolor="white")
    ax.axvline(x=BETA_TRUE, color="red", linestyle="--", linewidth=2, label=f"True β = {BETA_TRUE}")
    ax.axvline(x=samples.mean(), color="green", linestyle="-", linewidth=2, label=f"Mean = {samples.mean():.3f}")
    ax.set_xlabel("β")
    ax.set_ylabel("Density")
    ax.set_title("Posterior Distribution")
    ax.legend()
    
    # Autocorrelation
    ax = axes[2]
    max_lag = 50
    acf = [np.corrcoef(samples[:-lag], samples[lag:])[0, 1] if lag > 0 else 1.0 for lag in range(max_lag)]
    ax.bar(range(max_lag), acf, color="steelblue", alpha=0.7)
    ax.axhline(y=0, color="black", linewidth=0.5)
    ax.set_xlabel("Lag")
    ax.set_ylabel("Autocorrelation")
    ax.set_title("Autocorrelation Function")
    
    plt.tight_layout()
    plt.show()
    
    # Summary statistics
    print("\nPosterior Summary:")
    print(f"  Mean: {samples.mean():.4f}")
    print(f"  Std:  {samples.std():.4f}")
    print(f"  2.5%: {np.percentile(samples, 2.5):.4f}")
    print(f"  97.5%: {np.percentile(samples, 97.5):.4f}")
    print(f"  True:  {BETA_TRUE}")
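The autocorrelation plot can be summarized as an effective sample size (ESS), roughly the number of independent draws the correlated chain is worth: ESS ≈ n / (1 + 2 Σ ρ_k). The sketch below truncates the sum at the first non-positive autocorrelation, a simple heuristic; `arviz.ess` implements a more robust rank-normalized estimator. The demo uses synthetic i.i.d. and AR(1) series rather than the tutorial's samples.

```python
import numpy as np


def autocorrelation(x, max_lag):
    """Sample autocorrelation at lags 0..max_lag-1 (biased, divide-by-n estimator)."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    n = len(x)
    denom = np.dot(x, x)
    return np.array([np.dot(x[: n - k], x[k:]) / denom for k in range(max_lag)])


def effective_sample_size(x, max_lag=None):
    """ESS = n / (1 + 2 * sum of autocorrelations), truncated at the first non-positive lag."""
    n = len(x)
    max_lag = n // 2 if max_lag is None else max_lag
    rho = autocorrelation(x, max_lag)
    tail = 0.0
    for k in range(1, max_lag):
        if rho[k] <= 0:
            break
        tail += rho[k]
    return n / (1.0 + 2.0 * tail)


demo_rng = np.random.default_rng(1)
iid = demo_rng.normal(size=5000)          # independent draws: ESS should be close to n
ar1 = np.zeros(5000)                      # AR(1) with phi = 0.9: strongly autocorrelated
for step_idx in range(1, 5000):
    ar1[step_idx] = 0.9 * ar1[step_idx - 1] + demo_rng.normal()

# For AR(1), theory gives ESS of about n * (1 - 0.9) / (1 + 0.9), i.e. roughly 263 here.
print(effective_sample_size(iid), effective_sample_size(ar1))
```

For this tutorial's chain, `effective_sample_size(samples)` gives a rough idea of how precisely the posterior mean is estimated from the 500 draws.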

## Step 5: Posterior Predictive Check

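We can push posterior samples through the model to visualize uncertainty in the fit. Because every simulation reuses `MODEL_KEY` and no observation noise is added, the shaded band reflects uncertainty in β alone; a full posterior predictive interval would also include the σ = 10 observation noise assumed in `log_likelihood`.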

In [None]:
if samples is not None:
    # Generate predictions from posterior samples
    n_pred = 100  # Number of posterior samples to use
    rng = np.random.default_rng(0)  # seeded so the selected draws are reproducible
    sample_indices = rng.choice(len(samples), n_pred, replace=False)
    
    predictions = []
    for idx in sample_indices:
        beta_sample = samples[idx]
        pred = run_diff_sir(
            N_agents=N_AGENTS,
            I0=I0,
            beta=beta_sample,
            gamma=GAMMA,
            T=T,
            config=DiffConfig(tau=0.5, hard=True),
            key=MODEL_KEY,
        )
        predictions.append(np.array(pred["I"]))
    
    predictions = np.array(predictions)
    
    # Plot
    fig, ax = plt.subplots(figsize=(10, 5))
    
    # Posterior predictive interval
    pred_mean = predictions.mean(axis=0)
    pred_lower = np.percentile(predictions, 2.5, axis=0)
    pred_upper = np.percentile(predictions, 97.5, axis=0)
    
    t = np.array(observed["t"])
    ax.fill_between(t, pred_lower, pred_upper, alpha=0.3, color="steelblue", label="95% posterior interval")
    ax.plot(t, pred_mean, color="steelblue", linewidth=2, label="Posterior mean")
    ax.scatter(t, np.array(observed_I), color="black", s=30, zorder=5, label="Observed")
    
    ax.set_xlabel("Time (days)", fontsize=12)
    ax.set_ylabel("Number Infected", fontsize=12)
    ax.set_title("Posterior Predictive Check", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()
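To turn a parameter-only band into a posterior predictive interval for new observations, each simulated curve can be perturbed with noise drawn from the assumed likelihood. The sketch below does this with a hypothetical stand-in for the `predictions` array (synthetic curves, not emidm output); `add_observation_noise` is an illustrative helper, not an emidm function. Note that Gaussian noise can produce negative counts, a known limitation of this likelihood.

```python
import numpy as np


def add_observation_noise(pred_curves, sigma, rng):
    """Replicated data: model curve + Normal(0, sigma) noise for every draw and time step."""
    pred_curves = np.asarray(pred_curves, dtype=float)
    return pred_curves + rng.normal(0.0, sigma, size=pred_curves.shape)


demo_rng = np.random.default_rng(2)
# Hypothetical stand-in for `predictions`: 100 posterior draws of a 40-day curve.
days = np.arange(40)
base_curve = 60.0 * np.exp(-0.5 * ((days - 15) / 5.0) ** 2)
demo_curves = base_curve * demo_rng.normal(1.0, 0.05, size=(100, 1))

replicated = add_observation_noise(demo_curves, sigma=10.0, rng=demo_rng)

band_param = np.percentile(demo_curves, [2.5, 97.5], axis=0)  # parameter uncertainty only
band_full = np.percentile(replicated, [2.5, 97.5], axis=0)    # plus observation noise
print((band_param[1] - band_param[0]).mean(), (band_full[1] - band_full[0]).mean())
```

In the cell above, this would correspond to `add_observation_noise(predictions, sigma=10.0, rng=np.random.default_rng(0))`, using the same σ as `log_likelihood`; a well-calibrated model should then have roughly 95% of the observed points inside the wider band.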

## Summary

In this tutorial, we demonstrated:

1. **Differentiable likelihood**: Using `run_diff_sir` to compute gradients through the epidemic model
2. **Bayesian inference**: Combining likelihood with priors to form a posterior
3. **MCMC sampling**: Using BlackJAX NUTS to sample from the posterior
4. **Posterior analysis**: Visualizing uncertainty and performing predictive checks

The key advantage of differentiable models is that HMC/NUTS can use gradient information to explore the posterior efficiently, which is especially important for:

- **High-dimensional problems**: Many parameters to estimate
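- **Correlated posteriors**: Strongly correlated or curved parameter spaces, where random-walk samplers mix slowly (multi-modal posteriors remain difficult for HMC/NUTS as well)
- **Expensive models**: Where each simulation is costly, so well-directed, less autocorrelated proposals save compute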