In [None]:
%load_ext autoreload
%autoreload 2

import sys, os
import warnings
sys.path.append("../") 

import yaml

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("ticks")
import numpy as np
import pandas as pd

from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
# Root directory that is searched (recursively) for the per-run YAML summaries.
base_dir = "/".join(("sensitivity_analysis", "rc_3a_cdef_keep_both"))

In [None]:
# Load every `.yaml` file found anywhere under base_dir into a list of dicts.
# Directories and files are visited in sorted order: os.walk otherwise yields
# entries in arbitrary filesystem order, which would make the run order (and the
# row order of every DataFrame built from it) differ between machines.
all_info_dicts = []

for subdir, dirs, files in os.walk(base_dir):
    dirs.sort()  # sorting in place controls the order in which os.walk descends
    for fname in sorted(files):
        if fname.endswith('.yaml'):
            with open(os.path.join(subdir, fname), 'r') as fh:
                info_dict = yaml.safe_load(fh)
                all_info_dicts.append(info_dict)

# os.walk silently yields nothing for a missing directory; fail loudly instead of
# producing empty figures further down.
assert all_info_dicts, f"No .yaml files found under {base_dir!r}"

In [None]:
# For each run, merge the NPI effect samples according to both_dict, convert them to
# percentage reductions 100 * (1 - exp(-effect)), and take the median over samples
# (axis 0). The result has one record per (run, grouped NPI), tagged with exp_tag.
# NOTE(review): `both_dict` is not defined anywhere in this notebook. It must come from
# an earlier or deleted cell, so this cell fails on a fresh kernel; confirm its source.
npi_median_data_list = []

for d in all_info_dicts:
    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]
    comb_cm_effects, new_names = combine_npi_samples(both_dict, alpha_i, cm_names)
    med_per_red = np.median(100 * (1 - np.exp(-comb_cm_effects)), axis=0)
    npi_median_data_list.extend(
        {"npi": name, "med": med_per_red[cm_i], "tag": d["exp_tag"]}
        for cm_i, name in enumerate(new_names)
    )

npi_medians_df = pd.DataFrame(npi_median_data_list)

In [None]:
# Row order for the plots below. This is a list snapshot rather than a live
# dict_keys view, so it is a plain, indexable sequence for seaborn's `order=`.
# NOTE(review): `both_dict` is not defined in this notebook; confirm where it comes from.
npi_names = list(both_dict.keys())

In [None]:
# Pooled robustness view: one grey point per (run, NPI) median percentage reduction,
# overlaid with a width-normalised violin of those medians for each NPI.
plt.figure(figsize=(4, 4), dpi=300)
# Zero-effect reference line. axvline spans the full axis height, so it no longer
# depends on a hard-coded number of NPI rows (the old [-1, 19] assumed 20 rows).
plt.axvline(0, color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)
sns.stripplot(x="med", y="npi", data=npi_medians_df, color=[0.37647059, 0.37647059, 0.37647059, 1.], size=3.5, zorder=-4, jitter=False, order=list(npi_names))
sns.violinplot(x="med", y="npi", data=npi_medians_df, scale='width', inner=None, cut=0, order=list(npi_names), linewidth=0)
plt.title(f"Keep Both\nRobustness across {len(all_info_dicts)} sensitivity analyses", fontsize=10)
# "med" holds 100 * (1 - exp(-effect)), i.e. a percentage.
plt.xlabel("Median Reduction in $R_t$ (%)")
plt.ylabel("");

In [None]:
# Histogram of the per-sample runtime recorded by each loaded run.
runtimes = np.array([run_info["time_per_sample"] for run_info in all_info_dicts])
sns.histplot(runtimes)
plt.xlabel("seconds per sample")

In [None]:
# Per-NPI medians coloured by experiment tag, one colour per kind of sensitivity analysis.
# `cols` is also reused by the main-panel NPI grouping defined further down.
exp_tags = np.unique([d['exp_tag'] for d in all_info_dicts])
cols = [*sns.color_palette('colorblind'), *sns.color_palette('bright')]
if len(exp_tags) > len(cols):
    # zip() below stops at the shorter sequence, so extra tags would vanish silently.
    warnings.warn(f"{len(exp_tags)} experiment tags but only {len(cols)} colours; "
                  f"not plotted: {exp_tags[len(cols):].tolist()}")

plt.figure(figsize=(8, 4), dpi=300)
# Zero-effect reference line spanning the full axis height (independent of row count).
plt.axvline(0, color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)

