

*Original code by Ryan S. McGee. Modified by T.W. Alleman in consultation with the BIOMATH research unit headed by Prof. Ingmar Nopens.*

Copyright (c) 2020 by T.W. Alleman, BIOMATH, Ghent University. All Rights Reserved.

This notebook accompanies our preprint [*A deterministic, age-stratified, extended SEIRD model for assessing the effect of non-pharmaceutical interventions on SARS-CoV-2 spread in Belgium*](https://doi.org/10.1101/2020.07.17.20156034).

# Load required packages

In [1]:
import random
import os
import numpy as np
import json
import corner
import random

import pandas as pd
import datetime
import scipy
import matplotlib.dates as mdates
import matplotlib
import math
import xarray as xr
import emcee
import matplotlib.pyplot as plt
import datetime

from covid19model.optimization import objective_fcns,pso
from covid19model.models import models
from covid19model.models.utils import draw_sample_COVID19_SEIRD_google
from covid19model.models.time_dependant_parameter_fncs import google_lockdown, ramp_fun, contact_matrix
from covid19model.data import google
from covid19model.data import sciensano
from covid19model.data import model_parameters
from covid19model.visualization.output import population_status, infected, _apply_tick_locator 
from covid19model.visualization.optimization import plot_fit, traceplot


# OPTIONAL: Load the "autoreload" extension so that package code can change
%load_ext autoreload
# OPTIONAL: always reload modules so that as you change code in src, it gets loaded
%autoreload 2

# Load data

In [2]:
(initN, Nc_home, Nc_work, Nc_schools, Nc_transport,
 Nc_leisure, Nc_others, Nc_total) = model_parameters.get_interaction_matrices(dataset='willem_2012')
levels = initN.size
# Location-specific contact matrices, keyed by the names contact_matrix() looks up
Nc_all = dict(total=Nc_total, home=Nc_home, work=Nc_work, schools=Nc_schools,
              transport=Nc_transport, leisure=Nc_leisure, others=Nc_others)


In [None]:
# Heatmap of the school contact matrix, with a colorbar so the scale can be read
fig, ax = plt.subplots()
image = ax.imshow(Nc_schools, cmap='viridis')
fig.colorbar(image, ax=ax)
ax.set_title('Contact matrix: schools');

In [None]:
# Heatmap of the leisure contact matrix, with a colorbar so the scale can be read
fig, ax = plt.subplots()
image = ax.imshow(Nc_leisure, cmap='viridis')
fig.colorbar(image, ax=ax)
ax.set_title('Contact matrix: leisure');

In [None]:
# Heatmap of the "others" contact matrix, with a colorbar so the scale can be read
fig, ax = plt.subplots()
image = ax.imshow(Nc_others, cmap='viridis')
fig.colorbar(image, ax=ax)
ax.set_title('Contact matrix: others');

In [3]:
# Load the cached Sciensano hospitalisation/mortality data and show the last two rows
df_sciensano = sciensano.get_sciensano_COVID19_data(update=False)
df_sciensano.iloc[-2:]

Unnamed: 0_level_0,H_tot,ICU_tot,H_in,H_out,H_tot_cumsum,D_tot,D_25_44,D_45_64,D_65_74,D_75_84,D_85+
DATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2020-11-24,4820,1105,286,528,4594,83,0.0,12.0,10,24,37
2020-11-25,4570,1071,272,429,4437,17,1.0,4.0,3,3,6


In [4]:
# Load the cached Google mobility data (no plot) and show the last two rows
df_google = google.get_google_mobility_data(update=False, plot=False)
df_google.iloc[-2:]

Unnamed: 0_level_0,retail_recreation,grocery,parks,transport,work,residential
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2020-11-21,-58.0,-11.0,13.0,-41.0,-20.0,13.0
2020-11-22,-63.0,-18.0,-11.0,-45.0,-21.0,10.0


In [None]:

# def switch_beta(t,param,samples_dict):
#     if t < pd.to_datetime('2020-05-04'):
#         return np.random.choice(samples_dict['beta'],1,replace=False)
#     elif pd.to_datetime('2020-05-04') < t <= pd.to_datetime('2020-09-01'):
#         return np.random.choice(samples_dict['beta_summer'],1,replace=False)
#     else:
#         return np.random.choice(samples_dict['beta'],1,replace=False)

# Wave 2: September 2020 - present

In [5]:
# Initial model states on 2020-09-01, the start of the wave-2 calibration
initial_states_path = '../../data/interim/model_parameters/COVID19_SEIRD/calibrations/national/google/initial_states_2020-09-01.json'
with open(initial_states_path, 'r') as fp:
    initial_states = json.load(fp)

In [6]:
# --- Calibration window ---------------------------------------------------
# First date of the data (NOTE(review): not referenced again in the visible notebook)
start_data = '2020-09-01'
# First and last data points used to recalibrate the compliance ramp
start_calibration, end_calibration = '2020-09-01', '2020-11-23'

# --- Output locations -----------------------------------------------------
fig_path = '../../results/calibrations/COVID19_SEIRD/national/'  # figures
samples_path = '../../data/interim/model_parameters/COVID19_SEIRD/calibrations/national/'  # MCMC samples


## Model with 6 prevention parameters (home, schools, work, transport, leisure and others)

In [7]:
# Spatial unit: Belgium
# Also used as the tag in the output file names (checkplots figures, MCMC-sample JSON
# and fit figure) of the 6-prevention-parameter variant below.
spatial_unit = 'BE_6_prev_thin'

In [8]:
def contact_matrix(t, df_google, Nc_all, prev_home=1, prev_schools=1, prev_work=1, prev_transport=1, prev_leisure=1, prev_others=1, school=None, work=None, transport=None, leisure=None, others=None):
    """
    Effective contact matrix at time t, scaled by Google mobility and prevention parameters.

    Before 2020-03-15 the total contact matrix Nc_all['total'] is returned unchanged.
    From 2020-03-15 onwards the result is the weighted sum of the location-specific
    matrices, each multiplied by its prevention parameter and its level of opening.

    Parameters
    ----------
    t : pd.Timestamp
        current date (may carry a sub-day offset, e.g. when tau_days is added)
    df_google : pd.DataFrame
        Google mobility data indexed by date, with (at least) the columns 'work',
        'transport', 'retail_recreation' and 'grocery' (presumably % change w.r.t.
        baseline). The index is assumed to be sorted in ascending order.
    Nc_all : dictionary
        contact matrices 'total', 'home', 'schools', 'work', 'transport', 'leisure', 'others'
    prev_home, prev_schools, prev_work, prev_transport, prev_leisure, prev_others : float [0,1]
        prevention parameters to estimate
    school : float [0,1]
        level of opening of schools; must be given from 2020-03-15 onwards
    work, transport, leisure, others : float [0,1] or None
        level of opening of these sectors. If None, it is computed as 1 + mobility/100
        from the most recent Google observation on or before the date of t
        ('work', 'transport', 'retail_recreation', 'grocery' respectively); beyond the
        end of the data the last observation is held constant.

    Returns
    -------
    CM : same type as the entries of Nc_all (presumably np.ndarray)

    Raises
    ------
    ValueError
        if school is None on or after 2020-03-15, or if a sector level must be
        computed but no mobility data exists on or before the date of t
    """
    if t < pd.Timestamp('2020-03-15'):
        return Nc_all['total']

    if school is None:
        raise ValueError(
            "Please indicate to which extend schools are open")

    if any(level is None for level in (work, transport, leisure, others)):
        # Use the calendar date: t can carry a sub-day offset (e.g. when tau_days is added).
        date = pd.Timestamp(t.date())
        # Most recent observation on or before `date`: this covers the lockdown day itself,
        # fills gaps in the data and holds the last row beyond the end of the data.
        available = df_google[df_google.index <= date]
        if available.empty:
            raise ValueError(
                "No Google mobility data available on or before {}".format(date.date()))
        row = -available.iloc[-1] / 100

        if work is None:
            work = 1 - row['work']
        if transport is None:
            transport = 1 - row['transport']
        if leisure is None:
            leisure = 1 - row['retail_recreation']
        if others is None:
            others = 1 - row['grocery']

    # NOTE(review): the origin of the 1/2.3 down-weighting of home contacts is not
    # documented here — confirm before changing.
    CM = (prev_home*(1/2.3)*Nc_all['home'] + 
          prev_schools*school*Nc_all['schools'] + 
          prev_work*work*Nc_all['work'] + 
          prev_transport*transport*Nc_all['transport'] + 
          prev_leisure*leisure*Nc_all['leisure'] + 
          prev_others*others*Nc_all['others']) 

    return CM


In [9]:
def wave2_policies(t,param,df_google, Nc_all, l , tau, 
                   prev_schools, prev_work, prev_transport, prev_leisure, prev_others, prev_home):
    """
    Time-dependent contact matrix Nc for the second wave (from September 2020).

    Parameters
    ----------
    t : pd.Timestamp
        current date
    param :
        not used here; part of the time-dependent-parameter call signature of the
        model (presumably the current value of Nc — confirm in models.COVID19_SEIRD)
    df_google : pd.DataFrame
        Google mobility data, passed on to contact_matrix
    Nc_all : dict
        location-specific contact matrices, passed on to contact_matrix
    l : float
        duration (days) of the ramp from pre-lockdown to lockdown contacts
    tau : float
        delay (days) between the lockdown date (2020-10-19) and the start of the ramp
    prev_schools, prev_work, prev_transport, prev_leisure, prev_others, prev_home : float [0,1]
        prevention parameters, applied from the start of the ramp onwards

    Returns
    -------
    Contact matrix from contact_matrix, or from ramp_fun during the ramp.
    """
    # Convert tau and l to durations
    tau_days = pd.Timedelta(tau, unit='D')
    l_days = pd.Timedelta(l, unit='D')

    # Dates where intensity or school policy changes
    t6 = pd.Timestamp('2020-10-19') # lockdown
    t7 = pd.Timestamp('2020-11-16') # schools re-open
    t8 = pd.Timestamp('2020-12-18') # schools close
    t9 = pd.Timestamp('2021-01-04') # schools re-open

    ramp_start = t6 + tau_days
    ramp_end = ramp_start + l_days

    def with_prevention(school):
        # Contact matrix with all prevention parameters applied
        return contact_matrix(t, df_google, Nc_all, prev_home, prev_schools, prev_work, prev_transport, 
                              prev_leisure, prev_others, school=school)

    if t <= ramp_start:
        # Pre-lockdown behaviour: mobility-driven, schools open, no prevention parameters.
        # No lower bound on purpose: the simulation starts on 2020-09-01 itself, which a
        # strict `2020-09-01 < t` would send to the post-lockdown branch below.
        return contact_matrix(t, df_google, Nc_all, school=1)
    elif t <= ramp_end:
        policy_old = contact_matrix(t, df_google, Nc_all, school=1)
        policy_new = with_prevention(school=0)
        # NOTE(review): `l` is passed as a number of days while tau is passed as a
        # Timedelta — confirm this matches ramp_fun's signature.
        return ramp_fun(policy_old, policy_new, t, tau_days, l, t6)
    elif t <= t7:
        return with_prevention(school=0)
    elif t <= t8:
        return with_prevention(school=1)
    elif t <= t9:
        return with_prevention(school=0)
    else:
        return with_prevention(school=1)



In [10]:
# Default COVID19_SEIRD parameters, extended with the inputs of wave2_policies
# (initial values for l, tau and the six prevention parameters)
params = model_parameters.get_COVID19_SEIRD_parameters()
params.update(dict(
    df_google=df_google,
    Nc_all=Nc_all,
    l=5,
    tau=5,
    prev_schools=0.5,
    prev_work=0.5,
    prev_transport=0.5,
    prev_leisure=0.5,
    prev_others=0.5,
    prev_home=0.5,
))

# Model whose contact matrix 'Nc' is recomputed over time by wave2_policies
model = models.COVID19_SEIRD(initial_states, params, time_dependent_parameters={'Nc': wave2_policies})

In [None]:
warmup = 0
maxiter = 100
popsize = 100
processes = 32  # presumably sized for the 'eraser' compute server; lower this on smaller machines
steps_mcmc = 3000
discard = 1000

# define dataset
data=[df_sciensano['H_in'][start_calibration:end_calibration]]
states = [["H_in"]]

####################################################
####### CALIBRATING BETA AND COMPLIANCE RAMP #######
####################################################

print('------------------------------------')
print('CALIBRATING BETA AND COMPLIANCE RAMP')
print('------------------------------------\n')
print('Using data from '+start_calibration+' until '+end_calibration+'\n')
print('1) Particle swarm optimization\n')

# set PSO optimisation settings (one (lower, upper) bound per entry of parNames)
parNames = ['sigma_data','beta','l','tau',
            'prev_schools', 'prev_work', 'prev_transport', 'prev_leisure', 'prev_others', 'prev_home']
bounds=((1,2000),(0.010,0.060),(0.1,20),(0.1,20),
        (0,1),(0,1),(0,1),(0,1),(0,1),(0,1))

# run PSO optimisation; theta is the best parameter vector (presumably ordered as parNames)
theta = pso.fit_pso(model,data,parNames,states,bounds,maxiter=maxiter,popsize=popsize,
                    start_date=start_calibration,warmup=warmup, processes=processes) ## PROCESSES=1 to debug!

# run MCMC sampler
print('\n2) Markov-Chain Monte-Carlo sampling\n')
parNames_mcmc = parNames
# NOTE(review): the MCMC lower bound on beta (0.020) is tighter than the PSO bound (0.010);
# a PSO optimum below 0.020 would start the walkers outside the MCMC bounds.
bounds_mcmc=((1,2000),(0.020,0.060),(0.001,20),(0.001,20),
             (0,1),(0,1),(0,1),(0,1),(0,1),(0,1))

ndim = len(theta)
nwalkers = ndim*2
# Start walkers around the PSO optimum: unit-sd noise on sigma_data, sd 1e-3 on all others
perturbations = ([1]+(ndim-1)*[1e-3]) * np.random.randn(nwalkers, ndim)
pos = theta + perturbations

sampler = emcee.EnsembleSampler(nwalkers, ndim, objective_fcns.log_probability,
                    args=(model, bounds_mcmc, data, states, parNames_mcmc, None, start_calibration, warmup))
sampler.run_mcmc(pos, steps_mcmc, progress=True)

# emcee slices the chain with step `thin`, so it must be >= 1 (0 raises "slice step cannot be zero").
thin = 1
try:
    autocorr = sampler.get_autocorr_time()
    thin = max(1, int(0.5 * np.min(autocorr)))
except emcee.autocorr.AutocorrError as err:
    # Chain too short for a reliable autocorrelation estimate: keep every sample.
    print('Warning: {}'.format(err))
from covid19model.optimization.run_optimization import checkplots
checkplots(sampler, discard, thin, fig_path, spatial_unit, figname='BETA_RAMP_GOOGLE_WAVE2', 
           labels=[r'$\sigma_{data}$', r'$\beta$', 'l', r'$\tau$',
                   'prev_schools', 'prev_work', 'prev_transport', 'prev_leisure', 'prev_others', 'prev_home'])

#############################################
####### Output to dictionary ################
#############################################

flat_samples = sampler.get_chain(discard=discard,thin=thin,flat=True)

samples_dict_wave2 = {}
for count,name in enumerate(parNames_mcmc):
    samples_dict_wave2[name] = flat_samples[:,count].tolist()

samples_dict_wave2.update({
    'theta_pso' : [float(value) for value in theta],
    'warmup' : warmup,
    'calibration_data' : states[0][0],
    'start_date' : start_calibration,
    'end_date' : end_calibration,
    'maxiter' : maxiter,
    'popsize': popsize,
    'steps_mcmc': steps_mcmc,
    'discard' : discard,
})

with open(samples_path+str(spatial_unit)+'_'+str(datetime.date.today())+'_WAVE2_GOOGLE.json', 'w') as fp:
    json.dump(samples_dict_wave2, fp)

------------------------------------
CALIBRATING BETA AND COMPLIANCE RAMP
------------------------------------

Using data from 2020-09-01 until 2020-11-23

1) Particle swarm optimization

