# MCMC CALIBRATION TECHNIQUES IN THE CONTEXT OF INFECTIOUS DISEASE MODELING

## Prerequisites

In [None]:
# pip install multiprocess
# %pip install --upgrade --force-reinstall multiprocess

#Compatible with latest jax version  
# %pip install summerepi2==1.3.6
# %pip install jinja2

In [1]:
# This is required for pymc parallel evaluation in notebooks

import multiprocess as mp
import platform

# Windows is excluded from the 'forkserver' switch, as in the original cell.
if platform.system() != "Windows":
    # The start method can normally be set only once per process. Checking the
    # current value first (and forcing the change when it differs) keeps this
    # cell re-runnable instead of raising
    # "RuntimeError: context has already been set" on a second execution.
    if mp.get_start_method(allow_none=True) != 'forkserver':
        mp.set_start_method('forkserver', force=True)

In [2]:
import Calibrate as cal #Runing the calibration process and gathering results
from calibs_utilities import get_all_priors, get_targets, load_data
from models.models import model1 #All the models we design for the test
from Calibrate import plot_comparison_bars

# Combining targets and priors with our summer2 model in a BayesianCompartmentalModel (bcm_model_1)
from estival.model import BayesianCompartmentalModel


import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
from typing import List

import pymc as pm

# We use estivals parallel tools to run the model evaluations
from estival.utils.parallel import map_parallel

# import numpyro
# from numpyro import distributions as dist
from numpyro import infer
import arviz as az
import jax
from jax import numpy as jnp




## Calibration Data
If data are needed to define a target, we simply load them from a data file prepared in the expected format.
In this example, the data are imported from a YAML file.

In [3]:
# Example observation data used for the calibration
pd.options.plotting.backend = "plotly" # Enables plotly figures. Switch to "matplotlib" if plotting causes trouble

# Axis labels handed to the scatter plots through `labels=output_labels`
output_labels = {"index": "time", "value": "number infectious"}

# Path of the YAML file holding the calibration data
targets_yml = './data/target_yml.yml'
# NOTE(review): `targets` is reassigned from get_targets(targets_yml) in the next cell;
# only `targets_data` is kept from this load.
targets = load_data(targets_yml)
targets_data = targets['active_cases']

# targets_data.plot(kind="scatter",labels=output_labels)



In [4]:
# Names of the calibrated parameters and their ranges
params = {
    "contact_rate": (0.0,1.0),
    "recovery_rate": (0.0,1.0)

}
# Calibration targets are built from the YAML file referenced by `targets_yml`
targets = get_targets(targets_yml)
# Priors are derived from `params` (presumably one prior per parameter range — confirm in calibs_utilities)
priors = get_all_priors(params)

## Model Definition and Configuration

A mechanistic (ODE-based) model describing infectious disease transmission.

In [5]:
# Instantiate the model under test through the `model1` factory imported from models.models
model_1 = model1()

### Trial run 

In [6]:
# Hand-picked parameter values used for a first, uncalibrated run of the model
parameters = dict(
    contact_rate=0.2,
    recovery_rate=0.1,
    # active_cases_dispersion=0.5,
)

model_1.run(parameters)

# Put the modelled infectious curve next to the observations and plot both series
trial_outputs = {
    "modelled": model_1.get_outputs_df()["infectious"],
    "observed": targets_data,
}
pd.DataFrame(trial_outputs).plot(kind="scatter", labels=output_labels)

### Sampling 

In [7]:
# Define a Bayesian Compartmental Model from the model, the trial parameters, the priors and the targets

bcm_model_1 = BayesianCompartmentalModel(model_1, parameters, priors, targets)
# Inspect the `stdev` attribute attached to the "active_cases" target
T = bcm_model_1.targets["active_cases"]
T.stdev

2000.0

