**Calibration of the age-stratified deterministic model**

*Original code by Ryan S. McGee. Modified by T.W. Alleman in consultation with the BIOMATH research unit headed by Prof. Ingmar Nopens.*

Copyright (c) 2020 by T.W. Alleman, BIOMATH, Ghent University. All Rights Reserved.

This notebook accompanies our preprint: [*A deterministic, age-stratified, extended SEIRD model for assessing the effect of non-pharmaceutical interventions on SARS-CoV-2 spread in Belgium*](https://doi.org/10.1101/2020.07.17.20156034).

# Load required packages

In [1]:
import random
import os
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
from IPython.display import Image
from ipywidgets import interact,fixed,FloatSlider,IntSlider,ToggleButtons
import pandas as pd
from datetime import datetime, timedelta
import scipy
from scipy.integrate import odeint
import matplotlib.dates as mdates
import matplotlib
import scipy.stats as st

import math
import xarray as xr
import emcee
import json
import corner

from covid19model.optimization import objective_fcns
from covid19model.optimization import MCMC
from covid19model.models import models
from covid19model.data import google
from covid19model.data import sciensano
from covid19model.data import model_parameters
from covid19model.visualization.output import population_status, infected
from covid19model.visualization.optimization import plot_fit, traceplot
from covid19model.optimization.run_optimization import full_calibration, full_calibration_wave2
from covid19model.models.utils import draw_sample_COVID19_SEIRD


# OPTIONAL: Load the "autoreload" extension so that package code can change
%load_ext autoreload
# OPTIONAL: always reload modules so that as you change code in src, it gets loaded
%autoreload 2

In [None]:
# import dask

# from dask.distributed import Client, progress
# client = Client(threads_per_worker=16, n_workers=1)
# client

# Get public data

In [None]:
#df_sciensano = sciensano.get_sciensano_COVID19_data(update=False)


In [2]:
# Sciensano hospitalisation data; keep only the rows for the province of Liège.
raw_hosp_data = pd.read_csv('../data/raw/sciensano/COVID19BE_HOSP.csv', parse_dates=['DATE'])
is_liege = raw_hosp_data['PROVINCE'] == 'Liège'
Luik_hosp_data = raw_hosp_data.loc[is_liege]

In [None]:
#Luik_hosp_data.plot('DATE','NEW_IN')

In [3]:
# Date-indexed frame with the single column NEW_IN for Liège.
Luik_ts = Luik_hosp_data.set_index('DATE').loc[:, ['NEW_IN']]

In [41]:
# NEW_IN for Liège on a log scale, zoomed in on August-October 2020.
fig, ax = plt.subplots()
Luik_ts.reset_index().plot('DATE', 'NEW_IN', logy=True, ax=ax)
ax.set(xlabel='date', ylabel='daily new hospitalisations (NEW_IN)', title='Liège')
# Trailing semicolon suppresses the (xmin, xmax) repr in the cell output.
ax.set_xlim('2020-08-01', '2020-11-01');

<IPython.core.display.Javascript object>

(18475.0, 18567.0)

# Load data

In [4]:
# Load the interaction matrices (size: 9x9) and the national population per age class
initN, Nc_home, Nc_work, Nc_schools, Nc_transport, Nc_leisure, Nc_others, Nc_total = model_parameters.get_interaction_matrices(dataset='willem_2012')
# Define the number of age categories
levels = initN.size
# Replace the national population by that of the province of Liège (NIS code 60000).
# Presumably the first `levels` columns after NIS hold the age-class populations — verify
# against initN_province.csv. Slicing by `levels` (instead of a hard-coded 9) keeps the
# population vector aligned with the interaction matrices.
NIS_LIEGE = 60000
initN_province = pd.read_csv('../data/interim/demographic/initN_province.csv')
initN = np.array(list(initN_province.set_index('NIS').loc[NIS_LIEGE].iloc[:levels]))
assert initN.size == levels, 'provincial age classes do not match the interaction matrices'

# Initialize the model

In [5]:
# Define the compliance and lockdown function
def lockdown_func(t,param,policy_time,policy1,policy2,l,tau,prevention):
    """Time-dependent contact matrix with a linear ramp between two policies.

    Parameters
    ----------
    t : simulation time (compared numerically with ``policy_time + tau``).
    param : not used; presumably the current value of the parameter handed over
        by the model — TODO confirm.
    policy_time : time at which the policy change is announced.
    policy1 : contact matrix before the change.
    policy2 : contact matrix after the change, before scaling by ``prevention``.
    l : duration of the linear ramp.
    tau : delay between ``policy_time`` and the start of the ramp.
    prevention : multiplicative factor applied to ``policy2``.

    Returns
    -------
    ``policy1`` up to ``policy_time + tau``; a linear interpolation towards
    ``prevention*policy2`` during the following ``l`` time units; and
    ``prevention*policy2`` afterwards.
    """
    ramp_start = policy_time + tau
    target = prevention*policy2
    if t <= ramp_start:
        return policy1
    if t <= ramp_start + l:
        # Fraction (t - ramp_start)/l of the way from policy1 to the target.
        return policy1 + (target-policy1)/l*(t-policy_time-tau)
    return target

In [6]:
# Load the parameters using `get_COVID19_SEIRD_parameters()`.
params = model_parameters.get_COVID19_SEIRD_parameters()
params.update({'policy1': Nc_total,
              'policy2': 1.0*Nc_home + (1-0.60)*Nc_work + (1-0.70)*Nc_transport + (1-0.30)*Nc_others + (1-0.80)*Nc_leisure,
              'policy_time': 500,
              'l': 1,
              'tau': 5,
              'prevention': 0.5})
# Define the initial condition: one exposed inidividual in every age category
initial_states = {'S': initN, 'E': np.ones(levels)}
model = models.COVID19_SEIRD(initial_states, params, time_dependent_parameters={'Nc': lockdown_func})

In [7]:
# Output folders; the trailing slash matters because file names are appended by string concatenation.
fig_path, samples_path = '../results/calibrations_provinces/', '../data/interim/model_parameters/provinces/'

In [8]:
# Calibration target (NEW_IN for Liège) and the prefix used for the output file names.
spatial_unit = 'Luik_wave1_1000steps'
timeseries = Luik_ts.loc[:, 'NEW_IN']

# Calibration on the first wave

In [None]:
## First wave: calibration windows start_date -> end_beta and end_beta -> end_ramp
## (presumably beta-only fit, then the compliance ramp — see full_calibration_wave1).
start_date, end_beta, end_ramp = '2020-03-15', '2020-03-25', '2020-05-23'  # end_beta was previously '2020-03-22'

In [None]:
# Number of daily observations in the start_date -> end_beta window (both ends inclusive)
Luik_ts.loc[start_date:end_beta].shape[0]

In [None]:
# Number of daily observations in the end_beta -> end_ramp window (both ends inclusive)
Luik_ts.loc[end_beta:end_ramp].shape[0]

In [None]:
# function parameters 
# maxiter=100
# popsize=200
# steps_mcmc=10000


In [None]:
# `full_calibration_wave1` is not among the names imported in the first cell, so the call
# below raised NameError. Import it here next to its only use.
# NOTE(review): assumes run_optimization defines it alongside full_calibration_wave2 — confirm.
from covid19model.optimization.run_optimization import full_calibration_wave1

samples_dict = full_calibration_wave1(model, timeseries, spatial_unit, start_date, end_beta, end_ramp,
                                      fig_path=fig_path, samples_path=samples_path,
                                      initN=initN, Nc_total=Nc_total,
                                      maxiter=50, popsize=50, steps_mcmc=1000)



In [None]:
states = [['H_in']]
end_date = '2020-07-01'
data = [timeseries[start_date:end_ramp].values]

# Overlay 10 posterior draws of the first-wave fit on the calibration data.
fig, ax = plt.subplots()
# Same for every draw, so set once outside the loop.
model.parameters['policy_time'] = samples_dict['lag_time']
for i in range(10):
    # NOTE(review): `idx` from the beta draw is overwritten by the `l` draw, so beta is sampled
    # independently of (l, tau, prevention). If all four come from one MCMC chain, a single
    # shared index would keep their posterior correlation — confirm intent.
    idx, model.parameters['beta'] = random.choice(list(enumerate(samples_dict['beta'])))
    idx, model.parameters['l'] = random.choice(list(enumerate(samples_dict['l'])))
    model.parameters['tau'] = samples_dict['tau'][idx]
    model.parameters['prevention'] = samples_dict['prevention'][idx]
    y_model = model.sim(time=end_date, excess_time=samples_dict['lag_time'], start_date=start_date)
    ax = plot_fit(y_model, data, start_date, samples_dict['lag_time'], states, end_date=end_date, with_ints=False, ax=ax,
                  plt_kwargs={'color': 'blue', 'linewidth': 2, 'alpha': 0.05})

# Data after the calibration window (hollow red circles), drawn on the explicit axes
# rather than on pyplot's implicit "current" axes.
data_after_calib = timeseries[pd.to_datetime(end_ramp)+pd.to_timedelta('1d'):end_date]
ax.scatter(data_after_calib.index, data_after_calib.values, marker='o', color='red', linestyle='None', facecolors='none')
legend_text = ['daily \nhospitalizations']
ax.set_xlim('2020-03-10', end_date);
# fig.savefig(fig_path+spatial_unit+'.pdf',
#             bbox_inches='tight', dpi=600)

# Run the simulation until 1 September

In [9]:
# Load the posterior samples obtained by calibrating the model to the Liège hospitalisation data (first wave)
wave1_samples_file = '../data/interim/model_parameters/provinces/Luik_wave1_1000steps_2020-10-28.json'
with open(wave1_samples_file) as fp:
    samples_dict_wave1 = json.load(fp)

In [12]:
states = [['H_in']]
start_date = samples_dict_wave1['start_date']
end_ramp = samples_dict_wave1['end_ramp']
end_date = '2020-09-01'

data = [timeseries[start_date:end_ramp].values]

# Overlay 10 posterior draws of the first-wave fit, extended until end_date.
fig, ax = plt.subplots()
# Same for every draw, so set once outside the loop.
model.parameters['policy_time'] = samples_dict_wave1['lag_time']
for i in range(10):
    # NOTE(review): `idx` from the beta draw is overwritten by the `l` draw, so beta is sampled
    # independently of (l, tau, prevention). If all four come from one MCMC chain, a single
    # shared index would keep their posterior correlation — confirm intent.
    idx, model.parameters['beta'] = random.choice(list(enumerate(samples_dict_wave1['beta'])))
    idx, model.parameters['l'] = random.choice(list(enumerate(samples_dict_wave1['l'])))
    model.parameters['tau'] = samples_dict_wave1['tau'][idx]
    model.parameters['prevention'] = samples_dict_wave1['prevention'][idx]
    y_model = model.sim(time=end_date, excess_time=samples_dict_wave1['lag_time'], start_date=start_date)
    ax = plot_fit(y_model, data, start_date, samples_dict_wave1['lag_time'], states, end_date=end_date, with_ints=False, ax=ax,
                  plt_kwargs={'color': 'blue', 'linewidth': 2, 'alpha': 0.05})

# Data after the calibration window (hollow red circles), drawn on the explicit axes
# rather than on pyplot's implicit "current" axes.
data_after_calib = timeseries[pd.to_datetime(end_ramp)+pd.to_timedelta('1d'):end_date]
ax.scatter(data_after_calib.index, data_after_calib.values, marker='o', color='red', linestyle='None', facecolors='none')
legend_text = ['daily \nhospitalizations']
ax.set_xlim('2020-03-10', end_date);
# fig.savefig(fig_path+spatial_unit+'.pdf',
#             bbox_inches='tight', dpi=600)

<IPython.core.display.Javascript object>

(18331.0, 18506.0)

In [13]:
# policy_time is not set by draw_sample_COVID19_SEIRD (it handles the other sampled
# parameters), so fix it here to the lag time of the first-wave calibration.
lag_time_wave1 = samples_dict_wave1['lag_time']
model.parameters['policy_time'] = lag_time_wave1

In [14]:
# Simulate 100 posterior draws until 1 September and average the last time step over the draws;
# the result is the initial condition of `model_sept` in the second-wave section.
out_sept = model.sim(time='2020-09-01', excess_time=samples_dict_wave1['lag_time'], start_date='2020-03-15',
                     N=100, draw_fcn=draw_sample_COVID19_SEIRD, samples=samples_dict_wave1)
# Dedicated name instead of `states`, which the plotting cells use for the list of plotted states.
final_states_sept = out_sept.isel(time=-1).mean(dim="draws")
# NOTE(review): only the keys of `initial_states` (S and E) are carried over; all other
# compartments restart from the model default — confirm this is intended.
initial_states_sept = {key: final_states_sept[key].values for key in initial_states.keys()}

In [26]:
# Simulate 100 posterior draws until 1 August and average the last time step over the draws;
# the result is the initial condition of `model_aug` in the second-wave section.
out_aug = model.sim(time='2020-08-01', excess_time=samples_dict_wave1['lag_time'], start_date='2020-03-15',
                    N=100, draw_fcn=draw_sample_COVID19_SEIRD, samples=samples_dict_wave1)
# Dedicated name instead of `states`, which the plotting cells use for the list of plotted states.
final_states_aug = out_aug.isel(time=-1).mean(dim="draws")
# NOTE(review): only the keys of `initial_states` (S and E) are carried over; all other
# compartments restart from the model default — confirm this is intended.
initial_states_aug = {key: final_states_aug[key].values for key in initial_states.keys()}

In [None]:
#np.sum(initial_states_sept['S'])/np.sum(initial_states['S'])*100

# Calibration on the second wave

In [56]:
# Second wave: calibration window start_date -> end_beta
start_date, end_beta = '2020-08-01', '2020-10-25'

In [57]:
# Number of daily observations in the second-wave window (both ends inclusive)
Luik_ts.loc[start_date:end_beta].shape[0]

86

In [48]:
# Load the parameters using `get_COVID19_SEIRD_parameters()`.
params = model_parameters.get_COVID19_SEIRD_parameters()
# Give each model its own (shallow) copy: with one shared dict, an assignment such as
# `model_aug.parameters['beta'] = ...` would silently change model_sept (and `params`) too,
# unless the model class copies its parameters itself.
model_sept = models.COVID19_SEIRD(initial_states_sept, dict(params))
model_aug = models.COVID19_SEIRD(initial_states_aug, dict(params))

In [58]:
# function parameters 
timeseries = Luik_ts['NEW_IN']
spatial_unit = 'Luik_wave2_from_aug_1000steps'
# maxiter=100
# popsize=200
# steps_mcmc=10000


In [32]:
# Posterior means of the first wave, passed to full_calibration_wave2 (presumably as starting values).
beta_init, sigma_data_init = (np.mean(samples_dict_wave1[name]) for name in ('beta', 'sigma_data'))

In [33]:
# Calibrate the second wave on model_aug (initial condition from the 1 August simulation).
wave2_settings = dict(fig_path=fig_path, samples_path=samples_path,
                      initN=initN, Nc_total=Nc_total,
                      maxiter=100, popsize=100, steps_mcmc=1000)
samples_dict = full_calibration_wave2(model_aug, timeseries, spatial_unit, start_date, end_beta,
                                      beta_init, sigma_data_init, **wave2_settings)

100%|██████████| 1000/1000 [05:41<00:00,  2.93it/s]




In [59]:
# Load the posterior samples obtained by calibrating the model to the Liège hospitalisation data (second wave)
wave2_samples_file = '../data/interim/model_parameters/provinces/Luik_wave2_from_aug_1000steps_2020-10-29.json'
with open(wave2_samples_file) as fp:
    samples_dict = json.load(fp)


In [60]:
states = [['H_in']]
end_date = '2021-01-01'  # projection horizon; the calibration data end at end_beta
data = [timeseries[start_date:end_beta].values]

# Overlay 200 posterior draws of beta for the second wave and project them until end_date.
fig, ax = plt.subplots(figsize=(8,8))
for i in range(200):
    # Only beta is sampled here, so its index is not needed.
    model_aug.parameters['beta'] = random.choice(samples_dict['beta'])
    y_model = model_aug.sim(time=end_date, excess_time=0, start_date=start_date)
    ax = plot_fit(y_model, data, start_date, lag_time=0, states=states, end_date=end_date,
                  with_ints=False, ax=ax, plt_kwargs={'color': 'blue', 'linewidth': 2, 'alpha': 0.05})
# Operate on the explicit axes rather than pyplot's implicit "current" axes.
plt.setp(ax.xaxis.get_majorticklabels(), 'rotation', 0)
legend_text = ['daily \nhospitalizations']
ax.set_xlim('2020-08-01', end_date);

<IPython.core.display.Javascript object>

(18475.0, 18628.0)

In [55]:
# Save the current `fig` as <fig_path><spatial_unit>.pdf
fig.savefig(f'{fig_path}{spatial_unit}.pdf', bbox_inches='tight', dpi=600)

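# Run in parallel

The cells below calibrate every arrondissement in parallel. They require `dask` (its import near the top of the notebook is commented out) and a `nonpublic_ts` DataFrame of arrondissement-level hospitalisations, neither of which is set up in the cells above.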

In [None]:
# One calibration task per distinct NIS code in the (non-public) arrondissement data
arr_list = list(nonpublic_ts['NIS'].unique())

In [None]:
def run_me_parallel(arr):
    """Calibrate the model to one arrondissement and save a posterior-predictive figure.

    Parameters
    ----------
    arr : NIS code of the arrondissement. Used as the column label in the pivoted
        `nonpublic_ts` table, as the spatial unit for `full_calibration`, and in the
        output file name.

    Notes
    -----
    Reads the notebook globals `nonpublic_ts`, `model`, `start_date`, `end_beta`,
    `end_ramp`, `fig_path`, `samples_path` and the `Nc_*` matrices.
    NOTE(review): run top to bottom, the cells above leave `start_date`/`end_beta` at their
    second-wave values while `end_ramp` comes from the first-wave samples, which makes
    `start_date:end_ramp` an empty slice — reset the first-wave dates before running.
    """
    arrond_ts = nonpublic_ts.pivot(index='DATE', columns='NIS', values='hospitalised_IN')[arr]
    samples_dict = full_calibration(model, arrond_ts, arr, start_date, end_beta, end_ramp,
                                    fig_path, samples_path)
    # For a quick test run add: maxiter=10, popsize=10, steps_mcmc=250

    states = [['H_in']]
    end_date = '2020-09-20'
    data = [arrond_ts[start_date:end_ramp].values]
    # Contact matrix of the checkpoint before scaling by `prevention`; identical for every draw.
    Nc_lockdown = Nc_home + 0.4*Nc_work + 0.3*Nc_transport + 0.7*Nc_others + 0.2*Nc_leisure

    fig, ax = plt.subplots()
    for i in range(200):
        # NOTE(review): `idx` from the beta draw is overwritten by the `l` draw, so beta is
        # sampled independently of (l, tau, prevention) — confirm intent.
        idx, model.parameters['beta'] = random.choice(list(enumerate(samples_dict['beta'])))
        idx, model.parameters['l'] = random.choice(list(enumerate(samples_dict['l'])))
        model.parameters['tau'] = samples_dict['tau'][idx]
        prevention = samples_dict['prevention'][idx]
        # Create a dictionary of past policies
        chk = {'time': [start_date],
               'Nc': [prevention*Nc_lockdown]}
        y_model = model.sim(time=end_date, excess_time=samples_dict['lag_time'], checkpoints=chk)
        ax = plot_fit(y_model, data, start_date, samples_dict['lag_time'], states, end_date=end_date, with_ints=False, ax=ax,
                      plt_kwargs={'color': 'blue', 'linewidth': 2, 'alpha': 0.05})

    # Data after the calibration window (hollow red circles).
    data_after_calib = arrond_ts[pd.to_datetime(end_ramp)+pd.to_timedelta('1d'):end_date]
    ax.scatter(data_after_calib.index, data_after_calib.values, marker='o', color='red', linestyle='None', facecolors='none')
    ax.set_xlim('2020-03-10', '2020-08-03')
    # `datetime` is the class imported via `from datetime import datetime`, so the former
    # `datetime.date.today()` raised AttributeError; this yields the same 'YYYY-MM-DD' stamp.
    fig.savefig('../results/calibrations_arrondissements/'+str(arr)+'_'+str(datetime.now().date())+'.pdf',
                bbox_inches='tight', dpi=600)
    # Release the figure: this function runs once per arrondissement, often several times per worker.
    plt.close(fig)
    return

In [None]:
# One lazy dask task per arrondissement; nothing runs until dask.compute is called.
# NOTE(review): `dask` is not imported in this notebook (its import is commented out near the top).
tasks = [dask.delayed(run_me_parallel)(arr) for arr in arr_list]
    

In [None]:
# Execute all delayed calibrations with dask's multiprocessing scheduler.
# Requires `dask` to be imported (its import near the top of the notebook is commented out).
dask.compute(*tasks, scheduler='processes')

In [None]:
# Date stamp in the file name of the arrondissement samples loaded in the next cell
run_date = '2020-08-22'

In [None]:
# Posterior samples of arrondissement 44000 from the run dated `run_date`
samples_file = f'{samples_path}44000_{run_date}.json'
with open(samples_file, 'r') as fp:
    samples_dict = json.load(fp)