In [9]:
import datetime
import time as time_module
import sys
import os 
import pickle

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
import theano
import matplotlib
import pymc3 as pm
import theano.tensor as tt

try: 
    import covid19_inference as cov19
except ModuleNotFoundError:
    sys.path.append('../..')
    import covid19_inference as cov19

# Prefix for saved figures. It ends with a path separator because the plotting
# cells build file names by plain string concatenation (path_to_save + 'Fig_...').
path_to_save = '../../figures/'
# Directory for the pickled models/traces.
# NOTE(review): unlike path_to_save this value has NO trailing separator, so
# `path_save_pickled + 'name'` yields '../../dataname', not a file inside the
# directory. Build file names with os.path.join(path_save_pickled, 'name').
path_save_pickled = '../../data'
# True: rebuild the models, sample them and write the pickle.
# False: skip sampling and load the previously pickled models/traces instead.
rerun = True

In [None]:

# ---------------------------------------------------------------------------
# Data selection
# ---------------------------------------------------------------------------
country = 'Germany'
confirmed_cases = cov19.get_jhu_confirmed_cases()

# Observation window used for the fit.
# (Alternative end date: cov19.get_last_date(confirmed_cases).)
date_data_begin = datetime.datetime(2020, 3, 1)
date_data_end = datetime.datetime(2020, 4, 21)
num_days_data = (date_data_end - date_data_begin).days

# ---------------------------------------------------------------------------
# Simulation window
# ---------------------------------------------------------------------------
# The simulation starts `diff_data_sim` days before the first data point. This
# offset should be significantly larger than the expected reporting delay so
# that the same number of data points is always fitted.
diff_data_sim = 16
# Number of days simulated beyond the last data point.
num_days_future = 28

date_begin_sim = date_data_begin - datetime.timedelta(days=diff_data_sim)
date_end_sim = date_data_end + datetime.timedelta(days=num_days_future)
num_days_sim = (date_end_sim - date_begin_sim).days

# Case counts of the selected country restricted to the observation window.
cases_obs = cov19.filter_one_country(
    confirmed_cases,
    country,
    date_data_begin,
    date_data_end,
)

# Sanity print: the reversed slice yields the last two entries, newest first.
print(
    'Cases yesterday ({}): {} and day before yesterday: {}'.format(
        date_data_end.isoformat(),
        *cases_obs[:-3:-1]
    )
)

# Prior means for the start of each of the three change points.
prior_date_mild_dist_begin = datetime.datetime(2020, 3, 9)
prior_date_strong_dist_begin = datetime.datetime(2020, 3, 16)
prior_date_contact_ban_begin = datetime.datetime(2020, 3, 23)

# Widening factor applied to the transient-length sigma of the third prior set.
fact_larger = 4

# (pr_mean_date_begin_transient, pr_median_lambda) pairs shared by all prior sets.
_change_point_basics = (
    (prior_date_mild_dist_begin, 0.2),
    (prior_date_strong_dist_begin, 1/8),
    (prior_date_contact_ban_begin, 1/8/2),
)


def _make_change_points(sigma_dates, sigma_transient_len=None):
    """Return one prior dict per change point.

    sigma_dates: one `pr_sigma_date_begin_transient` value per change point.
    sigma_transient_len: value stored as `pr_sigma_transient_len` in every
        dict; the key is left out entirely when this is None.
    """
    prior_dicts = []
    for (mean_date, median_lambda), sigma_date in zip(_change_point_basics, sigma_dates):
        priors = {
            'pr_mean_date_begin_transient': mean_date,
            'pr_sigma_date_begin_transient': sigma_date,
            'pr_median_lambda': median_lambda,
            'pr_sigma_lambda': 0.5,
        }
        if sigma_transient_len is not None:
            priors['pr_sigma_transient_len'] = sigma_transient_len
        prior_dicts.append(priors)
    return prior_dicts


# Baseline change points: date sigmas 3/1/1 and no explicit transient-length sigma.
change_points = _make_change_points(sigma_dates=(3, 1, 1))

# Wider prior on the change-point dates (sigma 7 everywhere), transient-length sigma 0.3.
change_points_days = _make_change_points(sigma_dates=(7, 7, 7),
                                         sigma_transient_len=0.3)

# Baseline date sigmas, but a transient-length sigma widened by fact_larger.
change_points_transient = _make_change_points(sigma_dates=(3, 1, 1),
                                              sigma_transient_len=0.3*fact_larger)

def _build_and_sample(change_points_list, priors_dict, tune):
    """Build one SIR change-point model on the observed cases and sample it.

    change_points_list: list of change-point prior dicts passed to the model.
    priors_dict: extra priors passed to the model (None keeps its defaults).
    tune: number of tuning steps passed to pm.sample.

    Returns the tuple (model, trace).
    """
    model = cov19.SIR_with_change_points(new_cases_obs=np.diff(cases_obs),
                                         change_points_list=change_points_list,
                                         date_begin_simulation=date_begin_sim,
                                         num_days_sim=num_days_sim,
                                         diff_data_sim=diff_data_sim,
                                         N=83e6,  # presumably the population of Germany - confirm
                                         priors_dict=priors_dict,
                                         weekends_modulated=True,
                                         weekend_modulation_type='abs_sine')
    trace = pm.sample(model=model, init='advi+adapt_diag', draws=4000, tune=tune)
    return model, trace


# Join directory and file name explicitly: path_save_pickled may lack a trailing
# separator, in which case plain '+' would create '<dir>SIR_...' next to the
# directory instead of a file inside it.
pickle_file = os.path.join(path_save_pickled,
                           'SIR_with_sine_with_larger_prior.pickled')

