In [None]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.append("../../") 

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import numpy as np
import arviz as az

import yaml
import json

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples, plot_intervention_effectiveness, plot_intervention_correlation

from tqdm import tqdm
import pandas as pd

In [None]:
# The number of loaded experiment records is displayed right after the loading
# loop further down. Evaluating `len(all_info_dicts)` here, before that loop has
# created the list, raises NameError on a fresh kernel (Restart & Run All), so
# the stale expression was removed from this cell.

In [None]:
# Root directory that is walked for sensitivity-analysis result JSON files.
base_dir = 'sensitivity_analysis/default_cdefault'

In [None]:
# Collect every experiment record found under `base_dir`.
# A record is kept only if its JSON contains an 'alpha_i' entry; the later
# cells read 'alpha_i', 'cm_names' and 'exp_tag' from each kept record.
all_info_dicts = []

for subdir, _dirs, files in os.walk(base_dir):
    for fname in files:
        if not fname.endswith('.json'):
            continue

        # Use a distinct name for the handle so the filename loop variable
        # is not overwritten by the open file object.
        with open(os.path.join(subdir, fname), 'r') as json_file:
            info_dict = json.load(json_file)

        if 'alpha_i' in info_dict:
            all_info_dicts.append(info_dict)

In [None]:
# Number of experiment records kept by the loading loop (JSON files containing 'alpha_i').
len(all_info_dicts)

In [None]:
# Show seaborn's 'colorblind' palette as the cell output.
sns.color_palette('colorblind')

In [None]:
# Palette that the NPI-group dictionaries index into for their 'color' entries.
cols = sns.color_palette('colorblind')

In [None]:
# Read the results summary; the context manager closes the file once parsed.
with open('final_results_summary.yaml', 'r') as summary_file:
    res_yaml = yaml.safe_load(summary_file)

# Load the merged dataset and apply the preprocessing/masking steps.
# NOTE(review): the exact effect of featurize / mask_new_variant / mask_from_date
# is defined in the epimodel package and is not visible from this notebook.
data = preprocess_data('../../data/all_merged_data_2021-01-22.csv')
data.featurize()
data.mask_new_variant(new_variant_fraction_fname='../../data/nuts3_new_variant_fraction.csv')
data.mask_from_date('2021-01-09')

In [None]:
# Person-limit NPI names, in the order they appear inside every combined group.
# Index 0..3 corresponds to the limits 1, 2, 10 and 30 in the names.
_PUBLIC_PERSON_LIMITS = [
    'Public Indoor Gathering Person Limit - 1',
    'Public Indoor Gathering Person Limit - 2',
    'Public Indoor Gathering Person Limit - 10',
    'Public Indoor Gathering Person Limit - 30',
]
_PRIVATE_PERSON_LIMITS = [
    'Private Indoor Gathering Person Limit - 1',
    'Private Indoor Gathering Person Limit - 2',
    'Private Indoor Gathering Person Limit - 10',
    'Private Indoor Gathering Person Limit - 30',
]
_PUBLIC_HOUSEHOLD_LIMIT = ['Extra Public Indoor Household Limit']
_PRIVATE_HOUSEHOLD_LIMIT = ['Extra Private Indoor Household Limit']


def _gathering_npis(first, public=True, private=True, household=True):
    """Return a fresh list of gathering NPI names.

    `first` is the index of the first person limit to include (0 -> limit 1,
    1 -> limit 2, 2 -> limit 10, 3 -> limit 30); all later limits follow.
    `public` / `private` select which families are included, and `household`
    appends the matching 'Extra ... Household Limit' name after each family.
    """
    npis = []
    if public:
        npis += _PUBLIC_PERSON_LIMITS[first:]
        if household:
            npis += _PUBLIC_HOUSEHOLD_LIMIT
    if private:
        npis += _PRIVATE_PERSON_LIMITS[first:]
        if household:
            npis += _PRIVATE_HOUSEHOLD_LIMIT
    return npis


def _npi_group(npis, color, main=None):
    """Build one group entry; the 'main' key is only added when `main` is given."""
    entry = {'npis': list(npis), 'type': 'exclude', 'color': color}
    if main is not None:
        entry['main'] = main
    return entry


