In [1]:
import jax.numpy as np
from jax.experimental import optimizers
from jax import jit, grad
from jax.experimental.ode import odeint
from jax import random
from tqdm import tqdm
import itertools
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
# Root PRNG key for the whole notebook, seeded so every run is reproducible.
key = random.PRNGKey(seed=1234)



In [2]:
class ODEfit:
    """Fit the three rate parameters of an ODE model to observed trajectories.

    The right-hand side ``dxdt(x, t, N, beta, gamma, delta)`` is integrated
    with ``odeint`` from the fixed initial state ``x0`` at the times ``t``.
    The parameters ``(beta, gamma, delta)`` are learned with Adam by
    minimising the mean squared error between that trajectory and ``X``.
    """

    def __init__(self, t, X, x0, N, dxdt, rng_key=None):
        """Store the data, draw an initial guess and set up the optimizer.

        Args:
            t: evaluation times handed to ``odeint``.
            X: observed states, one row per entry of ``t``; ``X.shape[1]``
                is stored as ``self.dim``.
            x0: initial state handed to ``odeint``.
            N: constant forwarded to ``dxdt`` as its third argument.
            dxdt: right-hand side with signature
                ``dxdt(x, t, N, beta, gamma, delta)``.
            rng_key: optional JAX PRNG key for the initial parameter guess.
                When omitted, the notebook-level ``key`` is used, so existing
                calls behave exactly as before.
        """
        self.t = t
        self.x0 = x0
        self.X = X
        self.N = N
        self.dim = X.shape[1]
        self.dxdt = dxdt
        if rng_key is None:
            # Backward-compatible fallback to the notebook-global key.
            rng_key = key
        # Initial guess for (beta, gamma, delta), drawn uniformly from [0, 1).
        self.params = random.uniform(rng_key, (3,))
        print('Initial guess: beta = %f, gamma = %f, delta = %f' % (self.params[0],
                                                                    self.params[1],
                                                                    self.params[2]))

        # Adam with a learning rate that decays by 1% every 100 steps.
        self.learning_rate = optimizers.exponential_decay(1e-3,
                                                          decay_steps=100,
                                                          decay_rate=0.99)
        self.opt_init, \
        self.opt_update, \
        self.get_params = optimizers.adam(self.learning_rate)
        self.opt_state = self.opt_init(self.params)

        # Logger: a global step counter (persists across train() calls) and
        # the sampled loss history.
        self.itercount = itertools.count()
        self.loss_log = []

    def loss(self, params, batch):
        """Mean squared error between the integrated trajectory and ``batch``.

        Args:
            params: array-like ``(beta, gamma, delta)``.
            batch: observations with the same shape as the ``odeint`` output.

        Returns:
            Scalar MSE averaged over all times and state components.
        """
        beta = params[0]
        gamma = params[1]
        delta = params[2]
        pred = odeint(self.dxdt, self.x0, self.t, self.N, beta, gamma, delta)
        loss = np.mean((pred - batch)**2)
        return loss

    # ``self`` is declared static, so jit treats the instance as a
    # compile-time constant (hashed by identity, since no __hash__ is defined).
    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch):
        """Perform one Adam update at global step ``i`` and return the new state."""
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, batch)
        return self.opt_update(i, g, opt_state)

    def train(self, nIter=10000, log_every=50):
        """Run ``nIter`` optimization steps on the full data set ``self.X``.

        Every ``log_every`` iterations (starting with the first), refresh
        ``self.params`` and append the current loss to ``self.loss_log``.
        ``self.params`` always holds the final parameters on return.

        Raises:
            ValueError: if ``log_every`` is smaller than 1.
        """
        if log_every < 1:
            raise ValueError('log_every must be a positive integer')
        for it in tqdm(range(nIter)):
            self.opt_state = self.step(next(self.itercount), self.opt_state, self.X)
            if it % log_every == 0:
                self.params = self.get_params(self.opt_state)
                loss_value = self.loss(self.params, self.X)
                self.loss_log.append(loss_value)
        self.params = self.get_params(self.opt_state)

