<a href="https://colab.research.google.com/github/assemzh/ProbProg-COVID-19/blob/master/virus_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Epidemiology model

https://nbviewer.jupyter.org/github/pyro-ppl/pyro/blob/sir-tutorial-ii/tutorial/source/epi_regional.ipynb?fbclid=IwAR3Gv8tLuiEjOmZh7-NQUa_ggm_QUqtSc5TxRZ0_pSxVA7Y3lWWzSFGKjrA 


In [None]:
!git clone https://github.com/pyro-ppl/pyro.git

In [None]:
%cd /content/pyro


In [None]:
!pip install .[extras]

In [None]:
import os
import logging
import urllib.request
from collections import OrderedDict

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist
from pyro.ops.tensor_utils import convolve

%matplotlib inline
pyro.enable_validation(True)           
torch.set_default_dtype(torch.double) 


  ## Model without Policies
  

In [None]:
class CovidModel(CompartmentalModel):
    def __init__(self, population, new_cases, new_recovered, new_deaths):
        '''
        population (int) – Total population = S + E + I + R.
        '''
        assert len(new_cases) == len(new_recovered) == len(new_deaths)

        compartments = ("S", "E", "I")  # R is implicit.
        duration = len(new_cases)
        super().__init__(compartments, duration, population)

        self.new_cases = new_cases
        self.new_deaths = new_deaths
        self.new_recovered = new_recovered
        

    def global_model(self):
        tau_e = pyro.sample("rec_time",dist.Normal(15.0, 1))
        tau_i = pyro.sample("incub_time",dist.Normal(5.0, 1))
        # R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
        R0 = pyro.sample("R0", dist.Normal(2, 0.5))
        rho = pyro.sample("rho", dist.Beta(10, 10))  # About 50% response rate.
        mort_rate = pyro.sample("mort_rate", dist.Beta(2, 100))  # About 2% mortality rate.
        rec_rate = pyro.sample("rec_rate",dist.Beta(10, 10)) # About 50% recovery rate.
        return R0, tau_e, tau_i, rho, mort_rate, rec_rate

    def initialize(self, params):
        # Start with a single infection.
        return {"S": self.population - 1, "E": 0, "I": 1}

    def transition(self, params, state, t):
        R0, tau_e, tau_i, rho, mort_rate, rec_rate = params

        # Sample flows between compartments.
        S2E = pyro.sample("S2E_{}".format(t),
                            infection_dist(individual_rate=R0 / tau_i,
                                          num_susceptible=state["S"],
                                          num_infectious=state["I"],
                                          population=self.population))
        E2I = pyro.sample("E2I_{}".format(t),
                            binomial_dist(state["E"], 1 / tau_e )) 
        I2R = pyro.sample("I2R_{}".format(t),
                            binomial_dist(state["I"], 1 / tau_i))
 
        # Update compartments with flows.
        state["S"] = state["S"] - S2E 
        state["E"] = state["E"] + S2E - E2I
        state["I"] = state["I"] + E2I - I2R

        # Condition on observations.
        t_is_observed = isinstance(t, slice) or t < self.duration
        pyro.sample("new_cases_{}".format(t),
                    binomial_dist(S2E, rho),
                    obs=self.new_cases[t] if t_is_observed else None)
        pyro.sample("new_deaths_{}".format(t),
                      binomial_dist(I2R, mort_rate),
                      obs=self.new_deaths[t] if t_is_observed else None)
        pyro.sample("new_recovered_{}".format(t),
                    binomial_dist(E2I, rec_rate),
                    obs=self.new_recovered[t] if t_is_observed else None)


## Create Country

