In [None]:
from summer2 import CompartmentalModel
from summer2.parameters import Parameter


In [None]:
def build_model():
    """Assemble the S-I-R compartmental model used throughout the notebook.

    The model covers times 0 to 100, starts with 999 people in "S" and 1 in "I",
    and has two flows whose rates are supplied at run time through the
    "contact_rate" and "recovery_rate" parameters. The "infection" flow is
    exposed as the derived output "incidence".

    Returns:
        The configured CompartmentalModel (not yet run).
    """
    time_span = [0.0, 100.0]
    compartment_names = ["S", "I", "R"]
    infectious_names = ["I"]
    model = CompartmentalModel(time_span, compartment_names, infectious_names)

    starting_population = {"S": 999.0, "I": 1.0}
    model.set_initial_population(starting_population)

    # S -> I driven by the contact rate, then I -> R driven by the recovery rate.
    model.add_infection_frequency_flow(
        "infection", Parameter("contact_rate"), "S", "I"
    )
    model.add_transition_flow(
        "recovery", Parameter("recovery_rate"), "I", "R"
    )

    # Track the infection flow under the name "incidence".
    model.request_output_for_flow("incidence", "infection")

    return model


sir_model = build_model()

In [None]:
# Baseline parameter values for a single illustrative run.
parameters = dict(contact_rate=0.3, recovery_rate=0.1)

# Run the model once and plot the requested "incidence" derived output.
sir_model.run(parameters)
res = sir_model.get_derived_outputs_df()
res["incidence"].plot()

# Sample from a known distribution

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import truncnorm

def sample_from_truncnorm(mean, std_dev, lower_bound, upper_bound, sample_size, name, random_state=None):
    """Draw samples from a normal distribution truncated to [lower_bound, upper_bound].

    Args:
        mean: Location of the underlying (untruncated) normal distribution.
        std_dev: Scale of the underlying normal distribution; must be > 0.
        lower_bound: Lower truncation limit, in the same units as ``mean``.
        upper_bound: Upper truncation limit; must be greater than ``lower_bound``.
        sample_size: Number of samples to draw.
        name: Column name given to the samples in the returned DataFrame.
        random_state: Optional seed or random generator forwarded to
            ``truncnorm.rvs`` so that the draw can be reproduced. The default
            (None) keeps the previous unseeded behaviour.

    Returns:
        A single-column pandas DataFrame (column ``name``) holding the samples.

    Raises:
        ValueError: If ``std_dev`` is not positive or the bounds are not ordered.
    """
    if std_dev <= 0:
        raise ValueError("std_dev must be positive")
    if lower_bound >= upper_bound:
        raise ValueError("lower_bound must be smaller than upper_bound")

    # scipy expects the truncation limits in standardised units.
    a = (lower_bound - mean) / std_dev
    b = (upper_bound - mean) / std_dev
    samples = truncnorm.rvs(
        a, b, loc=mean, scale=std_dev, size=sample_size, random_state=random_state
    )

    return pd.DataFrame(samples, columns=[name])

# "True" contact-rate sample: two truncated normals (centred on 0.225 and 0.3),
# 10000 draws each, stacked into one single-column DataFrame.
samples = {
    "contact_rate": pd.concat(
        [
            sample_from_truncnorm(centre, 0.005, lower, upper, 10000, "contact_rate")
            for centre, lower, upper in [(0.225, 0.2, 0.25), (0.3, 0.25, 0.35)]
        ],
        ignore_index=True,
    )
}

In [None]:
import seaborn as sns
# Filled kernel density estimate of the sampled contact rates.
sns.kdeplot(samples["contact_rate"], fill=True)

# Run model forward (i.e. feed the samples to the model)

In [None]:
from estival.model import BayesianCompartmentalModel
import estival.priors as esp
import estival.targets as est
from estival.sampling import tools as esamp


# Wrap the model with a flat prior on contact_rate and no targets; this wrapper is
# only used below to run the model once per sampled contact rate.
priors = [
    esp.UniformPrior("contact_rate", [0, 1]),
]
targets = []
bcm = BayesianCompartmentalModel(model=sir_model, priors=priors, targets=targets, parameters=parameters)

# One parameter dict per sample. samples["contact_rate"] is a single-column
# DataFrame, so `.iloc[i]` on it would give a one-element Series rather than a
# number; select the column and hand over plain floats instead.
samples_for_estival = [
    {"contact_rate": float(value)}
    for value in samples["contact_rate"]["contact_rate"]
]

model_runs = esamp.model_results_for_samples(samples_for_estival, bcm)

In [None]:
# Plot the "incidence" output of every forward run (legend suppressed).
model_runs.results['incidence'].plot(legend=False)

## Collect the synthetic data and generate likelihood components

