In [None]:
import pandas as pd
import json  # NOTE(review): not referenced in the visible cells of this notebook

import matplotlib.pyplot as plt
# Embed fonts in saved PDFs as TrueType (Type 42) rather than Type 3.
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec  # NOTE(review): not referenced in the visible cells
# Figure typography. Optima must be installed; otherwise matplotlib warns and falls back
# to a default font.
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"


import matplotlib.dates as mdates
import seaborn as sns  # NOTE(review): not referenced in the visible cells

import arviz as az  # used to read the saved main-model posterior (az.from_netcdf)
import numpyro
from numpyro.infer import MCMC, NUTS, init_to_median

import numpy as np
from jax import random
import jax.numpy as jnp

# Define Data

In [None]:
# Countries in the validation. Their order fixes the row order of every per-country
# array below (the "United Kingdom" row is later filled with England-only case data).
countries = [
    "Netherlands",
    "United Kingdom",
    "Czech Republic",
    "Ireland",
    "Belgium",
    "Hungary",
]
# Daily analysis window, inclusive at both ends: 1 Jan 2021 - 23 May 2021 (143 days).
Ds = pd.date_range(start="2021-01-01", end="2021-05-23", freq="D")

# Load Case Data

In [None]:
# OxCGRT_latest.csv must be in the working directory. Presumably this is the "latest" export
# of the Oxford COVID-19 Government Response Tracker (github.com/OxCGRT/covid-policy-tracker)
# -- TODO confirm the exact source.
oxcgrt_df = pd.read_csv('OxCGRT_latest.csv')
oxcgrt_df["Date"] = pd.to_datetime(oxcgrt_df["Date"], format="%Y%m%d")
oxcgrt_df = oxcgrt_df.set_index(['CountryName', 'Date'])


def _daily_new_cases(df, country, dates):
    """Daily new confirmed cases for ``country`` on ``dates``.

    Computed as the first difference of the ``ConfirmedCases`` column. For the United
    Kingdom only the England rows (``RegionName == "England"``) are used.

    The result can contain NaN (missing values in ``ConfirmedCases``) and negative values
    (if the series ever decreases); the caller masks both.
    """
    country_df = df.loc[country]
    if country == "United Kingdom":
        country_df = country_df.reset_index().set_index(['RegionName', 'Date']).loc["England"]
    return np.asarray(country_df["ConfirmedCases"].diff().loc[dates], dtype=float)


new_cases = np.zeros((len(countries), len(Ds)))
for c_i, c in enumerate(countries):
    new_cases[c_i, :] = _daily_new_cases(oxcgrt_df, c, Ds)

# Exclude from the likelihood (mask):
#  * the first 10 days of the window;
#  * days with fewer than 1 new case (zero or negative differences);
#  * missing values. ``NaN < 1`` is False, so NaN needs its own test. Its underlying data is
#    also zero-filled: a NaN observation, even when masked, may still turn the gradient of
#    the masked log-density into NaN (0 * NaN).
missing = ~np.isfinite(new_cases)
excluded = missing | (new_cases < 1)
excluded[:, :10] = True
new_cases = np.ma.array(np.where(missing, 0.0, new_cases), mask=excluded)

# Load NPI data

In [None]:
npi_df = pd.read_csv('npi_data.csv')
npi_df["Date"] = pd.to_datetime(npi_df["Date"], format="%Y%m%d")
npi_df = npi_df.set_index(['CountryName', 'Date'])

# The first N_SW_NPIS columns are the second-wave (sw) NPI features; the remaining
# N_FW_NPIS columns are the first-wave (fw) features.
N_SW_NPIS = 19
N_FW_NPIS = 8
N_NPI_DAYS = 135  # npi_data.csv only covers the first 135 days of Ds

sw_npis = list(npi_df.columns[:N_SW_NPIS])
fw_npis = list(npi_df.columns[N_SW_NPIS:])
assert len(fw_npis) == N_FW_NPIS, f"expected {N_FW_NPIS} first-wave NPI columns, got {len(fw_npis)}"


def _active_npis(npis):
    """Activation array (countries x len(npis) x days) for the given NPI columns.

    Values are assigned positionally: each country's rows are assumed to be the first
    ``N_NPI_DAYS`` days of ``Ds``, in order (TODO confirm against npi_data.csv). Days after
    the data ends keep the value from the last available day.
    """
    active = np.zeros((len(countries), len(npis), len(Ds)))
    for c_i, c in enumerate(countries):
        active[c_i, :, :N_NPI_DAYS] = npi_df.loc[c][npis].to_numpy().T
        active[c_i, :, N_NPI_DAYS:] = active[c_i, :, N_NPI_DAYS - 1:N_NPI_DAYS]
    return active