In [3]:
def SEIR(x, t, N, beta, gamma, delta):
    """Right-hand side of the SEIR compartmental model.

    Args:
        x: state vector, unpacked as (S, E, I, R).
        t: time; unused, but required by the ``func(y, t, *args)``
            signature that ``odeint`` calls.
        N: normalising constant of the contact term (the total population
            in this notebook).
        beta: rate of the S -> E flow.
        gamma: rate of the I -> R flow.
        delta: rate of the E -> I flow.

    Returns:
        Array [dS/dt, dE/dt, dI/dt, dR/dt]. The entries sum to zero (up to
        rounding), so S + E + I + R is conserved.
    """
    S, E, I, R = x
    exposures = beta * S * I / N   # S -> E
    onsets = delta * E             # E -> I
    recoveries = gamma * I         # I -> R
    return np.array([-exposures,
                     exposures - onsets,
                     onsets - recoveries,
                     recoveries])

In [4]:
# Set reference parameters
N = 1000.0          # normalising constant of the contact term (total population)
beta = 1.0          # infected person infects 1 other person per day
D = 4.0             # infection lasts four days
gamma = 1.0 / D     # I -> R rate
delta = 1.0 / 3.0   # E -> I rate: incubation period of three days
noise = 0.05        # noise amplitude relative to each compartment's std. dev.

# Initial conditions: one infected, rest susceptible.
# NOTE: R0 is the initial number of recovered people, not the basic reproduction number.
S0, E0, I0, R0 = 999.0, 0.0, 1.0, 0.0

x0 = np.array([S0, E0, I0, R0])

# JAX PRNG keys must not be reused: derive independent subkeys for the sample
# times and the measurement noise. The global `key` itself is left untouched.
key_t, key_noise = random.split(key)
# odeint treats x0 as the state at t[0] and requires strictly increasing times.
t = np.sort(100.0*random.uniform(key_t, (100, )))

# Generate time-series data
X_true = odeint(SEIR, x0, t, N, beta, gamma, delta)
X = X_true + noise*X_true.std(0)*random.normal(key_noise, X_true.shape)

In [5]:
# Build the fitter for the SEIR right-hand side on the noisy observations.
model = ODEfit(t=t, X=X, x0=x0, N=N, dxdt=SEIR)

Initial guess: beta = 0.492109, gamma = 0.470864, delta = 0.140462


In [None]:
# Run 20,000 optimization steps (twice the train() default of 10,000).
model.train(nIter=20000)

 70%|██████▉   | 13901/20000 [00:18<00:06, 907.88it/s]

In [None]:
# Re-integrate the SEIR model with the fitted (beta, gamma, delta) and report
# them next to the reference values used to generate the data.
X_pred = odeint(SEIR, x0, t, N, *model.params)
print('True values: beta = %f, gamma = %f, delta = %f' % (beta, gamma, delta))
print('Pred values: beta = %f, gamma = %f, delta = %f' % tuple(model.params))

In [None]:
plt.rcParams.update({'font.size': 16})
plt.rcParams['axes.linewidth'] = 3

# Save next to the notebook rather than to a user-specific absolute path.
FIG_PATH = 'covid.png'
COMPARTMENTS = [r'$S(t)$', r'$E(t)$', r'$I(t)$', r'$R(t)$']
OBSERVED = {2}    # compartments whose noisy observations are drawn (only I(t))
LOG_EVERY = 50    # ODEfit.train appends one loss value every 50 iterations

# True trajectories (solid), fitted model (black dashed) and observed data (dots).
# The plotting order per compartment (data, truth, fit) keeps the color cycle unchanged.
fig, ax = plt.subplots(figsize=(9, 7))
for idx, label in enumerate(COMPARTMENTS):
    if idx in OBSERVED:
        ax.plot(t, X[:, idx], '.', markersize=12, alpha=0.5, label='Data')
    ax.plot(t, X_true[:, idx], linewidth=2, label=label)
    # Give the fitted curves a single legend entry.
    ax.plot(t, X_pred[:, idx], 'k--', linewidth=2,
            label='Fit' if idx == 0 else '_nolegend_')
ax.set_xlabel(r't')
ax.set_ylabel(r'x(t)')
ax.legend(loc='best', frameon=False)
fig.tight_layout()
fig.savefig(FIG_PATH, dpi=300)

# Loss history, plotted against the actual training iteration of each sample.
fig_loss, ax_loss = plt.subplots(figsize=(7, 5))
ax_loss.plot(range(0, LOG_EVERY * len(model.loss_log), LOG_EVERY),
             model.loss_log, 'k', linewidth=2)
ax_loss.set_xlabel('Iteration')
ax_loss.set_ylabel('Loss')
ax_loss.set_yscale('log')