No constraints given.
New best for swarm at iteration 1: [2.02484483e+02 2.32991882e-02 1.76434121e+01 8.85235632e+00
 0.00000000e+00 7.30273203e-01 1.00000000e+00 1.00000000e+00
 6.42328315e-01 1.76285056e-01] 496.3010578759139
Best after iteration 1: [2.02484483e+02 2.32991882e-02 1.76434121e+01 8.85235632e+00
 0.00000000e+00 7.30273203e-01 1.00000000e+00 1.00000000e+00
 6.42328315e-01 1.76285056e-01] 496.3010578759139
Best after iteration 2: [2.02484483e+02 2.32991882e-02 1.76434121e+01 8.85235632e+00
 0.00000000e+00 7.30273203e-01 1.00000000e+00 1.00000000e+00
 6.42328315e-01 1.76285056e-01] 496.3010578759139
New best for swarm at iteration 3: [1.48997355e+02 2.58051802e-02 9.02577875e+00 9.44811959e+00
 2.94525174e-01 1.00000000e+00 9.94404046e-03 1.00000000e+00
 7.16618933e-01 8.05543598e-01] 4

Best after iteration 26: [5.38947983e+01 2.84715652e-02 4.05680275e+00 2.21881709e+00
 3.40926848e-03 8.61000387e-03 2.73097216e-01 1.00000000e+00
 7.95426264e-01 1.82394044e-01] 375.96204370468234
