In [None]:
import datetime
import time as time_module
import sys
import os 

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
import theano
import matplotlib
import pymc3 as pm
import theano.tensor as tt

try: 
    import covid19_inference as cov19
except ModuleNotFoundError:
    sys.path.append('..')
    import covid19_inference as cov19



First, we load the desired data from a source. We also need to tell the class to download the data.

## Data retrieval

In [None]:
# Create the RKI data-retrieval object. Per the original author's note, the
# True argument makes it download all available data right away, so an
# explicit rki.download_all_available_data() call is not needed.
rki = cov19.data_retrieval.RKI(True)


Wait for the download to finish. It will print a message!

We can now access the downloaded data via the attribute shown below, but normally one would use the built-in filter methods.
```
rki.data
```
More information on data retrieval can be found in the [documentation](https://covid19-inference.readthedocs.io/en/latest/doc/data_retrieval.html).

## Create the model

In [None]:
# Date window of the reported data that is requested from the RKI object.
date_begin_data = datetime.datetime(2020, 3, 10)
date_end_data = datetime.datetime(2020, 4, 19)

# One column per federal state (the column labels are used as state names when
# plotting). np.diff along axis 0 turns consecutive rows into per-row
# increments.
# NOTE(review): presumably the rows are cumulative case counts per day, so the
# differences are daily new cases -- confirm against filter_all_bundesland.
df_bundeslaender = rki.filter_all_bundesland(date_begin_data, date_end_data)
new_cases_obs = np.diff(np.array(df_bundeslaender), axis=0)[:, :]

# Should be significantly larger than the expected delay, in order to always
# fit the same number of data points.
diff_data_sim = 16
num_days_forecast = 10

# Prior mean dates of the three change points.
prior_date_mild_dist_begin = datetime.datetime(2020, 3, 11)
prior_date_strong_dist_begin = datetime.datetime(2020, 3, 18)
prior_date_contact_ban_begin = datetime.datetime(2020, 3, 25)

# All change points share the same prior widths; only the prior mean date and
# the prior median of lambda after the change differ between them.
change_points = [
    dict(
        pr_mean_date_transient=prior_date,
        pr_sigma_date_transient=1.5,
        pr_median_lambda=prior_median_lambda,
        pr_sigma_lambda=0.5,
        pr_sigma_transient_len=0.5,
    )
    for prior_date, prior_median_lambda in (
        (prior_date_mild_dist_begin, 0.2),
        (prior_date_strong_dist_begin, 1/8),
        (prior_date_contact_ban_begin, 1/8/2),
    )
]

In [None]:
# Keyword arguments for the model: the observed new cases, the first date of
# the data, the forecast length and the data/simulation offset defined above.
params_model = dict(new_cases_obs = new_cases_obs[:],
                    data_begin = date_begin_data,
                    fcast_len = num_days_forecast,
                    diff_data_sim = diff_data_sim,
                    N_population = 83e6)
# Normally one would pass as N_population an array with the number of
# inhabitants of each state; here a single value (83e6) is used instead.

# Everything inside the `with` block is built within the context of `model`.
with cov19.Cov19Model(**params_model) as model:
    # Create the array of the time-dependent infection rate lambda, using the
    # change points defined above.
    lambda_t_log = cov19.lambda_t_with_sigmoids(pr_median_lambda_0 = 0.4, pr_sigma_lambda_0 = 0.5,
                                                change_points_list = change_points)

    # Prior median of the delay; shared by make_prior_I and delay_cases below.
    pr_median_delay = 10
    # Lognormal prior with median 1/8 (mu is given in log space).
    # NOTE(review): presumably the recovery rate of the SIR model -- confirm.
    mu = pm.Lognormal(name="mu", mu=np.log(1/8), sigma=0.2)

    # This builds a decorrelated prior for I_begin for faster inference.
    # It is not necessary to use it, one can simply remove it and use the default argument
    # for pr_I_begin in cov19.SIR
    prior_I = cov19.make_prior_I(lambda_t_log, mu, pr_median_delay = pr_median_delay)

    # Use lambda_t_log and mu to run the SIR model
    new_I_t = cov19.SIR(lambda_t_log, mu, pr_I_begin = prior_I)

    # Delay the cases by a lognormal reporting delay
    new_cases_inferred_raw = cov19.delay_cases(new_I_t, pr_median_delay=pr_median_delay,
                                               pr_median_scale_delay=0.3)

    # Modulate the inferred cases by an abs(sin(x)) function, to account for weekend effects
    new_cases_inferred = cov19.week_modulation(new_cases_inferred_raw)

    # Define the likelihood, uses the new_cases_obs set as model parameter
    cov19.student_t_likelihood(new_cases_inferred)


## MCMC sampling

In [None]:
# Sample from the model: 1000 tuning steps followed by 1000 draws, with the
# sampler initialised via 'advi+adapt_diag'.
# A fixed random_seed makes the run reproducible after a kernel restart;
# change or remove it to obtain an independent set of samples.
trace = pm.sample(model=model, tune=1000, draws=1000, init='advi+adapt_diag',
                  random_seed=42)

# Plotting

In [None]:
# Draw one histogram panel per free random variable of the model, arranged in
# a grid with a fixed number of columns.
varnames = cov19.plotting.get_all_free_RVs_names(model)
num_cols = 5
num_rows = int(np.ceil(len(varnames) / num_cols))
x_size = 2.5 * num_cols
y_size = 2.5 * num_rows

fig, axes = plt.subplots(num_rows, num_cols, figsize=(x_size, y_size), squeeze=False)
# axes.flat walks the grid row by row, so the flat index is also the index of
# the variable shown in that panel.
for i_ax, ax in enumerate(axes.flat):
    if i_ax >= len(varnames):
        # Left-over panels in the last row have nothing to show.
        ax.set_visible(False)
        continue
    i_row, i_col = divmod(i_ax, num_cols)
    cov19.plotting.plot_hist(model, trace, ax, varnames[i_ax],
                             colors=('tab:blue', 'tab:green'))
    if i_col != 0:
        # Only the first column keeps its y label.
        ax.set_ylabel('')
    elif i_row == 0:
        # A single legend in the top-left panel.
        ax.legend()
fig.subplots_adjust(wspace=0.25, hspace=0.4)

In [None]:
# One figure per hierarchical parameter: a horizontal violin for every region.
# The tick positions are derived from the number of region labels instead of
# a hard-coded 16, so ticks and labels always have matching lengths.
# NOTE(review): assumes trace[var_name] has one column per region, in the same
# order as df_bundeslaender.columns -- confirm.
region_names = df_bundeslaender.columns
# violinplot places the violins at positions 1..N by default.
region_ticks = np.arange(1, len(region_names) + 1)

for var_name in ['lambda_0_L2', 'lambda_1_L2', 'lambda_2_L2', 'lambda_3_L2',
                 'transient_day_1_L2', 'transient_day_2_L2', 'transient_day_3_L2', 'delay_L2']:
    f, ax = plt.subplots()
    ax.violinplot(trace[var_name], showextrema=False, vert=False)
    ax.set_yticks(region_ticks)
    ax.set_yticklabels(region_names)
    ax.set_xlabel(var_name)

### Plot
We first define the x range (the dates) that we want to plot. It should lie within the simulation period.

In [None]:
# Plot range: from the first data date up to the last data date plus
# num_days_forecast - 1 further days.
# The bounds reuse date_begin_data / date_end_data rather than repeating the
# literal dates, so the plot range follows any change to the data window.
bd = date_begin_data
ed = date_end_data + datetime.timedelta(days=num_days_forecast - 1)

Next, we have to extract the part of our trace that corresponds to that date range. We do that by calling the helper function `cov19.plot._get_array_from_trace_via_date()`.

In [None]:
# Extract the "new_cases" trace variable for the range bd..ed, together with
# the matching x values used for plotting.
# NOTE(review): y_all_regions is later indexed as [:, :, region] -- presumably
# its axes are (sample, day, region); confirm against the helper.
y_all_regions, x = cov19.plot._get_array_from_trace_via_date(model,trace,"new_cases",bd,ed)


Depending on which trace variable we want to plot, we have to look at its shape and possibly reshape it to match the number of dates, and possibly also sort the regions. Also, the plot function wants a one-dimensional array.

After that, we plot the time series and set the format of the date axis.

In [None]:
# Inspect which variable names are available in the trace.
trace.varnames

In [None]:
# One time-series figure per region, titled with the region's column label.
# Iterating over the column labels replaces the hard-coded region count of 16.
# NOTE(review): assumes the last axis of y_all_regions follows the column
# order of df_bundeslaender -- confirm.
for i, region_name in enumerate(df_bundeslaender.columns):
    # Select the slice along the last axis that belongs to this region.
    y = y_all_regions[:, :, i]
    ax = cov19.plot._timeseries(x, y)
    ax.set_title(region_name)
    cov19.plot._format_date_xticks(ax)
plt.show()