In [None]:
# Uncomment the line below to install on Colab or similar (pinned to commit 108e577).
# %pip installs into the running kernel's environment, unlike !pip.
#%pip install git+https://github.com/monash-emu/renewal.git@108e577

In [None]:
from datetime import datetime, timedelta

from jax import jit, random
import numpy as np
import pandas as pd
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from plotly.express.colors import qualitative as qual_colours

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti
from emu_renewal.outputs import plot_uncertainty_patches, plot_post_prior_comparison, plot_priors
from emu_renewal.calibration import StandardCalib

In [None]:
iso = "MYS"

# Rough national population sizes keyed by ISO3 country code;
# the entry for `iso` becomes `pop`, which is handed to RenewalModel.
approx_pops = dict(
    MYS=33e6,
    PHL=114e6,
    VNM=97e6,
)
pop = approx_pops[iso]

In [None]:
# Specify fixed parameters and get calibration data
proc_update_freq = 14  # passed to RenewalModel; presumably days between process updates
init_time = 50  # number of days of pre-analysis data used to initialise the model

case_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv", index_col=0
)[iso]
case_data.index = pd.to_datetime(case_data.index)
# Label-based date slicing with .loc below requires a monotonic (sorted) index
case_data = case_data.sort_index()

analysis_start = datetime(2021, 4, 1)
analysis_end = datetime(2021, 11, 1)
# .loc date slices include both endpoints, so [init_start, init_end] spans init_time days
init_start = analysis_start - timedelta(days=init_time)
init_end = analysis_start - timedelta(days=1)
select_data = case_data.loc[analysis_start:analysis_end]
init_data = case_data.loc[init_start:init_end]

# Fail loudly here rather than deep inside calibration if the data are missing or gappy
assert not select_data.empty, (
    f"No {iso} case data between {analysis_start:%Y-%m-%d} and {analysis_end:%Y-%m-%d}"
)
assert len(init_data) == init_time, (
    f"Expected {init_time} daily {iso} values before {analysis_start:%Y-%m-%d}, "
    f"got {len(init_data)}"
)

In [None]:
# Seventh positional argument of RenewalModel, previously a bare literal 50.
# NOTE(review): presumably the renewal (generation-interval) window length; it happens
# to equal init_time, so confirm whether init_data must cover this window before
# changing either value independently.
window_len = 50
renew_model = RenewalModel(
    pop,
    analysis_start,
    analysis_end,
    proc_update_freq,
    CosineMultiCurve(),
    GammaDens(),
    window_len,
    init_data,
    GammaDens(),
)

In [None]:
# Prior distributions for the calibrated parameters, keyed by model parameter name
priors = dict(
    # Presumably generation-interval mean and SD; truncated below to stay positive
    gen_mean=dist.TruncatedNormal(loc=5.0, scale=0.4, low=1.0),
    gen_sd=dist.TruncatedNormal(loc=3.8, scale=0.5, low=0.01),
    # Beta prior restricts this parameter to (0, 1); presumably the case detection ratio
    cdr=dist.Beta(concentration1=2.8, concentration0=10.0),
    rt_init=dist.Normal(loc=0.0, scale=0.25),
    # Presumably reporting-delay mean and SD; truncated below to stay positive
    report_mean=dist.TruncatedNormal(loc=5.0, scale=0.5, low=1.0),
    report_sd=dist.TruncatedNormal(loc=2.0, scale=0.5, low=0.01),
)

In [None]:
# Wrap the model, priors and target data for calibration, then sample with NUTS
calib = StandardCalib(renew_model, priors, select_data)
nuts_init = calib.custom_init(radius=0.5)
kernel = infer.NUTS(calib.calibration, dense_mass=True, init_strategy=nuts_init)
mcmc = infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)
# Fixed PRNG key so the sampled chains are reproducible between runs
mcmc.run(random.PRNGKey(1))

In [None]:
# Convert the numpyro draws (kept split by chain) into ArviZ InferenceData
idata = az.from_dict(posterior=mcmc.get_samples(group_by_chain=True))
# Thin the pooled posterior to 200 random draws for plotting. Passing a seeded
# generator makes this subsample, and therefore the downstream spaghetti/quantile
# plots, identical on every Restart & Run All (unseeded, it changes each run).
idata_sampled = az.extract(idata, num_samples=200, rng=np.random.default_rng(1))
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
@jit
def get_full_result(**kwargs):
    """Run the renewal model for a single parameter set (JIT-compiled with JAX).

    The sampled parameters in ``kwargs`` are merged with ``calib.fixed_params``
    (fixed values win on any key clash) and passed to ``renew_model.renewal_func``.
    """
    merged_params = kwargs | calib.fixed_params
    return renew_model.renewal_func(**merged_params)

spaghetti = get_spaghetti_from_params(renew_model, sample_params, get_full_result)
quantiles_df = get_quant_df_from_spaghetti(
    renew_model,
    spaghetti,
    quantiles=[0.05, 0.5, 0.95],
)
# 5-95% uncertainty patches around the median, overlaid on the calibration data
patch_fig = plot_uncertainty_patches(quantiles_df, select_data, qual_colours.Plotly)
patch_fig.update_layout(showlegend=False)

In [None]:
# Tabulate posterior summary statistics and convergence diagnostics for every parameter
az.summary(data=idata)

In [None]:
# Compare prior vs posterior only for parameters that were sampled (not in calib.fixed_params)
sampled_param_names = [name for name in priors if name not in calib.fixed_params]
plot_post_prior_comparison(idata, sampled_param_names, priors);