Best after iteration 27: [5.38947983e+01 2.84715652e-02 4.05680275e+00 2.21881709e+00
 3.40926848e-03 8.61000387e-03 2.73097216e-01 1.00000000e+00
 7.95426264e-01 1.82394044e-01] 375.96204370468234
Best after iteration 28: [5.38947983e+01 2.84715652e-02 4.05680275e+00 2.21881709e+00
 3.40926848e-03 8.61000387e-03 2.73097216e-01 1.00000000e+00
 7.95426264e-01 1.82394044e-01] 375.96204370468234
Best after iteration 29: [5.38947983e+01 2.84715652e-02 4.05680275e+00 2.21881709e+00
 3.40926848e-03 8.61000387e-03 2.73097216e-01 1.00000000e+00
 7.95426264e-01 1.82394044e-01] 375.96204370468234
Best after iteration 30: [5.38947983e+01 2.84715652e-02 4.05680275e+00 2.21881709e+00
 3.40926848e-03 8.61000387e-03 2.73097216e-01 1.00000000e+00
 7.95426264e-01 1.82394044e-01] 375.96204370468234
Best after

New best for swarm at iteration 56: [5.25044901e+01 2.85109157e-02 4.85427402e+00 2.54789716e+00
 2.69862868e-03 0.00000000e+00 2.41938488e-01 1.00000000e+00
 6.64191786e-01 5.15435543e-02] 374.7664965022086
