# Bayesian Inference for Time-Varying R(t) with the SAFIR Model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/OJWatson/emidm/blob/main/docs/notebooks/safir_inference.ipynb)

This notebook demonstrates how to:

1. Set up a realistic age-structured SAFIR model with UK demographics
2. Simulate a COVID-19-like epidemic with time-varying R(t) matching the UK 2020 pattern
3. Use Bayesian inference to estimate the time-varying R(t) from death data

---

**Authors:** Oliver (OJ) Watson ©. 2025. MIT Licence 2.0. <br>
**Affiliation:** MRC-GIDA, School of Public Health, Imperial College London

## 1. Setup and Installation

In [None]:
# Install emidm with JAX and Bayesian inference support (uncomment if needed)
# !pip install "emidm[jax,bayes] @ git+https://github.com/OJWatson/emidm.git"

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# JAX imports
import jax
import jax.numpy as jnp

# emidm imports
from emidm import run_diff_safir, to_dataframe, poisson_nll

# Set random seed for reproducibility
np.random.seed(42)

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## 2. UK Demographics and Contact Matrix

We use a simplified structure of eight age groups based on UK demographics:
- 0-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+

The contact matrix is a simplified version based on Prem et al. (2017, 2021) and represents average daily contacts between age groups in the UK.
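The contact-matrix cell below symmetrises by averaging the matrix with its transpose. That makes C_ij equal to C_ji, but it does not by itself enforce reciprocity of *total* contacts between groups of different sizes, C_ij N_i = C_ji N_j. A common correction averages total contacts in both directions and divides back by group size. The sketch below assumes that convention. `make_reciprocal` is a hypothetical helper and is not part of emidm.

```python
import numpy as np

def make_reciprocal(C, N):
    """Return a contact matrix satisfying C_ij * N_i == C_ji * N_j.

    Hypothetical helper (not part of emidm). Averages total daily contacts
    T_ij = C_ij * N_i with their transpose, then divides back by N_i.
    """
    C = np.asarray(C, dtype=float)
    N = np.asarray(N, dtype=float)
    total = C * N[:, None]               # total daily contacts from group i to group j
    total_sym = (total + total.T) / 2.0  # equalise totals in both directions
    return total_sym / N[:, None]        # convert back to per-person contacts

# Toy example: two groups of unequal size with a symmetric C
C = np.array([[2.0, 1.0],
              [1.0, 3.0]])
N = np.array([100.0, 300.0])
C_rec = make_reciprocal(C, N)
print(C_rec)
print(np.allclose(C_rec * N[:, None], (C_rec * N[:, None]).T))  # True
```

In this notebook the call would be `make_reciprocal(uk_contact_matrix, population)`. Applying it would change the simulated epidemic, so treat it as an optional variant rather than what the notebook does.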

In [None]:
# UK population by age group (in thousands, 2020 estimates)
# Age groups: 0-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+
uk_population = np.array([
    8000,   # 0-9
    7500,   # 10-19
    8500,   # 20-29
    8800,   # 30-39
    8200,   # 40-49
    9000,   # 50-59
    7200,   # 60-69
    9800,   # 70+
]) * 1000  # Convert to actual population

# For computational efficiency, we'll scale down the population
# while maintaining the age structure
scale_factor = 100  # 1:100 scaling
population = (uk_population / scale_factor).astype(int)

print(f"Total population (scaled): {population.sum():,}")
print(f"Population by age group: {population}")

In [None]:
# UK contact matrix (average daily contacts between age groups)
# Based on Prem et al. (2017) - simplified to 8 age groups
# Rows: contactor age group, Columns: contactee age group
uk_contact_matrix = np.array([
    # 0-9   10-19  20-29  30-39  40-49  50-59  60-69  70+
    [4.50,  1.20,  0.80,  1.50,  1.20,  0.60,  0.40,  0.30],  # 0-9
    [1.20,  8.50,  2.00,  1.00,  1.50,  1.00,  0.50,  0.30],  # 10-19
    [0.80,  2.00,  5.50,  2.50,  1.50,  1.00,  0.60,  0.40],  # 20-29
    [1.50,  1.00,  2.50,  4.00,  2.00,  1.20,  0.70,  0.50],  # 30-39
    [1.20,  1.50,  1.50,  2.00,  3.50,  1.80,  0.90,  0.60],  # 40-49
    [0.60,  1.00,  1.00,  1.20,  1.80,  3.00,  1.50,  0.80],  # 50-59
    [0.40,  0.50,  0.60,  0.70,  0.90,  1.50,  2.50,  1.20],  # 60-69
    [0.30,  0.30,  0.40,  0.50,  0.60,  0.80,  1.20,  2.00],  # 70+
])

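# Symmetrise the matrix by averaging it with its transpose.
# Note: this enforces C_ij == C_ji, not population-weighted reciprocity
# (C_ij * N_i == C_ji * N_j). The matrix above is already symmetric,
# so this step leaves it unchanged.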
contact_matrix = (uk_contact_matrix + uk_contact_matrix.T) / 2

print("Contact matrix shape:", contact_matrix.shape)
print("\nContact matrix:")
print(np.round(contact_matrix, 2))

In [None]:
# Visualize the contact matrix
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(contact_matrix, cmap='YlOrRd')
ax.set_xticks(range(8))
ax.set_yticks(range(8))
age_labels = ['0-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70+']
ax.set_xticklabels(age_labels, rotation=45)
ax.set_yticklabels(age_labels)
ax.set_xlabel('Age of contact')
ax.set_ylabel('Age of individual')
ax.set_title('UK Contact Matrix (daily contacts)')
plt.colorbar(im, label='Average daily contacts')
plt.tight_layout()
plt.show()

## 3. Simulating the UK 2020 Epidemic

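We'll create a time-varying R(t) that roughly follows the UK epidemic in 2020 (day 0 = 1 March; month boundaries are approximate):

1. **Early to mid-March**: initial exponential growth (R = 2.8)
2. **Late March to mid-April**: behaviour change, then the first lockdown; R drops sharply from 2.8 to about 0.75
3. **Mid-April to May**: strict lockdown (R ≈ 0.7-0.8)
4. **June-July**: gradual easing and summer reopening (R ≈ 0.8-1.1)
5. **August-September**: "Eat Out to Help Out", travel and schools reopening (R rises from about 1.1 to 1.5)
6. **October**: second wave building under tier restrictions (R ≈ 1.4, easing to about 1.25)
7. **November**: second lockdown (R falls from 1.25 to about 0.8)
8. **December**: Alpha variant emergence and Christmas mixing (R rises from 0.8 to about 1.4)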
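The schedule above is a sequence of plateaus and linear ramps. As an alternative to an `if`/`elif` chain, the same kind of schedule can be written as linear interpolation between (day, R) knots with `np.interp`. The knot values below are approximate readings of the list above. They are assumptions, not the exact values in `create_uk_2020_rt`, which also adds small sinusoidal wiggles and a step at the start of October.

```python
import numpy as np

# Hypothetical (day, R) knots approximately tracing the UK 2020 pattern above
knot_days = np.array([0, 19, 44, 90, 120, 150, 180, 210, 240, 270, 300])
knot_R = np.array([2.8, 2.8, 0.75, 0.8, 1.1, 1.1, 1.25, 1.4, 1.25, 0.8, 1.4])

def rt_from_knots(T, knot_days, knot_R):
    """Linearly interpolate R between knots for days 0..T inclusive."""
    days = np.arange(T + 1)
    return np.interp(days, knot_days, knot_R)

R_t_knots = rt_from_knots(300, knot_days, knot_R)
print(R_t_knots.shape)  # (301,)
```

Knot-based schedules are easy to edit and compare. They are also a natural stepping stone to inferring a handful of knot values instead of one value per interval.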

In [None]:
def create_uk_2020_rt(T=300):
    """Create a time-varying R(t) mimicking UK 2020 epidemic.
    
    Parameters
    ----------
    T : int
        Number of days (starting from ~March 1, 2020)
    
    Returns
    -------
    np.ndarray
        R(t) values for each day
    """
    R_t = np.zeros(T + 1)
    
    for t in range(T + 1):
        if t < 20:  # Early March: exponential growth
            R_t[t] = 2.8
        elif t < 30:  # Mid-March: awareness, some behavior change
            R_t[t] = 2.8 - (t - 20) * 0.1  # 2.8 -> 1.8
        elif t < 45:  # Late March - Early April: lockdown kicks in
            R_t[t] = 1.8 - (t - 30) * 0.07  # 1.8 -> 0.75
        elif t < 90:  # April - May: strict lockdown
            R_t[t] = 0.75 + 0.05 * np.sin((t - 45) * 0.1)  # ~0.7-0.8
        elif t < 120:  # June: gradual easing
            R_t[t] = 0.8 + (t - 90) * 0.01  # 0.8 -> 1.1
        elif t < 150:  # July: summer, outdoor activities
            R_t[t] = 1.0 + 0.1 * np.sin((t - 120) * 0.15)  # ~0.9-1.1
        elif t < 180:  # August: "Eat Out to Help Out", travel
            R_t[t] = 1.1 + (t - 150) * 0.005  # 1.1 -> 1.25
        elif t < 210:  # September: schools reopen
            R_t[t] = 1.25 + (t - 180) * 0.008  # 1.25 -> 1.5
        elif t < 240:  # October: second wave building
            R_t[t] = 1.4 - (t - 210) * 0.005  # Tier restrictions: 1.4 -> 1.25
        elif t < 270:  # November: second lockdown
            R_t[t] = 1.25 - (t - 240) * 0.015  # 1.25 -> 0.8
        else:  # December: Alpha variant, Christmas mixing
            R_t[t] = 0.8 + (t - 270) * 0.02  # 0.8 -> 1.4+
    
    return R_t

# Create R(t) for ~10 months (March - December 2020)
T = 300  # days
true_R_t = create_uk_2020_rt(T)

# Plot the true R(t)
fig, ax = plt.subplots(figsize=(12, 4))
days = np.arange(T + 1)
ax.plot(days, true_R_t, 'b-', linewidth=2)
ax.axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='R=1 threshold')
ax.fill_between(days, 0, true_R_t, where=true_R_t > 1, alpha=0.3, color='red', label='R > 1')
ax.fill_between(days, 0, true_R_t, where=true_R_t <= 1, alpha=0.3, color='green', label='R ≤ 1')
ax.set_xlabel('Days since March 1, 2020')
ax.set_ylabel('R(t)')
ax.set_title('True Time-Varying Reproduction Number (UK 2020 Pattern)')
ax.legend()
ax.set_xlim(0, T)
ax.set_ylim(0, 3.5)

# Add month labels
months = ['Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
month_days = [0, 31, 61, 92, 122, 153, 184, 214, 245, 275]
ax.set_xticks(month_days)
ax.set_xticklabels(months)
plt.tight_layout()
plt.show()

In [None]:
# Run the SAFIR model with time-varying R(t)
print("Running SAFIR simulation...")

result = run_diff_safir(
    population=population,
    contact_matrix=contact_matrix,
    R0=true_R_t[0],  # Initial R0
    R_t=true_R_t,    # Time-varying R(t)
    T=T,
    dt=0.25,         # Sub-daily timestep
    seed=42,
    tau=0.1,
    hard=True,
    n_seed=50,       # Initial infections
)

print(f"Simulation complete. Total deaths: {result['D'][-1]:.0f}")

In [None]:
# Plot the epidemic trajectory
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

t = np.array(result['t'])

# Infections
ax = axes[0, 0]
ax.plot(t, result['I'], 'orange', linewidth=2)
ax.set_ylabel('Infectious (I)')
ax.set_title('Active Infections')
ax.set_xlim(0, T)

# Deaths (cumulative)
ax = axes[0, 1]
ax.plot(t, result['D'], 'red', linewidth=2)
ax.set_ylabel('Cumulative Deaths (D)')
ax.set_title('Cumulative Deaths')
ax.set_xlim(0, T)

# Daily deaths (what we'll fit to)
ax = axes[1, 0]
daily_deaths = np.diff(np.array(result['D']), prepend=0)
daily_deaths = np.maximum(daily_deaths, 0)  # Ensure non-negative
ax.bar(t, daily_deaths, color='red', alpha=0.7, width=1)
ax.set_xlabel('Days')
ax.set_ylabel('Daily Deaths')
ax.set_title('Daily Deaths (Observed Data)')
ax.set_xlim(0, T)

# R(t) for reference
ax = axes[1, 1]
ax.plot(t, true_R_t, 'b-', linewidth=2)
ax.axhline(y=1.0, color='r', linestyle='--', alpha=0.5)
ax.set_xlabel('Days')
ax.set_ylabel('R(t)')
ax.set_title('True R(t)')
ax.set_xlim(0, T)
ax.set_ylim(0, 3.5)

# Add month labels to all plots
for ax in axes.flat:
    ax.set_xticks(month_days)
    ax.set_xticklabels(months)

plt.tight_layout()
plt.show()

## 4. Bayesian Inference Setup

We'll estimate R(t) using a **piecewise constant** parameterization with 2-week intervals. This gives us:
- 22 R(t) values to estimate (one per 14-day interval, covering days 0-300)
- A random walk prior to encourage smoothness
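Mapping interval values to days is an indexing operation: day t belongs to interval ⌊t / 14⌋. The minimal sketch below implements the same piecewise-constant mapping as `expand_R_intervals` (defined below) without a Python loop, which keeps the traced computation small under `jax.grad` or `jax.jit`. `expand_R_vectorised` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def expand_R_vectorised(R_intervals, T, interval_days=14):
    """Piecewise-constant daily R(t) of length T + 1 (hypothetical helper).

    Day t uses R_intervals[t // interval_days], so R_intervals must contain
    at least ceil((T + 1) / interval_days) values.
    """
    day_to_interval = jnp.arange(T + 1) // interval_days
    return jnp.asarray(R_intervals)[day_to_interval]

R_intervals = jnp.array([2.5, 1.2, 0.8])
R_daily = expand_R_vectorised(R_intervals, T=30, interval_days=14)
print(R_daily.shape)  # (31,)

# Gradients flow back to the interval values: each interval receives one unit
# of gradient per day it covers (14, 14 and 3 days here).
grad_R = jax.grad(lambda r: expand_R_vectorised(r, 30, 14).sum())(R_intervals)
print(grad_R)
```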

### Model Structure

**Likelihood:**
$$D_t \sim \text{Poisson}(\hat{D}_t(R_{1:K}))$$

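where $D_t$ is the observed number of deaths on day $t$ and $\hat{D}_t$ is the corresponding daily death count predicted by the SAFIR model given the $K$ interval values $R_1, \dots, R_K$.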
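The likelihood code below drops the $\log(k!)$ term of the Poisson log-probability because it does not depend on R. The quick check below uses `jax.scipy.stats.poisson.logpmf` to show that the simplified expression differs from the full log-pmf only by the constant $-\sum_t \log(D_t!)$, computed with `gammaln` because $\log k! = \text{gammaln}(k + 1)$. The counts and rates are toy values.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

obs = jnp.array([0.0, 3.0, 7.0, 12.0])  # observed daily deaths (toy values)
lam = jnp.array([0.5, 2.8, 8.1, 10.4])  # predicted daily deaths (toy values)

full = jnp.sum(poisson.logpmf(obs, lam))        # includes the -log(k!) terms
simplified = jnp.sum(obs * jnp.log(lam) - lam)  # form used in log_likelihood
constant = -jnp.sum(gammaln(obs + 1.0))         # -sum(log k!), independent of lam

print(full, simplified + constant)  # equal up to float32 rounding
```

Because the constant does not depend on $\hat{D}_t$, dropping it changes neither the MAP estimate nor the gradients used by NUTS.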

**Prior (Random Walk):**
$$\log(R_1) \sim \mathcal{N}(\log(2.5), 0.5^2)$$
$$\log(R_k) - \log(R_{k-1}) \sim \mathcal{N}(0, \sigma_{RW}^2)$$

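The random walk prior penalises large jumps in $\log R$ between adjacent intervals. With $\sigma_{RW} = 0.25$ (the value used below), a one-standard-deviation step changes R by a factor of about $e^{0.25} \approx 1.28$. The code also applies a large penalty whenever any $R_k$ falls outside $[0.1, 5]$.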
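The same random walk prior can be written with `jax.scipy.stats.norm.logpdf`, without the range penalty. It differs from the notebook's `log_prior` only by additive normalising constants, so it has the same gradient. `rw_log_prior` is a hypothetical name, and the default values mirror the prior above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def rw_log_prior(log_R, sigma_rw=0.25, mu0=jnp.log(2.5), sd0=0.5):
    """Normalised random-walk prior on log R (hypothetical helper).

    log R_1 ~ N(mu0, sd0^2); log R_k - log R_{k-1} ~ N(0, sigma_rw^2).
    """
    first = norm.logpdf(log_R[0], loc=mu0, scale=sd0)
    steps = norm.logpdf(jnp.diff(log_R), loc=0.0, scale=sigma_rw)
    return first + jnp.sum(steps)

smooth = jnp.log(jnp.array([2.5, 2.3, 2.1, 1.9]))
jumpy = jnp.log(jnp.array([2.5, 0.8, 2.4, 0.9]))
print(rw_log_prior(smooth) > rw_log_prior(jumpy))  # True: large jumps are penalised
print(jax.grad(rw_log_prior)(smooth))
```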

In [None]:
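# Store the "observed" daily deaths (noise-free output of the simulation above)
observed_deaths = np.array(daily_deaths)

# Define the inference window (2-week intervals)
interval_days = 14
# Ceiling division: enough intervals to cover days 0..T (22 for T = 300)
n_intervals = int(np.ceil((T + 1) / interval_days))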

print(f"Number of R(t) parameters to estimate: {n_intervals}")
print(f"Interval length: {interval_days} days")

In [None]:
def expand_R_intervals(R_intervals, T, interval_days=14):
    """Expand interval R values to daily R(t) array.
    
    Parameters
    ----------
    R_intervals : array
        R values for each interval
    T : int
        Total number of days
    interval_days : int
        Days per interval
    
    Returns
    -------
    array
        Daily R(t) values (length T+1)
    """
    R_daily = jnp.zeros(T + 1)
    for i, R_val in enumerate(R_intervals):
        start = i * interval_days
        end = min((i + 1) * interval_days, T + 1)
        R_daily = R_daily.at[start:end].set(R_val)
    return R_daily

# Test the expansion
test_R = jnp.ones(n_intervals) * 1.5
test_expanded = expand_R_intervals(test_R, T, interval_days)
print(f"Expanded R(t) shape: {test_expanded.shape}")

In [None]:
def log_prior(log_R_intervals, sigma_rw=0.3):
    """Random walk prior on log(R) values.
    
    Parameters
    ----------
    log_R_intervals : array
        Log R values for each interval
    sigma_rw : float
        Standard deviation of random walk increments
    
    Returns
    -------
    float
        Log prior probability
    """
    # Prior on first R value: log(R_1) ~ N(log(2.5), 0.5^2)
    log_prior_first = -0.5 * ((log_R_intervals[0] - jnp.log(2.5)) / 0.5) ** 2
    
    # Random walk prior on increments
    increments = jnp.diff(log_R_intervals)
    log_prior_rw = -0.5 * jnp.sum((increments / sigma_rw) ** 2)
    
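    # Step penalty for R outside [0.1, 5]. R = exp(log R) is always positive;
    # the penalty is piecewise constant, so it adds no gradient and creates
    # discontinuities that gradient-based samplers such as NUTS handle poorly.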
    R_intervals = jnp.exp(log_R_intervals)
    penalty = -100.0 * jnp.sum(jnp.where(R_intervals < 0.1, 1.0, 0.0))
    penalty += -100.0 * jnp.sum(jnp.where(R_intervals > 5.0, 1.0, 0.0))
    
    return log_prior_first + log_prior_rw + penalty

In [None]:
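def log_likelihood(log_R_intervals, observed_deaths, population, contact_matrix, T, interval_days, seed=0):
    """Compute the Poisson log-likelihood of observed daily deaths given R(t).

    The interval R values are expanded to a daily R(t), the differentiable
    SAFIR model is run with soft sampling, and predicted daily deaths are
    compared with the observations. The constant log(k!) term is omitted.
    """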
    # Expand to daily R(t)
    R_intervals = jnp.exp(log_R_intervals)
    R_t = expand_R_intervals(R_intervals, T, interval_days)
    
    # Run model
    result = run_diff_safir(
        population=population,
        contact_matrix=contact_matrix,
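        # Keep R0 as a JAX scalar: float() forces a concrete Python value,
        # which raises an error when this function is traced (e.g. under jax.jit).
        # Assumes run_diff_safir accepts a JAX scalar here.
        R0=R_intervals[0],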
        R_t=R_t,
        T=T,
        dt=0.25,
        seed=seed,
        tau=0.5,      # Higher tau for smoother gradients during inference
        hard=False,   # Soft sampling for better gradients
        n_seed=50,
    )
    
    # Compute daily deaths from model
    predicted_D = result['D']
    predicted_daily = jnp.diff(predicted_D, prepend=0)
    predicted_daily = jnp.maximum(predicted_daily, 0.1)  # Avoid log(0)
    
    # Poisson log-likelihood
    # log P(k | λ) = k*log(λ) - λ - log(k!)
    # We ignore the factorial term as it's constant
    obs = jnp.asarray(observed_deaths)
    log_lik = jnp.sum(obs * jnp.log(predicted_daily + 1e-8) - predicted_daily)
    
    return log_lik


def log_posterior(log_R_intervals, observed_deaths, population, contact_matrix, T, interval_days, sigma_rw=0.3):
    """Compute log-posterior = log-likelihood + log-prior."""
    ll = log_likelihood(log_R_intervals, observed_deaths, population, contact_matrix, T, interval_days)
    lp = log_prior(log_R_intervals, sigma_rw)
    return ll + lp

## 5. Maximum A Posteriori (MAP) Estimation

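Before running full MCMC, we find the maximum a posteriori (MAP) estimate by gradient-based optimisation of the negative log-posterior. This provides a good starting point for the sampler and checks that the model, likelihood and prior are wired together correctly.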
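The internals of emidm's `optimize_params` are not shown in this notebook. To illustrate the underlying idea only, here is a plain gradient-descent loop on a toy problem: estimating a single constant log-rate from Poisson counts using `jax.value_and_grad`. With no prior, the optimum is $\log(\bar{y})$, which gives a built-in check. The toy data and learning rate are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

y = jnp.array([3.0, 5.0, 4.0, 6.0, 2.0])  # toy counts, mean = 4

def neg_log_lik(log_lam):
    # Poisson negative log-likelihood without the constant log(k!) term
    lam = jnp.exp(log_lam)
    return -jnp.sum(y * log_lam - lam)

loss_and_grad = jax.jit(jax.value_and_grad(neg_log_lik))

log_lam = jnp.array(0.0)  # start at lambda = 1
learning_rate = 0.05
for step in range(200):
    loss, g = loss_and_grad(log_lam)
    log_lam = log_lam - learning_rate * g

print(jnp.exp(log_lam))  # converges towards mean(y) = 4.0
```

Optimising on the log scale keeps the rate positive without constraints, which is the same reason the notebook parameterises the interval values as log R.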

In [None]:
from emidm import optimize_params

# Initial guess: constant R = 1.5
init_log_R = jnp.log(jnp.ones(n_intervals) * 1.5)

# Define negative log-posterior for minimization
def neg_log_posterior(log_R_intervals):
    return -log_posterior(
        log_R_intervals, 
        observed_deaths, 
        population, 
        contact_matrix, 
        T, 
        interval_days,
        sigma_rw=0.25
    )

print("Finding MAP estimate...")
print(f"Initial negative log-posterior: {neg_log_posterior(init_log_R):.2f}")

In [None]:
# Run optimization
map_log_R, history = optimize_params(
    loss_fn=neg_log_posterior,
    init_params=init_log_R,
    n_steps=300,
    learning_rate=0.05,
)

map_R = jnp.exp(map_log_R)
print(f"\nFinal negative log-posterior: {history['loss'][-1]:.2f}")
print("\nMAP R(t) estimates:")
for i, r in enumerate(map_R):
    start = i * interval_days
    end = min((i + 1) * interval_days, T + 1) - 1  # last interval is truncated at day T
    print(f"  Interval {i+1} (days {start}-{end}): R = {float(r):.2f}")

In [None]:
# Plot optimization progress
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

ax = axes[0]
ax.plot(history['loss'])
ax.set_xlabel('Iteration')
ax.set_ylabel('Negative Log-Posterior')
ax.set_title('Optimization Progress')

ax = axes[1]
# Plot true vs estimated R(t)
map_R_expanded = expand_R_intervals(map_R, T, interval_days)
ax.plot(np.arange(T+1), true_R_t, 'b-', linewidth=2, label='True R(t)')
ax.step(np.arange(T+1), np.array(map_R_expanded), 'r-', linewidth=2, where='post', label='MAP Estimate')
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Days')
ax.set_ylabel('R(t)')
ax.set_title('True vs MAP Estimated R(t)')
ax.legend()
ax.set_xlim(0, T)
ax.set_ylim(0, 3.5)
ax.set_xticks(month_days)
ax.set_xticklabels(months)

plt.tight_layout()
plt.show()

In [None]:
# Compare predicted vs observed deaths with MAP estimate
map_R_expanded = expand_R_intervals(map_R, T, interval_days)

map_result = run_diff_safir(
    population=population,
    contact_matrix=contact_matrix,
    R0=float(map_R[0]),
    R_t=map_R_expanded,
    T=T,
    dt=0.25,
    seed=42,
    tau=0.1,
    hard=True,
    n_seed=50,
)

map_daily_deaths = np.diff(np.array(map_result['D']), prepend=0)
map_daily_deaths = np.maximum(map_daily_deaths, 0)

fig, ax = plt.subplots(figsize=(12, 5))
ax.bar(np.arange(T+1), observed_deaths, color='red', alpha=0.5, width=1, label='Observed')
ax.plot(np.arange(T+1), map_daily_deaths, 'b-', linewidth=2, label='MAP Prediction')
ax.set_xlabel('Days')
ax.set_ylabel('Daily Deaths')
ax.set_title('Observed vs MAP Predicted Daily Deaths')
ax.legend()
ax.set_xlim(0, T)
ax.set_xticks(month_days)
ax.set_xticklabels(months)
plt.tight_layout()
plt.show()

## 6. Full Bayesian Inference with MCMC

Now we'll use the No-U-Turn Sampler (NUTS) via BlackJAX to get full posterior distributions for R(t).

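**Note:** NUTS evaluates the log-density and its gradient, and therefore runs the full SAFIR simulation, at every leapfrog step, so this is computationally intensive. For demonstration we use only 100 warm-up and 200 sampling iterations, which is likely too few for reliable posterior summaries in practice.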
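NUTS uses gradients of the log-density to make long, well-directed moves. For comparison, the sketch below is a much simpler random-walk Metropolis sampler that needs only the log-density, written with `jax.lax.scan`. It is not what `run_blackjax_nuts` does. The target is a toy standard 2-D Gaussian, and the step size is an arbitrary assumption.

```python
import jax
import jax.numpy as jnp

def rw_metropolis(logdensity_fn, init, rng_key, num_samples=5000, step_size=0.8):
    """Random-walk Metropolis sampler (illustrative; not NUTS)."""
    def one_step(carry, key):
        x, logp = carry
        key_prop, key_acc = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop, x.shape)
        logp_prop = logdensity_fn(proposal)
        # Accept with probability min(1, p(proposal) / p(x))
        accept = jnp.log(jax.random.uniform(key_acc)) < logp_prop - logp
        x_new = jnp.where(accept, proposal, x)
        logp_new = jnp.where(accept, logp_prop, logp)
        return (x_new, logp_new), (x_new, accept)

    keys = jax.random.split(rng_key, num_samples)
    _, (samples, accepted) = jax.lax.scan(one_step, (init, logdensity_fn(init)), keys)
    return samples, accepted.mean()

def toy_log_density(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal, up to a constant

samples_rw, acc_rate = rw_metropolis(toy_log_density, jnp.zeros(2), jax.random.PRNGKey(0))
print(samples_rw.shape, acc_rate)
print(samples_rw.mean(axis=0), samples_rw.std(axis=0))  # roughly 0 and 1
```

Random-walk proposals mix poorly as the number of parameters grows. That is why a gradient-based sampler such as NUTS is preferred for the 22 correlated log R values here.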

In [None]:
try:
    from emidm import run_blackjax_nuts
    HAS_BLACKJAX = True
    print("BlackJAX available - will run MCMC")
except ImportError:
    HAS_BLACKJAX = False
    print("BlackJAX not available - skipping MCMC (install with: pip install blackjax)")

In [None]:
if HAS_BLACKJAX:
    # Define log-density function for MCMC
    def log_density(log_R_intervals):
        return log_posterior(
            log_R_intervals,
            observed_deaths,
            population,
            contact_matrix,
            T,
            interval_days,
            sigma_rw=0.25
        )
    
    print("Running NUTS sampler...")
    print("(This may take several minutes)")
    
    # Use MAP estimate as starting point
    samples = run_blackjax_nuts(
        logdensity_fn=log_density,
        initial_position=map_log_R,
        rng_seed=42,
        num_warmup=100,
        num_samples=200,
    )
    
    print(f"\nSampling complete!")
    print(f"Samples shape: {samples.shape}")
else:
    print("Skipping MCMC - using MAP estimate only")
    samples = None

In [None]:
if samples is not None:
    # Convert log(R) samples to R
    R_samples = np.exp(np.array(samples))
    
    # Compute posterior statistics
    R_mean = R_samples.mean(axis=0)
    R_std = R_samples.std(axis=0)
    R_q05 = np.percentile(R_samples, 5, axis=0)
    R_q95 = np.percentile(R_samples, 95, axis=0)
    
    print("Posterior R(t) estimates (mean ± SD, 90% credible interval):")
    for i in range(n_intervals):
        print(f"  Interval {i+1}: R = {R_mean[i]:.2f} ± {R_std[i]:.2f} (90% CrI: [{R_q05[i]:.2f}, {R_q95[i]:.2f}])")

In [None]:
# Final comparison plot
fig, axes = plt.subplots(2, 1, figsize=(14, 10))

# R(t) comparison
ax = axes[0]
ax.plot(np.arange(T+1), true_R_t, 'b-', linewidth=2, label='True R(t)')

if samples is not None:
    # Expand posterior samples to daily values
    for i in range(min(50, len(R_samples))):  # Plot subset of samples
        R_expanded = expand_R_intervals(jnp.array(R_samples[i]), T, interval_days)
        ax.step(np.arange(T+1), np.array(R_expanded), 'gray', alpha=0.1, where='post')
    
    # Plot posterior mean
    R_mean_expanded = expand_R_intervals(jnp.array(R_mean), T, interval_days)
    ax.step(np.arange(T+1), np.array(R_mean_expanded), 'r-', linewidth=2, where='post', label='Posterior Mean')
else:
    # Just plot MAP
    map_R_expanded = expand_R_intervals(map_R, T, interval_days)
    ax.step(np.arange(T+1), np.array(map_R_expanded), 'r-', linewidth=2, where='post', label='MAP Estimate')

ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
ax.set_ylabel('R(t)')
ax.set_title('Estimated vs True Time-Varying Reproduction Number')
ax.legend()
ax.set_xlim(0, T)
ax.set_ylim(0, 3.5)
ax.set_xticks(month_days)
ax.set_xticklabels(months)

# Deaths comparison
ax = axes[1]
ax.bar(np.arange(T+1), observed_deaths, color='red', alpha=0.5, width=1, label='Observed Deaths')

# Posterior predictive intervals would require re-running the model for each
# posterior draw; here we show the MAP prediction only.
ax.plot(np.arange(T+1), map_daily_deaths, 'b-', linewidth=2, label='MAP Prediction')
ax.set_xlabel('Days')
ax.set_ylabel('Daily Deaths')
ax.set_title('Model Fit to Death Data')
ax.legend()
ax.set_xlim(0, T)
ax.set_xticks(month_days)
ax.set_xticklabels(months)

plt.tight_layout()
plt.show()

## 7. Summary

In this notebook, we demonstrated:

1. **Setting up a realistic SAFIR model** with UK demographics and contact patterns
2. **Simulating a UK 2020-like epidemic** with time-varying R(t) capturing:
   - Initial exponential growth
   - First lockdown impact
   - Summer reopening
   - Second wave
3. **Bayesian inference** for estimating R(t) from death data:
   - Piecewise constant parameterization (2-week intervals)
   - Random walk prior for smoothness
   - Poisson likelihood for death counts
   - MAP estimation via gradient descent
   - Full posterior via NUTS (if BlackJAX available)

### Key Insights

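- **Differentiable models enable gradient-based inference** - automatic differentiation provides the gradient of the log-posterior with respect to every R(t) parameter, which both the MAP optimiser and NUTS use; gradient-based methods typically scale much better with the number of parameters than derivative-free ones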
- **Death data has a lag** - R(t) changes are reflected in deaths ~2-3 weeks later
- **Uncertainty quantification** - Bayesian approach provides credible intervals
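- **Model misspecification** - here the synthetic data come from the same model used for fitting, so recovery is optimistic; real data would also require accounting for reporting delays, weekday effects, observation noise, etc.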
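The lag between changes in R(t) and deaths arises because expected deaths are infections convolved with an infection-to-death delay distribution. The minimal sketch below uses a discretised gamma-shaped delay. The shape (4), scale (5 days, giving a mean of 20 days) and `IFR` are assumptions chosen for illustration, not values used by SAFIR or emidm.

```python
import numpy as np

def gamma_delay_pmf(shape, scale, max_delay):
    """Discretised gamma-shaped delay on days 0..max_delay, normalised to sum to 1."""
    d = np.arange(max_delay + 1, dtype=float)
    w = d ** (shape - 1) * np.exp(-d / scale)  # unnormalised gamma density
    return w / w.sum()

delay = gamma_delay_pmf(shape=4, scale=5.0, max_delay=60)

infections = np.zeros(100)
infections[10] = 1000.0  # a single pulse of infections on day 10
IFR = 0.01               # hypothetical infection fatality ratio

# Expected deaths on day t = IFR * sum_s infections[s] * delay[t - s]
expected_deaths = IFR * np.convolve(infections, delay)[: len(infections)]

print(int(np.argmax(infections)), int(np.argmax(expected_deaths)))  # 10 25
print(round(expected_deaths.sum(), 6))  # 10.0
```

The death peak trails the infection pulse by the delay's mode (15 days here). A change in R therefore only becomes visible in deaths weeks later, which limits how precisely the most recent intervals can be estimated.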

### Extensions

- Use a Gaussian Process prior for smoother R(t) estimates
- Incorporate multiple data streams (cases, hospitalizations, deaths)
- Add observation model for reporting delays and noise
- Estimate other parameters (IFR, hospitalization rates) jointly
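For the Gaussian Process extension, the sketch below places a GP prior on log R at the interval midpoints, using a squared-exponential kernel and `jax.scipy.stats.multivariate_normal.logpdf`. The mean (log R = 0, i.e. R = 1), amplitude, 42-day length-scale and jitter are all assumptions. `gp_log_prior` is a hypothetical helper that could replace the random-walk term in `log_prior`.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def gp_log_prior(log_R, t, mean=0.0, amplitude=0.5, length_scale=42.0, jitter=1e-4):
    """GP prior on log R with constant mean and squared-exponential kernel (hypothetical)."""
    diff = t[:, None] - t[None, :]
    K = amplitude ** 2 * jnp.exp(-0.5 * (diff / length_scale) ** 2)
    K = K + jitter * jnp.eye(t.shape[0])  # jitter keeps the Cholesky factorisation stable
    return multivariate_normal.logpdf(log_R, mean * jnp.ones_like(t), K)

t_mid = jnp.arange(6) * 14.0 + 7.0  # midpoints of six 14-day intervals
smooth = jnp.log(jnp.array([2.5, 2.2, 1.8, 1.4, 1.1, 0.9]))
jumpy = jnp.log(jnp.array([2.5, 0.8, 2.4, 0.9, 2.2, 0.8]))
print(gp_log_prior(smooth, t_mid) > gp_log_prior(jumpy, t_mid))  # True
```

Unlike the random walk, the length-scale directly sets how quickly R can vary over time. The cost is an O(K³) Cholesky factorisation per evaluation, which is negligible for a few dozen intervals.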