active_cms_sw = _active_npis(sw_npis)
active_cms_fw = _active_npis(fw_npis)
# Feature axis: sw features first, then fw features (27 in total).
active_cms_all = np.concatenate([active_cms_sw, active_cms_fw], axis=1)

In [None]:
def active_cms_to_shock_cms(active_cms, cm_alphas):
    """Turn NPI activation time series into one step ("shock") indicator per NPI change.

    A shock is created for every (country, day) on which any entry of
    ``active_cms[country, :, day]`` differs from the previous day. Its indicator is 0
    before that day and 1 from that day on, for that country only.

    Args:
        active_cms: array (countries x NPIs x days). The first 19 features are the
            second-wave (sw) NPIs and the last 8 the first-wave (fw) NPIs.
        cm_alphas: array (samples x NPIs) of NPI effect parameters, with columns aligned
            to the feature axis of ``active_cms``.

    Returns:
        shock_cms: array (countries x shocks x days) of 0/1 step indicators.
        shock_cms_info: list with one tuple per shock,
            ``(cm_index, cm_delta_sw, cm_delta_fw, reopening, change, c_i, total_change)``:
            - cm_index: position of the shock along axis 1 of ``shock_cms``;
            - cm_delta_sw / cm_delta_fw: per-sample ``sum(alpha * activation)`` over the
              sw / fw features on the day before the change, minus the same on the change day;
            - reopening: True when fewer non-school NPIs are active after the change;
            - change: day index of the change; c_i: country index;
            - total_change: the larger of the absolute changes in the sw and fw activation sums.
    """
    nCs, nCMs, nDs = active_cms.shape

    # School features are excluded from the reopening test:
    # sw schools are features 5-7, fw schools are the 3rd- and 2nd-to-last features.
    mask = np.ones(nCMs)
    mask[5:8] = 0
    mask[-3:-1] = 0

    sw_alphas = cm_alphas[:, :19]
    fw_alphas = cm_alphas[:, -8:]

    shock_columns = []
    shock_cms_info = []

    for c_i in range(nCs):
        country_cms = active_cms[c_i]
        # A day is a change day if ANY feature differs from the previous day. Comparing only
        # the total number of active NPIs missed days on which one NPI was switched on and
        # another switched off.
        changed = np.any(np.diff(country_cms, axis=-1) != 0, axis=0)
        cm_changes = np.nonzero(changed)[0] + 1  # diff entry i compares day i+1 with day i

        for change in cm_changes:
            before = country_cms[:, change - 1]
            after = country_cms[:, change]

            shock_cm = np.zeros((nCs, nDs))
            shock_cm[c_i, change:] = 1
            shock_columns.append(shock_cm)

            cm_delta_sw = np.sum(before[:19] * sw_alphas, axis=-1) - np.sum(after[:19] * sw_alphas, axis=-1)
            cm_delta_fw = np.sum(before[-8:] * fw_alphas, axis=-1) - np.sum(after[-8:] * fw_alphas, axis=-1)

            # Net change in the number of active non-school NPIs (positive = NPIs lifted).
            act_cm_delta = np.sum(before * mask) - np.sum(after * mask)
            total_change = max(
                np.abs(np.sum(before[:19]) - np.sum(after[:19])),
                np.abs(np.sum(before[-8:]) - np.sum(after[-8:])),
            )
            # A reopening means fewer NPIs are active afterwards. The previous test
            # (act_cm_delta < 0) flagged NPI introductions as reopenings.
            reopening = bool(act_cm_delta > 0)

            shock_cms_info.append(
                (len(shock_cms_info), cm_delta_sw, cm_delta_fw, reopening, change, c_i, total_change)
            )

    # Stack once at the end instead of np.append per shock (which copies every time).
    if shock_columns:
        shock_cms = np.stack(shock_columns, axis=1)
    else:
        shock_cms = np.zeros((nCs, 0, nDs))
    return shock_cms, shock_cms_info

In [None]:
full_res = az.from_netcdf('final_results/final_results2.netcdf') # to run this file, you will need to do a main model run, and save the results as a .netcdf file

