In [None]:
%%capture
import sys
if 'google.colab' in sys.modules:
    !pip install covasim optuna
    !git clone https://github.com/mazzio97/EpidemicModelLearning.git
    sys.path.append('EpidemicModelLearning')
    sys.path.append('EpidemicModelLearning/notebooks')
    sys.path.append('EpidemicModelLearning/notebooks/util')

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import covasim as cv
from util import data, calibration

sns.set_context('notebook')
sns.set_style('whitegrid')

# Population represented by the simulation; default_params simulates
# pop_size / pop_scale agents with pop_scale=pop_scale.
pop_size = 450e3
pop_scale = 10
# NOTE(review): 4.46e6 is presumably the real population of the region whose
# data is loaded, so scaling_factor maps the data onto pop_size — TODO confirm
# against util.data.get_regional_data.
df = data.get_regional_data(scaling_factor=4.46e6/pop_size)

# Data columns the simulation is fitted against (everything except the date and
# the test counts). Sorted into a list: iterating a set of strings gives a
# different order on every interpreter run, which made the replay plots'
# panel order non-reproducible.
cols = sorted(set(df.columns) - {'date', 'new_tests'})
n_runs = 3       # stochastic simulation replicates per Optuna trial
n_trials = 1000  # number of Optuna trials

sample_weight = calibration.get_sample_weights(df, 'proportional')
custom_estimator = calibration.get_custom_estimator('mse', sample_weight)

In [None]:
# Overview of the observed data: one panel per quantity, sharing the date axis.
fig, axes = plt.subplots(2, 2, figsize=(14, 8), sharex=True)
ax_tests, ax_cases, ax_deaths, ax_active = axes.flat  # row-major order

sns.lineplot(data=df, x='date', y='new_tests', ax=ax_tests).set(title='New Tests', ylabel='')
sns.lineplot(data=df, x='date', y='cum_diagnoses', ax=ax_cases).set(title='Cumulative Cases', ylabel='')
sns.lineplot(data=df, x='date', y='cum_deaths', ax=ax_deaths).set(title='Cumulative Deaths', ylabel='')

# Bottom-right panel overlays the severe and critical counts.
sns.lineplot(data=df, x='date', y='n_severe', label='n_severe', ax=ax_active).set(title='Active Cases')
sns.lineplot(data=df, x='date', y='n_critical', label='n_critical', ax=ax_active).set(ylabel='')

plt.tight_layout()

In [None]:
# Fixed (non-learned) covasim settings shared by every simulation run.
default_params = {
    'pop_type': 'hybrid',
    'location': 'italy',
    # simulate exactly the date span covered by the data
    'start_day': df['date'].iloc[0],
    'end_day': df['date'].iloc[-1],
    # pop_size / pop_scale agents are simulated; covasim's pop_scale/rescale
    # presumably scale the results back up to pop_size — see covasim docs
    'pop_size': pop_size / pop_scale,
    'pop_scale': pop_scale,
    'rescale': True,
    # capacity as 370.4 hospital and 14.46 ICU beds per 100,000 people
    # NOTE(review): source of these rates is not given — TODO cite
    'n_beds_hosp': pop_size * 370.4 / 100e3,
    'n_beds_icu': pop_size * 14.46 / 100e3,
    'quar_period': 14,
    'verbose': 0,
}

In [None]:
import optuna as op
from util.interventions import get_interventions

# Groups of parameter names passed to calibration.compute_violation.
# NOTE(review): the search ranges in objective() suggest each tuple is meant to
# be non-increasing from left to right — TODO confirm in util.calibration.
constraint_orderings = [
    # seasonal factors: winter value, baseline value, summer value
    ('winter_beta', 'beta', 'summer_beta'),
    ('winter_symp', 'rel_symp_prob', 'summer_symp'),
    ('winter_sev', 'rel_severe_prob', 'summer_sev'),
    ('winter_crit', 'rel_crit_prob', 'summer_crit'),
    ('winter_death', 'rel_death_prob', 'summer_death'),
] + [
    # colour-coded restriction tiers
    ('yellow_contacts', 'orange_contacts'),
    ('white_imp', 'yellow_imp', 'orange_imp'),
]

