In [1]:
import datetime
import io
import logging

import dateutil
import datetime
import numpy as np
import pandas as pd
import pymc3 as pm
import theano
import theano.tensor as T

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

import warnings
warnings.simplefilter("ignore")
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
### Mock parameters
# Disabled (`if 0`): flip to 1 to exercise the model on synthetic data. The real-data cell
# below runs under `if 1` and would overwrite these names, so disable that one as well.

if 0:
    # Fixed seed so the synthetic data (and hence the fit) is reproducible.
    rng = np.random.RandomState(0)
    # Countermeasures
    CMs = ["Stay home", "Respirators"]
    nCMs = len(CMs)
    # Countries
    Cs = ["CZ", "SK", "DE", "PL"]
    nCs = len(Cs)
    # Days
    Ds = [f"03-{i}" for i in range(10, 21)]
    nDs = len(Ds)

    # Probability of testing positive after transmission, from 0
    DelayProb = [0.1] * 10

    ### Mock input data

    # [country, CM, day] Which CMs are active, and to what extent
    ActiveCMs = rng.exponential(0.2, size=(nCs, nCMs, nDs)).astype(theano.config.floatX)

    # Mock growth rate [country, day]
    grate = rng.lognormal(0.2, 0.2, size=(nCs, nDs))
    # Cumulative product of the daily rates, times multiplicative measurement error.
    # Cast the whole product (not just the noise factor) so Confirmed really is floatX.
    Confirmed = (np.cumprod(grate, axis=1) * rng.lognormal(0.0, 0.3, size=(nCs, nDs))).astype(theano.config.floatX)
    #Confirmed[:,7] = np.nan
    # Treat NaNs (e.g. from the line above) as missing observations, as the real-data path does.
    Confirmed = np.ma.masked_invalid(Confirmed)

In [4]:
from epimodel.region_data import RegionDataset

# Region metadata (names, ISO codes, population, ... -- the 'basic' group below), then the
# ../data/data-<GROUP>.csv column groups (the log shows JH case counts and SCM countermeasures).
rds = RegionDataset.from_csv('../data/regions.csv')
rds.read_csv_groups('../data/data')

# Which column groups were loaded, and the columns in each.
print(rds.col_groups)

[2020-04-01 01:37:48,735] INFO(epimodel.region_data): Name index has 6 potential conflicts: ['american samoa', 'georgia', 'guam', 'northern mariana islands', 'puerto rico', 'united states minor outlying islands']
[2020-04-01 01:37:48,737] INFO(epimodel.region_data): Loading group JH from ../data/data-JH.csv ...
[2020-04-01 01:37:48,987] INFO(epimodel.region_data): Loading group SCM from ../data/data-SCM.csv ...


{'basic': ['Level', 'Name', 'OfficialName', 'OtherNames', 'Continent', 'Subregion', 'Country', 'ISOa3', 'M49Code', 'Lat', 'Lon', 'Population'], 'JH': {'JH_Active', 'JH_Confirmed', 'JH_Recovered', 'JH_Deaths'}, 'SCM': {'Asymptomatic isolation - blanket', 'Diagnostic criteria loosened', 'Mask wearing', 'Testing', 'Miscellaneous hygiene measures', 'Contact tracing', 'Nonessential business suspension', 'International travel restriction', 'Symptomatic isolation - blanket', 'Asymptomatic isolation - targeted', 'Assisting people to stay home', 'Asymptomatic isolation - semi-targeted', 'Diagnostic criteria tightened', 'Symptomatic isolation - semi-targeted', 'Public education and incentives', 'School closure', 'Symptomatic isolation - targeted', 'Domestic travel restriction', 'Gatherings banned', 'Resumption', 'Testing criteria', 'Public cleaning', 'Healthcare specialisation', 'Activity cancellation', 'Public interaction reduction', 'Hand washing'}}


In [5]:
### Real params and data