# Posterior samples of the 19 second-wave NPI effects (samples x 19). np.array copies, so
# the in-place edits below do not write back into ``full_res``.
cm_alpha_sw = np.array(full_res.posterior.alpha_i.data).reshape((-1, 19))
n_samples = cm_alpha_sw.shape[0]

# Give the three sw school features (columns 5-7) their common mean effect. The mean is
# computed once, BEFORE any column is overwritten; recomputing it after each assignment
# mixed the already-updated column 5 into columns 6 and 7.
school_mean_sw = np.sum(cm_alpha_sw[:, 5:8], axis=-1) / 3
cm_alpha_sw[:, 5:8] = school_mean_sw[:, np.newaxis]

# First-wave NPI effects. The file presumably stores multiplicative R factors, hence -log.
cm_alpha_fw = -np.log(np.loadtxt('fw_alpha.txt'))
# Split the combined effect of the two fw school features (3rd- and 2nd-to-last columns)
# 2:1 between them, again using the total from before either column is overwritten.
school_total_fw = np.sum(cm_alpha_fw[:, -3:-1], axis=-1)
cm_alpha_fw[:, -3] = 2 * school_total_fw / 3
cm_alpha_fw[:, -2] = 1 * school_total_fw / 3

# samples x 27: sw effects first, then fw effects (matching active_cms_all).
cm_alpha_all = np.concatenate([cm_alpha_sw, cm_alpha_fw[:n_samples, :]], axis=1)
assert cm_alpha_all.shape == (n_samples, 27), cm_alpha_all.shape

# Define Model

In [None]:
import sys
sys.path.append("../")  # let Python find the ``epimodel`` package one directory up (relative to the cwd)

from epimodel import preprocess_data, run_model, EpidemiologicalParameters, default_model
# Star import. Presumably it provides the helpers ``shock_rw_model`` relies on
# (``seed_infections``, ``get_discrete_renewal_transition``) and possibly ``dist`` / ``jax``
# -- TODO(review): replace with explicit imports once confirmed.
from epimodel.models.model_build_utils import *

# Epidemiological parameters; ``shock_rw_model`` reads its ``GIv`` and ``DPC`` attributes.
ep = EpidemiologicalParameters()

In [None]:
def shock_rw_model(ep, new_cases, shock_cms, r_walk_period=7, seeding_scale=4):
    """NumPyro model: R_t = basic_R * exp(random walk - sum of shock effects).

    Args:
        ep: epidemiological parameters. ``ep.GIv`` sets the renewal padding and
            ``ep.DPC`` is convolved with infections to give expected cases.
        new_cases: masked array (regions x days) of observed daily cases. Masked entries
            are excluded from the likelihood.
        shock_cms: array (regions x shocks x days) of 0/1 step indicators, e.g. from
            ``active_cms_to_shock_cms``.
        r_walk_period: number of days each random-walk increment is held.
        seeding_scale: passed through to ``seed_infections``.

    Sample sites: alpha_delta, basic_R, r_walk_noise_scale, r_walk_noise, psi_cases.
    Deterministic sites: Rt, total_infections, future_cases_t, expected_cases.
    """
    # Imported locally so the model does not depend on a star import for these names.
    import jax
    import jax.scipy.signal  # ensure the submodule used for the reporting-delay convolution is loaded
    import numpyro.distributions as dist

    nRs, nDs = new_cases.shape
    _, nShock, _ = shock_cms.shape

    n_days_seeding = 7
    # One free log-R_t change per shock; positive values reduce R_t (subtracted below).
    alpha_delta = numpyro.sample("alpha_delta", dist.Normal(loc=0, scale=0.25 * jnp.ones(nShock)))

    basic_R = numpyro.sample("basic_R", dist.Normal(loc=1.2 * np.ones(nRs), scale=0.5))
    cm_reduction = jnp.sum(shock_cms * alpha_delta.reshape((1, nShock, 1)), axis=1)

    nNP = int(nDs / r_walk_period) - 1
    r_walk_noise_scale = numpyro.sample("r_walk_noise_scale", dist.HalfNormal(0.15))

    # NOTE(review): the noise has a leading dimension of 1, so every region shares the same
    # random walk (it is broadcast across rows below). The ``[:nRs]`` slice suggests a
    # per-region walk of shape (nRs, nNP) may have been intended -- confirm.
    r_walk_noise = numpyro.sample(
        "r_walk_noise",
        dist.Normal(loc=jnp.zeros((1, nNP)), scale=1.0 / 10),
    )

    expanded_r_walk_noise = jnp.repeat(
        r_walk_noise_scale * 10.0 * jnp.cumsum(r_walk_noise, axis=-1),
        r_walk_period,
        axis=-1,
    )[:nRs, : (nDs - 2 * r_walk_period)]

    # No random-walk noise during the first two walk periods.
    # (.at[].set replaces jax.ops.index_update, which has been removed from JAX.)
    full_log_Rt_noise = jnp.zeros((nRs, nDs)).at[:, 2 * r_walk_period :].set(expanded_r_walk_noise)

    Rt = numpyro.deterministic(
        "Rt",
        jnp.exp(jnp.log(basic_R.reshape((nRs, 1))) + full_log_Rt_noise - cm_reduction),
    )

    seeding_padding = n_days_seeding
    total_padding = ep.GIv.size - 1

    init_infections, total_infections_placeholder = seed_infections(
        seeding_scale, nRs, nDs, seeding_padding, total_padding
    )
    discrete_renewal_transition = get_discrete_renewal_transition(ep)

    _, infections = jax.lax.scan(
        discrete_renewal_transition,
        init_infections,
        Rt.T,
    )

    total_infections = jnp.asarray(total_infections_placeholder).at[:, :seeding_padding].set(
        init_infections[:, -seeding_padding:]
    )
    total_infections = numpyro.deterministic(
        "total_infections",
        total_infections.at[:, seeding_padding:].set(infections.T),
    )

    future_cases_t = numpyro.deterministic("future_cases_t", total_infections)

    expected_cases = numpyro.deterministic(
        "expected_cases",
        jax.scipy.signal.convolve2d(future_cases_t, ep.DPC, mode="full")[
            :, seeding_padding : seeding_padding + nDs
        ],
    )

    # One over-dispersion parameter per region. This was previously sized with the global
    # ``countries`` list, which breaks when new_cases has a different number of rows.
    psi_cases = numpyro.sample(
        "psi_cases",
        dist.HalfNormal(scale=5.0 * jnp.ones((nRs, 1))),
    )

    with numpyro.handlers.mask(mask=jnp.logical_not(new_cases.mask)):
        numpyro.sample(
            "observed_cases",
            dist.GammaPoisson(
                concentration=psi_cases,
                rate=psi_cases / expected_cases,
            ),
            obs=new_cases.data,
        )