Best after iteration 56: [5.25044901e+01 2.85109157e-02 4.85427402e+00 2.54789716e+00
 2.69862868e-03 0.00000000e+00 2.41938488e-01 1.00000000e+00
 6.64191786e-01 5.15435543e-02] 374.7664965022086
Best after iteration 57: [5.25044901e+01 2.85109157e-02 4.85427402e+00 2.54789716e+00
 2.69862868e-03 0.00000000e+00 2.41938488e-01 1.00000000e+00
 6.64191786e-01 5.15435543e-02] 374.7664965022086
New best for swarm at iteration 58: [5.27010898e+01 2.85049310e-02 4.93080288e+00 2.51163552e+00
 2.70736474e-03 0.00000000e+00 2.43238558e-01 1.00000000e+00
 6.64831493e-01 6.96885758e-02] 374.76638707759383
Best after iteration 58: [5.27010898e+01 2.85049310e-02 4.93080288e+00 2.51163552e+00
 2.70736474e-03 0.00000000e+00 2.43238558e-01 1.00000000e+00
 6.64831493e-01 6.96885758e-02] 374.766387

Best after iteration 82: [5.27168751e+01 2.84983513e-02 3.83632409e+00 2.90153649e+00
 6.98804777e-04 0.00000000e+00 3.04432974e-01 1.00000000e+00
 6.81543196e-01 2.97386888e-02] 374.65447920490544