In [None]:
# Observation times: every 10 time units from 20 up to and including 80.
data_times = [*range(20, 81, 10)]
len(data_times)

In [None]:
from jax.scipy.stats import gaussian_kde
import jax.numpy as jnp

# One Gaussian KDE per observation time, fitted to the simulated incidence values
# at that time (bandwidth factor fixed at 0.01).
likelihood_comps = {
    obs_time: gaussian_kde(
        jnp.array(model_runs.results['incidence'].loc[obs_time]),
        bw_method=0.01,
    )
    for obs_time in data_times
}

In [None]:
# Check each likelihood component against a histogram of the simulated incidence
import numpy as np
import matplotlib.pyplot as plt

# Shared evaluation grid for all KDE curves.
x_values = np.linspace(0, 50, 1000)

for t in data_times:
    # KDE curve for this observation time ...
    kde = likelihood_comps[t]
    pdf_values = kde(x_values)
    plt.plot(x_values, pdf_values)

    # ... overlaid on a normalised histogram of the simulated incidence at that time.
    incidence_at_t = model_runs.results['incidence'].loc[t]
    incidence_at_t.plot.hist(density=True, bins=50)
    plt.show()

# Refit the model using the likelihood components derived from synthetic data

In [None]:
from jax import lax

# Flat (uniform) prior on the contact rate over [0.1, 0.5]
priors = [esp.UniformPrior("contact_rate", [0.1, 0.5])]

# Number of observation times; used to scale each likelihood component.
n_data_points = len(data_times)
# Define a custom target using the likelihood components
def make_eval_func(t):
    """Build the log-likelihood evaluator for the likelihood component keyed by ``t``.

    Args:
        t: Key into ``likelihood_comps`` selecting the KDE to evaluate.

    Returns:
        A function ``eval_func(modelled, obs, parameters, time_weights)`` returning
        the log of the KDE density at ``modelled`` divided by ``n_data_points``.
        ``obs``, ``parameters`` and ``time_weights`` are accepted but not used.
    """
    def eval_func(modelled, obs, parameters, time_weights):
        density = likelihood_comps[t](modelled)
        # Floor the density so the log never sees zero. 1e-300 is not representable
        # in float32 (it becomes 0.0), so fall back to the smallest normal number of
        # the density's own dtype there; in float64 the floor is still 1e-300.
        floor = max(1.e-300, float(jnp.finfo(density.dtype).tiny))
        density = jnp.max(jnp.maximum(density, floor))
        return jnp.log(density) / n_data_points

    return eval_func

# One custom target per observation time; the observed value (0.) is a placeholder
# because eval_func does not use `obs`.
targets = [
    est.CustomTarget(
        f"likelihood_comp_{t}",
        pd.Series([0.], index=[t]),
        make_eval_func(t),
        model_key='incidence',
    )
    for t in data_times
]

refit_bcm = BayesianCompartmentalModel(
    model=sir_model,
    priors=priors,
    targets=targets,
    parameters=parameters,
)

### PyMC samplers

In [None]:
import pymc as pm
from estival.wrappers import pymc as epm

In [None]:
chains = 4

# Size everything from `chains` so changing the chain count cannot overrun `Tab`.
Tab = np.zeros(chains)
init_vals = []
for c in range(chains):
    # Random starting contact rate for chain `c`, drawn uniformly from [0.20, 0.4).
    init_vals.append({"contact_rate": np.random.uniform(0.20, 0.4)})
    # NOTE(review): `Tab` receives a separate draw, not a copy of the value stored in
    # `init_vals`, and is not referenced by any other visible cell.
    Tab[c] = np.random.uniform(0.20, 0.4)

In [None]:
# Display the per-chain starting values.
init_vals

In [None]:
# Hand-picked starting contact rates, one dict per chain.
# NOTE(review): `T` is not referenced by any other visible cell.
T = [{'contact_rate': rate} for rate in (0.34, 0.20, 0.24, 0.4)]


In [None]:
# Sampling results keyed by sampler name.
IDATA = dict()
for sampler in [pm.Metropolis, pm.DEMetropolis, pm.DEMetropolisZ]:
    with pm.Model() as model:
        variables = epm.use_model(refit_bcm)
        # `init_vals` holds one entry per chain, so the chain count passed here must
        # come from the same `chains` variable rather than a separate literal.
        IDATA[sampler.__name__] = pm.sample(
            step=[sampler(variables)],
            initvals=init_vals,
            draws=20000,
            tune=1000,
            cores=chains,
            chains=chains,
        )


#### Sequential Monte Carlo

