In [None]:
import sys

# Put the parent directory on the import path; the `epimodel` imports below rely on it
# (presumably the repository root — confirm if the notebook moves).
sys.path += ["../"]

In [None]:
import pandas as pd
import numpy as np

In [None]:
from epimodel.preprocessing import preprocess_data, PreprocessedData
from epimodel.epiparam import EpidemiologicalParameters

In [None]:
from epimodel.models.giprior_complex_mixed_intervention_rw_model import giprior_complex_mixed_intervention_rw_model

In [None]:
import pickle

# Load the pickled UK test set; the file handle is closed deterministically by `with`.
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
with open('uk_test_set.pkl', 'rb') as pickle_file:
    data_pickle = pickle.load(pickle_file)

In [None]:
# Build the PreprocessedData container from the pickled fields, passed positionally
# in the order: regions, days, CMs, new_cases, new_deaths, active_cms.
data = PreprocessedData(*(
    data_pickle[field]
    for field in ('regions', 'days', 'CMs', 'new_cases', 'new_deaths', 'active_cms')
))

In [None]:
# Mask the first N entries along axis 1 of each observation series (assumes axis 1 is
# time — TODO confirm). Presumably the earliest reports are considered unreliable;
# NOTE(review): confirm the rationale for these cut-offs.
N_MASKED_CASE_DAYS = 10
N_MASKED_DEATH_DAYS = 30

# np.ma.array keeps any existing mask, so re-running this cell is idempotent.
data.new_cases = np.ma.array(data.new_cases)
data.new_cases[:, :N_MASKED_CASE_DAYS] = np.ma.masked
data.new_deaths = np.ma.array(data.new_deaths)
data.new_deaths[:, :N_MASKED_DEATH_DAYS] = np.ma.masked

In [None]:
data.countries = ['liverpool', 'oxford', 'cambridge', 'london']

# Assign regions to groups round-robin: region i goes to group i % len(data.countries).
# Deriving the stride from `data.countries` keeps the two lists in sync.
# NOTE(review): assumes 80 regions ordered so this interleaving matches the group names.
N_REGIONS = 80
data.la_indices = [
    np.arange(N_REGIONS)[group::len(data.countries)]
    for group in range(len(data.countries))
]

In [None]:
# Epidemiological parameters built with their defaults (no arguments); `ep` is passed to the
# model through mcmc.run in the sampling cell.
ep = EpidemiologicalParameters()

In [None]:
import matplotlib.pyplot as plt

In [None]:
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random
import jax.numpy as jnp
import jax.scipy.signal as jss

import jax

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import time 

In [None]:
from numpyro.infer import init_to_median

# Sampler configuration.
NUM_WARMUP = 500
NUM_SAMPLES = 500
NUM_CHAINS = 1
SEED = 0

nuts_kernel = NUTS(giprior_complex_mixed_intervention_rw_model, init_strategy=init_to_median)
mcmc = MCMC(nuts_kernel, num_samples=NUM_SAMPLES, num_warmup=NUM_WARMUP, num_chains=NUM_CHAINS)
rng_key = random.PRNGKey(SEED)

# The wall-clock time includes model compilation, not just sampling.
start = time.perf_counter()
with numpyro.validation_enabled():
    # MCMC.run stores its samples on `mcmc` and returns None, so nothing is assigned here.
    mcmc.run(rng_key, data, ep)
end = time.perf_counter()

# Draws pooled over chains (leading axis = draw). The original passed np.array([0]) as
# group_by_chain, which only worked because that array is falsy.
posterior_samples = mcmc.get_samples(group_by_chain=False)

print(
    f'{NUM_WARMUP} warmup + {NUM_SAMPLES} sampling iterations per chain '
    f'({NUM_CHAINS} chain(s)) took {end - start:.1f} s'
)

# Notes on implementation:

* Includes uncertainty in the generation interval (GI), using a discrete renewal model.
* Compilation is slow: compiling the model to XLA took about 4 minutes. It may be worth investigating whether the compiled model can be cached and reused across runs.
* **TODO:** consider switching to `numpyro.plate` (or a similar construct).
* Once compilation has finished, sampling runs somewhat faster.