New best for swarm at iteration 83: [5.26925946e+01 2.85018719e-02 3.78091551e+00 2.89701075e+00
 8.56836038e-04 0.00000000e+00 3.02621798e-01 1.00000000e+00
 6.81823314e-01 3.07147915e-02] 374.64251326252224
Best after iteration 83: [5.26925946e+01 2.85018719e-02 3.78091551e+00 2.89701075e+00
 8.56836038e-04 0.00000000e+00 3.02621798e-01 1.00000000e+00
 6.81823314e-01 3.07147915e-02] 374.64251326252224
Best after iteration 84: [5.26925946e+01 2.85018719e-02 3.78091551e+00 2.89701075e+00
 8.56836038e-04 0.00000000e+00 3.02621798e-01 1.00000000e+00
 6.81823314e-01 3.07147915e-02] 374.64251326252224
Best after iteration 85: [5.26925946e+01 2.85018719e-02 3.78091551e+00 2.89701075e+00
 8.56836038e-04 0.00000000e+00 3.02621798e-01 1.00000000e+00
 6.81823314e-01 3.07147915e-02] 374.64251326252224

  lnpdiff = f + nlp - state.log_prob[j]
 10%|█         | 304/3000 [42:08<6:30:26,  8.69s/it]

In [None]:
# Marginal posterior of prev_schools
fig, ax = plt.subplots()
ax.hist(samples_dict_wave2['prev_schools'])
ax.set(title='Posterior samples of prev_schools', xlabel='prev_schools', ylabel='Number of samples');

In [None]:
end_sim = '2021-01-01'
n_draws = 100  # number of posterior samples to simulate and plot

# Calibrated parameters that are also model parameters. Every one of them must be
# re-drawn, otherwise it keeps whatever value the model last held.
# ('sigma_data' is presumably a likelihood-only parameter and is not set on the model.)
posterior_names = ['beta', 'l', 'tau', 'prev_schools', 'prev_work', 'prev_transport',
                   'prev_leisure', 'prev_others', 'prev_home']
n_samples = len(samples_dict_wave2['beta'])

fig, ax = plt.subplots(figsize=(10, 4))
for _ in range(n_draws):
    # One joint posterior draw: the same sample index for every parameter
    idx = random.randrange(n_samples)
    for name in posterior_names:
        model.parameters[name] = samples_dict_wave2[name][idx]

    # Simulate and plot the model's H_in, summed over the Nc dimension
    y_model = model.sim(end_sim, start_date=start_calibration, warmup=0)
    ax.plot(y_model['time'], y_model["H_in"].sum(dim="Nc"), color='blue', alpha=0.05)

