In [None]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.append("../") 

import yaml

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("ticks")
import numpy as np
import pandas as pd

from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
# Root of the sensitivity-analysis results for this notebook (a path relative to
# the notebook's working directory); the next cell walks this tree and loads
# every *.yaml file found under it.
base_dir = "sensitivity_analysis/rc_3a_cdef_drop_outdoors"

In [None]:
# Load one result dictionary per YAML file found anywhere under `base_dir`.
# os.walk yields directories and files in filesystem-dependent order, and later
# cells pick "the first run" of a tag (e.g. window_analysis[0]), so both levels
# are sorted to make `all_info_dicts` reproducible across machines.
all_info_dicts = []

for subdir, dirs, files in os.walk(base_dir):
    # Sorting `dirs` in place fixes the order in which os.walk descends (topdown=True).
    dirs.sort()
    for f in sorted(files):
        if f.endswith('.yaml'):
            with open(os.path.join(subdir, f), 'r') as file:
                info_dict = yaml.safe_load(file)
            all_info_dicts.append(info_dict)

In [None]:
# NPI groupings used to summarise each run: every key is a display name and its
# 'npis' entry lists the model NPIs combined under it by combine_npi_samples.
# Ten NPIs stand alone; the gathering keys bundle a ladder of indoor limits.

_GATHERING_LIMITS = (1, 2, 6, 30)


def _indoor_gathering_npis(setting, min_limit):
    """Indoor gathering NPIs of `setting` ("Public"/"Private") with person limit
    >= `min_limit`, plus the extra household-limit NPI when `min_limit` <= 2."""
    npis = [
        f"{setting} Indoor Gathering Person Limit - {limit}"
        for limit in _GATHERING_LIMITS
        if limit >= min_limit
    ]
    if min_limit <= 2:
        npis.append(f"Extra {setting} Indoor Household Limit")
    return npis


_STANDALONE_NPIS = (
    'Some Face-to-Face Businesses Closed',
    'Gastronomy Closed',
    'Leisure Venues Closed',
    'Curfew',
    'Childcare Closed',
    'Primary Schools Closed',
    'Secondary Schools Closed',
    'Universities Away',
    'Stay at Home Order AND All F2F Businesses Closed',
    'Mandatory Mask Wearing >= 3',
)

drop_outdoor_dict = {npi: {'npis': [npi], 'type': 'exclude'} for npi in _STANDALONE_NPIS}
drop_outdoor_dict.update(
    (
        f"{setting} Gathering Person Limit - {limit}",
        {'npis': _indoor_gathering_npis(setting, limit), 'type': 'exclude'},
    )
    for setting in ("Public", "Private")
    for limit in _GATHERING_LIMITS
)

In [None]:
# One row per (run, NPI group): the median along axis 0 (presumably the sample
# axis) of the percentage reduction 100 * (1 - exp(-effect)), with NPIs grouped
# according to `drop_outdoor_dict`.
npi_median_data_list = []

for d in all_info_dicts:
    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]
    comb_cm_effects, new_names = combine_npi_samples(drop_outdoor_dict, alpha_i, cm_names)
    percent_reduction = 100 * (1 - np.exp(-comb_cm_effects))
    med_per_red = np.median(percent_reduction, axis=0)
    npi_median_data_list += [
        {"npi": npi, "med": med_per_red[idx], "tag": d["exp_tag"]}
        for idx, npi in enumerate(new_names)
    ]

npi_medians_df = pd.DataFrame(npi_median_data_list)

In [None]:
# Top-to-bottom y-axis order for the plots below: the ten stand-alone NPIs,
# then the public and private gathering-limit groups from strictest to loosest.
npi_names = [
    'Some Face-to-Face Businesses Closed',
    'Gastronomy Closed',
    'Leisure Venues Closed',
    'Curfew',
    'Childcare Closed',
    'Primary Schools Closed',
    'Secondary Schools Closed',
    'Universities Away',
    'Stay at Home Order AND All F2F Businesses Closed',
    'Mandatory Mask Wearing >= 3',
] + [
    f"{setting} Gathering Person Limit - {limit}"
    for setting in ("Public", "Private")
    for limit in (1, 2, 6, 30)
]

