In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import covasim as cv
from util import data

# Plot styling applied to every figure below.
sns.set_context('notebook')
sns.set_style('whitegrid')

# Population scaling.
pop_region = 4_460_000.0   # population of the modelled region (presumably; TODO confirm source)
pop_size = 50_000.0        # population the simulation represents; data are rescaled by pop_region / pop_size
pop_scale = 10             # covasim runs pop_size / pop_scale agents and rescales results by this factor

# Calibration settings.
cols = ['n_severe', 'n_critical', 'cum_deaths']   # series compared against the data in the fit
n_runs = 3                                        # stochastic replicates per Optuna trial
n_trials = 5000                                   # number of Optuna trials in the search

In [None]:
# Regional observations, rescaled by pop_region / pop_size (see `data.get_regional_data`).
df = data.get_regional_data(scaling_factor=pop_region / pop_size)

# Overview of the observed series: testing, deaths, and active severe/critical cases.
_, axes = plt.subplots(1, 3, figsize=(18, 4), sharey=False)
tests_ax, deaths_ax, active_ax = axes

sns.lineplot(data=df, x='date', y='new_tests', ax=tests_ax).set(title='New Tests', ylabel='')
sns.lineplot(data=df, x='date', y='cum_deaths', ax=deaths_ax).set(title='Cumulative Deaths', ylabel='')

for series in ('n_severe', 'n_critical'):
    sns.lineplot(data=df, x='date', y=series, label=series, ax=active_ax)
active_ax.set(title='Active Cases', ylabel='')

plt.tight_layout()

In [None]:
# Hospital / ICU bed availability per 100k inhabitants, scaled below to `pop_size` people.
hosp_beds_per_100k = 370.4
icu_beds_per_100k = 14.46

# covasim parameters shared by every simulation (the Optuna search and the final replay).
default_params = {
    'pop_type': 'hybrid',
    'location': 'italy',
    # simulate exactly the span of dates covered by the regional data
    'start_day': df['date'].iloc[0],
    'end_day': df['date'].iloc[-1],
    # pop_size / pop_scale agents, with results rescaled by pop_scale
    'pop_size': pop_size / pop_scale,
    'pop_scale': pop_scale,
    'rescale': True,
    'n_beds_hosp': pop_size * hosp_beds_per_100k / 100e3,
    'n_beds_icu': pop_size * icu_beds_per_100k / 100e3,
    'quar_period': 14,
    'verbose': 0,
}

In [None]:
from sklearn.metrics import r2_score

# In the fit loss, observations dated on or after the cutoff weigh `max_weight` times
# more than earlier ones.
max_weight = 5
weight_cutoff = pd.to_datetime('2020-09-15').date()

# One weight per row of `df`, indexed by date. Presumably df['date'] holds datetime.date
# values, since they are subtracted from a datetime.date below (TODO confirm).
sample_weight = df.set_index(['date'], drop=False)['date'].map(
    lambda day: max_weight if (day - weight_cutoff).days >= 0 else 1
)


def custom_estimator(y_true, y_pred):
    """Loss for covasim's `compute_fit`: 1 minus the date-weighted R^2 of one series.

    Weights come from the module-level `sample_weight`; 0 means a perfect fit and
    larger values are worse.

    NOTE(review): `r2_score` needs y_true / y_pred to be as long as `sample_weight`
    (one entry per row of `df`). If the fit drops dates (e.g. missing observations),
    this raises — confirm the fitted columns have no gaps.
    """
    weighted_r2 = r2_score(y_true, y_pred, sample_weight=sample_weight.values)
    return 1 - weighted_r2


sample_weight

In [None]:
import optuna as op
from util.interventions import get_interventions

# Search space for the intervention parameters: key -> (low, high, step).
# Every key is registered with Optuna as f'interv_{key}', so stripping the 'interv_'
# prefix from the study's best params yields exactly the keys `get_interventions` receives.
interv_search_space = {
    'trace_prob': (0.3, 0.9, 1e-2),         # CONTACT TRACING
    'trace_time': (1.0, 7.0, 1e-1),
    'work_contacts': (0.4, 1.0, 1e-2),      # SMART WORKING
    'school_contacts': (0.2, 0.8, 1e-2),    # SCHOOLS CLOSED
    'yellow_contacts': (0.4, 1.0, 1e-2),    # LOCKDOWN INTER.
    'orange_contacts': (0.2, 0.9, 1e-2),
    'summer_imports': (0.0, 20.0, 1e-1),    # IMPORTED CASES
    'yellow_imports': (0.0, 12.0, 1e-1),
    'orange_imports': (0.0, 8.0, 1e-1),
    'summer_beta': (0.0, 0.2, 1e-3),        # VIRAL LOAD RED.
    'winter_beta': (0.0, 0.2, 1e-3),
    'summer_symp': (0.0, 10.0, 1e-2),
    'winter_symp': (0.0, 10.0, 1e-2),
    'summer_sev': (0.0, 10.0, 1e-2),
    'winter_sev': (0.0, 10.0, 1e-2),
    'summer_crit': (0.0, 10.0, 1e-2),
    'winter_crit': (0.0, 10.0, 1e-2),
    'summer_death': (0.0, 10.0, 1e-2),
    'winter_death': (0.0, 10.0, 1e-2),
}


