In [None]:
# `display` lives in IPython.display; importing it from IPython.core.display is deprecated.
from IPython.display import display, HTML
# Widen the notebook's content area to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import os

# Work from the parent of the directory the kernel started in (presumably the
# repository root, so that `import clds` and the relative `results/` paths resolve).
# The start directory is recorded only once, so re-running this cell does not keep
# moving further up the directory tree.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import plotly.graph_objects as go

import ax
from ax.modelbridge.registry import Models
from ax.plot.slice import plot_slice
from ax.plot.contour import plot_contour
from ax.service.utils.instantiation import parameter_from_json
from ax.utils.notebook.plotting import render

import clds

# Estimate level curves using Gaussian processes (GPs) and Ax

## Simulation of a Single Configuration

In [None]:
def simulate(N=1e7,
             alpha=0.570,
             beta=0.011,
             gamma=0.456,
             delta=0.011,
             compensation=0.,
             lockdown_effectiveness=0.175,
             duty_cycle=1,
             period=7,
             n_steps=365,
             *,
             suppression_start=20,
             switching_start=50,
             **kwargs):
    """Simulate one SIDARTHE epidemic under a periodic (FPSP) suppression controller.

    At every step the four transmission rates (alpha, beta, gamma, delta) are
    multiplied by the output of a ``clds.agents.BatchFPSP`` controller configured
    with ``beta_high`` (1, set to ``1 + compensation`` once the step index exceeds
    ``suppression_start``) for ``duty_cycle`` steps and
    ``beta_low = lockdown_effectiveness`` for the remaining ``period - duty_cycle``
    steps of each period. NOTE(review): the exact switching schedule relative to
    ``suppression_start``/``switching_start`` is implemented inside BatchFPSP -
    confirm there.

    Args:
        N: total population size.
        alpha, beta, gamma, delta: baseline transmission rates, scaled each step by
            the controller output.
        compensation: the controller's high level becomes ``1 + compensation``
            after step ``suppression_start``.
        lockdown_effectiveness: the controller's low level (multiplier applied to
            the transmission rates).
        duty_cycle: steps per period at the high level; must satisfy
            ``0 <= duty_cycle <= period``.
        period: length of one switching period, in steps.
        n_steps: number of simulation steps after the initial reset.
        suppression_start, switching_start: forwarded to BatchFPSP. The defaults
            (20, 50) are the values this function previously hard-coded.
        **kwargs: ignored, so callers can pass a full configuration dict.

    Returns:
        (s, p): ``s`` is the model output reshaped to 8 columns and expressed as a
        percentage of N (presumably ordered like the initial state
        [S, I, D, A, R, T, H, E]); ``p`` is the controller output reshaped to one
        column. Both have one row per recorded step (n_steps + 1 rows, assuming a
        single batch element).

    Raises:
        ValueError: if ``duty_cycle`` lies outside ``[0, period]``, which would
            otherwise give the controller a negative number of low steps.
    """
    if not 0 <= duty_cycle <= period:
        raise ValueError(
            f"duty_cycle must lie in [0, period]; got duty_cycle={duty_cycle}, period={period}")

    steps_high = duty_cycle
    steps_low = period - duty_cycle

    # ODE integration step
    step_size = 0.001

    # initial condition (absolute counts); S is the remainder of the population
    I = 500/6
    D = 20
    A = 1
    R = 2
    T = H = E = 0
    S = N - I - D - A - R - T - H - E
    s0 = np.array([S, I, D, A, R, T, H, E])

    # default latent parameters (not varied by any experiment)
    epsilon = 0.171
    theta = 0.371
    zeta = 0.125
    eta = 0.125
    mu = 0.012
    nu = 0.027
    tau = 0.003
    h = 0.034
    rho = 0.034
    kappa = 0.017
    xi = 0.017
    sigma = 0.017

    # The four transmission rates are given as names ('a', 'b', 'g', 'd') that are
    # supplied on every step by the `pre` hook registered below.
    model = clds.agents.BatchSIDARTHE(s0=s0,
                                      alpha='a',  # variable
                                      beta='b',  # variable
                                      gamma='g',  # variable
                                      delta='d',  # variable
                                      epsilon=epsilon,
                                      theta=theta,
                                      zeta=zeta,
                                      eta=eta,
                                      mu=mu,
                                      nu=nu,
                                      tau=tau,
                                      h=h,
                                      rho=rho,
                                      kappa=kappa,
                                      xi=xi,
                                      sigma=sigma,
                                      N=N,
                                      step_size=step_size)

    fpsp = clds.agents.BatchFPSP(beta_high=1,
                                 beta_low=lockdown_effectiveness,
                                 steps_high=steps_high,
                                 steps_low=steps_low,
                                 suppression_start=suppression_start,
                                 switching_start=switching_start)

    env = clds.Composite()
    env.add(model,
            # scale every transmission rate by the controller's current output
            pre=lambda x: {'a': x['fpsp']*alpha,
                           'b': x['fpsp']*beta,
                           'g': x['fpsp']*gamma,
                           'd': x['fpsp']*delta},
            out='model')
    env.add(fpsp, out='fpsp')

    #print(model.R0({'a': alpha, 'b': beta, 'g': gamma, 'd': delta}))
    o = [env.reset()]
    for i in range(n_steps):
        if i > suppression_start:
            # after suppression starts, the "open" level may be raised by `compensation`
            fpsp.beta_high = (1+compensation)
        o.append(env.step()[0])

    s = np.array([x['model'] for x in o]).reshape(-1, 8)/N*100
    p = np.array([x['fpsp'] for x in o]).reshape(-1, 1)

    return s, p

