In [None]:
import sys
if 'google.colab' in sys.modules:
    !pip install covasim optuna
    !git clone https://github.com/mazzio97/EpidemicModelLearning.git
    sys.path.append('EpidemicModelLearning')
    sys.path.append('EpidemicModelLearning/notebooks')
    sys.path.append('EpidemicModelLearning/notebooks/util')

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import covasim as cv
from util import data, calibration

sns.set_context('notebook')
sns.set_style('whitegrid')

# Population represented by the simulations; covasim simulates pop_size / pop_scale
# agents and rescales them (see default_params below).
pop_size = 450e3
pop_scale = 10
# NOTE(review): 4.46e6 is presumably the regional population, so the observed data are
# scaled down to pop_size — confirm in util.data.get_regional_data.
df = data.get_regional_data(scaling_factor=4.46e6 / pop_size)

# Calibration targets and search budget: simulations averaged per trial, Optuna trials.
cols = ['n_severe', 'n_critical', 'cum_diagnoses', 'cum_deaths']
n_runs, n_trials = 3, 50

# Per-date sample weights feed the custom ('mse') estimator later passed to compute_fit.
sample_weight = calibration.get_sample_weights(df, 'proportional')
custom_estimator = calibration.get_custom_estimator('mse', sample_weight)

In [None]:
_, axes = plt.subplots(2, 2, figsize=(14, 8), sharex=True)

# One observed series per panel, except bottom-right, which overlays both active-case counts.
for (panel_row, panel_col), series, panel_title in [
    ((0, 0), 'new_tests', 'New Tests'),
    ((0, 1), 'cum_diagnoses', 'Cumulative Cases'),
    ((1, 0), 'cum_deaths', 'Cumulative Deaths'),
]:
    sns.lineplot(data=df, x='date', y=series, ax=axes[panel_row, panel_col]).set(title=panel_title, ylabel='')

active_ax = axes[1, 1]
for series in ('n_severe', 'n_critical'):
    sns.lineplot(data=df, x='date', y=series, label=series, ax=active_ax)
active_ax.set(title='Active Cases', ylabel='')

plt.tight_layout()

In [None]:
# Fixed (non-learned) covasim settings shared by every simulation in this notebook.
# - The simulated window spans the first to last date of the observed data.
# - pop_size / pop_scale agents are simulated and rescaled (rescale=True).
# - Bed capacities are per-100k rates scaled to pop_size (presumably hospital and ICU
#   beds per 100k inhabitants — confirm the source of 370.4 and 14.46).
default_params = {
    'start_day': df['date'].iloc[0],
    'end_day': df['date'].iloc[-1],
    'pop_size': pop_size / pop_scale,
    'pop_scale': pop_scale,
    'rescale': True,
    'n_beds_hosp': pop_size * 370.4 / 100e3,
    'n_beds_icu': pop_size * 14.46 / 100e3,
    'quar_period': 14,
    'verbose': 0,
}

In [None]:
def get_interventions(pars):
    """Build the covasim intervention list for one set of intervention parameters.

    Args:
        pars: mapping with keys 'trace_probs', 'trace_time', 'summer_beta',
            'winter_beta' and, for every colour in ``periods``,
            '<colour>_beta_change' and '<colour>_imports'.

    Returns:
        A list of five interventions, in this order:
        - daily testing driven by the 'new_tests' data column;
        - contact tracing;
        - per-period beta multipliers;
        - a summer/winter override of ``beta``;
        - per-period numbers of imported infections.

    Relies on the module-level ``periods`` series (day offset -> colour) and on
    ``get_delta``. Both are defined after this function but before it is called.
    """
    period_days = periods.index.values
    period_colours = list(periods)

    testing = cv.test_num('new_tests', quar_policy='both', sensitivity=0.8, do_plot=False)
    tracing = cv.contact_tracing(
        trace_probs=pars['trace_probs'],
        trace_time=pars['trace_time'],
        do_plot=False,
    )
    beta_changes = cv.change_beta(
        days=period_days,
        changes=[pars[f'{colour}_beta_change'] for colour in period_colours],
    )
    seasonal_beta = cv.dynamic_pars(beta={
        'days': [get_delta('2020-05-01'), get_delta('2020-09-01')],
        'vals': [pars['summer_beta'], pars['winter_beta']],
    })
    imported_cases = cv.dynamic_pars(n_imports={
        'days': period_days,
        'vals': [pars[f'{colour}_imports'] for colour in period_colours],
    })
    return [testing, tracing, beta_changes, seasonal_beta, imported_cases]

def get_delta(date):
    """Return the whole-day offset of ``date`` from the first date of ``df``.

    ``date`` is anything ``pd.to_datetime`` parses to a single timestamp (e.g.
    'YYYY-MM-DD'); dates before the start of the data give negative offsets.
    NOTE(review): assumes df['date'] holds ``datetime.date`` values, since it is
    subtracted from a ``.date()`` — confirm.
    """
    requested_day = pd.to_datetime(date).date()
    first_day = df['date'].iloc[0]
    return (requested_day - first_day).days