if 1:
    # Countermeasures (a subset of the SCM column group printed above)
    CMs = ['Asymptomatic isolation - semi-targeted', 'International travel restriction', 'School closure', 'Healthcare specialisation', 'Resumption', 'Asymptomatic isolation - targeted', 'Assisting people to stay home', 'Diagnostic criteria tightened', 'Public cleaning', 'Asymptomatic isolation - blanket', 'Public interaction reduction', 'Domestic travel restriction', 'Symptomatic isolation - targeted', 'Nonessential business suspension', 'Mask wearing', 'Public education and incentives', 'Activity cancellation', 'Testing criteria', 'Symptomatic isolation - blanket']
    #CMs += ['Gatherings banned', 'Contact tracing', 'Testing']
    nCMs = len(CMs)
    # Countries
    # NOTE(review): "GE" is Georgia's ISO code; if Germany was intended this should be "DE" -- confirm.
    Cs = ["DK", "CZ", "GE", "FR", "ES", "GB", "PL", "GR", "CH", "BE", "FI", "HU", "NO", "RO", "SE", "SI", "SK"]
    nCs = len(Cs)
    # Days: every calendar day from FirstDay to LastDay inclusive
    FirstDay, LastDay = datetime.date(2020, 2, 20), datetime.date(2020, 3, 28)
    FullDs = [FirstDay + datetime.timedelta(days=i) for i in range((LastDay - FirstDay).days + 1)]
    # HACK: Assume a fixed lag between a CM being active and its effect showing up in the
    # observed series: ~7 days for "JH_Confirmed", ~14 days for "JH_Deaths" (used here).
    JHName, LagDays = "JH_Deaths", 14
    CM_Ds = FullDs[:-LagDays]
    JH_Ds = FullDs[LagDays:]

    nDs = len(CM_Ds)

    # Probability of testing positive after transmission, from 0
    DelayProb = [0.00, 0.01, 0.02, 0.05, 0.09, 0.13, 0.15, 0.15, 0.13, 0.10, 0.07, 0.05, 0.03, 0.01, 0.01]

    # [country, CM, day] Which CMs are active, and to what extent.
    # Each CM is scaled by its maximum over all regions and days. The scaling is applied to the
    # selected values rather than to rds.series in place: a `.loc[list, list]` selection is a
    # copy, so in-place normalisation after selecting would never reach ActiveCMs, and mutating
    # rds.series would make re-running this cell depend on earlier runs.
    CMScale = {}
    for cm in CMs:
        cm_all = rds.series[cm]  # this CM for all regions and days
        cm_min, cm_mean, cm_max = cm_all.min().min(), cm_all.mean().mean(), cm_all.max().max()
        print(f"{cm[:29]:30}, {cm_min:.3f}, {cm_mean:.3f}, {cm_max:.3f}")
        # Guard all-zero (or all-missing) CMs: dividing by 0 would turn them into NaN.
        CMScale[cm] = cm_max if cm_max > 0 else 1.0
    ActiveCMs = np.stack(
        [rds.series.loc[Cs, [(cm, day) for day in CM_Ds]].values / CMScale[cm] for cm in CMs],
        axis=1,
    )
    assert ActiveCMs.shape == (nCs, nCMs, nDs)
    ActiveCMs = ActiveCMs.astype(theano.config.floatX)

    # [country, day] Observed cumulative counts of JHName; counts below MinCount are treated as missing
    MinCount = 5.0
    Confirmed = rds.series.loc[Cs, [(JHName, day) for day in JH_Ds]].values.astype(theano.config.floatX)
    assert Confirmed.shape == (nCs, nDs)
    Confirmed[Confirmed < MinCount] = np.nan
    Confirmed = np.ma.masked_invalid(Confirmed)


[2020-04-01 01:37:52,306] INFO(numexpr.utils): NumExpr defaulting to 4 threads.


Asymptomatic isolation - semi , 0.000, 0.039, 5.000
International travel restrict , 0.000, 0.999, 6.000
School closure                , 0.000, 0.386, 18.000
Healthcare specialisation     , 0.000, 0.067, 8.000
Resumption                    , 0.000, 0.002, 1.000
Asymptomatic isolation - targ , 0.000, 0.011, 1.000
Assisting people to stay home , 0.000, 0.072, 8.000
Diagnostic criteria tightened , 0.000, 0.004, 1.000
Public cleaning               , 0.000, 0.022, 5.000
Asymptomatic isolation - blan , 0.000, 0.091, 3.000
Public interaction reduction  , 0.000, 0.018, 4.000
Domestic travel restriction   , 0.000, 0.025, 2.000
Symptomatic isolation - targe , 0.000, 0.232, 2.000
Nonessential business suspens , 0.000, 0.174, 9.000
Mask wearing                  , 0.000, 0.195, 100.000
Public education and incentiv , 0.000, 0.062, 7.000
Activity cancellation         , 0.000, 0.269, 17.000
Testing criteria              , 0.000, 0.006, 0.900
Symptomatic isolation - blank , 0.000, 0.018, 2.000