# Run Model

In [None]:
# Convert the (country, NPI, day) activations into one step indicator per NPI change,
# along with the per-shock metadata used in the summaries below.
shock_cms_comb, shock_cms_info_comb = active_cms_to_shock_cms(
    active_cms=active_cms_all,
    cm_alphas=cm_alpha_all,
)

In [None]:
# shock_cms_comb / shock_cms_info_comb are computed in the previous cell; recomputing them
# here was redundant.
nuts_kernel = NUTS(
    shock_rw_model,
    init_strategy=init_to_median,
    max_tree_depth=15,
)

mcmc = MCMC(
    nuts_kernel,
    num_samples=500,
    num_warmup=250,
    num_chains=1,
    chain_method="sequential",
)
rng_key = random.PRNGKey(0)  # fixed seed so the run is reproducible

mcmc.run(rng_key, ep, new_cases, shock_cms_comb)
samples_shock = mcmc.get_samples()

# Validate Results

In [None]:
# One figure per country. Left: posterior R_t, with a vertical line at each of the country's
# shocks. Right: expected vs. observed (unmasked) daily cases. Shaded bands are the
# 2.5-97.5 percentile range over posterior samples; lines are the medians.
for c_i, country in enumerate(countries):
    fig, (ax_rt, ax_cases) = plt.subplots(1, 2, figsize=(4.75, 2.75), dpi=200)

    rt = samples_shock['Rt'][:, c_i, :]
    ax_rt.fill_between(Ds, np.percentile(rt, 2.5, axis=0), np.percentile(rt, 97.5, axis=0), alpha=0.2, color="tab:purple", linewidth=0)
    ax_rt.plot(Ds, np.percentile(rt, 50, axis=0), color="tab:purple")
    ax_rt.set_ylim([0.25, 2])
    ax_rt.set_ylabel('$R_t$')

    # Shock tuples have seven fields (cm_index, delta_sw, delta_fw, reopening, day, country,
    # total_change). Pick day and country by position: unpacking them into six names raised
    # "too many values to unpack".
    for shock in shock_cms_info_comb:
        if shock[5] == c_i:
            ax_rt.axvline(Ds[shock[4]], color='k', linewidth=0.5, alpha=0.5)

    ac = new_cases[c_i, :]
    observed = ~np.ma.getmaskarray(ac)
    ec = samples_shock['expected_cases'][:, c_i, :]
    ax_cases.fill_between(Ds, np.percentile(ec, 2.5, axis=0), np.percentile(ec, 97.5, axis=0), alpha=0.2, color="tab:orange")
    ax_cases.plot(Ds, np.percentile(ec, 50, axis=0), color="tab:orange")
    ax_cases.scatter(Ds[observed], ac[observed], s=4, color='tab:blue', marker='.')
    ax_cases.set_ylabel('Reported Cases')
    # nanmax: the raw data may contain NaN, which would make the axis limit invalid.
    ax_cases.set_ylim([0, 1.15 * np.nanmax(new_cases[c_i, 7:-10].data)])

    for ax in (ax_rt, ax_cases):
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%d %b"))
        ax.set_xlim([Ds[7], Ds[-1]])
        ax.set_xticks(Ds[7::20])
        plt.setp(ax.get_xticklabels(), fontsize=8, rotation=-45, ha='left')

    fig.suptitle(country)
    fig.tight_layout()