In [None]:
# Spread, across all loaded runs, of each NPI's median reduction in R_t:
# grey dots are individual runs, violins summarise their distribution.
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
ax.plot([0, 0], [-1, 19], color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)

shared_kw = dict(x="med", y="npi", data=npi_medians_df, order=npi_names, ax=ax)
sns.stripplot(**shared_kw, color=[0.37647059, 0.37647059, 0.37647059, 1.], size=3.5, zorder=-4, jitter=False)
sns.violinplot(**shared_kw, scale='width', inner=None, cut=0)

ax.set_title(f"Robustness across {len(all_info_dicts)} sensitivity analyses", fontsize=10)
ax.set_xlabel("Median Reduction in $R_t$")
ax.set_ylabel("")

In [None]:
# Per-experiment view of the per-run medians: one colour per experiment tag so
# outlying runs can be traced back to the sensitivity analysis that produced them.
exp_tags = np.unique([d['exp_tag'] for d in all_info_dicts])
cols = [*sns.color_palette('colorblind'), *sns.color_palette('bright')]

plt.figure(figsize=(8, 4), dpi=300)
# Zero-effect reference line spanning one row beyond each end of the category axis.
plt.plot([0, 0], [-1, len(npi_names) + 1], color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)

for tag_i, e in enumerate(exp_tags):
    # Cycle the palette: zip(exp_tags, cols) silently dropped every tag beyond
    # the number of available colours.
    col = cols[tag_i % len(cols)]
    filtered_df = npi_medians_df[npi_medians_df['tag'] == e]
    # Off-screen proxy point so each tag contributes exactly one legend entry.
    plt.scatter(-200, -200, color=col, label=e)
    sns.stripplot(x="med", y="npi", data=filtered_df, color=col, size=3.5, zorder=-4, jitter=True, order=npi_names)

# This notebook analyses the drop-outdoors runs (see base_dir); the former
# "Keep Both" title was a leftover from another variant.
plt.title("Drop Outdoors", fontsize=10)
plt.xlabel("Median Reduction in $R_t$")
plt.ylabel("")
plt.legend(bbox_to_anchor=(1.01, 1.01), loc='upper left')
# Fix the x-range before tight_layout so the layout accounts for the final ticks.
plt.xlim([-30, 50])
plt.tight_layout()

In [None]:
# Histogram of the 'time_per_sample' value recorded in each loaded run.
runtimes = np.array([run['time_per_sample'] for run in all_info_dicts])
hist_ax = sns.histplot(runtimes)
hist_ax.set_xlabel('seconds per sample')

In [None]:
# NPI grouping for the main results panel: each key is a display label, 'npis'
# lists the model NPIs combined under it by combine_npi_samples, and 'color'
# selects an entry of `cols` for plotting.


def _panel_gathering_npis(min_limit):
    """Indoor gathering NPIs, public first then private, whose person limit is
    >= `min_limit`; each setting's extra household-limit NPI is included only
    when `min_limit` <= 2."""
    npis = []
    for setting in ("Public", "Private"):
        npis.extend(
            f"{setting} Indoor Gathering Person Limit - {limit}"
            for limit in (1, 2, 6, 30)
            if limit >= min_limit
        )
        if min_limit <= 2:
            npis.append(f"Extra {setting} Indoor Household Limit")
    return npis


_panel_business_npis = ['Some Face-to-Face Businesses Closed', 'Gastronomy Closed', 'Leisure Venues Closed']