In [None]:
import arviz as az

In [None]:
# Convert the fitted numpyro MCMC object into an ArviZ InferenceData for diagnostics.
inference_data = az.from_numpyro(mcmc)

In [None]:
# Total number of divergent NUTS transitions across all draws; a nonzero count is a
# common warning sign that the posterior exploration may be unreliable.
inference_data.sample_stats.diverging.data.sum()

In [None]:
%load_ext autoreload 
# Mode 2 reloads modules before each cell executes, so edits to `epimodel` are picked up.
# NOTE(review): this is usually placed in the config cell at the top of the notebook.
%autoreload 2

In [None]:
from epimodel.plotting import epicurve_plots, intervention_plots, param_plots

In [None]:
# Open a fresh 4x4-inch, 300-dpi figure; the helper presumably draws on the current
# figure/axes (confirm). Plots intervention effectiveness from the posterior draws.
plt.figure(figsize=(4, 4), dpi=300)
intervention_plots.plot_intervention_effectiveness(posterior_samples, data.CMs, xlim=[-50, 50])

In [None]:
# Shape of the 'sigma_i' posterior draws (leading axis = draws, pooled over chains).
posterior_samples['sigma_i'].shape

In [None]:
# Per-intervention spread from the posterior (presumably built from 'sigma_i' — confirm),
# drawn on a fresh 4x4-inch, 300-dpi figure.
plt.figure(figsize=(4, 4), dpi=300)
intervention_plots.plot_intervention_sd(posterior_samples, data.CMs, xlim=[0, 1])

In [None]:
# Posterior correlation between intervention effects, on a fresh 4x4-inch, 300-dpi figure.
plt.figure(figsize=(4, 4), dpi=300)
intervention_plots.plot_intervention_correlation(posterior_samples, data.CMs)

In [None]:
# Names of all sampled sites available in the posterior.
posterior_samples.keys()

In [None]:
# NOTE(review): identical to an earlier cell; safe to remove.
posterior_samples.keys()

In [None]:
# Posterior of the generation-interval (GI) parameters (presumably — confirm in param_plots).
param_plots.plot_gi(posterior_samples)

In [None]:
# Number of sampled sites whose name contains 'cd_' (plain-Python count; no list or numpy needed).
sum('cd_' in site_name for site_name in posterior_samples)

In [None]:
# Group names assigned earlier, in the same order as data.la_indices.
data.countries

In [None]:
# NOTE(review): identical to two earlier cells; safe to remove.
posterior_samples.keys()

In [None]:
import seaborn as sns

In [None]:
# Posterior of the case and death reporting-delay parameters.
param_plots.plot_cases_death_delays(data, posterior_samples)

In [None]:
# Posterior of the observation (output) noise-scale parameters.
param_plots.plot_output_noise_scales(posterior_samples)

In [None]:
# Posterior of the random-walk noise-scale parameters.
param_plots.plot_rw_noise_scales(posterior_samples)

In [None]:
# Epicurve summary for one area; the second argument (0) is presumably the area index — confirm.
epicurve_plots.area_summary_plot(posterior_samples, 0, data)

In [None]:
# NOTE(review): called with no arguments, unlike area_summary_plot above; this probably needs
# (posterior_samples, area_index, data) or similar — check the function signature.
epicurve_plots.area_transmission_plot()

In [None]:
# NOTE(review): displays the full 'iar' posterior array; consider `.shape` or a slice instead.
posterior_samples['iar']

In [None]:
# NOTE(review): `log_iar_noise` is only assigned in later cells, so this fails on a fresh
# kernel. 500 and 122 are hard-coded (presumably number of draws and number of days).
log_iar_noise[:, 0].reshape((500, 1)).repeat(122, axis=1).shape