# Observed H_in over the calibration window (hollow black markers)
observed = df_sciensano['H_in'][start_calibration:end_calibration]
ax.scatter(observed.index, observed, color='black', alpha=0.6, linestyle='None', facecolors='none')
ax = _apply_tick_locator(ax)
ax.set_xlim(pd.Timestamp('2020-09-01'), pd.Timestamp(end_sim))
ax.set_ylabel('$H_{in}$')
fig.savefig(fig_path+'others/FIT_WAVE2_GOOGLE_'+spatial_unit+'_'+str(datetime.date.today())+'.pdf', dpi=400, bbox_inches='tight')

## Model with 4 prevention parameters (home, schools, work, and one shared parameter for transport, leisure and others)

In [None]:
# Spatial unit: Belgium
# Also used as the tag in the output file names (checkplots figures, MCMC-sample JSON
# and fit figure) of the 4-prevention-parameter variant below.
spatial_unit = 'BE_4_prev_full'

In [None]:
def contact_matrix(t, df_google, Nc_all, prev_home=1, prev_schools=1, prev_work=1, prev_rest=1, school=None, work=None, transport=None, leisure=None, others=None):
    """
    Effective contact matrix at time t, scaled by Google mobility and prevention parameters.

    Before 2020-03-15 the total contact matrix Nc_all['total'] is returned unchanged.
    From 2020-03-15 onwards the result is the weighted sum of the location-specific
    matrices, each multiplied by its prevention parameter and its level of opening.
    Transport, leisure and others share the single prevention parameter prev_rest.

    Parameters
    ----------
    t : pd.Timestamp
        current date (may carry a sub-day offset, e.g. when tau_days is added)
    df_google : pd.DataFrame
        Google mobility data indexed by date, with (at least) the columns 'work',
        'transport', 'retail_recreation' and 'grocery' (presumably % change w.r.t.
        baseline). The index is assumed to be sorted in ascending order.
    Nc_all : dictionary
        contact matrices 'total', 'home', 'schools', 'work', 'transport', 'leisure', 'others'
    prev_home, prev_schools, prev_work : float [0,1]
        prevention parameters to estimate
    prev_rest : float [0,1]
        prevention parameter shared by transport, leisure and others
    school : float [0,1]
        level of opening of schools; must be given from 2020-03-15 onwards
    work, transport, leisure, others : float [0,1] or None
        level of opening of these sectors. If None, it is computed as 1 + mobility/100
        from the most recent Google observation on or before the date of t
        ('work', 'transport', 'retail_recreation', 'grocery' respectively); beyond the
        end of the data the last observation is held constant.

    Returns
    -------
    CM : same type as the entries of Nc_all (presumably np.ndarray)

    Raises
    ------
    ValueError
        if school is None on or after 2020-03-15, or if a sector level must be
        computed but no mobility data exists on or before the date of t
    """
    if t < pd.Timestamp('2020-03-15'):
        return Nc_all['total']

    if school is None:
        raise ValueError(
            "Please indicate to which extend schools are open")

    if any(level is None for level in (work, transport, leisure, others)):
        # Use the calendar date: t can carry a sub-day offset (e.g. when tau_days is added).
        date = pd.Timestamp(t.date())
        # Most recent observation on or before `date`: this covers the lockdown day itself,
        # fills gaps in the data and holds the last row beyond the end of the data.
        available = df_google[df_google.index <= date]
        if available.empty:
            raise ValueError(
                "No Google mobility data available on or before {}".format(date.date()))
        row = -available.iloc[-1] / 100

        if work is None:
            work = 1 - row['work']
        if transport is None:
            transport = 1 - row['transport']
        if leisure is None:
            leisure = 1 - row['retail_recreation']
        if others is None:
            others = 1 - row['grocery']

    # NOTE(review): the origin of the 1/2.3 down-weighting of home contacts is not
    # documented here — confirm before changing.
    CM = (prev_home*(1/2.3)*Nc_all['home'] + 
          prev_schools*school*Nc_all['schools'] + 
          prev_work*work*Nc_all['work'] + 
          prev_rest*transport*Nc_all['transport'] + 
          prev_rest*leisure*Nc_all['leisure'] + 
          prev_rest*others*Nc_all['others']) 

    return CM