main_panel_dict_drop_outdoors = {
    name: {'npis': [name], 'type': 'exclude', 'color': cols[0]}
    for name in _panel_business_npis
}
main_panel_dict_drop_outdoors.update({
    'Stay at Home Order &\nAll Businesses Closed': {
        'npis': ['Stay at Home Order AND All F2F Businesses Closed', *_panel_business_npis],
        'type': 'exclude',
        'color': cols[0],
    },
    'All Educational Institutions Closed': {
        'npis': ['Primary Schools Closed', 'Secondary Schools Closed', 'Universities Away'],
        'type': 'exclude',
        'color': cols[2],
    },
    'Curfew': {'npis': ['Curfew'], 'type': 'exclude', 'color': cols[1]},
    'Mandatory Mask Wearing': {'npis': ['Mandatory Mask Wearing >= 3'], 'type': 'exclude', 'color': cols[1]},
    'All Gatherings Banned': {'npis': _panel_gathering_npis(1), 'type': 'exclude', 'color': cols[3]},
    'Gatherings limited to 2': {'npis': _panel_gathering_npis(2), 'type': 'exclude', 'color': cols[3]},
    'Gatherings limited to 6': {'npis': _panel_gathering_npis(6), 'type': 'exclude', 'color': cols[3]},
    'Gatherings limited to 30': {'npis': _panel_gathering_npis(30), 'type': 'exclude', 'color': cols[3]},
})

In [None]:
# Same per-run medians as before, but grouped with the main-panel NPI grouping:
# one row per (run, group), median along axis 0 (presumably the sample axis) of
# 100 * (1 - exp(-effect)).
npi_median_data_list = []

for d in all_info_dicts:
    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]
    comb_cm_effects, new_names = combine_npi_samples(main_panel_dict_drop_outdoors, alpha_i, cm_names)
    percent_reduction = 100 * (1 - np.exp(-comb_cm_effects))
    med_per_red = np.median(percent_reduction, axis=0)
    npi_median_data_list += [
        {"npi": npi, "med": med_per_red[idx], "tag": d["exp_tag"]}
        for idx, npi in enumerate(new_names)
    ]

npi_medians_df = pd.DataFrame(npi_median_data_list)

In [None]:
# y-axis order for the plots below: the main-panel group labels. Materialised
# as a list because a dict_keys view cannot be indexed or sliced and would
# track any later mutation of the dict.
npi_names = list(main_panel_dict_drop_outdoors)

In [None]:
# Display the distinct experiment tags (np.unique of each run's 'exp_tag').
exp_tags

In [None]:
# Experiment tags grouped under a display heading.
# NOTE(review): not referenced by any later cell visible here.
grouped_tags = {
    "Epidemiological Parameters": (
        "basic_R_prior_mean basic_R_prior_scale death_delay_mean gen_int_mean "
        "infection_noise_scale intervention_prior output_noise_scale_prior "
        "seeding_days seeding_scaling"
    ).split(),
}

In [None]:
# Per-experiment view of the main-panel medians: one colour per experiment tag.
exp_tags = np.unique([d['exp_tag'] for d in all_info_dicts])
cols = [*sns.color_palette('colorblind'), *sns.color_palette('bright')]

plt.figure(figsize=(8, 4), dpi=300)
# Zero-effect reference line spanning one row beyond each end of the category
# axis (derived from npi_names instead of the old hard-coded 19).
plt.plot([0, 0], [-1, len(npi_names) + 1], color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)

for tag_i, e in enumerate(exp_tags):
    # Cycle the palette rather than letting zip() drop tags beyond the colour count.
    col = cols[tag_i % len(cols)]
    filtered_df = npi_medians_df[npi_medians_df['tag'] == e]
    # Off-screen proxy point so each tag contributes exactly one legend entry.
    plt.scatter(-200, -200, color=col, label=e)
    sns.stripplot(x="med", y="npi", data=filtered_df, color=col, size=3.5, zorder=-4, jitter=True, order=npi_names)

