  \\[
  \begin{aligned}
  \frac{dS}{dt} &= -\frac{\beta SI}{N} \\
  \frac{dI}{dt} &= \frac{\beta SI}{N} - \gamma I -\theta I \\
  \frac{dR}{dt} &= \gamma I \\
  \frac{dQ}{dt} &= \theta I
  \end{aligned}
  \\]

In [1]:
from abc import ABCMeta, abstractmethod
from collections import namedtuple

import numpy as np
from scipy.optimize import minimize
from scipy.special import xlog1py, xlogy
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots

from easymh import mh

In [161]:
class Law(metaclass=ABCMeta):
    """Observation model for the number of positive tests in a sample.

    Laws are used as classes (callers pass e.g. ``Poi`` and call
    ``law.sample`` / ``law.loglikely`` on it).  Every method takes

    * ``n`` -- number of people tested,
    * ``d`` -- probability that a tested person is positive (the call sites
      use ``I[t] / m``, the infectious fraction of the population),
    * ``k`` -- number of positive tests observed.
    """

    @staticmethod
    @abstractmethod
    def sample(n, d):
        """Draw a random number of positives among ``n`` tests."""

    @staticmethod
    @abstractmethod
    def loglikely(n, d, k):
        """Return the log-likelihood of ``k`` positives, up to a constant independent of ``d``."""

    @classmethod
    def likelihood(cls, n, d, k):
        """Return ``exp(cls.loglikely(n, d, k))``.

        Dispatches through ``cls`` on purpose: inside a staticmethod the bare
        name ``loglikely`` would resolve to the module-level
        ``loglikely(epidemic, sample, law)`` function, not to this law.
        """
        return np.exp(cls.loglikely(n, d, k))


class Bin(Law):
    """Binomial law: each of the ``n`` tested people is positive with probability ``d``."""

    @staticmethod
    def sample(n, d):
        return np.random.binomial(n, d)

    @staticmethod
    def loglikely(n, d, k):
        # log C(n, k) is dropped because it does not depend on d.
        # xlogy / xlog1py define 0 * log(0) as 0, so d == 0 with k == 0
        # (or d == 1 with k == n) gives 0 instead of nan.
        return xlogy(k, d) + xlog1py(n - k, -d)


class Poi(Law):
    """Poisson approximation: the number of positives has mean ``n * d``."""

    @staticmethod
    def sample(n, d):
        return np.random.poisson(n * d)

    @staticmethod
    def loglikely(n, d, k):
        # log(k!) is dropped because it does not depend on d.
        # xlogy(0, 0) == 0 keeps a zero-mean, zero-count observation finite.
        mean = n * d
        return xlogy(k, mean) - mean

In [162]:
class Dynamic(metaclass=ABCMeta):
    """Deterministic compartmental epidemic model."""

    @abstractmethod
    def estimate(self, region, T):
        """Simulate ``T`` time steps starting from the compartments of ``region``."""


class SIR(Dynamic):
    """Discrete-time SIR model (one explicit Euler step per iteration).

    Parameters
    ----------
    beta : float
        Infection rate per unit time.
    gamma : float
        Removal rate per unit time.
    dt : float, optional
        Step length; both rates are multiplied by ``dt`` so each iteration of
        :meth:`estimate` advances the system by ``dt``.
    """

    # Built once so every call to ``estimate`` returns the same result type.
    Epidemic = namedtuple('Epidemic', 'S I R')

    def __init__(self, beta, gamma, dt=1):
        self.beta = beta * dt
        self.gamma = gamma * dt

    def __repr__(self):
        return "β={}, γ={}".format(self.beta, self.gamma)

    def estimate(self, region, T):
        """Integrate the model for ``T`` steps.

        ``region`` must expose the initial compartment sizes ``S``, ``I`` and
        ``R``; their sum ``N`` is conserved by the update.

        Returns an ``Epidemic(S, I, R)`` namedtuple of float arrays of length
        ``T + 1``; index 0 holds the initial state.
        """
        S = np.zeros(T + 1)
        I = np.zeros(T + 1)
        R = np.zeros(T + 1)
        S[0], I[0], R[0] = region.S, region.I, region.R
        N = S[0] + I[0] + R[0]

        for t in range(T):
            infected = self.beta * S[t] * I[t] / N
            removed = self.gamma * I[t]
            S[t + 1] = S[t] - infected
            I[t + 1] = I[t] + infected - removed
            R[t + 1] = R[t] + removed

        return self.Epidemic(S, I, R)

    def predict(self, region, T):
        """Alias of :meth:`estimate`."""
        return self.estimate(region, T)

    @staticmethod
    def plot(epidemic):
        """Return a plotly figure with one line per compartment (values truncated to int)."""
        fig = go.Figure()
        fig.update_layout(margin=dict(b=0, l=0, r=0, t=25))
        steps = np.arange(len(epidemic.S))
        for values, label in (
            (epidemic.S, "Susceptible"),
            (epidemic.I, "Infectious"),
            (epidemic.R, "Removed"),
        ):
            fig.add_scatter(x=steps, y=values.astype(int), name=label, hovertemplate="%{y}")
        return fig