In [8]:
##____Uniform Initialisation_________
# Four starting points, each parameter drawn uniformly between 0 and 1.
# A seeded generator replaces the global, unseeded np.random state so that a
# "Restart Kernel -> Run All" reproduces exactly the same start points.
INIT_SEED = 0
init_rng = np.random.default_rng(INIT_SEED)
init_vals = [
    {param: init_rng.uniform(0.0, 1.0) for param in parameters.keys()}
    for _ in range(4)
]

init_vals

[{'contact_rate': 0.006952903622498674, 'recovery_rate': 0.4316636377853278},
 {'contact_rate': 0.5846087007436249, 'recovery_rate': 0.5094688392103793},
 {'contact_rate': 0.06708807968719022, 'recovery_rate': 0.6345131377516215},
 {'contact_rate': 0.9103110726055844, 'recovery_rate': 0.8033968169898004}]

#### Simple Run

In [9]:
%%time
# Benchmark of several MCMC samplers on the same Bayesian compartmental model.
# The three lists below are consumed in parallel by zip(): entry i of each list defines run i.
D = 2 # Dimension of the parameter's space
# Sampler classes to compare, each repeated once per (draws, tune) combination
samplers = [infer.NUTS]*2 + [pm.DEMetropolisZ]*2 + [pm.DEMetropolis]*2 + [pm.Metropolis]*4
# Number of draws requested for each run
Draws = [2000]*2 + [4000]*6+ [8000]*2
# Tuning length of each run: alternates between 100 and 1000
Tunes = [100,1000]*5
chains = 2*D
# One entry per run, holding the value returned by cal.Compute_metrics
results = []
# import logging
# logger = logging.getLogger("pymc5")
# logger.setLevel(logging.ERROR)
# logger.propagate = False
# logging.basicConfig(filename='pymc5_warnings.log', level=logging.ERROR)

# NOTE(review): after the loop, `idata`, `Time`, `sampler`, `draws` and `tune` stay bound to the last run.
for sampler, draws, tune in zip (samplers, Draws, Tunes):
    #print(sampler)
    # Sampling_calib returns the trace and a second value bound to `Time`
    # (presumably the elapsed sampling time — confirm in Calibrate)
    idata, Time = cal.Sampling_calib(
        bcm_model = bcm_model_1,
        mcmc_algo = sampler,
        initial_params = init_vals,
        draws = draws,
        tune = tune,
        cores = 4,
        chains = chains,
        )

    # Summarise the run; the collected entries are concatenated with pd.concat in the next cell
    results.append(cal.Compute_metrics(
        mcmc_algo = sampler,
        idata = idata,
        Time = Time,
        draws = draws, 
        chains = chains,
        tune = tune,
            )
        )
    