## Evaluation Wrapper Function

In [None]:
# evaluate configuration
def eval_fn(params):
    """Ax evaluation function: simulate one parameterisation and score it.

    Search-space parameters are mapped onto ``simulate`` arguments:
      * R0 -> alpha, beta, gamma, delta, all scaled by ``R0 / R0_base``, so that
        R0 == R0_base reproduces simulate's default rates;
      * q -> lockdown_effectiveness; d -> compensation;
      * period is multiplied by 7 (weeks -> days/steps);
      * X is the number of days out of every 7 spent at the high level, so
        duty_cycle = X * period_days / 7.

    Args:
        params: mapping with keys 'R0', 'q', 'd', 'X', 'period'. It is copied, not
            modified.

    Returns:
        ``{'peak_daily': (mean, sem)}``. ``mean`` is the maximum, from step
        ``score_start`` onward, of the per-step sum of state columns 1-5 (I, D, A,
        R, T in the order of simulate's initial state), in percent of N. ``sem`` is
        reported as 0.0, i.e. the simulation is treated as noise-free.
    """
    # R0 value at which the scaled rates equal simulate's defaults.
    R0_base = 2.38461532
    # First step included in the score (same value as simulate's default switching_start).
    score_start = 50

    config = dict(params)
    # R0 -> alpha, beta, gamma, delta (all scaled by the same factor)
    correction = config['R0'] / R0_base
    config['alpha'] = 0.570 * correction
    config['beta'] = 0.011 * correction
    config['gamma'] = 0.456 * correction
    config['delta'] = 0.011 * correction

    # q -> lockdown effectiveness (controller low level), d -> compensation
    config['lockdown_effectiveness'] = config['q']
    config['compensation'] = config['d']
    # period is given in weeks; simulate counts steps (days)
    config['period'] *= 7
    # X days out of every 7 -> duty cycle in days. Integer arithmetic is exact,
    # whereas the float round-trip X/7*period could land just below an integer
    # and be truncated one day short by int().
    config['duty_cycle'] = int(config['X'] * config['period'] // 7)
    s, p = simulate(**config)

    # outcome measure: peak, from step `score_start` onwards, of the daily sum of columns 1-5
    total_infected_daily = np.sum(s[:, 1:6], axis=1)
    peak_daily = float(total_infected_daily[score_start:].max())
    return {'peak_daily': (peak_daily, 0.0)}  # dict of tuples (mean, standard error)

## Experiments

### Initial set of 100 simulations per experiment (10 Sobol batches of 10 configurations)

In [None]:
import ax
from ax.modelbridge.registry import Models
from ax.plot.slice import plot_slice
from ax.plot.contour import plot_contour
from ax.utils.notebook.plotting import render

# Developer API
batch_size = 1  # number of parallel GPEI trials
n_sobol = 10
n_gpei = 1

# Values used for a parameter whenever a sweep holds it fixed.
_R0_BASELINE = 2.38461532
_Q_BASELINE = 0.175
_D_BASELINE = 0.
_X_BASELINE = 2

# Search bounds used for a parameter whenever a sweep varies it.
_SWEEP_RANGES = {"R0": [1.0, 4.4], "q": [1e-6, 0.5], "d": [0.0, 1.0]}


def _range_param(name, bounds, integer=False):
    """JSON spec of an Ax range parameter, optionally integer-valued."""
    spec = {"name": name}
    if integer:
        spec["parameter_type"] = ax.ParameterType.INT
    spec["type"] = "range"
    spec["bounds"] = list(bounds)
    return spec


def _fixed_param(name, value, integer=False):
    """JSON spec of an Ax fixed parameter, optionally integer-valued."""
    spec = {"name": name}
    if integer:
        spec["parameter_type"] = ax.ParameterType.INT
    spec["type"] = "fixed"
    spec["value"] = value
    return spec


def _sweep(varied, axis):
    """Specs for a 2-D sweep of `varied` (R0, q or d) against `axis` ('days' or 'period').

    The 'days' axis varies X (days per week) at a fixed one-week period; the 'period'
    axis varies the period length (in weeks) at a fixed X.
    """
    baselines = {"R0": _R0_BASELINE, "q": _Q_BASELINE, "d": _D_BASELINE}
    specs = [
        _range_param(name, _SWEEP_RANGES[name]) if name == varied
        else _fixed_param(name, baselines[name])
        for name in ("R0", "q", "d")
    ]
    if axis == "days":
        specs += [_range_param("X", [1, 7], integer=True),
                  _fixed_param("period", 1, integer=True)]
    else:
        specs += [_fixed_param("X", _X_BASELINE),
                  _range_param("period", [1, 7], integer=True)]
    return specs


# parameter search space with every epidemic parameter free (not used by the sweeps below)
params_full = [
    _range_param("R0", [1.0, 4.4]),
    _range_param("q", [1e-6, 4.0]),
    _range_param("d", [0.0, 1.0]),
    _range_param("X", [1, 7], integer=True),
    _fixed_param("period", 1, integer=True),
]

# one search space per named experiment, e.g. params['days_vs_q']
params = {
    f"{axis}_vs_{varied}": _sweep(varied, axis)
    for axis in ("days", "period")
    for varied in ("R0", "q", "d")
}

def do_experiment(experiment, n_batches=10):
    """Run the initial Sobol sweep for one named experiment, checkpointing each batch.

    Builds an Ax SimpleExperiment over the search space ``params[experiment]`` that
    minimises the ``peak_daily`` metric returned by ``eval_fn``. It then runs
    ``n_batches`` batch trials of ``n_sobol`` Sobol arms each. After every batch it
    evaluates the experiment and saves it to
    ``results/{experiment}_SIDARTHE_{i}.json``, so a later cell can resume from the
    last checkpoint.

    Args:
        experiment: key into the module-level ``params`` dict.
        n_batches: number of Sobol batches to run (default 10, the previously
            hard-coded value).

    Reads the module-level settings ``n_sobol``, ``n_gpei`` and ``batch_size``.
    """
    parameters = params[experiment]
    outfile = f'results/{experiment}_SIDARTHE'
    print("Experiment", experiment)

    # search space for the developer API; `parameter_from_json` comes from the import cell
    param_list = [parameter_from_json(p) for p in parameters]
    search_space = ax.SearchSpace(parameters=param_list, parameter_constraints=None)

    exp = ax.SimpleExperiment(
        name="test_experiment",
        search_space=search_space,
        evaluation_function=eval_fn,
        objective_name="peak_daily",
        minimize=True
    )

    # repeatedly evaluate SOBOL and save to disk (one checkpoint per batch)
    sobol = Models.SOBOL(exp.search_space)
    for i in range(n_batches):
        print(f"Running trials {i*n_sobol} to {(i+1)*n_sobol}..")
        exp.new_batch_trial(generator_run=sobol.gen(n_sobol))
        exp.eval()
        ax.save(exp, outfile+f'_{i}.json')

    # NOTE(review): the GPEI candidates generated here are neither evaluated nor saved
    # (the last checkpoint is written above), so they never reach the resume or
    # plotting cells - only the GP fit and the printout are observable.
    for i in range(n_gpei):
        gp = Models.GPEI(experiment=exp, data=exp.eval())
        print(f"GPEI trial {i+1}")
        exp.new_trial(generator_run=gp.gen(batch_size))

# Every (structural axis, epidemic parameter) pair:
# days_vs_R0, days_vs_d, days_vs_q, period_vs_R0, period_vs_d, period_vs_q.
experiments = [f'{axis}_vs_{varied}'
               for axis in ('days', 'period')
               for varied in ('R0', 'd', 'q')]
for e in experiments:
    do_experiment(e)

### Resume and complete the remaining 900 simulations per experiment (90 further batches of 10)

In [None]:
# Developer API
batch_size = 1 # number of parallel GPEI trials
n_continue=10
n_sobol=10
n_gpei=1

def resume_experiment(experiment, n_batches=90):
    """Resume a saved Sobol sweep and append ``n_batches`` more checkpointed batches.

    Named differently from ``do_experiment`` (the initial-sweep function) so that
    defining it does not silently shadow that earlier definition.

    Loads ``results/{experiment}_SIDARTHE_{n_continue-1}.json`` (the last checkpoint
    of the initial sweep) and re-attaches ``eval_fn`` as the evaluation function.
    It then runs ``n_batches`` batch trials of ``n_sobol`` Sobol arms. After each
    batch it evaluates the experiment and saves it as checkpoint ``i + n_continue``.

    Args:
        experiment: experiment name used in the checkpoint file names.
        n_batches: number of additional Sobol batches (default 90, the previously
            hard-coded value).

    Reads the module-level settings ``n_continue``, ``n_sobol``, ``n_gpei`` and
    ``batch_size``.
    """
    outfile = f'results/{experiment}_SIDARTHE'
    print("Experiment", experiment)

    exp = ax.load(outfile + f'_{n_continue-1}.json')
    # the evaluation function is assigned again after loading (presumably it is not
    # restored from the JSON checkpoint)
    exp.evaluation_function = eval_fn

    # NOTE(review): a new SOBOL generator is constructed here, so these points are not
    # a continuation of the initial run's sequence; confirm whether a seed /
    # init_position should be passed to keep the combined design space-filling.
    sobol = Models.SOBOL(exp.search_space)
    for i in range(n_batches):
        batch = i + n_continue
        print(f"Running trials {batch*n_sobol} to {(batch+1)*n_sobol}..")
        exp.new_batch_trial(generator_run=sobol.gen(n_sobol))
        exp.eval()
        ax.save(exp, outfile + f'_{batch}.json')

    # NOTE(review): as in the initial sweep, the GPEI candidates are neither evaluated
    # nor saved; the plotting cells fit their own GP from the saved checkpoints.
    for i in range(n_gpei):
        gp = Models.GPEI(experiment=exp, data=exp.eval())
        print(f"GPEI trial {i+1}")
        exp.new_trial(generator_run=gp.gen(batch_size))

for e in experiments:
    resume_experiment(e)

## Visualisation

In [None]:
def save_fig(ax_config, filename):
    """Export an Ax plot config as ``filename + '.png'`` and ``filename + '.eps'``.

    The figure is rebuilt as a plotly Figure with 18 pt text and 18 pt tick labels
    on both axes. NOTE(review): plotly's ``write_image`` presumably needs a
    static-export backend (e.g. kaleido) installed.
    """
    font_size = 18
    figure = go.Figure(data=ax_config.data)
    figure.update_layout(
        font=dict(size=font_size),
        xaxis=dict(tickfont_size=font_size),
        yaxis=dict(tickfont_size=font_size),
    )
    for extension in ('.png', '.eps'):
        figure.write_image(filename + extension)
    
def plot_level_curve(experiment, n_trials, x, y, metric='peak_daily'):
    """Fit a GP to a saved checkpoint and plot the contour of ``metric`` over (x, y).

    Despite its name, ``n_trials`` is the number of saved checkpoints (batches).
    The file loaded is ``results/{experiment}_SIDARTHE_{n_trials - 1}.json``, since
    checkpoints are numbered from 0. The contour is saved under the same stem via
    ``save_fig`` and rendered inline.
    """
    checkpoint = f'results/{experiment}_SIDARTHE_{int(n_trials - 1)}'
    saved_exp = ax.load(checkpoint + '.json')
    surrogate = Models.GPEI(experiment=saved_exp, data=saved_exp.eval())
    contour = plot_contour(surrogate, param_x=x, param_y=y, metric_name=metric)
    save_fig(contour, checkpoint)
    render(contour)

In [None]:
# (experiment, x-axis parameter, y-axis parameter) for each figure; every experiment
# is plotted from its 100th saved checkpoint.
for _experiment, _x, _y in [
    ('days_vs_R0', 'X', 'R0'),
    ('days_vs_d', 'X', 'd'),
    ('days_vs_q', 'X', 'q'),
    ('period_vs_R0', 'period', 'R0'),
    ('period_vs_d', 'period', 'd'),
    ('period_vs_q', 'period', 'q'),
]:
    plot_level_curve(_experiment, 100, x=_x, y=_y)