In [6]:
with pm.Model() as model:
    # [] Baseline daily growth factor (wide prior OK; median 1.1, i.e. ~10% daily growth)
    BaseGrowthRate = pm.Lognormal("BaseGrowthRate", np.log(1.1), 2.0)
    # [country] Initial size of epidemic (the day before the start, only those detected; wide prior OK)
    InitialSize = pm.Lognormal("InitialSize", 0.0, 10, shape=(nCs,))
    # [country] Country growth rate, lognormally spread around the baseline
    # TODO: Estimate growth rate variance
    CountryGrowthRate = pm.Lognormal("CountryGrowthRate", pm.math.log(BaseGrowthRate), 0.2, shape=(nCs,))
    # [CM] Multiplicative effect of each fully active CM on the growth rate
    # TODO: Estimate variance, or use another dist.
    #CMReduction = pm.Bound(pm.Lognormal, upper=1.0)("CMReduction", 0.0, 0.1, shape=(nCMs,))
    CMReduction = pm.Lognormal("CMReduction", 0.0, 0.1, shape=(nCMs,))
    # [country, CM, day] Reduction factor for each C,CM,D: CMReduction raised to the CM's activity level
    ActiveCMReduction = T.reshape(CMReduction, (1, nCMs, 1)) ** ActiveCMs
    # [country, day] Combined reduction factor from all CMs for each C,D (noise added below)
    GrowthReduction = pm.Deterministic("GrowthReduction", T.prod(ActiveCMReduction, axis=1))
    # [country, day] The ideal predicted daily growth
    PredictedGrowth = pm.Deterministic("PredictedGrowth", T.reshape(CountryGrowthRate, (nCs, 1)) * GrowthReduction)
    # [country, day] The actual (still hidden) growth each day
    # TODO: Estimate noise variance (should be small, measurement variance below)
    #       Miscalibration: too low: time effects pushed into CMs, too high: explains away CMs
    DailyGrowth = pm.Lognormal("DailyGrowth", pm.math.log(PredictedGrowth), 0.1, shape=(nCs, nDs))

    # Below I assume plain exponential growth of the observed counts rather than e.g. depending on actives etc.

    # [country, day] The number of cases that would be detected with noiseless testing
    # (Noise source includes both false-P/N rates and local variance in test volume and targeting)
    # (Since we only care about growth rates and assume consistent testing, it is fine to ignore real size)
    Size = pm.Deterministic("Size", T.reshape(InitialSize, (nCs, 1)) * DailyGrowth.cumprod(axis=1))
    # [country, day] Cumulative observed counts (the JHName series; masked entries are imputed)
    Observed = pm.Lognormal("Observed", pm.math.log(Size), 0.4, shape=(nCs, nDs), observed=Confirmed)


In [None]:
# Log-probability of each variable at the model's test point (catches -inf/NaN before sampling).
print(model.check_test_point())
with model:
    # NUTS with diagonal mass-matrix adaptation; fixed seed so the trace is reproducible.
    trace = pm.sample(1000, chains=2, cores=4, init='adapt_diag', tune=1000, random_seed=42)

[2020-04-01 01:37:57,183] INFO(pymc3): Auto-assigning NUTS sampler...
[2020-04-01 01:37:57,184] INFO(pymc3): Initializing NUTS using adapt_diag...


BaseGrowthRate_log__         -1.61
InitialSize_log__           -54.77
CountryGrowthRate_log__      11.74
CMReduction_log__            26.29
DailyGrowth_log__           564.53
Observed_missing              0.00
Observed                  -5176.18
Name: Log-probability of test_point, dtype: float64


[2020-04-01 01:38:00,269] INFO(pymc3): Multiprocess sampling (2 chains in 4 jobs)
[2020-04-01 01:38:00,270] INFO(pymc3): NUTS: [Observed_missing, DailyGrowth, CMReduction, CountryGrowthRate, InitialSize, BaseGrowthRate]
Sampling 2 chains, 0 divergences:   3%|▎         | 123/4000 [00:09<07:09,  9.03draws/s]

In [None]:
# Posterior traces and marginal densities for the growth-rate and countermeasure parameters.
pm.traceplot(
    trace,
    var_names=[
        "BaseGrowthRate",
        "CountryGrowthRate",
        "DailyGrowth",
        "CMReduction",
    ],
)

In [None]:
# 90% credible intervals for each CM's multiplicative effect on growth (CMReduction[i]).
# `var_names` (not the deprecated `varnames`) matches the traceplot cell; the deprecation
# warning would be hidden by the global warnings filter set in the first cell.
pm.forestplot(trace, var_names=['CMReduction'], credible_interval=0.9)
# Index -> countermeasure name, to read the forest plot rows.
print(', '.join(f"{i}: {c}" for i, c in enumerate(CMs)))