class SIRQ(Dynamic):
    """Discrete-time SIR model with an extra quarantine compartment ``Q``.

    Infectious individuals leave ``I`` towards ``R`` at rate ``gamma`` and
    towards ``Q`` at rate ``theta``.  All rates are multiplied by ``dt``.
    """

    # Built once so every call to ``estimate`` returns the same result type.
    Epidemic = namedtuple('Epidemic', 'S I R Q')

    def __init__(self, beta, gamma, theta, dt=1):
        self.beta = beta * dt
        self.gamma = gamma * dt
        self.theta = theta * dt

    def __repr__(self):
        return "β={}, γ={}, θ={}".format(self.beta, self.gamma, self.theta)

    def estimate(self, region, T):
        """Integrate the model for ``T`` steps.

        ``region`` must expose ``S``, ``I``, ``R`` and ``Q``; their sum ``N``
        is conserved by the update.

        Returns an ``Epidemic(S, I, R, Q)`` namedtuple of float arrays of
        length ``T + 1``; index 0 holds the initial state.
        """
        S = np.zeros(T + 1)
        I = np.zeros(T + 1)
        R = np.zeros(T + 1)
        Q = np.zeros(T + 1)
        S[0], I[0], R[0], Q[0] = region.S, region.I, region.R, region.Q
        N = S[0] + I[0] + R[0] + Q[0]

        for t in range(T):
            infected = self.beta * S[t] * I[t] / N
            removed = self.gamma * I[t]
            quarantined = self.theta * I[t]
            S[t + 1] = S[t] - infected
            I[t + 1] = I[t] + infected - removed - quarantined
            R[t + 1] = R[t] + removed
            Q[t + 1] = Q[t] + quarantined

        return self.Epidemic(S, I, R, Q)

    def predict(self, region, T):
        """Alias of :meth:`estimate`."""
        return self.estimate(region, T)

    @staticmethod
    def plot(epidemic):
        """Return a plotly figure with one line per compartment (values truncated to int)."""
        fig = go.Figure()
        fig.update_layout(margin=dict(b=0, l=0, r=0, t=25))
        steps = np.arange(len(epidemic.S))
        for values, label in (
            (epidemic.S, "Susceptible"),
            (epidemic.I, "Infectious"),
            (epidemic.R, "Removed"),
            (epidemic.Q, "Quarantined"),
        ):
            fig.add_scatter(x=steps, y=values.astype(int), name=label, hovertemplate="%{y}")
        return fig
    