In [None]:
def wave2_policies(t,param,df_google, Nc_all, l , tau, 
                   prev_schools, prev_work, prev_rest, prev_home):
    """
    Time-dependent contact matrix Nc for the second wave (from September 2020).

    Parameters
    ----------
    t : pd.Timestamp
        current date
    param :
        not used here; part of the time-dependent-parameter call signature of the
        model (presumably the current value of Nc — confirm in models.COVID19_SEIRD)
    df_google : pd.DataFrame
        Google mobility data, passed on to contact_matrix
    Nc_all : dict
        location-specific contact matrices, passed on to contact_matrix
    l : float
        duration (days) of the ramp from pre-lockdown to lockdown contacts
    tau : float
        delay (days) between the lockdown date (2020-10-19) and the start of the ramp
    prev_schools, prev_work, prev_rest, prev_home : float [0,1]
        prevention parameters (prev_rest covers transport, leisure and others),
        applied from the start of the ramp onwards

    Returns
    -------
    Contact matrix from contact_matrix, or from ramp_fun during the ramp.
    """
    # Convert tau and l to durations
    tau_days = pd.Timedelta(tau, unit='D')
    l_days = pd.Timedelta(l, unit='D')

    # Dates where intensity or school policy changes
    t6 = pd.Timestamp('2020-10-19') # lockdown
    t7 = pd.Timestamp('2020-11-16') # schools re-open
    t8 = pd.Timestamp('2020-12-18') # schools close
    t9 = pd.Timestamp('2021-01-04') # schools re-open

    ramp_start = t6 + tau_days
    ramp_end = ramp_start + l_days

    def with_prevention(school):
        # Contact matrix with all prevention parameters applied
        return contact_matrix(t, df_google, Nc_all, prev_home, prev_schools, prev_work, prev_rest, 
                              school=school)

    if t <= ramp_start:
        # Pre-lockdown behaviour: mobility-driven, schools open, no prevention parameters.
        # No lower bound on purpose: the simulation starts on 2020-09-01 itself, which a
        # strict `2020-09-01 < t` would send to the post-lockdown branch below.
        return contact_matrix(t, df_google, Nc_all, school=1)
    elif t <= ramp_end:
        policy_old = contact_matrix(t, df_google, Nc_all, school=1)
        policy_new = with_prevention(school=0)
        # NOTE(review): `l` is passed as a number of days while tau is passed as a
        # Timedelta — confirm this matches ramp_fun's signature.
        return ramp_fun(policy_old, policy_new, t, tau_days, l, t6)
    elif t <= t7:
        return with_prevention(school=0)
    elif t <= t8:
        return with_prevention(school=1)
    elif t <= t9:
        return with_prevention(school=0)
    else:
        return with_prevention(school=1)



In [None]:
# Default COVID19_SEIRD parameters, extended with the inputs of wave2_policies
# (initial values for l, tau and the four prevention parameters)
params = model_parameters.get_COVID19_SEIRD_parameters()
params.update(dict(
    df_google=df_google,
    Nc_all=Nc_all,
    l=5,
    tau=5,
    prev_schools=0.5,
    prev_work=0.5,
    prev_rest=0.5,
    prev_home=0.5,
))

# Model whose contact matrix 'Nc' is recomputed over time by wave2_policies
model = models.COVID19_SEIRD(initial_states, params, time_dependent_parameters={'Nc': wave2_policies})

In [None]:
# --- Calibration settings -------------------------------------------------
warmup = 0
maxiter, popsize = 100, 100  # PSO iterations and swarm size
processes = 32  # presumably sized for the 'eraser' compute server; lower this on smaller machines
steps_mcmc, discard = 3000, 1000  # MCMC chain length and burn-in

# --- Calibration dataset: H_in over the calibration window ----------------
states = [["H_in"]]
data = [df_sciensano['H_in'][start_calibration:end_calibration]]

In [None]:
####################################################
####### CALIBRATING BETA AND COMPLIANCE RAMP #######
####################################################

print('------------------------------------')
print('CALIBRATING BETA AND COMPLIANCE RAMP')
print('------------------------------------\n')
print('Using data from {} until {}\n'.format(start_calibration, end_calibration))
print('1) Particle swarm optimization\n')

# Parameters estimated by PSO, each with its (lower, upper) search bound
parNames = ['sigma_data', 'beta', 'l', 'tau',
            'prev_schools', 'prev_work', 'prev_rest', 'prev_home']
bounds = (
    (1, 2000),       # sigma_data
    (0.010, 0.060),  # beta
    (0.1, 20),       # l
    (0.1, 20),       # tau
    (0, 1),          # prev_schools
    (0, 1),          # prev_work
    (0, 1),          # prev_rest
    (0, 1),          # prev_home
)

# Particle swarm optimisation (set processes=1 to debug)
theta = pso.fit_pso(model, data, parNames, states, bounds,
                    maxiter=maxiter, popsize=popsize,
                    start_date=start_calibration, warmup=warmup,
                    processes=processes)