plt.title("Drop Outdoors", fontsize=10)
plt.xlabel("Median Reduction in $R_t$")
plt.ylabel("")
plt.legend(bbox_to_anchor=(1.01, 1.01), loc='upper left')
# Fix the x-range before tight_layout so the layout accounts for the final ticks.
plt.xlim([-30, 50])
plt.tight_layout()

# Country leave-outs

In [None]:
# Per-run medians for the runs tagged "leaveout" only, grouped with
# `drop_outdoor_dict`; each row is labelled with the first entry of the run's
# exp_config["country_names"].
npi_median_data_list = []

for d in all_info_dicts:
    if d["exp_tag"] != "leaveout":
        continue
    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]
    comb_cm_effects, new_names = combine_npi_samples(drop_outdoor_dict, alpha_i, cm_names)
    percent_reduction = 100 * (1 - np.exp(-comb_cm_effects))
    med_per_red = np.median(percent_reduction, axis=0)
    npi_median_data_list += [
        {"npi": npi, "med": med_per_red[idx], "country": d["exp_config"]["country_names"][0]}
        for idx, npi in enumerate(new_names)
    ]

npi_medians_df = pd.DataFrame(npi_median_data_list)

In [None]:
# Leave-out runs, one colour per country label from each run's exp_config.
unique_countries = np.unique(npi_medians_df["country"])
cols = [*sns.color_palette('colorblind'), *sns.color_palette('bright')]
# The leave-out medians were grouped with `drop_outdoor_dict`, but `npi_names`
# at this point holds the main-panel labels, so ordering by it hid every NPI
# whose name is not shared by both groupings. Order by the names actually
# present in the data (first-appearance order) instead.
leaveout_npi_order = list(dict.fromkeys(npi_medians_df["npi"]))

plt.figure(figsize=(8, 4), dpi=300)
# Zero-effect reference line spanning one row beyond each end of the category axis.
plt.plot([0, 0], [-1, len(leaveout_npi_order) + 1], color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)

for country_i, uc in enumerate(unique_countries):
    # Cycle the palette rather than letting zip() drop countries beyond the colour count.
    col = cols[country_i % len(cols)]
    filtered_df = npi_medians_df[npi_medians_df['country'] == uc]
    # Off-screen proxy point so each country contributes exactly one legend entry.
    plt.scatter(-200, -200, color=col, label=uc)
    sns.stripplot(x="med", y="npi", data=filtered_df, color=col, size=3.5, zorder=-4, jitter=True, order=leaveout_npi_order)

plt.title("Drop Outdoors", fontsize=10)
plt.xlabel("Median Reduction in $R_t$")
plt.ylabel("")
plt.legend(bbox_to_anchor=(1.01, 1.01), loc='upper left')
# Fix the x-range before tight_layout so the layout accounts for the final ticks.
plt.xlim([-30, 50])
plt.tight_layout()

In [None]:
# Keep only the runs tagged "window_of_analysis".
window_analysis = list(filter(lambda run: run["exp_tag"] == "window_of_analysis", all_info_dicts))

In [None]:
# Display the experiment configuration of the first window-of-analysis run.
window_analysis[0]["exp_config"]

In [None]:
from epimodel.plotting.intervention_plots import plot_intervention_effectiveness

In [None]:
# NPI names and the `alpha_i` array of the first window-of-analysis run, plotted next.
cm_names, alpha_i = window_analysis[0]["cm_names"], np.array(window_analysis[0]["alpha_i"])

In [None]:
# Plot intervention effectiveness for the first window-of-analysis run, with xlim=[-25, 25].
plot_intervention_effectiveness(alpha_i, cm_names, xlim=[-25, 25])