def objective(trial, violation_weight=1.0):
    """Optuna objective: sample a parameter set, simulate it and score the fit.

    Parameters are sampled under two name prefixes, which the replay cell
    strips off again: ``init_*`` (covasim simulation parameters, merged with
    ``default_params``) and ``interv_*`` (arguments of ``get_interventions``).
    The simulation is run ``n_runs`` times. The loss is the mean covasim fit
    mismatch over those replicates on the ``cols`` data columns, plus
    ``violation_weight`` times the value ``calibration.compute_violation``
    returns for ``constraint_orderings``. The penalty was previously computed
    but discarded, so the orderings had no effect on the search.

    Args:
        trial: the Optuna trial used to suggest the parameters.
        violation_weight: multiplier of the ordering-constraint penalty.
            NOTE(review): the relative scale of mismatch and violation is not
            visible here — TODO tune; 0 recovers the unconstrained loss.

    Returns:
        The scalar loss to minimise. Both components are also stored as the
        trial user attributes 'mismatch' and 'violation'.
    """
    # define learnable parameters
    initial_params = dict(
        # suggest_int expects integer bounds; pop_size // 1e4 is a float
        pop_infected=trial.suggest_int('init_pop_infected', int(pop_size // 1e4), int(pop_size // 1e2), step=5),
        n_imports=trial.suggest_float('init_n_imports', 5.0, 20.0, step=1e-1),
        rel_symp_prob=trial.suggest_float('init_rel_symp_prob', 0.5, 5.0, log=True),
        rel_severe_prob=trial.suggest_float('init_rel_severe_prob', 0.5, 5.0, log=True),
        rel_crit_prob=trial.suggest_float('init_rel_crit_prob', 0.5, 5.0, log=True),
        rel_death_prob=trial.suggest_float('init_rel_death_prob', 0.5, 5.0, log=True),
        beta=trial.suggest_float('init_beta', 0.01, 0.04, step=None),
        **default_params
    )
    intervention_params = dict(
        trace_prob=trial.suggest_float('interv_trace_prob', 0.3, 0.9, step=1e-2),           # CONTACT TRACING
        trace_time=trial.suggest_float('interv_trace_time', 1.0, 7.0, step=1e-1),           #
        work_contacts=trial.suggest_float('interv_work_contacts', 0.4, 1.0, step=1e-2),     # SMART WORKING
        school_contacts=trial.suggest_float('interv_school_contacts', 0.2, 0.8, step=1e-2), # SCHOOLS CLOSED
        yellow_contacts=trial.suggest_float('interv_yellow_contacts', 0.4, 1.0, step=1e-2), # LOCKDOWN INTER.
        orange_contacts=trial.suggest_float('interv_orange_contacts', 0.2, 0.9, step=1e-2), #
        white_imp=trial.suggest_float('interv_white_imp', 5.0, 20.0, step=1e-1),            # IMPORTED CASES
        yellow_imp=trial.suggest_float('interv_yellow_imp', 3.0, 12.0, step=1e-1),          #
        orange_imp=trial.suggest_float('interv_orange_imp', 2.0, 8.0, step=1e-1),           #
        summer_beta=trial.suggest_float('interv_summer_beta', 0.0, 0.01, step=None),        # VIRAL LOAD RED.
        winter_beta=trial.suggest_float('interv_winter_beta', 0.01, 0.04, step=None),       #
        summer_symp=trial.suggest_float('interv_summer_symp', 0.5, 5.0, log=True),          #
        winter_symp=trial.suggest_float('interv_winter_symp', 0.5, 5.0, log=True),          #
        summer_sev=trial.suggest_float('interv_summer_sev', 0.5, 5.0, log=True),            #
        winter_sev=trial.suggest_float('interv_winter_sev', 0.5, 5.0, log=True),            #
        summer_crit=trial.suggest_float('interv_summer_crit', 0.5, 5.0, log=True),          #
        winter_crit=trial.suggest_float('interv_winter_crit', 0.5, 5.0, log=True),          #
        summer_death=trial.suggest_float('interv_summer_death', 0.5, 5.0, log=True),        #
        winter_death=trial.suggest_float('interv_winter_death', 0.5, 5.0, log=True)         #
    )
    # define and run simulations
    sim = cv.Sim(pars=initial_params, interventions=get_interventions(intervention_params), datafile=df)
    msim = cv.MultiSim(sim)
    msim.run(n_runs=n_runs)
    # compute loss: mean fit mismatch over the replicates (every column weighted equally)
    weights = {c: 1 for c in cols}
    mismatches = [s.compute_fit(keys=cols, weights=weights, estimator=custom_estimator).mismatch for s in msim.sims]
    mean_mismatch = sum(mismatches) / len(mismatches)
    violation = calibration.compute_violation({**initial_params, **intervention_params}, constraint_orderings)
    # keep both components inspectable per trial (trial.user_attrs)
    trial.set_user_attr('mismatch', mean_mismatch)
    trial.set_user_attr('violation', violation)
    return mean_mismatch + violation_weight * violation

# Minimise the objective (made explicit; it is also Optuna's default direction).
# TPESampler is Optuna's default single-objective sampler. It is seeded here so
# the sampler's choices are reproducible on Restart & Run All.
sampler_seed = 42
study = op.create_study(direction='minimize', sampler=op.samplers.TPESampler(seed=sampler_seed))
study.optimize(func=objective, n_trials=n_trials)

In [None]:
# One row per trial (objective value + every sampled parameter), best first.
results = pd.DataFrame([dict(objective=t.value, **t.params) for t in study.trials]).sort_values('objective')
# count/min/max of each column over the top 5% of trials (at least one trial,
# so a small n_trials still gives a non-empty summary).
top_k = max(1, n_trials // 20)
summary = results.head(top_k).describe().loc[['count', 'min', 'max']]
# DataFrame.append was removed in pandas 2.0: concatenate a one-row frame instead.
best_row = pd.DataFrame([{'objective': study.best_value, **study.best_params}], index=['best'])
summary = pd.concat([summary, best_row])
# One row per column, with count/min/max/best as the summary columns.
summary = summary.transpose().astype({'count': 'int'})
summary

In [None]:
# Re-simulate the best trials with 10 replicates each and plot the averaged
# trajectories. Only the top `n_replay` completed trials are replayed: each one
# runs 10 simulations and draws a figure, so replaying all n_trials trials
# would be extremely slow.
n_replay = 5
completed_trials = [t for t in study.trials if t.value is not None]
for best in sorted(completed_trials, key=lambda t: t.value)[:n_replay]:
    # Rebuild the simulation inputs by stripping the 'init_' / 'interv_'
    # prefixes used in objective(). trial.params keeps the sampled types
    # (pop_infected stays an int), unlike a DataFrame row, which upcasts to float.
    initial_params = {k[5:]: v for k, v in best.params.items() if k.startswith('init_')}
    initial_params.update(default_params)
    intervention_params = {k[7:]: v for k, v in best.params.items() if k.startswith('interv_')}
    interventions = get_interventions(intervention_params)
    sim = cv.Sim(pars=initial_params, interventions=interventions, datafile=df)
    msim = cv.MultiSim(sim)
    msim.run(n_runs=10)
    msim.mean()
    print('RUN:', best.number)
    msim.plot(list(cols) + ['n_infectious', 'n_susceptible'])
    plt.show()  # render now so each figure appears right after its "RUN:" label
    print('\n')