# Groups shown in the main panel. Keys are the y-axis labels; insertion order
# is the plotting order.
sol_two_main_dict = {
    'All non-essential\nbusinesses closed': _npi_group(
        ['Retail Closed', 'Some Face-to-Face Businesses Closed',
         'Gastronomy Closed', 'Leisure Venues Closed'],
        cols[0], main=True),
    'Night clubs closed': _npi_group(
        ['Some Face-to-Face Businesses Closed'], cols[0], main=False),
    'Leisure and entertainment\nvenues closed': _npi_group(
        ['Leisure Venues Closed'], cols[0], main=False),
    'Gastronomy closed': _npi_group(
        ['Gastronomy Closed'], cols[0], main=False),
    'Retail and close-contact\nservices closed': _npi_group(
        ['Retail Closed'], cols[0], main=False),
    'All gatherings banned': _npi_group(
        _gathering_npis(0), cols[3], main=True),
    'All gatherings limited to 2 people': _npi_group(
        _gathering_npis(1), cols[3], main=False),
    'All gatherings limited to ≤10 people\nfrom 2 households': _npi_group(
        _gathering_npis(2), cols[3], main=False),
    'All gatherings limited to ≤10 people': _npi_group(
        _gathering_npis(2, household=False), cols[3], main=False),
    'All gatherings limited to ≤30 people': _npi_group(
        _gathering_npis(3, household=False), cols[3], main=False),
    'All educational\ninstitutions closed': _npi_group(
        ['Primary Schools Closed', 'Secondary Schools Closed', 'Universities Away'],
        cols[2], main=True),
    'Night time curfew': _npi_group(
        ['Curfew'], cols[1], main=True),
    'Stricter mask-wearing\npolicy': _npi_group(
        ['Mandatory Mask Wearing >= 3'], cols[5], main=True),
}

# Groups shown in the gatherings panel: public gatherings first, then private
# household mixing. These entries carry no 'main' key.
gatherings_dict = {
    'All public gatherings banned': _npi_group(
        _gathering_npis(0, private=False), cols[4]),
    'Public gatherings limited to 2 people': _npi_group(
        _gathering_npis(1, private=False), cols[4]),
    'Public gatherings limited to ≤10 people\nfrom 2 households': _npi_group(
        _gathering_npis(2, private=False), cols[4]),
    'Public gatherings limited to ≤10 people': _npi_group(
        _gathering_npis(2, private=False, household=False), cols[4]),
    'Public gatherings limited to ≤30 people': _npi_group(
        _gathering_npis(3, private=False, household=False), cols[4]),
    'All household mixing in private banned': _npi_group(
        _gathering_npis(0, public=False), cols[5]),
    'Household mixing in private\nlimited to 2 people': _npi_group(
        _gathering_npis(1, public=False), cols[5]),
    'Household mixing in private\nlimited to ≤10 people from 2 households': _npi_group(
        _gathering_npis(2, public=False), cols[5]),
    'Household mixing in private\nlimited to ≤10 people': _npi_group(
        _gathering_npis(2, public=False, household=False), cols[5]),
    'Household mixing in private\nlimited to ≤30 people': _npi_group(
        _gathering_npis(3, public=False, household=False), cols[5]),
}

In [None]:
# Distinct experiment tags present in the loaded records.
np.unique([d["exp_tag"] for d in all_info_dicts])

In [None]:
# Coarse category for each experiment tag; the category becomes the 'tag'
# column (and therefore the hue) of the median data frames built below.
# NOTE(review): "boostrap" is spelled this way here — confirm it matches the
# exp_tag stored in the result files, otherwise those runs are silently skipped.
exp_main_tags = {
    "Unrecorded factors": ['npi_leaveout', "eng_ifr_iar"],
    "Data": ["maximum_fraction_voc", "boostrap", "delay_schools"],
    "Delay distributions": ["gen_int_mean", "cases_delay_mean", "death_delay_mean"],
    "Prior distributions": ["basic_R_prior_mean", "basic_R_prior_scale", "infection_noise_scale",
                            "output_noise_scale_prior", "r_walk_noise_scale_prior",
                            "intervention_prior", "seeding_scaling"],
    "Model structure": ["r_walk_period", "seeding_days"]
}

# Reverse lookup (experiment tag -> category), built once instead of scanning
# `exp_main_tags` for every experiment. If a tag were listed under several
# categories the last one wins, matching the previous scan.
_category_by_exp_tag = {
    exp_tag: category
    for category, exp_tags_in_category in exp_main_tags.items()
    for exp_tag in exp_tags_in_category
}


def _median_percent_reductions(npi_groups, alpha_i, cm_names):
    """Combine NPI samples per group and return (group names, median % reduction).

    `combine_npi_samples` is called with the group dictionary, the 'alpha_i'
    array and the 'cm_names' list of one experiment; it returns the combined
    effects and the names of the combined groups. Each effect is converted
    with 100 * (1 - exp(-effect)) and the median is taken along axis 0
    (assumed to be the posterior-sample axis — TODO confirm).
    """
    comb_cm_effects, new_names = combine_npi_samples(npi_groups, alpha_i, cm_names)
    med_per_red = np.median(100 * (1 - np.exp(-comb_cm_effects)), axis=0)
    return new_names, med_per_red


npi_median_data_list_coarse_main_panel = []
npi_median_data_list_coarse_gath_panel = []

# Number of experiments that belong to one of the categories above.
n_exps = 0