# Restriction tier ("colour") in force from each date onwards, re-keyed below by the
# day offset from the start of the data, as expected by the covasim interventions.
periods = pd.Series(dict([
    ('2020-03-08', 'red'),
    ('2020-05-18', 'white'),
    ('2020-11-08', 'yellow'),
    ('2020-11-15', 'orange'),
    ('2020-12-10', 'yellow'),
    ('2020-12-21', 'orange'),
    ('2021-02-01', 'yellow'),
    ('2021-02-21', 'orange'),
    ('2021-03-01', 'red'),
]))
periods.index = periods.index.map(get_delta)

In [None]:
import optuna as op

def constraint_violation(pars, eps=1e-9):
    """Multiplicative penalty for parameter sets that break the expected orderings.

    Each ``(num, den)`` pair below encodes the expectation ``pars[num] <= pars[den]``.
    For every pair the penalty is multiplied by ``max(1, num / den)``, so a fully
    consistent parameter set yields exactly 1.0 and leaves the loss unchanged.

    Args:
        pars: mapping of parameter names to numbers. It must contain every key in the
            ordering table; extra keys (e.g. the default covasim settings) are ignored.
            Values may be ints or floats.
        eps: offset added to both terms of every ratio to avoid division by zero
            (several parameters can be sampled as exactly 0.0).

    Returns:
        A float >= 1.0.

    Raises:
        KeyError: if a required parameter is missing.
    """
    orderings = (
        ('beta', 'winter_beta'),                         # winter beta >= baseline beta
        ('summer_beta', 'beta'),                         # summer beta <= baseline beta
        ('yellow_beta_change', 'white_beta_change'),     # beta multiplier non-increasing
        ('orange_beta_change', 'yellow_beta_change'),    # along white -> yellow -> orange -> red
        ('red_beta_change', 'orange_beta_change'),
        ('yellow_imports', 'white_imports'),             # imports non-increasing
        ('orange_imports', 'yellow_imports'),            # along white -> yellow -> orange -> red
        ('red_imports', 'orange_imports'),
    )
    violation = 1.
    for num, den in orderings:
        # float() lets int-valued parameters (e.g. red_imports=0) through; the previous
        # float-only filter silently dropped them and then raised KeyError.
        violation *= max(1, (float(pars[num]) + eps) / (float(pars[den]) + eps))
    return violation