def objective(trial):
    """Optuna objective: mean covasim fit mismatch for one sampled parameter set.

    Samples the simulation parameters (registered as 'init_<key>') and the intervention
    parameters (registered as 'interv_<key>', see `interv_search_space`). Runs `n_runs`
    stochastic replicates against `df` and returns the average `compute_fit` mismatch
    over `cols`, computed with `custom_estimator`. Lower is better.
    """
    # learnable simulation parameters, merged with the fixed `default_params`
    initial_params = dict(
        pop_infected=trial.suggest_int('init_pop_infected', 10, int(pop_size / 100), step=10),
        n_imports=trial.suggest_float('init_n_imports', 0.0, 10.0, step=1e-1),
        beta=trial.suggest_float('init_beta', 0.0, 0.2, step=1e-3),
        rel_symp_prob=trial.suggest_float('init_rel_symp_prob', 0.0, 10.0, step=1e-2),
        rel_severe_prob=trial.suggest_float('init_rel_severe_prob', 0.0, 10.0, step=1e-2),
        rel_crit_prob=trial.suggest_float('init_rel_crit_prob', 0.0, 10.0, step=1e-2),
        rel_death_prob=trial.suggest_float('init_rel_death_prob', 0.0, 10.0, step=1e-2),
        **default_params
    )
    # learnable intervention parameters; trial names are derived from the keys
    intervention_params = {
        key: trial.suggest_float(f'interv_{key}', low, high, step=step)
        for key, (low, high, step) in interv_search_space.items()
    }

    # define and run the stochastic replicates
    sim = cv.Sim(pars=initial_params, interventions=get_interventions(intervention_params), datafile=df)
    msim = cv.MultiSim(sim)
    msim.run(n_runs=n_runs)

    # every fitted series counts equally; the time weighting is inside `custom_estimator`
    weights = dict.fromkeys(cols, 1)
    mismatches = [
        s.compute_fit(keys=cols, weights=weights, estimator=custom_estimator).mismatch
        for s in msim.sims
    ]
    return sum(mismatches) / len(mismatches)


# A seeded sampler makes the search reproducible when trials run sequentially.
study = op.create_study(direction='minimize', sampler=op.samplers.TPESampler(seed=0))
study.optimize(func=objective, n_trials=n_trials)

In [None]:
# Parameters of the lowest-objective trial; keys keep their 'init_' / 'interv_' prefixes.
best_params = study.best_trial.params
best_params

In [None]:
# Summarise the best-scoring trials: the range of every parameter among the top
# `top_quantile` fraction of trials (by objective), plus the single best trial.
top_quantile = .01

trials_df = pd.DataFrame([dict(objective=t.value, **t.params) for t in study.trials])
# trials_df.to_csv(f'res/Pop{pop_size//1e3}k_logFalse_Full')
top_trials = trials_df[trials_df['objective'] <= trials_df['objective'].quantile(top_quantile)]
summary = top_trials.describe().loc[['count', 'min', 'max']]

# DataFrame.append was removed in pandas 2.0: add the best trial as a one-row frame instead.
best_row = pd.DataFrame([{'objective': study.best_value, **study.best_params}], index=['best'])
results = pd.concat([summary, best_row]).transpose().astype({'count': 'int'})
results

In [None]:
def strip_prefix(params, prefix):
    """Return the entries of `params` whose key starts with `prefix`, with the prefix removed."""
    return {key[len(prefix):]: value for key, value in params.items() if key.startswith(prefix)}


# Imported-case parameters may be registered as trials 'interv_*_imp', while `objective`
# calls `get_interventions` with '*_imports' keys. Normalise them so the replay passes
# the same keys the study optimised.
interv_key_aliases = {
    'summer_imp': 'summer_imports',
    'yellow_imp': 'yellow_imports',
    'orange_imp': 'orange_imports',
}

initial_params = strip_prefix(best_params, 'init_')
initial_params.update(default_params)

intervention_params = {
    interv_key_aliases.get(key, key): value
    for key, value in strip_prefix(best_params, 'interv_').items()
}
interventions = get_interventions(intervention_params)

# Average over more stochastic replicates than the `n_runs` used during the search.
n_final_runs = 30
sim = cv.Sim(pars=initial_params, interventions=interventions, datafile=df)
msim = cv.MultiSim(sim)
msim.run(n_runs=n_final_runs)

msim.mean()
msim.plot(cols);

In [None]:
# Side-by-side comparison of observed and simulated series on a shared y axis.
_, axes = plt.subplots(1, 2, figsize=(18, 6), sharey=True, tight_layout=True)

# NOTE(review): `msim.results` is passed straight to seaborn as long-form data with a
# 'date' column — confirm covasim's results object supports this (else build a DataFrame).
panels = (('Real Data', df), ('Simulated Data', msim.results))
for series in cols:
    for ax, (title, source) in zip(axes, panels):
        sns.lineplot(data=source, x='date', y=series, label=series, ax=ax).set(title=title, ylabel='')