# Iterate the list directly so tqdm knows the total and can show real progress.
for d in tqdm(all_info_dicts):
    main_tag = _category_by_exp_tag.get(d["exp_tag"])

    # Experiments without a category are left out of both panels; checking
    # this first avoids combining samples that would be thrown away.
    if main_tag is None:
        continue

    n_exps += 1

    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]

    for npi_groups, records in (
        (sol_two_main_dict, npi_median_data_list_coarse_main_panel),
        (gatherings_dict, npi_median_data_list_coarse_gath_panel),
    ):
        new_names, med_per_red = _median_percent_reductions(npi_groups, alpha_i, cm_names)
        for cm_i, name in enumerate(new_names):
            records.append({
                "npi": name,
                "med": med_per_red[cm_i],
                "tag": main_tag
            })

# One row per (experiment, NPI group): columns 'npi', 'med' and 'tag'.
npi_medians_df_coarse_main_panel = pd.DataFrame(npi_median_data_list_coarse_main_panel)
npi_medians_df_coarse_gath_panel = pd.DataFrame(npi_median_data_list_coarse_gath_panel)

In [None]:
def _draw_sensitivity_panel(ax, medians_df, npi_dict, separator_rows, marker_size):
    """Draw one strip-plot panel of per-experiment median NPI effects on `ax`.

    ax             -- matplotlib Axes to draw on.
    medians_df     -- data frame with columns 'med' (x value), 'npi' (row label)
                      and 'tag' (hue).
    npi_dict       -- NPI-group dictionary; its keys give the row order and each
                      value's 'color' gives the row band and tick-label color.
    separator_rows -- row indices after which a horizontal gray line is drawn.
    marker_size    -- marker size passed to seaborn's stripplot.
    """
    npi_names = list(npi_dict.keys())
    npi_cols = [spec['color'] for spec in npi_dict.values()]

    # Vertical reference line at zero reduction.
    ax.plot([0, 0], [-1, 40], linestyle='-', alpha=1, zorder=-5, linewidth=1, color="tab:gray")

    # Faint band per row in the row's color; odd rows are slightly darker so
    # neighbouring rows of the same color stay distinguishable.
    for row, col in enumerate(npi_cols):
        ax.fill_between(
            [-100, 100],
            [row - 0.5, row - 0.5],
            [row + 0.5, row + 0.5],
            color=col,
            alpha=0.05 if row % 2 == 0 else 0.1,
            linewidth=0,
            zorder=-5
        )

    # Horizontal separators drawn just below the given rows.
    for last_row in separator_rows:
        y_sep = last_row + 0.5
        ax.plot([-100, 100], [y_sep, y_sep], color="tab:gray", linewidth=1, alpha=1)

    sns.stripplot(x="med", y="npi", data=medians_df, size=marker_size, zorder=1,
                  jitter=0.3, order=npi_names, alpha=0.2, hue="tag", dodge=True,
                  palette="colorblind", linewidth=0.05, ax=ax)

    # Color each row label like its band.
    for tick, col in zip(ax.get_yticklabels(), npi_cols):
        tick.set_color(col)

    ax.set_xlabel("Median Reduction in R (%)")
    ax.set_ylabel("")
    ax.set_xlim([-30, 50])
    ax.grid(axis='x', linewidth=0.5, zorder=-10, color="tab:gray", alpha=0.25)
    ax.set_xticks([-20, 0, 20, 40])


fig = plt.figure(constrained_layout=True, figsize=(5.75, 8.25), dpi=400)
gs = gridspec.GridSpec(ncols=8, nrows=11, figure=fig)

# Top panel: the groups of `sol_two_main_dict`.
# NOTE(review): the separator rows are kept exactly as before (lines after rows
# 4, 9, 12 and 15), but `sol_two_main_dict` has 13 rows (indices 0-12), so the
# last separator lies outside the rows and no line separates rows 10, 11 and 12.
# These values look stale — confirm the intended grouping.
main_ax = fig.add_subplot(gs[:8, :])
_draw_sensitivity_panel(
    main_ax,
    npi_medians_df_coarse_main_panel,
    sol_two_main_dict,
    separator_rows=(4, 9, 12, 15),
    marker_size=5,
)
# Only the bottom panel carries the (shared) legend.
main_ax.get_legend().remove()
plt.setp(main_ax.get_yticklabels(), linespacing=1)

# Bottom panel: public vs. private gathering groups of `gatherings_dict`.
gath_ax = fig.add_subplot(gs[8:, :])
_draw_sensitivity_panel(
    gath_ax,
    npi_medians_df_coarse_gath_panel,
    gatherings_dict,
    separator_rows=(4, 9),
    marker_size=3,
)
plt.setp(gath_ax.get_yticklabels(), fontsize=7, linespacing=0.8)
gath_ax.legend(bbox_to_anchor=(0.5, -0.3), loc='upper center', fontsize=8, ncol=3,
               fancybox=True, shadow=True, handletextpad=0.05)

fig.savefig('FigVal.svg', bbox_inches='tight')
fig.savefig('validation_all.pdf', bbox_inches='tight')