### Import packages

In [None]:
import sys
sys.path.append('..')

from epimodel import EpidemiologicalParameters, preprocess_data
from epimodel.pymc3_models.models import ComplexDifferentEffectsModel
from get_NPI_dataset import get_NPI_dataset

import numpy as np
import pymc3 as pm

import matplotlib.pyplot as plt 
import arviz as az
import pickle

### Create Dataset

In [None]:
# Base path (no extension) for the NPI dataset; the preprocessing cell reads f"{save_path}.csv"
save_path = "../data/NPI_dataset"

In [None]:
# Create the NPI dataset from 2020-08-31 onward (end=None: no explicit end date is passed).
# NOTE(review): presumably this writes f"{save_path}.csv", which the next cell reads back — confirm.
data = get_NPI_dataset(
    save_path=save_path,
    start="2020-08-31",
    end=None,
)

### Load and Preprocess Dataset & Model

In [None]:
# Load the CSV written under save_path and preprocess it (smoothing=1), then apply the
# reopening mask without printing a report.
data = preprocess_data(
    f"{save_path}.csv",
    smoothing=1,
)
data.mask_reopenings(print_out=False)

In [None]:
# Epidemiological parameter object; the next cell derives from it the keyword arguments
# passed to model.build_model(**bd).
ep = EpidemiologicalParameters()

In [None]:
# Build dict taken from the epidemiological parameters; it is unpacked as keyword
# arguments into model.build_model(**bd) when the model is built.
bd = ep.get_model_build_dict()

In [None]:
#show model build dict (a bare expression on the last line is rendered as the cell output)
bd

In [None]:
#build model: create the model on the preprocessed data, then add its variables
#inside the model's context using the keyword arguments from the build dict
model = ComplexDifferentEffectsModel(data)
with model:
    model.build_model(**bd)

In [None]:
#visualize & save model structure
# Build the graphviz graph once and reuse it for both rendering to disk and inline display,
# instead of calling model_to_graphviz twice.
graph = pm.model_to_graphviz(model)
# Writes the DOT source to "../figs/final_model_structure" plus a rendered copy in
# graphviz's default output format.
graph.render("../figs/final_model_structure")
# Bare expression on the last line: Jupyter displays the graph as the cell output.
graph

In [None]:
#sampling: 4 chains on 4 cores, 500 tuning iterations and 2000 kept draws per chain.
#max_treedepth / target_accept are extra kwargs that pm.sample forwards to the step method;
#the high target_accept asks for smaller step sizes (slower, but fewer divergences).
with model:
    model.trace = pm.sample(
        draws=2000,
        tune=500,
        chains=4,
        cores=4,
        max_treedepth=18,
        target_accept=0.96,
    )

In [None]:
# save results in a pickle file so the trace can be reloaded later without re-sampling.
# The context manager guarantees the file is flushed and closed once the dump finishes,
# so the pickle on disk is complete before anything tries to read it back.
with open('../data/traces/NPI_trace.pkl', 'wb') as trace_file:
    pickle.dump(model.trace, trace_file)

In [None]:
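#load a previously saved trace instead of re-sampling (uncomment to use).
#Only unpickle files you created yourself: pickle.load can execute arbitrary code.
#with open("../data/traces/NPI_trace.pkl", "rb") as file:
#    model.trace = pickle.load(file)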

### Get Insights

In [None]:
#params including Gelman-Rubin statistic (ArviZ reports it as the r_hat column),
#with all numbers rounded to 2 decimals
with model:
    summary_table = az.summary(model.trace, round_to=2)
    display(summary_table)

In [None]:
#draw posterior predictive samples from the sampled trace; the result maps each
#observed variable name (e.g. 'ObservedCases', 'ObservedDeaths') to an array of draws
with model:
    post_pred = pm.sample_posterior_predictive(trace=model.trace)

In [None]:
# Posterior predictive cases: one faint grey line per draw (the transpose puts each draw
# in its own column, assuming draws run along axis 0 of the array).
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(post_pred['ObservedCases'].T, color="0.5", alpha=.1);
ax.set_title("Posterior predictive Cases")
# NOTE(review): confirm the x-axis really counts days since 100 cases for ObservedCases.
ax.set_xlabel("Days since 100 cases")
ax.set_ylabel("Positive cases");

In [None]:
# Posterior predictive deaths: one faint grey line per draw (the transpose puts each draw
# in its own column, assuming draws run along axis 0 of the array).
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(post_pred['ObservedDeaths'].T, color="0.5", alpha=.1);
ax.set(
       title="Posterior predictive Deaths",
       # NOTE(review): x-axis label mirrors the cases plot; confirm it matches how
       # ObservedDeaths is indexed.
       xlabel="Days since 100 cases",
       # This plot shows deaths, not cases (the label was copy-pasted from the cases plot).
       ylabel="Deaths");

In [None]:
#show trace plots for CMAlphaScales and save the figure they live on
axes = az.plot_trace(model.trace, var_names=['CMAlphaScales'])
# plot_trace returns an array of Axes; every Axes shares one Figure, so take it from the first.
fig = axes.flat[0].figure
fig.savefig("../figs/NPI_trace_plots.png")