# Summarise Results

In [None]:
ignore_reopenings = False

last_day = 120  # index into Ds: Ds[120] is 1 May 2021; shocks on or after it are excluded


def _shock_errors(observed, predicted):
    """Errors of the posterior-median prediction against the observed shock effects.

    ``predicted`` has one row per posterior sample; the median is taken over rows.
    Returns (mean absolute error, mean squared error, per-shock squared error). The two
    means ignore NaN (excluded) shocks.
    """
    error = observed - np.median(predicted, axis=0)
    squared_error = np.power(error, 2)
    return np.nanmean(np.abs(error)), np.nanmean(squared_error), squared_error


# Shock tuple fields used below: 1 = sw delta, 2 = fw delta, 3 = reopening flag,
# 4 = day index, 6 = total_change (number of NPIs changed).
after_last_day = np.array([info[4] >= last_day for info in shock_cms_info_comb], dtype=bool)

# second wave: observed effect (% reduction in R_t implied by the posterior-median
# alpha_delta) and predicted effect (one row per posterior sample of the sw NPI effects)
sw_shock_values = 100*(1-np.exp(-np.median(samples_shock['alpha_delta'], axis=0)))
sw_shock_values_pred = 100*(1-np.exp(np.array([info[1] for info in shock_cms_info_comb]))).T
sw_reopening_mask = np.array([info[3] for info in shock_cms_info_comb])

npis_changd = np.array([info[6] for info in shock_cms_info_comb])

if ignore_reopenings:
    sw_shock_values[sw_reopening_mask] = np.nan

sw_shock_values[after_last_day] = np.nan
sw_shock_values_pred[:, after_last_day] = np.nan
npis_changd[after_last_day] = np.nan

sw_made, sw_mse, sw_mse_ps = _shock_errors(sw_shock_values, sw_shock_values_pred)

# first wave: same observed effects, predictions from the fw NPI effects
fw_shock_values = 100*(1-np.exp(-np.median(samples_shock['alpha_delta'], axis=0)))
fw_shock_values_pred = 100*(1-np.exp(np.array([info[2] for info in shock_cms_info_comb]))).T
fw_reopening_mask = np.array([info[3] for info in shock_cms_info_comb])

if ignore_reopenings:
    fw_shock_values[fw_reopening_mask] = np.nan

fw_shock_values[after_last_day] = np.nan
fw_shock_values_pred[:, after_last_day] = np.nan

fw_made, fw_mse, fw_mse_ps = _shock_errors(fw_shock_values, fw_shock_values_pred)

# Mean absolute effect sizes (in %), ignoring excluded shocks.
obs_mean = np.nanmean(np.abs(sw_shock_values))
sw_mean = np.nanmean(np.abs(sw_shock_values_pred))
fw_mean = np.nanmean(np.abs(fw_shock_values_pred))

