In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import covasim as cv
from util import data

# Plot appearance shared by every figure in the notebook.
sns.set_context('notebook')
sns.set_style('whitegrid')

# Population settings (kept as floats, exactly as before).
pop_region = 4_460_000.0  # pop_region / pop_size is the scaling factor passed to the data loader
pop_size = 450_000.0      # population handed to the simulation (before dividing by pop_scale)
pop_scale = 10            # the simulation is built with pop_size / pop_scale as its pop_size

# Calibration settings.
cols = ['n_severe', 'n_critical', 'cum_deaths']  # series used for the fit and for the plots
n_runs = 3       # simulations run (and averaged) per optimisation trial
n_trials = 1_000  # number of optimisation trials

In [None]:
# Load the regional data, passing the region/simulation population ratio as the
# scaling factor, and keep only the rows dated up to 2020-09-15 (inclusive).
df = data.get_regional_data(scaling_factor=pop_region/pop_size)
df = df.loc[df['date'] <= pd.to_datetime('2020-09-15')]

_, axes = plt.subplots(1, 3, figsize=(18, 4), sharey=False)

# First two panels: a single series each, with the y-axis label blanked.
for ax, col, title in zip(axes, ['new_tests', 'cum_deaths'], ['New Tests', 'Cumulative Deaths']):
    sns.lineplot(data=df, x='date', y=col, ax=ax).set(title=title, ylabel='')

# Third panel: the severe and critical series share one axis, told apart by the legend.
for col in ['n_severe', 'n_critical']:
    sns.lineplot(data=df, x='date', y=col, label=col, ax=axes[2])
axes[2].set(title='Active Cases', ylabel='')

plt.tight_layout()

In [None]:
# Simulation parameters that stay fixed across all optimisation trials.
default_params = {
    'pop_type': 'hybrid',
    'location': 'italy',
    # the simulated period spans the first and last dates of the loaded data
    'start_day': df['date'].iloc[0],
    'end_day': df['date'].iloc[-1],
    # pop_size / pop_scale is passed as the simulation's pop_size, together with pop_scale
    'pop_size': pop_size / pop_scale,
    'pop_scale': pop_scale,
    'rescale': True,
    # bed capacities expressed as a rate per 100k multiplied by pop_size
    'n_beds_hosp': pop_size * 370.4 / 100e3,
    'n_beds_icu': pop_size * 14.46 / 100e3,
    'quar_period': 14,
    'verbose': 0,
}

In [None]:
import optuna as op
from util.interventions import get_interventions

def unused_param(trial, name, value=0.):
    """Register ``name`` on ``trial`` as a float pinned to a single value.

    The parameter is suggested with identical lower and upper bounds, so it is
    recorded by the trial under ``name`` while only ``value`` is allowed.

    :param trial: object exposing ``suggest_float(name, low, high)``.
    :param name: name under which the parameter is suggested.
    :param value: the only admissible value (default ``0.``).
    :return: whatever ``trial.suggest_float`` returns.
    """
    low = high = value
    return trial.suggest_float(name, low, high)

def objective(trial):
    """Optuna objective: average fit mismatch of the simulations for one trial.

    Suggests the initial simulation parameters (``init_*``) and the intervention
    parameters (``interv_*``) on ``trial``, runs ``n_runs`` simulations against
    ``df`` and scores each one with ``compute_fit`` over ``cols`` (weight 1 each).

    :param trial: the Optuna trial used to suggest parameter values.
    :return: the sum of the per-simulation mismatches divided by ``n_runs``.
    """
    suggest_float = trial.suggest_float

    def fixed(name):
        # parameter kept in the study but pinned to the default of unused_param
        return unused_param(trial, name)

    # learnable simulation parameters, completed with the fixed defaults
    sim_pars = dict(
        pop_infected=trial.suggest_int('init_pop_infected', 2400, 4500, step=10),
        n_imports=suggest_float('init_n_imports', 0.6, 2.6, step=1e-1),
        beta=suggest_float('init_beta', 0.018, 0.033, step=1e-3),
        rel_symp_prob=suggest_float('init_rel_symp_prob', 1.8, 10.0, log=True),
        rel_severe_prob=suggest_float('init_rel_severe_prob', 0.05, 0.2, log=True),
        rel_crit_prob=suggest_float('init_rel_crit_prob', 0.7, 1.5, log=True),
        rel_death_prob=suggest_float('init_rel_death_prob', 4.0, 10.0, log=True),
        **default_params
    )
    # learnable intervention parameters (suggested in this exact order)
    interv_pars = dict(
        # CONTACT TRACING
        trace_prob=suggest_float('interv_trace_prob', 0.4, 0.85, step=1e-2),
        trace_time=suggest_float('interv_trace_time', 2.3, 6.3, step=1e-1),
        # SMART WORKING
        work_contacts=suggest_float('interv_work_contacts', 0.55, 0.75, step=1e-2),
        # SCHOOLS CLOSED
        school_contacts=suggest_float('interv_school_contacts', 0.0, 1.0, step=1e-2),
        # LOCKDOWN INTERACTIONS
        yellow_contacts=fixed('interv_yellow_contacts'),
        orange_contacts=fixed('interv_orange_contacts'),
        # IMPORTED CASES
        summer_imports=suggest_float('interv_summer_imp', 0., 20., step=1e-1),
        yellow_imports=fixed('interv_yellow_imp'),
        orange_imports=fixed('interv_orange_imp'),
        # VIRAL LOAD REDUCTION
        summer_beta=suggest_float('interv_summer_beta', 0.0, 0.2, step=1e-3),
        winter_beta=fixed('interv_winter_beta'),
        summer_symp=suggest_float('interv_summer_symp', 0.01, 10.0, log=True),
        winter_symp=fixed('interv_winter_symp'),
        summer_sev=suggest_float('interv_summer_sev', 0.01, 10.0, log=True),
        winter_sev=fixed('interv_winter_sev'),
        summer_crit=suggest_float('interv_summer_crit', 0.01, 10.0, log=True),
        winter_crit=fixed('interv_winter_crit'),
        summer_death=suggest_float('interv_summer_death', 0.01, 10.0, log=True),
        winter_death=fixed('interv_winter_death')
    )

    # build the simulation and run it n_runs times
    base_sim = cv.Sim(pars=sim_pars, interventions=get_interventions(interv_pars), datafile=df)
    runs = cv.MultiSim(base_sim)
    runs.run(n_runs=n_runs)

    # loss: every column in cols weighted equally, averaged over the runs
    def mismatch_of(single_sim):
        return single_sim.compute_fit(keys=cols, weights=dict.fromkeys(cols, 1)).mismatch

    return sum([mismatch_of(s) for s in runs.sims]) / n_runs