In [None]:
# run MCMC sampler
print('\n2) Markov-Chain Monte-Carlo sampling\n')
parNames_mcmc = parNames
# NOTE(review): the MCMC lower bound on beta (0.020) is tighter than the PSO bound (0.010);
# a PSO optimum below 0.020 would start the walkers outside the MCMC bounds.
bounds_mcmc=((1,2000),(0.020,0.060),(0.001,20),(0.001,20),
             (0,1),(0,1),(0,1),(0,1))

ndim = len(theta)
nwalkers = ndim*2
# Start walkers around the PSO optimum: unit-sd noise on sigma_data, sd 1e-3 on all others
perturbations = ([1]+(ndim-1)*[1e-3]) * np.random.randn(nwalkers, ndim)
pos = theta + perturbations

sampler = emcee.EnsembleSampler(nwalkers, ndim, objective_fcns.log_probability,
                    args=(model, bounds_mcmc, data, states, parNames_mcmc, None, start_calibration, warmup))
sampler.run_mcmc(pos, steps_mcmc, progress=True)

# emcee slices the chain with step `thin`, so it must be >= 1 (0 raises "slice step cannot be zero").
thin = 1
try:
    autocorr = sampler.get_autocorr_time()
    thin = max(1, int(0.5 * np.min(autocorr)))
except emcee.autocorr.AutocorrError as err:
    # Chain too short for a reliable autocorrelation estimate: keep every sample.
    print('Warning: {}'.format(err))
from covid19model.optimization.run_optimization import checkplots
checkplots(sampler, discard, thin, fig_path, spatial_unit, figname='BETA_RAMP_GOOGLE_WAVE2', 
           labels=[r'$\sigma_{data}$', r'$\beta$', 'l', r'$\tau$',
                   'prev_schools', 'prev_work', 'prev_rest', 'prev_home'])

#############################################
####### Output to dictionary ################
#############################################

flat_samples = sampler.get_chain(discard=discard,thin=thin,flat=True)

samples_dict_wave2 = {}
for count,name in enumerate(parNames_mcmc):
    samples_dict_wave2[name] = flat_samples[:,count].tolist()

samples_dict_wave2.update({
    'theta_pso' : [float(value) for value in theta],
    'warmup' : warmup,
    'calibration_data' : states[0][0],
    'start_date' : start_calibration,
    'end_date' : end_calibration,
    'maxiter' : maxiter,
    'popsize': popsize,
    'steps_mcmc': steps_mcmc,
    'discard' : discard,
})

with open(samples_path+str(spatial_unit)+'_'+str(datetime.date.today())+'_WAVE2_GOOGLE.json', 'w') as fp:
    json.dump(samples_dict_wave2, fp)

In [None]:
# Marginal posterior of prev_rest (shared by transport, leisure and others)
fig, ax = plt.subplots()
ax.hist(samples_dict_wave2['prev_rest'])
ax.set(title='Posterior samples of prev_rest', xlabel='prev_rest', ylabel='Number of samples');

In [None]:
end_sim = '2021-01-01'
n_draws = 100  # number of posterior samples to simulate and plot

# Calibrated parameters that are also model parameters. Every one of them must be
# re-drawn, otherwise it keeps whatever value the model last held.
# ('sigma_data' is presumably a likelihood-only parameter and is not set on the model.)
posterior_names = ['beta', 'l', 'tau', 'prev_schools', 'prev_work', 'prev_rest', 'prev_home']
n_samples = len(samples_dict_wave2['beta'])

fig, ax = plt.subplots(figsize=(10, 4))
for _ in range(n_draws):
    # One joint posterior draw: the same sample index for every parameter
    idx = random.randrange(n_samples)
    for name in posterior_names:
        model.parameters[name] = samples_dict_wave2[name][idx]

    # Simulate and plot the model's H_in, summed over the Nc dimension
    y_model = model.sim(end_sim, start_date=start_calibration, warmup=0)
    ax.plot(y_model['time'], y_model["H_in"].sum(dim="Nc"), color='blue', alpha=0.05)

# Observed H_in over the calibration window (hollow black markers)
observed = df_sciensano['H_in'][start_calibration:end_calibration]
ax.scatter(observed.index, observed, color='black', alpha=0.6, linestyle='None', facecolors='none')
ax = _apply_tick_locator(ax)
ax.set_xlim(pd.Timestamp('2020-09-01'), pd.Timestamp(end_sim))
ax.set_ylabel('$H_{in}$')
fig.savefig(fig_path+'others/FIT_WAVE2_GOOGLE_'+spatial_unit+'_'+str(datetime.date.today())+'.pdf', dpi=400, bbox_inches='tight')