if rerun:
    # Create the target directory up front so hours of sampling are not lost
    # to a failing dump at the very end.
    os.makedirs(path_save_pickled, exist_ok=True)

    # Filled incrementally so partial results survive an interrupted run.
    traces = []
    models = []

    # Variant 1: widened prior on the reporting delay.
    # NOTE(review): this variant uses tune=1000 while the other two use
    # tune=4000; kept exactly as before - confirm this is intended.
    model_delay, trace_delay = _build_and_sample(
        change_points,
        dict(pr_sigma_delay=0.2*fact_larger),
        tune=1000)
    models.append(model_delay)
    traces.append(trace_delay)

    # Variant 2: widened prior on the change-point dates.
    model_days, trace_days = _build_and_sample(change_points_days, None, tune=4000)
    models.append(model_days)
    traces.append(trace_days)

    # Variant 3: widened prior on the transient length.
    model_transient, trace_transient = _build_and_sample(change_points_transient,
                                                         None, tune=4000)
    models.append(model_transient)
    traces.append(trace_transient)

    # The context manager guarantees the file is flushed and closed.
    with open(pickle_file, 'wb') as pickle_out:
        pickle.dump([models, traces], pickle_out)

else:
    # Earlier runs concatenated directory and file name without a separator;
    # keep such caches loadable when no correctly named file exists.
    legacy_pickle_file = path_save_pickled + 'SIR_with_sine_with_larger_prior.pickled'
    load_path = pickle_file
    if not os.path.exists(load_path) and os.path.exists(legacy_pickle_file):
        load_path = legacy_pickle_file

    # SECURITY: unpickling executes arbitrary code - only load caches that
    # were produced by this notebook on a trusted machine.
    with open(load_path, 'rb') as pickle_in:
        models, traces = pickle.load(pickle_in)

    

Cases yesterday (2020-04-21T00:00:00): 148291 and day before yesterday: 147065
pr_beta_I_begin was set to default value 100
pr_median_lambda_0 was set to default value 0.4
pr_sigma_lambda_0 was set to default value 0.5
pr_median_mu was set to default value 0.125
pr_sigma_mu was set to default value 0.2
pr_median_delay was set to default value 8
pr_beta_sigma_obs was set to default value 10
week_end_days was set to default value (6, 7)
pr_mean_weekend_factor was set to default value 0.7
pr_sigma_weekend_factor was set to default value 0.17
pr_median_transient_len of change point 0 was set to default value 3
pr_median_transient_len of change point 1 was set to default value 3
pr_median_transient_len of change point 2 was set to default value 3
pr_sigma_transient_len of change point 0 was set to default value 0.3
pr_sigma_transient_len of change point 1 was set to default value 0.3
pr_sigma_transient_len of change point 2 was set to default value 0.3


Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
Average Loss = 474.75:   9%|▊         | 17464/200000 [00:34<05:56, 511.59it/s]
Convergence achieved at 17500
Interrupted at 17,499 [8%]: Average Loss = 596.4
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [offset_modulation_rad, weekend_factor, sigma_obs, delay, mu, transient_len_2, transient_len_1, transient_len_0, transient_begin_2, transient_begin_1, transient_begin_0, lambda_3, lambda_2, lambda_1, lambda_0, I_begin]
Sampling 4 chains, 0 divergences:  62%|██████▏   | 12352/20000 [12:21<07:24, 17.19draws/s]

In [None]:
# Execute the plotting script in this notebook's global namespace so that the
# names it defines become available to the cells below (presumably
# create_figure_distributions and create_figure_timeseries - confirm).
# The `with` block closes the file handle; compiling with the real file name
# makes tracebacks point at figures_revised.py instead of '<string>'.
with open('figures_revised.py') as _figures_script:
    exec(compile(_figures_script.read(), 'figures_revised.py', 'exec'))

In [None]:
# Keyword arguments common to all three distribution figures.
_dist_fig_kwargs = dict(
    additional_insets=None,
    xlim_lambda=(0, 0.53),
    color='tab:green',
    num_changepoints=3,
)

# One figure per model variant; only xlim_tbegin and the output name differ.
create_figure_distributions(
    models[0], traces[0],
    xlim_tbegin=7,
    save_to=path_to_save + 'Fig_larger_prior_delay_dist',
    **_dist_fig_kwargs
)
create_figure_distributions(
    models[1], traces[1],
    xlim_tbegin=8,
    save_to=path_to_save + 'Fig_larger_prior_day_dist',
    **_dist_fig_kwargs
)
create_figure_distributions(
    models[2], traces[2],
    xlim_tbegin=7,
    save_to=path_to_save + 'Fig_larger_prior_transient_dist',
    **_dist_fig_kwargs
)


In [None]:
# Output files for the case-number time series, one per model variant
# (same order as `traces`).
_cases_fig_delay = path_to_save + 'Fig_larger_prior_delay_cases'
_cases_fig_days = path_to_save + 'Fig_larger_prior_day_cases'
_cases_fig_transient = path_to_save + 'Fig_larger_prior_transient_cases'

create_figure_timeseries(
    traces[0], 'tab:green', plot_red_axis=True, save_to=_cases_fig_delay
)
create_figure_timeseries(
    traces[1], 'tab:green', plot_red_axis=True, save_to=_cases_fig_days
)
create_figure_timeseries(
    traces[2], 'tab:green', plot_red_axis=True, save_to=_cases_fig_transient
)

In [None]:
# Print the pm.loo result for each model variant under its own header
# (same order as `models` / `traces`).
for _loo_index, _loo_header in enumerate(('\nlarge prior delay\n',
                                          '\nlarge prior days\n',
                                          '\nlarge prior transients\n')):
    print(_loo_header)
    print(pm.loo(traces[_loo_index], models[_loo_index]))