In [2]:
!pip install jaxlib jax numpyro

Collecting numpyro
[?25l  Downloading https://files.pythonhosted.org/packages/35/f1/7bada66676245f9e085b870b1051ba183b377af287002e10a2e1bea1b498/numpyro-0.4.1-py3-none-any.whl (176kB)
[K     |█▉                              | 10kB 17.4MB/s eta 0:00:01[K     |███▊                            | 20kB 16.0MB/s eta 0:00:01[K     |█████▋                          | 30kB 14.2MB/s eta 0:00:01[K     |███████▍                        | 40kB 13.6MB/s eta 0:00:01[K     |█████████▎                      | 51kB 11.4MB/s eta 0:00:01[K     |███████████▏                    | 61kB 11.5MB/s eta 0:00:01[K     |█████████████                   | 71kB 11.5MB/s eta 0:00:01[K     |██████████████▉                 | 81kB 11.8MB/s eta 0:00:01[K     |████████████████▊               | 92kB 11.4MB/s eta 0:00:01[K     |██████████████████▋             | 102kB 11.6MB/s eta 0:00:01[K     |████████████████████▍           | 112kB 11.6MB/s eta 0:00:01[K     |██████████████████████▎         | 122kB 11.6

In [None]:
import jax.numpy as np
from jax import vmap, jit
from jax.experimental.ode import odeint
import jax.random as random
from jax.config import config
config.update("jax_enable_x64", True)

from numpyro import sample
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import numpy as onp
import matplotlib.pyplot as plt
from functools import partial
import time

%matplotlib inline

In [None]:
class ODEfit:
    """Bayesian calibration of an ODE model's rate parameters with NumPyro NUTS.

    The right-hand side ``dxdt`` is integrated with
    ``jax.experimental.ode.odeint``. A single column of the solution
    (``obs_idx``) is compared against the observed series under Gaussian noise.

    Args:
        t: time points at which the ODE is solved and ``X`` was observed.
        X: observed values of solution column ``obs_idx`` at times ``t``.
        x0: initial state vector passed to ``odeint``.
        N: forwarded to ``dxdt`` as its first extra argument (the total
            population for the SEIR model).
        dxdt: right-hand side with signature ``dxdt(x, t, N, beta, gamma, delta)``.
        obs_idx: column of the ODE solution that ``X`` observes. Defaults to 2,
            the infectious compartment for an (S, E, I, R) state ordering.
    """

    def __init__(self, t, X, x0, N, dxdt, obs_idx=2):
        self.t = t
        self.x0 = x0
        self.X = X
        self.N = N
        self.dxdt = dxdt
        self.obs_idx = obs_idx
        # Filled in by train() so convergence diagnostics can be inspected later.
        self.mcmc = None

    def model(self, X):
        """NumPyro generative model; pass ``X=None`` to sample without conditioning."""
        # Priors: rates are truncated at zero so they stay non-negative.
        beta = sample('beta', dist.TruncatedNormal(low=0.0, loc=0.5, scale=2.0))
        gamma = sample('gamma', dist.TruncatedNormal(low=0.0, loc=0.5, scale=1.0))
        delta = sample('delta', dist.TruncatedNormal(low=0.0, loc=0.5, scale=1.0))
        # Despite the site name, this value is used as the *standard deviation*
        # of the likelihood below; the name is kept so samples['noise_var'] and
        # the printed summary stay unchanged.
        # NOTE(review): LogNormal(0, 10) is extremely diffuse; consider a tighter prior.
        noise = sample('noise_var', dist.LogNormal(0.0, 10.0))

        # Likelihood: only the observed column of the ODE solution is compared to X.
        z = self.predict(beta, gamma, delta)[:, self.obs_idx]
        sample("X", dist.Normal(z, noise), obs=X)

    def train(self, settings, rng_key):
        """Run NUTS on the stored observations and return the posterior samples.

        Args:
            settings: dict with required keys ``'num_warmup'`` and
                ``'num_samples'``; ``'num_chains'`` (default 1) and
                ``'target_accept_prob'`` (default 0.8, NUTS's own default)
                are optional.
            rng_key: JAX PRNG key driving the sampler.

        Returns:
            dict mapping sample-site names to arrays of posterior draws.
        """
        start = time.perf_counter()
        kernel = NUTS(self.model,
                      target_accept_prob=settings.get('target_accept_prob', 0.8))
        mcmc = MCMC(kernel,
                    num_warmup=settings['num_warmup'],
                    num_samples=settings['num_samples'],
                    num_chains=settings.get('num_chains', 1),
                    progress_bar=True,
                    jit_model_args=True)
        mcmc.run(rng_key, self.X)
        mcmc.print_summary()
        elapsed = time.perf_counter() - start
        print('\nMCMC elapsed time: %.2f seconds' % (elapsed))
        self.mcmc = mcmc
        return mcmc.get_samples()

    def predict(self, beta, gamma, delta):
        """Integrate the ODE for the given rates.

        Returns the full solution evaluated at ``self.t``: one row per time
        point, one column per state component.
        """
        return odeint(self.dxdt, self.x0, self.t, self.N, beta, gamma, delta)

In [None]:
def SEIR(z, t, N, beta, gamma, delta):
    """Time derivative of the SEIR compartments, for use with ``odeint``.

    Args:
        z: current state ``(S, E, I, R)``.
        t: current time; unused (the system is autonomous) but required by the
            ``func(y, t, *args)`` calling convention of ``odeint``.
        N: total population size.
        beta: transmission rate, the parameter most affected by human activity.
        gamma: recovery rate (inverse of the infectious period).
        delta: progression rate from exposed to infectious (inverse of the
            incubation period).

    Returns:
        Array ``[dS/dt, dE/dt, dI/dt, dR/dt]``.
    """
    susceptible, exposed, infectious, _removed = z
    force = beta * susceptible * infectious / N   # new infections per unit time
    onset = delta * exposed                        # E -> I transitions
    recovery = gamma * infectious                  # I -> R transitions
    return np.array([-force, force - onset, onset - recovery, recovery])

In [None]:
# Reference parameters used to simulate the synthetic data set.
beta = 1.0         # transmission rate: while S ~ N, one infectious person infects ~1 other per day
D = 4.0            # infectious period: an infection lasts four days
gamma = 1.0 / D    # recovery rate
delta = 1.0 / 3.0  # E -> I rate: incubation period of three days
noise = 0.05       # observation noise, as a fraction of each compartment's std over time
key = random.PRNGKey(1234)

# Initial conditions: one infectious individual, everyone else susceptible.
S0, E0, I0, R0 = 999.0, 0.0, 1.0, 0.0
x0 = np.array([S0, E0, I0, R0])
# Derive the total population from the initial state instead of hard-coding it,
# so editing the initial conditions cannot make N and x0 disagree.
N = float(S0 + E0 + I0 + R0)

# 100 evenly spaced observation times from day 0 to day 100.
t = np.linspace(0, 100, 100)

# Simulate the true trajectories, then add Gaussian noise scaled per compartment.
X_true = odeint(SEIR, x0, t, N, beta, gamma, delta)
data = X_true + noise * X_true.std(0) * random.normal(key, X_true.shape)

In [None]:
# Calibrate against the noisy infectious compartment (column 2 of the simulated data).
model = ODEfit(t=t, X=data[:, 2], x0=x0, N=N, dxdt=SEIR)
rng_key_train, rng_key_predict = random.split(random.PRNGKey(0))

In [None]:
# NUTS settings: one chain, 1000 warm-up iterations, 2000 retained posterior draws.
num_warmup, num_samples = 1000, 2000
num_chains = 1
target_accept_prob = 0.85
settings = dict(num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                target_accept_prob=target_accept_prob)
samples = model.train(settings, rng_key_train)
print('True values: beta = %f, gamma = %f, delta = %f' % (beta, gamma, delta))

warmup:  23%|██▎       | 699/3000 [01:18<03:07, 12.25it/s, 39 steps of size 1.16e-02. acc. prob=0.84] 

In [None]:
# Solve the ODE once per posterior draw of (beta, gamma, delta), batched with vmap.
vmap_args = tuple(samples[name] for name in ('beta', 'gamma', 'delta'))
X_pred = vmap(model.predict)(*vmap_args)

# Pointwise posterior mean and a +/- 2 std band (~95% if roughly Gaussian).
mean_prediction = np.mean(X_pred, axis=0)
std_prediction = np.std(X_pred, axis=0)
lower = mean_prediction - 2.0 * std_prediction
upper = mean_prediction + 2.0 * std_prediction

In [None]:
plt.rcParams.update({'font.size': 16})
plt.rcParams['axes.linewidth'] = 3

fig, ax = plt.subplots(figsize=(16, 9))
for i, name in enumerate(['S', 'E', 'I', 'R']):
    # Ground-truth trajectory; its colour is reused for the matching uncertainty band.
    true_line, = ax.plot(t, X_true[:, i], linewidth=3, label='%s (true)' % name)
    # Mean and band get a legend entry only once, to avoid four duplicates.
    ax.plot(t, mean_prediction[:, i], 'k--', linewidth=2,
            label='posterior mean' if i == 0 else '_nolegend_')
    ax.fill_between(t, lower[:, i], upper[:, i], color=true_line.get_color(), alpha=0.3,
                    label='posterior mean ± 2 std' if i == 0 else '_nolegend_')
# Only the infectious compartment (column 2) was observed and used for fitting.
ax.plot(t, data[:, 2], 'o', markersize=10, alpha=0.5, label='I (observed)')
ax.set(ylim=(0, N), xlabel="days", ylabel="population",
       title="Posterior predictive (mean ± 2 std) for the SEIR model")
ax.legend()
plt.show()