class Sample:
    """Simulated testing campaign drawn from an epidemic trajectory.

    Entry ``i`` describes ``ns[i]`` tests carried out at time step ``ts[i]``
    in a population of size ``ms[i]``; its positive count is drawn with
    ``law.sample(n, epidemic.I[t] / m)``.

    ``t``, ``m`` and ``n`` keep the inputs as given; ``positive`` holds the
    drawn counts in an array with the shape and dtype of ``ts``.
    """

    def __init__(self, epidemic, ts, ms, ns, law, seed=None):
        if seed is not None:
            # Reseeds numpy's global generator (the one Bin/Poi draw from).
            np.random.seed(seed)

        draws = [law.sample(n, epidemic.I[t] / m) for t, m, n in zip(ts, ms, ns)]
        positive = np.zeros_like(ts)
        # Only the entries zip() produced are filled; any tail stays at 0.
        positive[:len(draws)] = draws

        self.t, self.m, self.n = ts, ms, ns
        self.positive = positive
        self._law = law

    def __repr__(self):
        template = " t: {} \n m: {} \n n: {} \n positive: {}"
        return template.format(self.t, self.m, self.n, self.positive)

    def plot(self, fig):
        """Overlay the rescaled positives (``positive / n * m``) on ``fig`` and return it."""
        rescaled = self.positive / self.n * self.m
        fig.add_scatter(x=self.t, y=rescaled, mode="markers", name="Guessed", hovertemplate="%{y}")
        return fig

In [226]:
def loglikely(epidemic, sample, law):
    """Return the log-likelihood of ``sample`` under ``epidemic`` and ``law``.

    The positive probability at each sampled time ``t`` is
    ``epidemic.I[t] / m``; the per-time values from ``law.loglikely`` are
    added up, i.e. the sampled times are treated as independent.
    """
    prevalence = epidemic.I[sample.t] / sample.m
    total = 0
    for tested, d, positives in zip(sample.n, prevalence, sample.positive):
        total += law.loglikely(tested, d, positives)
    return total


def likelihood(epidemic, sample, law):
    """Return ``exp(loglikely(...))``; may underflow to 0 when the log-likelihood is very negative."""
    log_l = loglikely(epidemic, sample, law)
    return np.exp(log_l)


# NOTE(review): neither helper is referenced in the visible code; presumably
# meant as default transforms / flat weights -- confirm against callers.
def Id(x):
    """Identity function: return ``x`` unchanged."""
    return x


def one(x):
    """Constant function: return ``1`` whatever ``x`` is."""
    return 1