In [None]:
# The same effectiveness plot for the first run tagged "default". It uses its own
# variable instead of overwriting `window_analysis`, so re-running the
# window-of-analysis cells above still shows window-of-analysis data.
default_runs = [d for d in all_info_dicts if d["exp_tag"] == "default"]
cm_names = default_runs[0]["cm_names"]
alpha_i = np.array(default_runs[0]["alpha_i"])
plot_intervention_effectiveness(alpha_i, cm_names, xlim=[-25, 25])

In [None]:
# Show the signature and docstring of the plotting helper. The bare name was
# misspelled ("effectivenes") and raised NameError on Restart & Run All.
help(plot_intervention_effectiveness)

# Nicer Validation Figure

In [None]:
# NPI grouping for the main results panel (redefined here for the validation
# figure): each key is a display label, 'npis' lists the model NPIs combined
# under it by combine_npi_samples, and 'color' selects an entry of `cols`.


def _panel_gathering_npis(min_limit):
    """Indoor gathering NPIs, public first then private, whose person limit is
    >= `min_limit`; each setting's extra household-limit NPI is included only
    when `min_limit` <= 2."""
    npis = []
    for setting in ("Public", "Private"):
        npis.extend(
            f"{setting} Indoor Gathering Person Limit - {limit}"
            for limit in (1, 2, 6, 30)
            if limit >= min_limit
        )
        if min_limit <= 2:
            npis.append(f"Extra {setting} Indoor Household Limit")
    return npis


_panel_business_npis = ['Some Face-to-Face Businesses Closed', 'Gastronomy Closed', 'Leisure Venues Closed']

main_panel_dict_drop_outdoors = {
    name: {'npis': [name], 'type': 'exclude', 'color': cols[0]}
    for name in _panel_business_npis
}
main_panel_dict_drop_outdoors.update({
    'Stay at Home Order &\nAll Businesses Closed': {
        'npis': ['Stay at Home Order AND All F2F Businesses Closed', *_panel_business_npis],
        'type': 'exclude',
        'color': cols[0],
    },
    'All Educational Institutions Closed': {
        'npis': ['Primary Schools Closed', 'Secondary Schools Closed', 'Universities Away'],
        'type': 'exclude',
        'color': cols[2],
    },
    'Curfew': {'npis': ['Curfew'], 'type': 'exclude', 'color': cols[1]},
    'Mandatory Mask Wearing': {'npis': ['Mandatory Mask Wearing >= 3'], 'type': 'exclude', 'color': cols[1]},
    'All Gatherings Banned': {'npis': _panel_gathering_npis(1), 'type': 'exclude', 'color': cols[3]},
    'Gatherings limited to 2': {'npis': _panel_gathering_npis(2), 'type': 'exclude', 'color': cols[3]},
    'Gatherings limited to 6': {'npis': _panel_gathering_npis(6), 'type': 'exclude', 'color': cols[3]},
    'Gatherings limited to 30': {'npis': _panel_gathering_npis(30), 'type': 'exclude', 'color': cols[3]},
})

In [None]:
# Row order and per-row colours for the validation figure, both read from the
# same main-panel grouping so that colour i always belongs to row i.
# (`main_panel_dict` was never defined in this notebook -> NameError.)
npi_names = list(main_panel_dict_drop_outdoors)
npi_cols = [group['color'] for group in main_panel_dict_drop_outdoors.values()]

In [None]:
# Per-run medians for the validation figure, grouped with the main-panel NPI
# grouping: one row per (run, group), median along axis 0 (presumably the
# sample axis) of 100 * (1 - exp(-effect)).
npi_median_data_list = []

for d in all_info_dicts:
    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]
    comb_cm_effects, new_names = combine_npi_samples(main_panel_dict_drop_outdoors, alpha_i, cm_names)
    percent_reduction = 100 * (1 - np.exp(-comb_cm_effects))
    med_per_red = np.median(percent_reduction, axis=0)
    npi_median_data_list += [
        {"npi": npi, "med": med_per_red[idx], "tag": d["exp_tag"]}
        for idx, npi in enumerate(new_names)
    ]