def objective(trial):
    """Optuna objective: penalized calibration loss of one sampled parameter set.

    Samples covasim Sim parameters (names prefixed 'init_') and intervention
    parameters (prefixed 'interv_'). Later cells strip these prefixes to rebuild
    the simulation. Runs ``n_runs`` simulations against ``df`` and returns the mean
    ``compute_fit(...).mismatch`` over ``cols``, multiplied by
    ``constraint_violation`` of the sampled values (1.0 when all orderings hold).

    Args:
        trial: the Optuna trial supplied by ``study.optimize``.

    Returns:
        The float loss to minimize.
    """
    # define learnable parameters
    # pop_size is a float, so the integer bounds of suggest_int are cast explicitly.
    initial_params = dict(
        pop_infected=trial.suggest_int('init_pop_infected', int(pop_size // 1e4), int(pop_size // 1e2), step=5),
        rel_symp_prob=trial.suggest_float('init_rel_symp_prob', 0.4, 4.0, log=True),
        rel_severe_prob=trial.suggest_float('init_rel_severe_prob', 0.4, 4.0, log=True),
        rel_crit_prob=trial.suggest_float('init_rel_crit_prob', 0.4, 4.0, log=True),
        rel_death_prob=trial.suggest_float('init_rel_death_prob', 0.4, 4.0, log=True),
        beta=trial.suggest_float('init_beta', 0.0, 0.04),
        n_imports=trial.suggest_float('init_n_imports', 5.0, 20.0, step=1e-1),
        **default_params
    )
    intervention_params = dict(
        trace_probs=trial.suggest_float('interv_trace_probs', 0.2, 0.9, step=1e-2),               # CONTACT TRACING
        trace_time=trial.suggest_float('interv_trace_time', 1.0, 7.0, step=1e-1),                 #
        summer_beta=trial.suggest_float('interv_summer_beta', 0.0, 0.04),                         # VIRAL LOAD RED.
        winter_beta=trial.suggest_float('interv_winter_beta', 0.0, 0.04),                         #
        white_beta_change=trial.suggest_float('interv_white_beta_change', 0.8, 1.0, step=1e-2),   # BETA CHANGE
        yellow_beta_change=trial.suggest_float('interv_yellow_beta_change', 0.6, 0.9, step=1e-2), #
        orange_beta_change=trial.suggest_float('interv_orange_beta_change', 0.2, 0.6, step=1e-2), #
        red_beta_change=trial.suggest_float('interv_red_beta_change', 0.0, 0.3, step=1e-2),       #
        white_imports=trial.suggest_float('interv_white_imports', 5.0, 20.0, step=1e-1),          # N. IMPORTS
        yellow_imports=trial.suggest_float('interv_yellow_imports', 3.0, 12.0, step=1e-1),        #
        orange_imports=trial.suggest_float('interv_orange_imports', 2.0, 8.0, step=1e-1),         #
        red_imports=trial.suggest_float('interv_red_imports', 0.0, 6.0, step=1e-1)                #
    )
    # define and run simulations
    sim = cv.Sim(pars=initial_params, interventions=get_interventions(intervention_params), datafile=df)
    msim = cv.MultiSim(sim)
    msim.run(n_runs=n_runs)
    # compute loss: every target column weighted equally, averaged over the runs actually produced
    weights = {c: 1 for c in cols}
    mismatches = [s.compute_fit(keys=cols, weights=weights, estimator=custom_estimator).mismatch for s in msim.sims]
    violation = constraint_violation({**initial_params, **intervention_params})
    return violation * sum(mismatches) / len(mismatches)

# Seed the (default) TPE sampler so re-running the notebook explores the same trials.
study = op.create_study(direction='minimize', sampler=op.samplers.TPESampler(seed=42))
study.optimize(func=objective, n_trials=n_trials)

In [None]:
# Parameters of the lowest-loss trial (keys keep their 'init_'/'interv_' prefixes).
best_params = study.best_trial.params
best_params

In [None]:
# Summarise the best 5% of trials (count/min/max of every parameter) next to the single best trial.
# DataFrame.append was removed in pandas 2.0, so the 'best' row is attached with pd.concat.
# The best trial is read from the study itself, not from globals that other cells may rebind.
trials_df = pd.DataFrame([dict(objective=t.value, **t.params) for t in study.trials])
top_trials = trials_df[trials_df['objective'] <= trials_df['objective'].quantile(.05)]
best_row = pd.Series({'objective': study.best_value, **study.best_params}, name='best')
results = (
    pd.concat([top_trials.describe().loc[['count', 'min', 'max']], best_row.to_frame().T])
    .transpose()
    .astype({'count': 'int'})
)
results

In [None]:
# Re-simulate the best trial with 30 runs and plot the averaged trajectories.
# The parameters come from the study itself rather than the `best_params` global, which
# other cells may rebind. Prefixes are stripped with len() instead of magic slice offsets.
best_trial_params = study.best_params
initial_params = {k[len('init_'):]: v for k, v in best_trial_params.items() if k.startswith('init_')}
initial_params.update(default_params)

intervention_params = {k[len('interv_'):]: v for k, v in best_trial_params.items() if k.startswith('interv_')}
interventions = get_interventions(intervention_params)

sim = cv.Sim(pars=initial_params, interventions=interventions, datafile=df)
msim = cv.MultiSim(sim)
msim.run(n_runs=30)

msim.mean()
msim.plot(cols + ['n_infectious', 'n_susceptible']);

In [None]:
_, axes = plt.subplots(2, 2, figsize=(14, 8), sharey='row', sharex=True, tight_layout=True)

# Top row: cumulative series; bottom row: active counts.
# Left column: observed data; right column: averaged simulation.
# NOTE(review): msim.results is covasim's results container rather than a DataFrame —
# confirm seaborn resolves the 'date' and target columns from it as intended.
for series in cols:
    panel_row = 0 if 'cum_' in series else 1
    sources = [(df, 'Real Data'), (msim.results, 'Simulated Data')]
    for panel_col, (source, panel_title) in enumerate(sources):
        sns.lineplot(data=source, x='date', y=series, label=series, ax=axes[panel_row, panel_col]).set(title=panel_title, ylabel='')

In [None]:
# Re-simulate every completed trial, best first, and plot each one for visual comparison.
# Loop variables use distinct names, so this cell no longer overwrites `best_params`,
# `results`, `msim`, etc.; earlier cells stay correct if re-run afterwards.
trials_by_objective = (
    pd.DataFrame([dict(objective=t.value, **t.params) for t in study.trials])
    .dropna(subset=['objective'])  # failed trials have no value; re-running them would fail again
    .sort_values('objective')
)
for trial_id, trial_row in trials_by_objective.iterrows():
    trial_initial = {k[len('init_'):]: v for k, v in trial_row.items() if k.startswith('init_')}
    trial_initial.update(default_params)
    trial_intervention = {k[len('interv_'):]: v for k, v in trial_row.items() if k.startswith('interv_')}
    trial_sim = cv.Sim(pars=trial_initial, interventions=get_interventions(trial_intervention), datafile=df)
    trial_msim = cv.MultiSim(trial_sim)
    trial_msim.run(n_runs=10)
    trial_msim.mean()
    print('RUN:', trial_id)
    trial_msim.plot(cols + ['n_infectious', 'n_susceptible'])
    # Render the figure under its 'RUN:' header and release it before the next run.
    # This is a no-op if covasim has already shown it.
    plt.show()
    print('\n')