class InferSIR:
    """Fit the SIR rates ``beta`` and ``gamma`` to a sample of test results.

    ``law`` is the observation model (a ``Law`` subclass such as ``Poi``)
    used to score a simulated epidemic against the sampled positives.  After
    a ``fit_*`` call the estimates are stored in ``beta``, ``gamma`` and
    ``loglikely`` (the log-likelihood at the estimate).
    """

    def __init__(self, law=Poi, algo="map"):
        self.law = law
        # NOTE(review): ``algo`` is stored but not read anywhere in this class.
        self.algo = algo
        # Set by the fit_* methods; None until then so ``str(self)`` works.
        self.beta = None
        self.gamma = None
        self.loglikely = None

    def __str__(self):
        return "β={}, γ={}, loglikely={}".format(self.beta, self.gamma, self.loglikely)

    @staticmethod
    def _loglikely_at(region, sample, law, beta, gamma):
        """Log-likelihood of ``sample`` for an SIR run with rates (beta, gamma) up to ``sample.t[-1]``."""
        epidemic = SIR(beta, gamma).estimate(region, sample.t[-1])
        # Module-level function (class scope is not searched from a method body).
        return loglikely(epidemic, sample, law)

    def plot(self, region, sample, law=None):
        """Return a contour map of the log-likelihood over a (beta, gamma) grid.

        Both rates take 50 log-spaced values in [0.01, 1] (beta on x, gamma
        on y).  The plotted quantity is ``log(max(z) - z + 1)``, so the most
        likely region appears as the minimum of the contour.
        """
        if law is None:
            law = self.law

        betas = np.logspace(-2, 0, 50)
        gammas = np.logspace(-2, 0, 50)
        z = np.zeros((len(gammas), len(betas)))
        for i, gamma in enumerate(gammas):
            for j, beta in enumerate(betas):
                z[i, j] = self._loglikely_at(region, sample, law, beta, gamma)

        fig = go.Figure(data=go.Contour(z=np.log(np.max(z) - z + 1), x=betas, y=gammas, showscale=False))
        fig.update_layout(
            showlegend=False,
            margin=dict(b=0, l=0, r=0, t=25),
            # On a log axis plotly's range is in log10 units: (-2, 0) is [0.01, 1].
            xaxis=dict(scaleanchor="y", scaleratio=1, constrain="domain", range=(-2, 0)),
        )
        fig.update_xaxes(type="log")
        fig.update_yaxes(type="log")
        return fig

    def fit_beta_gamma_map(self, region, sample, law=None, **kvarg):
        """Maximise the likelihood over (beta, gamma) with Nelder-Mead.

        Starts from (0.5, 0.5) with no bounds, so the optimiser is not
        prevented from trying negative rates.  Stores ``beta``, ``gamma``
        and ``loglikely`` on ``self``, then shows the likelihood map with the
        optimum marked.  ``kvarg`` is accepted but currently ignored.
        """
        if law is None:
            law = self.law

        def neg_loglikely(x):
            return -self._loglikely_at(region, sample, law, x[0], x[1])

        res = minimize(neg_loglikely, (0.5, 0.5), method='nelder-mead', options={'xatol': 1e-8, 'disp': True})
        self.beta, self.gamma = res.x
        self.loglikely = -res.fun
        fig = self.plot(region, sample, law)
        fig.add_scatter(x=[self.beta], y=[self.gamma])
        fig.show()

    def fit_beta_gamma_mh(self, region, sample, law=None, method='naive', **kvarg):
        """Sample (beta, gamma) with ``easymh.mh`` over the box [0.01, 1]^2.

        ``method`` selects the exploration scheme:

        * ``'naive'``  -- walk directly on (beta, gamma);
        * ``'mirror'`` -- same, passing ``ascdes=(np.log, np.exp)`` to ``mh``
          (its exact meaning is defined by easymh);
        * ``'repar'``  -- walk on (log10 beta, log10 gamma) in [-2, 0]^2.

        Extra ``kvarg`` are forwarded to ``mh``.  Stores ``beta``, ``gamma``,
        ``loglikely`` and the chain ``walker`` on ``self``, then shows the
        chain over the likelihood map.

        Raises
        ------
        ValueError
            If ``method`` is not one of the three names above.
        """
        if law is None:
            law = self.law

        def density(x):
            # Likelihood in the natural (beta, gamma) parametrisation.
            return np.exp(self._loglikely_at(region, sample, law, *x))

        def density_log10(x):
            # Density of u = log10(theta) is p(10**u) * |d theta / d u|, with
            # Jacobian 10**u0 * 10**u1 * ln(10)**2.  The constant ln(10)**2
            # cancels in the Metropolis-Hastings ratio, so it is dropped.
            return density(np.power(10.0, x)) * np.power(10.0, x[0] + x[1])

        if method == 'naive':
            res, walker = mh([0.5, 0.5], density, np.array([[0.01, 1], [0.01, 1]]), width=0.1, **kvarg)
        elif method == 'mirror':
            res, walker = mh([0.5, 0.5], density, np.array([[0.01, 1], [0.01, 1]]), width=0.1,
                             ascdes=(np.log, np.exp), **kvarg)
        elif method == 'repar':
            res, walker = mh([-1., -1.], density_log10, np.array([[-2, 0], [-2, 0]]), width=0.1, **kvarg)
            res = np.power(10.0, res)
            walker = np.power(10.0, walker)
        else:
            raise ValueError("unknown method {!r}; expected 'naive', 'mirror' or 'repar'".format(method))

        self.beta, self.gamma = res
        # Computed in log space: log(likelihood) would underflow to -inf once
        # the likelihood drops below the smallest positive float.
        self.loglikely = self._loglikely_at(region, sample, law, *res)
        self.walker = walker

        fig = self.plot(region, sample, law)
        # NOTE(review): assumes ``walker`` is a 2-D array (steps x 2) as returned by mh.
        fig.add_scatter(x=self.walker[:, 0], y=self.walker[:, 1], mode="markers+lines")
        fig.show()