In [None]:
draws = 1000
# The starting points are drawn from the prior.
# For now the manual initialisation is not working, hence start=None.
sampler = pm.sample_smc
with pm.Model() as model:
    variables = epm.use_model(refit_bcm)
    # Stored under the function's own name, i.e. "sample_smc".
    IDATA[sampler.__name__] = sampler(
        kernel=pm.smc.IMH,
        start=None,
        draws=draws,
        chains=4,
        threshold=0.1,
        correlation_threshold=0.5,
    )

### NUTS sampling (NumPyro)

In [None]:
import numpyro
from numpyro import infer
from numpyro import distributions as dist
from jax import random



def nmodel():
    """NumPyro model: uniform prior on contact_rate plus the refit log-likelihood.

    The contact rate is sampled from Uniform(0.1, 0.5), the same interval as the
    UniformPrior given to ``refit_bcm``, so that the NUTS posterior is comparable
    with the posteriors obtained from the other samplers. The log-likelihood
    returned by ``refit_bcm.loglikelihood`` is added through a factor site named "ll".
    """
    sampled = {"contact_rate": numpyro.sample("contact_rate", dist.Uniform(0.1, 0.5))}
    numpyro.factor("ll", refit_bcm.loglikelihood(**sampled))

    

In [None]:
# Inspect the `parameters` attribute of the refit model.
refit_bcm.parameters

In [None]:
# Starting contact rate for NUTS: the same value (0.26) repeated once per chain (4).
init_vals_nuts = dict(contact_rate=jnp.full((4,), 0.26))

init_vals_nuts

In [None]:
# For a kernel built from a NumPyro model, `init_params` given to `MCMC.run` is read
# in unconstrained space, so passing 0.26 there does not start the chains at a
# contact rate of 0.26. `init_to_value` takes values in the parameter's own
# (constrained) space instead. It accepts one value per site, so the first per-chain
# entry of `init_vals_nuts` is used and every chain starts from it.
kernel = infer.NUTS(
    nmodel,
    init_strategy=infer.init_to_value(
        values={name: per_chain[0] for name, per_chain in init_vals_nuts.items()}
    ),
)
mcmc = infer.MCMC(kernel, num_warmup=1000, num_chains=4, num_samples=2000, progress_bar=True)

mcmc.run(random.PRNGKey(0))

In [None]:
import arviz as az

In [None]:
# Convert the NumPyro run with ArviZ and store it under "NUTS" next to the other samplers' results.
IDATA["NUTS"] = az.from_numpyro(mcmc)

In [None]:
# R-hat diagnostic for the sequential Monte Carlo result.
az.rhat(IDATA["sample_smc"])

In [None]:
# Trace plot for every stored sampler result, labelled by sampler name.
for sampler in IDATA:
    idata = IDATA[sampler]
    print(sampler)
    az.plot_trace(idata)
    plt.show()

In [None]:
# One panel per stored sampler result. The panel count follows len(IDATA) so that
# every entry gets an axis; a fixed count of 3 raised IndexError once more than
# three sampler results were stored.
n_panels = max(len(IDATA), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 3), squeeze=False)
axes = axes.ravel()  # keep a flat array of axes even when there is a single panel

for ax, (sampler, idata) in zip(axes, IDATA.items()):
    posterior_sample = idata.posterior.to_dataframe()['contact_rate'].to_list()
    # Overlay the sampler's posterior on the distribution the synthetic data came from.
    sns.kdeplot(samples["contact_rate"], ax=ax, fill=True, label="true sample")
    sns.kdeplot(posterior_sample, ax=ax, fill=True, label=sampler)
    ax.legend(loc="upper center")

plt.suptitle("Posterior by different MCMC samplers", fontsize=12)
plt.tight_layout()


In [None]:
# Likelihood extras for one sampler's draws, evaluated with the refit model.
# NOTE(review): `idata` is only ever bound as a loop variable in earlier cells, so
# which sampler is analysed here depends on where the last loop over IDATA stopped;
# consider selecting the entry explicitly (e.g. IDATA["NUTS"]).
lls = esamp.likelihood_extras_for_idata(idata, refit_bcm)

In [None]:
# Smallest value in the 'logposterior' column.
lls['logposterior'].min()

In [None]:
# Histogram of the 'logposterior' values.
lls['logposterior'].plot.hist()

In [None]:
# Re-run the model for the draws held in `idata`.
# NOTE(review): `idata` is a leftover loop variable from an earlier cell; confirm it
# refers to the intended sampler.
posterior_model_runs = esamp.model_results_for_samples(idata, refit_bcm)

In [None]:
# "incidence" trajectories of the posterior runs (legend suppressed).
posterior_model_runs.results['incidence'].plot(legend=False)

In [None]:
# "incidence" trajectories of the original forward runs, plotted again —
# presumably for visual comparison with the posterior runs.
model_runs.results['incidence'].plot(legend=False)