In [None]:
def compute_under_overestimate(fw, sw, observed):
    """Mean signed error of first- and second-wave predictions against observed effects.

    Args:
        fw, sw: predicted effects; the median is taken over axis 0 (posterior samples in
            this notebook), leaving one value per shock.
        observed: 1-D array with one observed effect per shock.

    A shock counts as a closing when either median prediction is positive. The signed
    error is ``prediction - observed`` for closings and ``observed - prediction``
    otherwise. NaN entries are ignored.

    Returns:
        (mean signed error of fw, mean signed error of sw)
    """
    n_shocks = observed.size
    median_fw = np.median(fw, axis=0)[:n_shocks]
    median_sw = np.median(sw, axis=0)[:n_shocks]

    is_closing = (median_fw > 0) | (median_sw > 0)
    signed_fw = np.where(is_closing, median_fw - observed, -median_fw + observed)
    signed_sw = np.where(is_closing, median_sw - observed, -median_sw + observed)

    return np.nanmean(signed_fw), np.nanmean(signed_sw)


In [None]:
# Mean signed prediction error for the first- and second-wave predictions, in that order.
compute_under_overestimate(
    fw=fw_shock_values_pred,
    sw=sw_shock_values_pred,
    observed=fw_shock_values,
)

In [None]:
ignore_reopenings = False

last_day = 120  # index into Ds: Ds[120] is 1 May 2021; shocks on or after it are excluded

# Shock tuple fields: 1 = sw delta, 2 = fw delta, 3 = reopening flag, 4 = day index,
# 5 = country index, 6 = total_change. Index by position: unpacking into six names fails
# on the seven-field tuples, and x[-1] is total_change, not the country.
shock_day = np.array([info[4] for info in shock_cms_info_comb])
shock_c = np.array([info[5] for info in shock_cms_info_comb])
after_last_day = shock_day >= last_day

# Observed effect (% reduction in R_t from the posterior-median alpha_delta) and predicted
# effects (one row per posterior sample) for each wave's NPI-effect estimates.
sw_shock_values = 100*(1-np.exp(-np.median(samples_shock['alpha_delta'], axis=0)))
sw_shock_values_pred = 100*(1-np.exp(np.array([info[1] for info in shock_cms_info_comb]))).T
sw_reopening_mask = np.array([info[3] for info in shock_cms_info_comb])

fw_shock_values = 100*(1-np.exp(-np.median(samples_shock['alpha_delta'], axis=0)))
fw_shock_values_pred = 100*(1-np.exp(np.array([info[2] for info in shock_cms_info_comb]))).T
fw_reopening_mask = np.array([info[3] for info in shock_cms_info_comb])

if ignore_reopenings:
    sw_shock_values[sw_reopening_mask] = np.nan
    fw_shock_values[fw_reopening_mask] = np.nan

sw_shock_values[after_last_day] = np.nan
fw_shock_values[after_last_day] = np.nan

# Squared error of the posterior-median prediction per shock (NaN for excluded shocks).
sw_mse_ps = np.power(sw_shock_values - np.median(sw_shock_values_pred, axis=0), 2)
fw_mse_ps = np.power(fw_shock_values - np.median(fw_shock_values_pred, axis=0), 2)

# Per country: second-wave MSE / first-wave MSE (below 1 = second wave predicts better).
country_relative_mses = np.zeros(len(countries))
for c_i, c in enumerate(countries):
    country_mask = shock_c == c_i
    sw_mse_pc = np.nanmean(sw_mse_ps[country_mask])
    fw_mse_pc = np.nanmean(fw_mse_ps[country_mask])
    country_relative_mses[c_i] = sw_mse_pc / fw_mse_pc

In [None]:
n_countries = len(countries)
fig, ax = plt.subplots(figsize=(4, 3), dpi=500)
ax.set_title('Relative Average Prediction Error')
ax.scatter(country_relative_mses, countries, color='tab:gray', marker='d')
# Dashed reference at 100% (equal MSE), drawn past both ends of the country rows. The y range
# is derived from the number of countries instead of assuming six.
ax.plot([1, 1], [-3, n_countries + 2], 'k--', linewidth=0.5)
ax.set_ylim((-0.5, n_countries - 0.5))
# NOTE(review): ratios above 2 (200%) fall outside this x-range and are not shown.
ax.set_xlim([0, 2])

ax.set_xticks([0, 1, 2])
ax.set_xticklabels(["0%", "100%", "200%"])
ax.set_xlabel("\nSecond wave MSE as a percentage of first wave MSE")
# Interpretation labels placed below the first country row (y = -1.3 in data units).
ax.text(0, -1.3, "Second wave predicts better", fontsize=6, ha='center')
ax.text(1, -1.3, "Waves predict equally well", fontsize=6, ha='center')
ax.text(2, -1.3, "First wave predicts better", fontsize=6, ha='center')
fig.savefig('FigRelativePred.pdf', bbox_inches='tight')