In [1]:
import random
import jax.numpy as jnp
import jax.random as jaxran
import pandas as pd
import numpyro
import numpyro.distributions as dist

In [2]:
from numpyro import handlers
from numpyro.examples.runge_kutta import runge_kutta_4
from numpyro.infer import NUTS, MCMC
def seir_update(day, seir_state, *, mu, beta, nu, sigma, gamma):
    """Right-hand side of the SEIR ODE with vital dynamics and vaccination.

    Args:
        day: Current time. Unused here; presumably required by the
            ``runge_kutta_4`` callback signature.
        seir_state: Array whose last axis holds (S, E, I, R); any leading
            axes are treated as batch dimensions.
        mu: Birth/death rate. Inflow ``mu * N`` enters S while every
            compartment loses ``mu`` per capita, so N stays constant.
        beta: Transmission rate (S -> E, frequency-dependent incidence).
        nu: Vaccination rate (S -> R directly).
        sigma: Rate at which exposed become infectious (E -> I).
        gamma: Recovery rate (I -> R).

    Returns:
        Array shaped like ``seir_state[..., :4]`` holding (dS, dE, dI, dR) on
        the last axis.
    """
    susceptible, exposed, infected, recovered = (seir_state[..., k] for k in range(4))
    total = susceptible + exposed + infected + recovered
    # New infections per unit time; computed once and shared by dS and dE.
    incidence = beta * (susceptible * infected / total)
    d_susceptible = mu * (total - susceptible) - incidence - nu * susceptible
    d_exposed = incidence - (mu + sigma) * exposed
    d_infected = sigma * exposed - (mu + gamma) * infected
    d_recovered = gamma * infected - mu * recovered + nu * susceptible
    return jnp.stack([d_susceptible, d_exposed, d_infected, d_recovered], axis=-1)

In [3]:
# Load the Danish COVID-19 time series (assumed to be one row per day).
covid_data = pd.read_csv('data/COVID19_Denmark.csv')
population = covid_data['Population'].iloc[0]
# Start the epidemic with a single exposed person; everyone else is susceptible.
initial_seir_state = jnp.array([population - 1., 1., 0., 0.])
infection_data = covid_data['StillInfected'].to_numpy()
# The model has no death compartment, so deaths are counted as recovered.
recovery_data = (covid_data['Recovered'] + covid_data['Deaths']).to_numpy()

In [4]:
# Integration step, in days, for the RK4 integrator (`runge_kutta_4`).
step_size = 0.1
num_days = infection_data.shape[0]
# `model` reshapes the trajectory into whole days of 1 / step_size sub-steps,
# so the step has to divide one day evenly; fail early with a clear message.
assert abs(1 / step_size - round(1 / step_size)) < 1e-9, 'step_size must evenly divide one day'
# round() rather than int(): a float quotient can land just below the intended
# integer, and int() would then silently drop the last step.
num_steps = round(num_days / step_size)
seir = runge_kutta_4(seir_update, step_size, num_steps)

In [5]:
def model(initial_seir_state, infection_data, recovery_data):
    """NumPyro SEIR model with a Poisson likelihood on daily counts.

    Places HalfCauchy(scale=1000) priors on beta, gamma, sigma and mu, fixes
    the vaccination rate nu at zero, and integrates the ODE with the
    notebook-level ``seir`` solver. It then scores the simulated infected
    (state index 2) and recovered (index 3) compartments against the data.

    Relies on the notebook globals ``seir``, ``num_days`` and ``step_size``.

    Args:
        initial_seir_state: Starting (S, E, I, R) vector.
        infection_data: Daily currently-infected counts, or None to let the
            'infections' site be sampled instead of observed.
        recovery_data: Daily recovered counts (deaths included), or None.
    """
    # Any observed series must span exactly the simulated horizon.
    for observed in (infection_data, recovery_data):
        if observed is not None:
            assert num_days == observed.shape[0]

    # Sample-site order (beta, gamma, sigma, mu) matters under a seed handler.
    rates = {
        name: numpyro.sample(name, dist.HalfCauchy(scale=1000.))
        for name in ('beta', 'gamma', 'sigma', 'mu')
    }
    nu = jnp.array(0.0)  # No vaccine yet

    # NOTE(review): `numpyro.rng_key` does not look like a stock NumPyro
    # primitive (stock has `numpyro.prng_key()`); presumably it pulls a key
    # from the enclosing seed handler — confirm.
    rng_key = numpyro.rng_key('rng_key')
    trajectory, _lp_reg = seir(
        rng_key,
        initial_seir_state,
        mu=rates['mu'],
        beta=rates['beta'],
        nu=nu,
        sigma=rates['sigma'],
        gamma=rates['gamma'],
    )

    # Keep the state at the final sub-step of each day. The small offset
    # presumably keeps the Poisson rate strictly positive.
    # Assumes `trajectory` holds num_days * int(1 / step_size) states of 4 — TODO confirm.
    steps_per_day = int(1 / step_size)
    daily_state = jnp.reshape(trajectory, (num_days, steps_per_day, 4))[:, -1, :] + 1e-3

    with numpyro.plate('data', num_days):
        numpyro.sample('infections', dist.Poisson(daily_state[:, 2]), obs=infection_data)
        numpyro.sample('recovery', dist.Poisson(daily_state[:, 3]), obs=recovery_data)

In [6]:
# Fixed seed so that Restart & Run All reproduces the same chains; change it
# deliberately to check sensitivity to the random initialisation.
SEED = 0
rngkey = jaxran.PRNGKey(SEED)
# One key seeds the model via `handlers.seed`; the other is passed to `mcmc.run`.
rng_seed, rngkey = jaxran.split(rngkey)

In [None]:
# Expensive: the recorded run ran at ~4 it/s, i.e. close to an hour for the
# 12,500 total iterations.
# NOTE(review): that run also shows step size ~1e-5 with 1023 leapfrog steps
# per iteration (the default max tree depth of 10), which usually signals poor
# posterior geometry. The very wide HalfCauchy(1000) rate priors are a likely
# suspect — consider tightening them.
num_warmup = 2500
num_samples = 10000
# The model is seeded with the fixed `rng_seed`, so any randomness inside it is
# identical on every evaluation (presumably needed for a deterministic potential).
kernel = NUTS(handlers.seed(model, rng_seed))
# num_warmup / num_samples are keyword-only in recent NumPyro releases; the
# keyword form also works on older ones.
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rngkey, initial_seir_state, infection_data, recovery_data)
mcmc.print_summary()

sample:  75%|███████▌  | 9382/12500 [37:58<13:42,  3.79it/s, 1023 steps of size 1.09e-05. acc. prob=0.85]  