for e, col in zip(exp_tags, cols):
    filtered_df = npi_medians_df[npi_medians_df['tag'] == e]
    # Off-screen dummy point whose only job is to create a legend entry for this tag.
    plt.scatter(-200, -200, color=col, label=e)
    sns.stripplot(x="med", y="npi", data=filtered_df, color=col, size=3.5, zorder=-4, jitter=True, order=list(npi_names))

plt.title("Keep Both", fontsize=10)
plt.xlabel("Median Reduction in $R_t$ (%)")
plt.ylabel("")
plt.legend(bbox_to_anchor=(1.01, 1.01), loc='upper left')
# Set the final x-range before tight_layout so the layout accounts for the final ticks.
plt.xlim([-30, 50])
plt.tight_layout()

In [None]:
# Settings used in the gathering-limit NPI names, in the order they are listed.
_GATHERING_SETTINGS = ("Public Indoor", "Private Indoor", "Public Outdoor", "Private Outdoor")


def _gathering_npis(limits, include_household_limit):
    """Return gathering-limit NPI names for every setting in _GATHERING_SETTINGS.

    For each setting, in order, this emits one
    "<setting> Gathering Person Limit - <n>" name per n in ``limits``. When
    ``include_household_limit`` is true, that is followed by
    "Extra <setting> Household Limit".
    """
    names = []
    for setting in _GATHERING_SETTINGS:
        names.extend(f"{setting} Gathering Person Limit - {limit}" for limit in limits)
        if include_household_limit:
            names.append(f"Extra {setting} Household Limit")
    return names


# NPI grouping for the main results panel, passed to combine_npi_samples.
# Each entry maps a display label to:
#   'npis'  - the model NPI names that combine_npi_samples merges into this label,
#   'type'  - passed through to combine_npi_samples (interpretation lives there),
#   'color' - plotting colour from `cols` (defined in an earlier plotting cell).
# The gathering lists are cumulative: each level lists its own limit plus every
# looser one. They are generated rather than hand-written so the four settings
# cannot drift out of sync.
main_panel_dict_keep_both = {
    'Some Face-to-Face Businesses Closed': {
        'npis': ['Some Face-to-Face Businesses Closed'],
        'type': "exclude",
        'color': cols[0],
    },
    'Gastronomy Closed': {
        'npis': ['Gastronomy Closed'],
        'type': 'exclude',
        'color': cols[0],
    },
    'Leisure Venues Closed': {
        'npis': ['Leisure Venues Closed'],
        'type': 'exclude',
        'color': cols[0],
    },
    'Stay at Home Order &\nAll Businesses Closed': {
        'npis': ['Stay at Home Order AND All F2F Businesses Closed', 'Some Face-to-Face Businesses Closed',
                 'Gastronomy Closed', 'Leisure Venues Closed'],
        'type': 'exclude',
        'color': cols[0],
    },
    'All Educational Institutions Closed': {
        'npis': ['Primary Schools Closed', 'Secondary Schools Closed', 'Universities Away'],
        'type': 'exclude',
        'color': cols[2],
    },
    'Curfew': {
        'npis': ['Curfew'],
        'type': 'exclude',
        'color': cols[1],
    },
    'Mandatory Mask Wearing': {
        'npis': ['Mandatory Mask Wearing >= 3'],
        'type': 'exclude',
        'color': cols[1],
    },
    'All Gatherings Banned': {
        'npis': _gathering_npis((1, 2, 6, 30), include_household_limit=True),
        'type': 'exclude',
        'color': cols[3],
    },
    'Gatherings limited to 2': {
        'npis': _gathering_npis((2, 6, 30), include_household_limit=True),
        'type': "exclude",
        'color': cols[3],
    },
    'Gatherings limited to 6': {
        'npis': _gathering_npis((6, 30), include_household_limit=False),
        'type': 'exclude',
        'color': cols[3],
    },
    'Gatherings limited to 30': {
        'npis': _gathering_npis((30,), include_household_limit=False),
        'type': 'exclude',
        'color': cols[3],
    },
}





In [None]:
# Row order for the main-panel plots: a list snapshot of the grouping labels,
# in dict insertion order, usable directly as seaborn's `order=`.
npi_names = list(main_panel_dict_keep_both.keys())

In [None]:
# Same per-run medians as before, but grouped with main_panel_dict_keep_both.
# For each run, the grouped effect samples are converted to percentage reductions
# 100 * (1 - exp(-effect)) and reduced to their median over samples (axis 0).
npi_median_data_list = []

for d in all_info_dicts:
    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]
    comb_cm_effects, new_names = combine_npi_samples(main_panel_dict_keep_both, alpha_i, cm_names)
    med_per_red = np.median(100 * (1 - np.exp(-comb_cm_effects)), axis=0)
    npi_median_data_list += [
        {"npi": name, "med": med_per_red[cm_i], "tag": d["exp_tag"]}
        for cm_i, name in enumerate(new_names)
    ]