npi_medians_df = pd.DataFrame(npi_median_data_list)

In [None]:
# NOTE(review): scratch cell. It draws a throwaway violin plot of 0..9 to see
# what plt.violinplot returns; `x` is only displayed by the following cell.
# Consider deleting.
x = plt.violinplot(np.arange(10))

In [None]:
# Display the return value of the scratch plt.violinplot call (debugging aid).
x

In [None]:
from matplotlib.collections import PolyCollection
# NOTE(review): these rcParams are global and persist for every later figure in
# the kernel; if the "Optima" font is not installed, matplotlib warns and falls
# back to a default font.
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

# Validation figure: per-group violins + strip plots of the per-run medians in
# `npi_medians_df`, coloured with `npi_cols` in `npi_names` order.
plt.figure(figsize=(4, 4), dpi=300)
# Dashed vertical reference line at zero reduction.
# NOTE(review): the y-extent [-1, 19] is hard-coded rather than derived from len(npi_names).
plt.plot([0, 0], [-1, 19], color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)
nCMs = len(npi_names)
# Light grey horizontal band behind every other row (rows 0, 2, 4, ...),
# covering y in [i - 0.5, i + 0.5] and x in [-100, 100].
for i in range(0, nCMs, 2):
    plt.fill_between(
        [-100, 100],
        [i + 0.5, i + 0.5],
        [i - 0.5, i - 0.5],
        color="tab:grey",
        alpha=0.1,
        linewidth=0,
    )
    
# One violin and one strip plot per group, each restricted to that group's rows
# and drawn in its colour; order=npi_names keeps every group on its own row.
# `v` (the returned Axes) is not used afterwards.
for col, npi in zip(npi_cols, npi_names):
    filtered_df = npi_medians_df[npi_medians_df["npi"] == npi]
    v = sns.violinplot(x="med", y="npi", data=filtered_df,  
                 size=3.5, zorder=-6, order=npi_names, color=col ,split=True, 
                   inner=None, linewidth=0, alpha=0.25, cut=0)
    sns.stripplot(x="med", y="npi", data=filtered_df,  
                 size=3.5, zorder=10, order=npi_names, alpha=0.35, color=col, jitter=0.2)


# Post-process polygons on the axes. `b` is rebound to every PolyCollection
# child and is left pointing at the last one found. For the selected children,
# vertex y-values are clipped from above at their mean (collapsing one half of
# the shape) and then shifted by -0.1 in data units, and alpha is set to 0.8.
# NOTE(review): art_i indexes plt.gca().get_children(), and the grey bands from
# fill_between are PolyCollections too. The thresholds `> 17` and `== 6`
# therefore depend on exactly how many artists were created above (reference
# line, bands, each violin/strip plot). They presumably select the violin
# bodies; changing the number of groups or bands will silently select
# different artists. TODO confirm.
for art_i, art in enumerate(plt.gca().get_children()):
    if isinstance(art, PolyCollection):
        b = art
        mean = np.mean(b.get_paths()[0].vertices[:, 1])
        if art_i > 17 or art_i == 6:
            b.get_paths()[0].vertices[:, 1] = np.clip(b.get_paths()[0].vertices[:, 1], -np.inf, mean)-0.1
            b.set_alpha(0.8)

# Colour each y tick label with its group's colour; assumes the tick labels
# appear in the same order as npi_cols (i.e. npi_names order).
for i, tick in enumerate(plt.gca().get_yticklabels()):
    tick.set_color(npi_cols[i])

plt.title(f"Robustness across {len(all_info_dicts)} sensitivity analyses", fontsize=10)
plt.xlabel("Median Reduction in R (%)")
plt.ylabel("")
plt.xlim([-10, 50])

In [None]:
# Debug: y-coordinates of the first path of the last PolyCollection `b` found
# by the validation-figure cell (after any clipping applied there).
b.get_paths()[0].vertices[:, 1]