# Minimise the objective (the direction Optuna uses by default) over n_trials trials.
study = op.create_study(direction='minimize')
study.optimize(objective, n_trials=n_trials)

In [None]:
# Parameter values of the best trial found by the study, keyed by the names
# used in objective() ('init_*' / 'interv_*'); displayed as the cell output.
best_params = study.best_params
best_params

In [None]:
# Summarise the best trials: keep those whose objective is within the lowest 1%,
# report count/min/max of every parameter over them, and add the best trial's values.
results = pd.DataFrame([dict(objective=t.value, **t.params) for t in study.trials])
results = results[results['objective'] <= results['objective'].quantile(0.01)]
results = results.drop('objective', axis=1)
results = results.describe().loc[['count', 'min', 'max']]
# DataFrame.append was removed in pandas 2.0: turn the best parameters into a
# one-row frame (indexed 'best') and concatenate it below the statistics instead.
best_row = pd.Series(best_params, name='best').to_frame().T
results = pd.concat([results, best_row])
# one row per parameter, with columns count / min / max / best
results = results.transpose().astype({'count': 'int'})
# results.to_csv(f'res/Pop{int(pop_size/1e3)}k_Scale{pop_scale}')
results

In [None]:
# Rebuild the simulation parameters of the best trial: drop the 'init_' prefix
# from the tuned values and add the fixed defaults.
init_prefix, interv_prefix = 'init_', 'interv_'
initial_params = {
    name[len(init_prefix):]: value
    for name, value in best_params.items()
    if name.startswith(init_prefix)
}
initial_params.update(default_params)

# Rebuild the intervention parameters: drop the 'interv_' prefix. The study stores
# the import-related parameters under abbreviated names ('interv_*_imp'), while
# objective() passes them to get_interventions as '*_imports', so the full key is
# restored here to hand get_interventions the same keys used during optimisation.
intervention_params = {}
for name, value in best_params.items():
    if not name.startswith(interv_prefix):
        continue
    key = name[len(interv_prefix):]
    if key.endswith('_imp'):
        key = key[:-len('_imp')] + '_imports'
    intervention_params[key] = value
interventions = get_interventions(intervention_params)

# Run the best configuration 30 times, reduce the runs with mean() and plot the fitted columns.
sim = cv.Sim(pars=initial_params, interventions=interventions, datafile=df)
msim = cv.MultiSim(sim)
msim.run(n_runs=30)

msim.mean()
msim.plot(cols);

In [None]:
# Side-by-side comparison of the real series and the simulated ones (shared y-axis).
_, axes = plt.subplots(1, 2, figsize=(18, 6), sharey=True)
real_ax, sim_ax = axes

for col in cols:
    sns.lineplot(data=df, x='date', y=col, label=col, ax=real_ax)
    sns.lineplot(data=msim.results, x='date', y=col, label=col, ax=sim_ax)

# titles and blank y-labels are set once, after all the series have been drawn
real_ax.set(title='Real Data', ylabel='')
sim_ax.set(title='Simulated Data', ylabel='')

plt.tight_layout()