In [None]:
# function to make the time series of confirmed and daily confirmed cases for a specific country
def create_country (country, start_date, end_date, state = False) : 

    url = 'https://raw.githubusercontent.com/assemzh/ProbProg-COVID-19/master/full_grouped.csv'
    data = pd.read_csv(url)

    data.Date = pd.to_datetime(data.Date)

    if state :
        df = data.loc[data["Province/State"] == country, ["Province/State", "Date", "Confirmed", "Deaths", "Recovered", "Active", "New cases", "New deaths", "New recovered"]]
    else : 
        df = data.loc[data["Country/Region"] == country, ["Country/Region", "Date", "Confirmed", "Deaths", "Recovered", "Active", "New cases", "New deaths", "New recovered"]]
    df.columns = ["country", "date", "confirmed", "deaths", "recovered", "active", "new_cases", "new_deaths", "new_recovered"]

    # group by country and date
    df = df.groupby(['country','date'])['confirmed', 'deaths', 'recovered',"active", "new_cases", "new_deaths", "new_recovered"].sum().reset_index()

    # convert date string to datetime
    df.date = pd.to_datetime(df.date)
    df = df.sort_values(by = "date")
    df = df[df.date >= start_date]
    df = df[df.date <= end_date]

    active = df['active'].tolist()
    recovered = df['recovered'].tolist()
    deaths = df['deaths'].tolist()
    new_cases = df['new_cases'].tolist()
    new_recovered = df['new_recovered'].tolist()
    new_deaths = df['new_deaths'].tolist()
    
    active = torch.tensor(list(map(float, active))).view(len(active),1) 
    recovered = torch.tensor(list(map(float, recovered))).view(len(recovered),1) 
    deaths = torch.tensor(list(map(float, deaths))).view(len(deaths),1) 
    new_cases = torch.tensor(list(map(float, new_cases))).view(len(new_cases),1)  
    new_recovered = torch.tensor(list(map(float, new_recovered))).view(len(new_recovered),1) 
    new_deaths = torch.tensor(list(map(float, new_deaths))).view(len(new_deaths),1) 


    return_data = {
        'active':active,
        'recovered':recovered,
        'deaths':deaths,
        'new_cases':new_cases,
        'new_recovered': new_recovered,
        'new_deaths':new_deaths }
        
    return return_data


## Get data for countries


In [None]:
Japan = create_country("Japan", start_date = "2020-02-01", end_date = "2020-04-01")
Sweden = create_country("Sweden", start_date = "2020-02-01", end_date = "2020-04-01")


  app.launch_new_instance()


##Train the model using MCMC.



In [None]:
Japan_model = CovidModel(126500000, Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"] )

In [None]:
%%time
pyro.set_rng_seed(20200607)
Japan_mcmc = Japan_model.fit_mcmc(num_samples=200)

INFO 	 Running inference...
Warmup:   0%|          | 0/400 [00:00, ?it/s]INFO 	 Heuristic init: R0=1.65, incub_time=6.58, mort_rate=0.00376, rec_rate=0.438, rec_time=15.9, rho=0.221
Sample: 100%|██████████| 400/400 [01:55,  3.47it/s, step size=2.67e-03, acc. prob=0.883]

CPU times: user 1min 53s, sys: 480 ms, total: 1min 54s
Wall time: 1min 55s





In [None]:
Japan_mcmc.summary()


                     mean       std    median      5.0%     95.0%     n_eff     r_hat
       rec_time     16.00      0.00     16.00     15.99     16.00      3.09      1.92
     incub_time      6.53      0.00      6.54      6.53      6.54      7.69      1.00
             R0      1.58      0.01      1.57      1.56      1.60      2.48      2.73
            rho      0.29      0.01      0.29      0.28      0.30      2.58      2.46
      mort_rate      0.02      0.00      0.02      0.02      0.02      2.50      2.79
       rec_rate      0.14      0.01      0.14      0.13      0.14      2.57      2.51

Number of divergences: 0


In [None]:
Sweden_model = CovidModel(10230000, Sweden["new_cases"], Sweden["new_recovered"], Sweden["new_deaths"] )


In [None]:
%%time
pyro.set_rng_seed(20200607)
Sweden_mcmc = Sweden_model.fit_mcmc(num_samples=10)

INFO 	 Running inference...
Warmup:   0%|          | 0/20 [00:00, ?it/s]INFO 	 Heuristic init: R0=2.4, incub_time=3.83, mort_rate=0.0132, rec_rate=0.273, rec_time=16.3, rho=0.285
Sample: 100%|██████████| 20/20 [00:02,  8.45it/s, step size=2.57e-05, acc. prob=1.000]

CPU times: user 2.37 s, sys: 9.97 ms, total: 2.38 s
Wall time: 2.38 s





In [None]:
Sweden_mcmc.summary()


                     mean       std    median      5.0%     95.0%     n_eff     r_hat
       rec_time     16.33      0.00     16.33     16.33     16.33      2.78      2.28
     incub_time      3.83      0.00      3.83      3.83      3.83      7.60      1.42
             R0      2.40      0.00      2.40      2.40      2.40      7.69      1.09
            rho      0.54      0.03      0.52      0.51      0.58      2.39      4.38
      mort_rate      0.02      0.00      0.02      0.01      0.02      2.38      5.53
       rec_rate      0.27      0.00      0.27      0.27      0.27      2.38      4.84

Number of divergences: 0