Multiprocess sampling (4 chains in 4 jobs)
DEMetropolisZ: [contact_rate, recovery_rate]
Sampling 4 chains for 100 tune and 4_000 draw iterations (400 + 16_000 draws total) took 45 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Multiprocess sampling (4 chains in 4 jobs)
DEMetropolisZ: [contact_rate, recovery_rate]
Sampling 4 chains for 1_000 tune and 4_000 draw iterations (4_000 + 16_000 draws total) took 47 seconds.
Population sampling (4 chains)
DEMetropolis: [contact_rate, recovery_rate]
Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`.
Sampling 4 chains for 100 tune and 4_000 draw iterations (400 + 16_000 draws total) took 40 seconds.
The 

CPU times: total: 5min 3s
Wall time: 10min 56s


In [10]:
# Stack the per-run metric tables into a single frame with a fresh 0..n-1 index
results_df = pd.concat(results).reset_index(drop=True)

# Multi-line label identifying each run: sampler name, number of draws, tuning length
results_df["Run"] = (
    results_df.Sampler
    + "\nDraws=" + results_df.Draws.astype(str)
    + "\nTune=" + results_df.Tune.astype(str)
)

results_df.style.set_caption("MCMC COMPARISON")

Unnamed: 0,Sampler,Draws,Chains,Tune,Time,Mean_ESS,Min_Ess,Ess_per_sec,Mean_Rhat,Rhat_max,Trace,Run
0,NUTS,2000,4,100,84.176366,11.102887,10.717993,0.1319,1.274811,1.289247,Inference data with groups: 	> posterior 	> log_likelihood 	> sample_stats 	> observed_data,NUTS Draws=2000 Tune=100
1,NUTS,2000,4,1000,90.804782,1155.426503,1152.050439,12.724291,1.004963,1.00512,Inference data with groups: 	> posterior 	> log_likelihood 	> sample_stats 	> observed_data,NUTS Draws=2000 Tune=1000
2,DEMetropolisZ,4000,4,100,52.203054,69.904049,42.262172,1.33908,1.072762,1.079728,Inference data with groups: 	> posterior 	> sample_stats,DEMetropolisZ Draws=4000 Tune=100
3,DEMetropolisZ,4000,4,1000,52.484066,2086.963489,2051.277897,39.763754,1.001412,1.001657,Inference data with groups: 	> posterior 	> sample_stats,DEMetropolisZ Draws=4000 Tune=1000
4,DEMetropolis,4000,4,100,45.590847,7.66211,5.530279,0.168062,1.671135,2.011277,Inference data with groups: 	> posterior 	> sample_stats,DEMetropolis Draws=4000 Tune=100
5,DEMetropolis,4000,4,1000,54.318143,1003.793759,1002.247847,18.479898,1.002253,1.002922,Inference data with groups: 	> posterior 	> sample_stats,DEMetropolis Draws=4000 Tune=1000
6,Metropolis,4000,4,100,50.023454,4.815239,4.81349,0.09626,2.467098,2.469732,Inference data with groups: 	> posterior 	> sample_stats,Metropolis Draws=4000 Tune=100
7,Metropolis,4000,4,1000,60.580051,14.195786,14.151174,0.234331,1.209595,1.209876,Inference data with groups: 	> posterior 	> sample_stats,Metropolis Draws=4000 Tune=1000
8,Metropolis,8000,4,100,69.771679,4.732203,4.729177,0.067824,2.587542,2.59212,Inference data with groups: 	> posterior 	> sample_stats,Metropolis Draws=8000 Tune=100
9,Metropolis,8000,4,1000,71.444511,36.685616,36.432387,0.513484,1.07289,1.073454,Inference data with groups: 	> posterior 	> sample_stats,Metropolis Draws=8000 Tune=1000


In [None]:
def plot_comparison_Bars(results_df: pd.DataFrame):
    """Draw two side-by-side bar charts comparing the calibration runs.

    The left panel shows the effective sample size per second of every run;
    the right panel shows the mean R-hat, with the y-axis starting at 1.

    NOTE: the notebook also imports ``plot_comparison_bars`` (lower-case "b")
    from ``Calibrate``; this local definition is a separate function.

    Parameters
    ----------
    results_df : pd.DataFrame
        One row per run. Must contain the columns ``"Run"`` (bar labels),
        ``"Ess_per_sec"`` and ``"Mean_Rhat"``.

    Raises
    ------
    KeyError
        If a required column is missing. The check happens before any figure
        is created, so no empty figure is left behind.

    Returns
    -------
    None
        The figure is rendered with ``plt.show()``.
    """
    required_columns = ("Run", "Ess_per_sec", "Mean_Rhat")
    missing = [col for col in required_columns if col not in results_df.columns]
    if missing:
        raise KeyError(f"results_df is missing required column(s): {missing}")

    fig, (ax_ess, ax_rhat) = plt.subplots(1, 2, figsize=(22, 6))

    ax_ess.bar(x=results_df["Run"], height=results_df["Ess_per_sec"])
    ax_ess.set_title("ESS per Second")
    ax_ess.set_xlabel("")

    ax_rhat.bar(x=results_df["Run"], height=results_df["Mean_Rhat"])
    ax_rhat.set_title(r"$\hat{R}$")
    ax_rhat.set_xlabel("")
    # Only the lower limit is set: bars are read relative to the R-hat = 1 baseline
    ax_rhat.set_ylim(1)

    fig.suptitle("Comparison of Runs for ... Target Distribution", fontsize=16)
    fig.tight_layout()
    plt.show()

In [None]:
# Compare all runs collected in results_df (ESS per second and mean R-hat)
plot_comparison_Bars(results_df=results_df)


### Multiple runs

In [None]:
# iterations = np.range(1,50)
# @jax.jit
def run_comparison(samplers=None, draws_per_sampler=None, tune=1000, chains=4, cores=4):
    """Run one calibration per sampler and gather the metrics in a DataFrame.

    Called without arguments it behaves as before: a single ``pm.Metropolis``
    run with 2000 draws, 1000 tuning steps, 4 chains and 4 cores.

    The function reads ``bcm_model_1``, ``init_vals`` and ``cal`` from the
    notebook namespace.

    Parameters
    ----------
    samplers : sequence, optional
        Sampler classes passed as ``mcmc_algo``. Defaults to ``[pm.Metropolis]``.
    draws_per_sampler : sequence of int, optional
        Number of draws for each sampler, in the same order as ``samplers``.
        Defaults to 2000 draws for every sampler.
    tune : int, optional
        Number of tuning steps used for every run.
    chains : int, optional
        Number of chains used for every run.
    cores : int, optional
        Number of cores handed to the sampling routine.

    Returns
    -------
    pd.DataFrame
        Concatenation of the ``cal.Compute_metrics`` outputs with an extra
        ``"Run"`` label column and a fresh integer index.

    Raises
    ------
    ValueError
        If ``samplers`` and ``draws_per_sampler`` differ in length (``zip``
        would otherwise silently drop the unmatched entries).
    """
    if samplers is None:
        samplers = [pm.Metropolis]
    samplers = list(samplers)
    if draws_per_sampler is None:
        draws_per_sampler = [2000] * len(samplers)
    draws_per_sampler = list(draws_per_sampler)
    if len(samplers) != len(draws_per_sampler):
        raise ValueError(
            f"samplers ({len(samplers)}) and draws_per_sampler "
            f"({len(draws_per_sampler)}) must have the same length"
        )

    results = []
    for sampler, draws in zip(samplers, draws_per_sampler):
        idata, Time = cal.Sampling_calib(
            bcm_model = bcm_model_1,
            mcmc_algo = sampler,
            initial_params = init_vals,
            draws = draws,
            tune = tune,
            cores = cores,
            chains = chains,
        )
        results.append(
            cal.Compute_metrics(
                mcmc_algo = sampler,
                idata = idata,
                Time = Time,
                draws = draws,
                chains = chains,
                tune = tune,
            )
        )

    results_df = pd.concat(results)
    # Multi-line label identifying each run: sampler name, draws and tuning length
    results_df["Run"] = results_df.Sampler + "\nDraws=" + results_df.Draws.astype(str) + "\nTune=" + results_df.Tune.astype(str)

    results_df = results_df.reset_index(drop=True)
    return results_df

In [None]:
%%time
All_results = []
#_________Removing the PYMC text outputs in the consol______
import logging
logger = logging.getLogger('pymc5')
logger.propagate = False
#___________________________________________________________
for k in jnp.arange(2):
    # print(k)
    Res = run_comparison()
    All_results.append(Res)


## Using arviz for trace visualization

In [None]:
# Draw the trace plots of every calibration run stored in results_df.
# NOTE(review): `draws` and `tune` are unpacked but not used in the loop body.
# NOTE(review): after the loop `idata` stays bound to the last trace of results_df;
# later cells read `idata` directly and therefore depend on this loop having run.
for idata, Run, draws, tune in zip(results_df.Trace, results_df.Run, results_df.Draws, results_df.Tune):
    # Restrict the selection to the posterior group (the open-ended slice starts at draw 0)
    subset = idata.sel(draw=slice(0, None), groups="posterior")
    print("Run = ",Run)
    # Figure height scales with len(subset.posterior) (presumably the number of sampled variables)
    az.plot_trace(subset, figsize=(16,3.2*len(subset.posterior)),compact=False)#, lines=[("m", {}, mtrue), ("c", {}, ctrue)]);
    plt.show()



In [None]:
# Sampler statistics of `idata`, i.e. of whichever trace an earlier loop left bound to that name
idata["sample_stats"]

In [None]:
from estival.sampling.tools import likelihood_extras_for_idata
from estival.sampling import tools as esamptools

## Computing the likelihood function for each sampler

In [None]:
from estival.model.base import ResultsData

In [None]:
# NOTE: a bare annotation only records the intended type; it does NOT create `map_res`
map_res : List[ResultsData]
# Observations wrapped in a DataFrame
targets_datas = pd.DataFrame(targets_data)

# Show the observations as a raw array
targets_datas.values

In [None]:
# For every calibration run: locate the maximum a posteriori (MAP) sample, i.e. the
# posterior sample with the highest log-posterior, and run the model with it.
# `map_res` must be created here: an annotation without a value does not define the name,
# so the `.append` below would otherwise raise NameError on a fresh kernel.
map_res: List[ResultsData] = []
map_params = []
for idata, sampler, run in zip(results_df.Trace, results_df.Sampler, results_df.Run):
    # Sampler names are capitalised in results_df ("Metropolis"), so compare case-insensitively;
    # the previous literal comparison with "metropolis" could never be true.
    if sampler.lower() == "metropolis":
        # pm.Metropolis traces go through the sample-based helper of estival.sampling.tools
        likelihood_df = esamptools.likelihood_extras_for_samples(idata.posterior, bcm_model_1)
    else:
        likelihood_df = likelihood_extras_for_idata(idata, bcm_model_1)

    ldf_sorted = likelihood_df.sort_values(by="logposterior", ascending=False)
    # Convert the posterior once and reuse the row for both display and the model run
    map_row = idata.posterior.to_dataframe().loc[ldf_sorted.index[0]]
    print(map_row)

    map_parameter = map_row.to_dict()
    map_params.append(map_parameter)
    map_res.append(bcm_model_1.run(map_parameter))
        



In [None]:
# Modelled "active_cases" series of the last MAP run.
# (The former loop over zip(map_res) had an empty body and only served to leave the
# last element, wrapped in a 1-tuple, bound to `res`.)
map_res[-1].derived_outputs["active_cases"]

In [None]:
# Compare the MAP trajectory of every run with the observations, one figure per run.
# Each element of `map_res` is used directly: it already exposes `.derived_outputs`
# (as in the other cells), so no extra `[0]` indexing is applied.
for sampler, res, draws in zip(results_df.Sampler, map_res, results_df.Draws):
    fig = pd.DataFrame(
        {
            f"{sampler}, draws = {draws}": res.derived_outputs["active_cases"],
            "Observed" : targets_data,
        }
    ).plot(backend="plotly", kind="scatter", labels=output_labels)
    # `labels=` is a plotly option, so the backend is requested explicitly. A plotly
    # figure created inside a loop is not rendered automatically and plt.show() has
    # no effect on it: show it explicitly.
    fig.show()

#    comb_df = pd.concat(dfs, axis=1)

#    comb_df.plot(kind="scatter", labels=output_labels)
# comb_df

In [None]:
# model_1.run(map_params)
# The string literal below is a disabled plotting snippet kept for reference; it is not executed.
"""
pd.DataFrame(
    {
        "DEMetropolisZ": map_res[0].derived_outputs["active_cases"],
        "DEMetropolis": map_res[1].derived_outputs["active_cases"],
        "Metropolis": map_res[2].derived_outputs["active_cases"],
        f"Metropolis, draws": map_res[3].derived_outputs["active_cases"],


        "observed": targets_data,
    }
).plot(kind="scatter", labels=output_labels)
"""


## Uncertainty sampling

In [None]:
# Use the arviz extract method to obtain some samples, then convert to a DataFrame
# NOTE(review): `idata` is not assigned in this cell; it is whichever trace an earlier
# loop left bound to that name — confirm this is the run intended for the uncertainty analysis.
# NOTE(review): no random seed is passed to az.extract here, so the selected samples may differ between runs.
sample_idata = az.extract(idata, num_samples=4000)
# Drop the "chain" and "draw" columns (presumably leaving only the sampled parameter values)
samples_df = sample_idata.to_dataframe().drop(columns=["chain","draw"])


In [None]:
# Display the inference data object currently bound to `idata`
idata

In [None]:
# Wrapper function captures our bcm from the main namespace to pass into map_parallel
# Using this idiom in closures/factory functions is typical
def run_sample(idx_sample):
    """Run the Bayesian compartmental model for one ``(index, parameters)`` pair.

    The index is returned alongside the output of ``bcm_model_1.run`` so that
    each result can be matched back to the sample it came from.
    """
    sample_index, sample_params = idx_sample
    run_result = bcm_model_1.run(sample_params)
    return sample_index, run_result

# Run the samples through our BCM using the above function
# map_parallel takes a function and an iterable as input;
# samples_df.iterrows() yields the (index, row) pairs that run_sample unpacks

# We use 4 workers here, default is cpu_count/2 (assumes hyperthreading)
sample_res = map_parallel(run_sample, samples_df.iterrows(), n_workers=4)


In [None]:
# Inspect the (index, result) pairs returned by map_parallel
sample_res

In [None]:
# We'll use xarray for this step; aside from computing things very quickly, it's useful
# to persist the run results to netcdf/zarr etc

import xarray as xr

In [None]:
# Build a DataArray out of our results, then assign coords for indexing
sample_outputs = [r.derived_outputs for _, r in sample_res]
xres = xr.DataArray(np.stack(sample_outputs),
                    dims=["sample","time","variable"])
# The time and variable coordinates are taken from the sampled runs themselves, so this
# cell no longer depends on `map_res` from the unrelated MAP cells having been executed.
xres = xres.assign_coords(sample=sample_idata.coords["sample"],
                          time=sample_outputs[0].index, variable=sample_outputs[0].columns)

In [None]:
# Set some quantiles to calculate
# NOTE(review): no quantile below 0.25 is requested, so the lower tail is not summarised — confirm this is intended
quantiles = (0.25,0.5,0.75,0.80,0.95)

# Generate a new DataArray containing the quantiles, computed across the "sample" dimension
xquantiles = xres.quantile(quantiles,dim=["sample"])

In [None]:
# Reshape the quantiles into a wide pandas DataFrame for ease of plotting:
# rows are indexed by time, columns by (variable, quantile)
uncertainty_df = (
    xquantiles
    .to_dataframe(name="value")
    .reset_index()
    .set_index("time")
    .pivot(columns=("variable","quantile"))
    ["value"]
)

In [None]:
variable = "active_cases"

# The keyword arguments used below (alpha, style, ms) are matplotlib options, whereas the
# notebook-wide pandas plotting backend was switched to plotly earlier. The matplotlib
# backend is therefore requested per call, and the three layers are drawn on one shared Axes.
ax = uncertainty_df[variable].plot(backend="matplotlib", title=variable, alpha=0.7)
# Trajectory of the last MAP run, dashed
pd.Series(map_res[-1].derived_outputs[variable]).plot(backend="matplotlib", ax=ax, style='--')
# Calibration target data, black dots
bcm_model_1.targets[variable].data.plot(backend="matplotlib", ax=ax, style='.',color="black", ms=5, alpha=0.8)
fig = ax.figure

## Analysing the posterior likelihood landscape using ELA

In [None]:
# !pip install pflacco
from pflacco.classical_ela_features import *
from pflacco.local_optima_network_features import compute_local_optima_network, calculate_lon_features
#__To___create_a_initial____sample
from pflacco.sampling import create_initial_sample