In [None]:
# IAR multiplier relative to the first day: exp(log_iar_noise[t] - log_iar_noise[0]).
# NOTE(review): depends on `log_iar_noise` from later cells (fails on a fresh kernel);
# `log_iar_noise[:, :1]` would broadcast without the hard-coded 500/122.
jnp.exp(log_iar_noise - log_iar_noise[:, 0].reshape((500, 1)).repeat(122, axis=1))

In [None]:
# Median of the 'cfr' posterior over the leading (draw) axis.
plt.plot(np.median(posterior_samples['cfr'], axis=0))

In [None]:
# Shape of one region's weekly noise-point series. The shape does not depend on which
# index is chosen, so use 0 instead of relying on `r_i` leaking from the plotting loop below.
posterior_samples['noisepoint_log_iar_noise_series'][:, 0, :].shape

In [None]:
# For each region index, plot the posterior of the IAR random-walk multiplier relative to
# day 0. Weekly noise points are scaled by the posterior-median walk scale, cumulatively
# summed, repeated 7x to daily resolution, and truncated to the observed days.
iar_walk_scale_median = np.median(posterior_samples['iar_walk_noise_scale'])  # loop-invariant
N_PLOT_REGIONS = 80  # NOTE(review): matches the 80 regions assumed elsewhere; confirm against data

for r_i in range(N_PLOT_REGIONS):
    fig, ax = plt.subplots()
    log_iar_noise = jnp.repeat(
        jnp.cumsum(
            iar_walk_scale_median * posterior_samples['noisepoint_log_iar_noise_series'][:, r_i, :],
            axis=-1,
        ),
        7,
        axis=-1,
    )[:, :len(data.Ds)]

    # Broadcast against day 0 rather than hard-coding (n_draws, n_days) = (500, 122).
    li, lq, m, uq, ui = np.percentile(
        jnp.exp(log_iar_noise - log_iar_noise[:, :1]), [2.5, 25, 50, 75, 97.5], axis=0
    )

    ax.plot(data.Ds, m, color="k")
    ax.fill_between(data.Ds, li, ui, color="k", alpha=0.1, linewidth=0)
    ax.fill_between(data.Ds, lq, uq, color="k", alpha=0.3, linewidth=0)
    ax.set(title=f"region index {r_i}", ylabel="IAR multiplier relative to day 0")
    # Display now and release the figure so 80 figures do not accumulate in memory.
    plt.show()
    plt.close(fig)


In [None]:
# Prior sketch of the IAR random walk: draw weekly standard-normal noise points,
# cumulatively sum them, repeat each 7x to daily resolution, and squash with a
# sigmoid scaled by 5 * iar_0.
iar_0 = 1
iar_walk_noise_scale = 1
# Number of weekly 'noise points'; 7 * nNP > data.nDs, so the daily series covers every day.
nNP = data.nDs // 7 + 1
prior_rng = np.random.default_rng(0)  # seeded so the sketch is reproducible
noisepoint_log_iar_noise_series = prior_rng.normal(size=nNP)

# Use a distinct name so this 1-D array does not overwrite the 2-D `log_iar_noise`
# produced by the per-region posterior plotting loop.
prior_log_iar_noise = jnp.repeat(
    jnp.cumsum(iar_walk_noise_scale * noisepoint_log_iar_noise_series, axis=-1),
    7,
    axis=-1,
)
iar_t = iar_0 * 5 * jax.nn.sigmoid(prior_log_iar_noise)

In [None]:
# Sanity check: sigmoid(0) = 0.5, so with iar_0 = 1 the sketch gives iar_t = 5 * 0.5 = 2.5
# wherever the walk is at 0.
jax.nn.sigmoid(0.0)

In [None]:
# Daily IAR trajectory from the prior sketch (one value per day index).
plt.plot(iar_t)

In [None]:
# Raw values of the prior-sketch IAR trajectory.
iar_t

In [None]:
# Sanity check: sigmoid(0.5) is about 0.622, i.e. a log-walk value of 0.5 gives iar_t of about 3.1 when iar_0 = 1.
jax.nn.sigmoid(0.5)