npi_medians_df = pd.DataFrame(npi_median_data_list)

In [None]:
# Main-panel NPI groups: per-run medians coloured by experiment tag.
exp_tags = np.unique([d['exp_tag'] for d in all_info_dicts])
cols = [*sns.color_palette('colorblind'), *sns.color_palette('bright')]
if len(exp_tags) > len(cols):
    # zip() below stops at the shorter sequence, so extra tags would vanish silently.
    warnings.warn(f"{len(exp_tags)} experiment tags but only {len(cols)} colours; "
                  f"not plotted: {exp_tags[len(cols):].tolist()}")

plt.figure(figsize=(8, 4), dpi=300)
# Zero-effect reference line spanning the full axis height (independent of row count).
plt.axvline(0, color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)

for e, col in zip(exp_tags, cols):
    filtered_df = npi_medians_df[npi_medians_df['tag'] == e]
    # Off-screen dummy point whose only job is to create a legend entry for this tag.
    plt.scatter(-200, -200, color=col, label=e)
    sns.stripplot(x="med", y="npi", data=filtered_df, color=col, size=3.5, zorder=-4, jitter=True, order=list(npi_names))

plt.title("Keep Both", fontsize=10)
plt.xlabel("Median Reduction in $R_t$ (%)")
plt.ylabel("")
plt.legend(bbox_to_anchor=(1.01, 1.01), loc='upper left')
# Set the final x-range before tight_layout so the layout accounts for the final ticks.
plt.xlim([-30, 50])
plt.tight_layout()

# Leave-out runs: main-panel medians, one colour per country

In [None]:
# Leave-out runs only: per-NPI median percentage reductions, labelled by country.
# Grouped with main_panel_dict_keep_both, not the undefined `both_dict`, so the
# resulting NPI names match `npi_names` (the main-panel keys), which the plotting
# cell below passes as the stripplot order. With mismatched names, rows would be dropped.
npi_median_data_list = []

for d in all_info_dicts:
    if d["exp_tag"] != "leaveout":
        continue
    alpha_i = np.array(d["alpha_i"])
    cm_names = d["cm_names"]
    comb_cm_effects, new_names = combine_npi_samples(main_panel_dict_keep_both, alpha_i, cm_names)
    med_per_red = np.median(100*(1-np.exp(-comb_cm_effects)), axis=0)
    # NOTE(review): assumes the first entry of country_names identifies this
    # leave-out run; confirm against how the leave-out configs are written.
    country = d["exp_config"]["country_names"][0]
    for cm_i, name in enumerate(new_names):
        npi_median_data_list.append({
            "npi": name,
            "med": med_per_red[cm_i],
            "country": country,
        })

npi_medians_df = pd.DataFrame(npi_median_data_list)

In [None]:
# Sorted, de-duplicated country labels present in the leave-out results.
np.unique(npi_medians_df["country"].to_numpy())

In [None]:
# Leave-out runs: per-NPI medians, one colour per country label.
unique_countries = np.unique(npi_medians_df["country"])
cols = [*sns.color_palette('colorblind'), *sns.color_palette('bright')]
if len(unique_countries) > len(cols):
    # zip() below stops at the shorter sequence, so extra countries would vanish silently.
    warnings.warn(f"{len(unique_countries)} countries but only {len(cols)} colours; "
                  f"not plotted: {unique_countries[len(cols):].tolist()}")

plt.figure(figsize=(8, 4), dpi=300)
# Zero-effect reference line spanning the full axis height (independent of row count).
plt.axvline(0, color="k", linestyle='--', alpha=0.75, zorder=-5, linewidth=0.75)

for uc, col in zip(unique_countries, cols):
    filtered_df = npi_medians_df[npi_medians_df['country'] == uc]
    # Off-screen dummy point whose only job is to create a legend entry for this country.
    plt.scatter(-200, -200, color=col, label=uc)
    sns.stripplot(x="med", y="npi", data=filtered_df, color=col, size=3.5, zorder=-4, jitter=True, order=list(npi_names))

# Distinguish this figure from the identically titled per-tag plots above.
plt.title("Keep Both: leaveout runs", fontsize=10)
plt.xlabel("Median Reduction in $R_t$ (%)")
plt.ylabel("")
plt.legend(bbox_to_anchor=(1.01, 1.01), loc='upper left')
# Set the final x-range before tight_layout so the layout accounts for the final ticks.
plt.xlim([-30, 50])
plt.tight_layout()