# MCMC CALIBRATION TECHNIQUES IN THE CONTEXT OF INFECTIOUS DISEASE MODELING

## Prerequisites

In [None]:
# pip install multiprocess
# %pip install --upgrade --force-reinstall multiprocess

#Compatible with latest jax version  
# %pip install summerepi2==1.3.6
# %pip install jinja2

In [None]:
# This is required for pymc parallel evaluation in notebooks

import multiprocess as mp
import platform

# Select the 'forkserver' start method on every platform except Windows.
# force=True makes the cell safe to re-execute: without it, a second call to
# set_start_method raises "RuntimeError: context has already been set".
if platform.system() != "Windows":
    mp.set_start_method('forkserver', force=True)

In [None]:
import Calibrate as cal #Runing the calibration process and gathering results
from calibs_utilities import get_all_priors, get_targets, load_data
from models.models import model1 #All the models we design for the test
from Calibrate import plot_comparison_Bars

# Combining targets and priors with our summer2 model in a BayesianCompartmentalModel (bcm_model_1)
from estival.model import BayesianCompartmentalModel


import pandas as pd
import numpy as np
# import plotly.express as px
import matplotlib.pyplot as plt
from typing import List

import pymc as pm

# We use estivals parallel tools to run the model evaluations
from estival.utils.parallel import map_parallel

# import numpyro
# from numpyro import distributions as dist
from numpyro import infer
import arviz as az
import pickle
# import jax
# from jax import numpy as jnp


## Calibration Data
If data are needed to define a target, we simply import them from a suitably formatted data file.
In this example, we import the data from a YAML file.

In [None]:
# An example of data for the calibration
# Route pandas' .plot() calls through plotly. Switch to "matplotlib" if plotting causes trouble.
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Switch to "matplotlib" if facing some troubles while plotting

# Axis labels passed to the scatter plots of this notebook
output_labels = {"index": "time", "value": "number infectious"}

# Load the target file and keep its 'active_cases' entry as the observed series
targets_yml = './data/target_yml.yml'
targets = load_data(targets_yml)
targets_data = targets['active_cases']

# targets_data.plot(kind="scatter",labels=output_labels)



In [None]:
# Names of the calibrated parameters and their (lower, upper) ranges
params = {
    "contact_rate": (0.0,1.0),
    "recovery_rate": (0.0,1.0)

}
# NOTE: this rebinds `targets`, which previously held the output of load_data(targets_yml)
targets = get_targets(targets_yml)
# Build the priors from the parameter ranges above
priors = get_all_priors(params)

## Model Definition and Configuration

A mechanistic (ODE-based) model describing infectious disease transmission.

In [None]:
# Instantiate the model produced by the `model1` factory imported from models.models
model_1 = model1()

### Trial run 

In [None]:
# Fixed parameter values used for a single exploratory run of the model.
# ("active_cases_dispersion": 0.5 is currently disabled.)
parameters = {"contact_rate": 0.2, "recovery_rate": 0.1}

model_1.run(parameters)

# Compare the modelled "infectious" output with the observed target series.
modelled_infectious = model_1.get_outputs_df()["infectious"]
pd.DataFrame(
    {"modelled": modelled_infectious, "observed": targets_data}
).plot(kind="scatter", labels=output_labels)

### Sampling 

In [None]:
# Defining a Bayesian Compartmental Model: bundles the model, the trial-run
# parameter values, the priors and the targets into a single object

bcm_model_1 = BayesianCompartmentalModel(model_1, parameters, priors, targets)
# Inspect the `stdev` attribute of the "active_cases" target
T = bcm_model_1.targets["active_cases"]
T.stdev

In [None]:
##____Uniform Initialisation_________
# Four starting points, each mapping every key of `parameters` to an
# independent draw from Uniform(0, 1).
init_vals = [
    {name: np.random.uniform(0.0, 1.0) for name in parameters}
    for _ in range(4)
]

init_vals

#### Simple Run

In [None]:
%%time
# Benchmark grid: the i-th sampler is paired with the i-th entry of Draws and Tunes,
# and each triple is calibrated once through cal.multirun.
D = 2  # Dimension of the parameter's space
chains = 2 * D
samplers = (
    2 * [infer.NUTS]
    + 2 * [pm.DEMetropolisZ]
    + 2 * [pm.DEMetropolis]
    + 4 * [pm.Metropolis]
)
Draws = 2 * [2000] + 6 * [4000] + 2 * [8000]
Tunes = 5 * [100, 1000]

df = pd.DataFrame()
for sampler, draws, tune in zip(samplers, Draws, Tunes):
    results = cal.multirun(
        sampler=sampler,
        draws=draws,
        tune=tune,
        bcm_model=bcm_model_1,
        n_iterations=1,
        n_jobs=1,
        initial_params=init_vals,
    )
    # Append this configuration's rows to the running table
    df = pd.concat([df, results])

# Same rows as `df`, renumbered 0..n-1
results_df = df.reset_index(drop=True)


In [None]:
# Storing the single-run comparison table in a pickle file
with open('./Results/Model_1/Experiment_1/Simple_run_results.pkl', 'wb') as fp:
    pickle.dump(results_df, fp)


# #Loading a pickle file
# with open('./Results/Model_1/Experiment_1/Simple_run_results.pkl', 'rb') as fp:
#     res = pickle.load(fp)

# Keep the styled table as the LAST expression of the cell: a notebook only
# auto-displays the value of a cell's final expression, so the caption table
# was silently discarded while it sat above the `with` block.
results_df.style.set_caption("MCMC COMPARISON")

In [None]:
# Plot the comparison of the simple-run results (plot_comparison_Bars is imported from Calibrate)
plot_comparison_Bars(results_df=results_df)


### Multiple runs

In [None]:
# Container for the multi-run results, keyed by sampler class name in the cells below
all_results = dict()

In [None]:
# NUTS: 2000 draws, 1000 tuning steps; result stored under the sampler's class name.
# NOTE(review): n_iterations=100 / n_jobs=4 presumably mean 100 repetitions over 4 workers — confirm in Calibrate.multirun
sampler = infer.NUTS
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 2000,
    tune = 1000, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 4,
    initial_params = init_vals
    )

In [None]:
# DEMetropolis: 4000 draws, 1000 tuning steps; result stored under the sampler's class name.
# NOTE(review): n_iterations=100 / n_jobs=4 presumably mean 100 repetitions over 4 workers — confirm in Calibrate.multirun
sampler = pm.DEMetropolis
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 4000,
    tune = 1000, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 4,
    initial_params = init_vals
    )

In [None]:
# DEMetropolisZ: 4000 draws, 1000 tuning steps; result stored under the sampler's class name.
# NOTE(review): n_iterations=100 / n_jobs=4 presumably mean 100 repetitions over 4 workers — confirm in Calibrate.multirun
sampler = pm.DEMetropolisZ
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 4000,
    tune = 1000, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 4,
    initial_params = init_vals
    )

In [None]:
# Metropolis: 8000 draws, 1000 tuning steps; result stored under the sampler's class name.
# NOTE(review): n_iterations=100 / n_jobs=4 presumably mean 100 repetitions over 4 workers — confirm in Calibrate.multirun
sampler = pm.Metropolis
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 8000,
    tune = 1000, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 4,
    initial_params = init_vals
    )

In [None]:
#Storing the results for later analysis
# The multi-run results get their own file: 'Simple_run_results.pkl' is already
# used for the single-run comparison table, and reusing that path here would
# overwrite it.
with open('./Results/Model_1/Experiment_1/Multiple_runs_results.pkl', 'wb') as fp:
    pickle.dump(all_results, fp)


In [None]:
# Number of rows of `results` whose "Rhat_max" exceeds 1.1.
# NOTE(review): `results` is the variable left over from the last cal.multirun call
# of the simple-run loop, not `results_df` or an entry of `all_results` — confirm
# this is the table that was meant to be inspected.
len(results[results["Rhat_max"] > 1.1])

## Using arviz for trace visualization

In [None]:
# Draw the posterior trace plot of every row of the comparison table.
# `draws` and `tune` are unpacked but not used in the loop body.
for idata, Run, draws, tune in zip(results_df.Trace, results_df.Run, results_df.Draws, results_df.Tune):
    # Restrict the selection to the posterior group (slice(0, None) keeps every draw)
    subset = idata.sel(draw=slice(0, None), groups="posterior")
    print("Run = ",Run)
    # Figure height grows with len(subset.posterior)
    az.plot_trace(subset, figsize=(16,3.2*len(subset.posterior)),compact=False)#, lines=[("m", {}, mtrue), ("c", {}, ctrue)]);
    plt.show()



In [None]:
#idata["sample_stats"]

In [None]:
from estival.sampling.tools import likelihood_extras_for_idata
from estival.sampling import tools as esamptools

## Computing the likelihood function for each sampler

In [None]:
from estival.model.base import ResultsData

In [None]:
# Annotation only: this records the intended type of `map_res` but does NOT
# create the variable (no value is assigned).
map_res : List[ResultsData]
# DataFrame view of the observed target series
targets_datas = pd.DataFrame(targets_data)

targets_datas.values

In [None]:
#pd.options.plotting.backend = "matplotlib"
# For every calibration run: find the posterior sample with the highest
# "logposterior", record its parameter values in `map_params`, and store the
# output of running the model with those values in `map_res`.
#
# The list must actually be created here: a bare `map_res : List[ResultsData]`
# annotation binds nothing, so `map_res.append(...)` raised NameError. Resetting
# it in this cell also prevents duplicated entries when the cell is re-run.
map_res: List[ResultsData] = []
map_params = []
for idata, sampler, run in zip(results_df.Trace, results_df.Sampler, results_df.Run):
    if sampler == "metropolis":
        # Because pm.metropolis is not compatible directly with the likelihood
        # function from est.samp.tools, work from the posterior samples instead.
        # NOTE(review): confirm the exact spelling/case stored in results_df.Sampler.
        likelihood_df = esamptools.likelihood_extras_for_samples(idata.posterior, bcm_model_1)
    else:
        likelihood_df = likelihood_extras_for_idata(idata, bcm_model_1)

    # Best sample first
    ldf_sorted = likelihood_df.sort_values(by="logposterior", ascending=False)

    # Convert the posterior once and reuse it for both the lookup and the printout
    posterior_df = idata.posterior.to_dataframe()
    map_row = posterior_df.loc[ldf_sorted.index[0]]
    map_parameter = map_row.to_dict()
    print(map_row)

    map_params.append(map_parameter)
    map_res.append(bcm_model_1.run(map_parameter))
    # res = bcm_model_1.run(map_parameter)
    # output = res.derived_outputs["active_cases"]

    # ax = pd.DataFrame(
    #     {
    #         f"{run}": output,
    #         "Observed" : targets_datas,
    #     }
    # ).plot(kind="scatter", x=targets_datas.index, y=targets_datas.values, labels=output_labels)
    # plt.show()
        



In [None]:
# Show the modelled "active_cases" series of the last MAP run.
# Direct indexing replaces a loop over zip(map_res) whose body did nothing and
# which only served to leave the last element (wrapped in a 1-tuple) in `res`.
map_res[-1].derived_outputs["active_cases"]

In [None]:
dfs = []

# One scatter plot per calibration run: modelled "active_cases" at the MAP
# parameters against the observed target series.
for sampler, res, draws in zip(results_df.Sampler, map_res, results_df.Draws):
    # Each element of `map_res` is the object returned by bcm_model_1.run(...),
    # which is used directly through `.derived_outputs` elsewhere in this
    # notebook. The former `res = res[0]` unwrapping step was inconsistent with
    # that and has been removed.
    ax = pd.DataFrame(
        {
            f"{sampler}, draws = {draws}": res.derived_outputs["active_cases"],
            "Observed" : targets_data,
        }
    ).plot(kind="scatter", labels=output_labels)

    # NOTE(review): plt.show() only renders matplotlib figures; if the pandas
    # backend is plotly, `ax` is presumably a plotly figure that needs its own
    # display call — confirm the figures actually appear.
    plt.show()

#    comb_df = pd.concat(dfs, axis=1)

#    comb_df.plot(kind="scatter", labels=output_labels)
# comb_df

In [None]:
# model_1.run(map_params)
# The triple-quoted block below is a string literal used to disable a combined
# comparison plot; nothing inside it is executed.
"""
pd.DataFrame(
    {
        "DEMetropolisZ": map_res[0].derived_outputs["active_cases"],
        "DEMetropolis": map_res[1].derived_outputs["active_cases"],
        "Metropolis": map_res[2].derived_outputs["active_cases"],
        f"Metropolis, draws": map_res[3].derived_outputs["active_cases"],


        "observed": targets_data,
    }
).plot(kind="scatter", labels=output_labels)
"""


## Uncertainty sampling

In [None]:
# Use the arviz extract method to obtain some samples, then convert to a DataFrame
# NOTE(review): `idata` is not assigned in this cell; it is whatever the last loop
# over results_df.Trace left behind (presumably the final run) — confirm that this
# is the trace whose uncertainty should be sampled.
sample_idata = az.extract(idata, num_samples=4000)
# Keep only the parameter columns
samples_df = sample_idata.to_dataframe().drop(columns=["chain","draw"])


In [None]:
# Module-level wrapper so that map_parallel can call the BCM: the model is read
# from the notebook namespace instead of being passed as an argument.
def run_sample(idx_sample):
    """Run one posterior sample through ``bcm_model_1``.

    Args:
        idx_sample: An ``(index, parameters)`` pair, such as the items yielded
            by ``DataFrame.iterrows()``.

    Returns:
        A tuple ``(index, result)`` where ``result`` is the value returned by
        ``bcm_model_1.run`` for the given parameters.
    """
    sample_index, sample_params = idx_sample
    run_result = bcm_model_1.run(sample_params)
    return sample_index, run_result

# Run the samples through our BCM using the above function
# map_parallel takes a function and an iterable as input; each item produced by
# samples_df.iterrows() is an (index, row) pair, which is what run_sample unpacks

# We use 4 workers here, default is cpu_count/2 (assumes hyperthreading)
sample_res = map_parallel(run_sample, samples_df.iterrows(), n_workers=4)


In [None]:
# Display the (index, result) pairs returned by map_parallel
sample_res

In [None]:
# We'll use xarray for this step; aside from computing things very quickly, it's useful
# to persist the run results to netcdf/zarr etc

import xarray as xr

In [None]:
# Stack the derived outputs of every sampled run into one array with
# dimensions (sample, time, variable).
stacked_outputs = np.stack([run.derived_outputs for _, run in sample_res])
xres = xr.DataArray(stacked_outputs, dims=["sample", "time", "variable"])

# Label the axes: samples come from the extracted posterior, while time and
# variable names are taken from the derived outputs of the last MAP run.
reference_outputs = map_res[-1].derived_outputs
xres = xres.assign_coords(
    sample=sample_idata.coords["sample"],
    time=reference_outputs.index,
    variable=reference_outputs.columns,
)

In [None]:
# Set some quantiles to calculate
quantiles = (0.25,0.5,0.75,0.80,0.95)

# Generate a new DataArray containing the quantiles, computed across the "sample" dimension
xquantiles = xres.quantile(quantiles,dim=["sample"])

In [None]:
# Reshape the quantile array into a time-indexed pandas DataFrame whose columns
# are (variable, quantile) pairs, which is convenient for plotting.
uncertainty_df = (
    xquantiles
    .to_dataframe(name="value")
    .reset_index()
    .set_index("time")
    .pivot(columns=("variable","quantile"))
)["value"]

In [None]:
variable = "active_cases"

# Overlay, for the chosen variable: the quantile bands, the last MAP run
# (dashed) and the calibration target data (black dots).
#
# The keyword arguments used here (alpha, style, ms) are matplotlib ones, but the
# notebook switches pandas' global plotting backend to plotly earlier on. Request
# the matplotlib backend explicitly per call so that this cell works regardless
# of the global option.
fig = uncertainty_df[variable].plot(title=variable, alpha=0.7, backend="matplotlib")
pd.Series(map_res[-1].derived_outputs[variable]).plot(style='--', backend="matplotlib")
# NOTE(review): assumes the target's `.data` is a pandas object — confirm.
bcm_model_1.targets[variable].data.plot(style='.', color="black", ms=5, alpha=0.8, backend="matplotlib")

## Analysing the posterior likelihood landscape using ELA

In [None]:
# !pip install pflacco
from pflacco.classical_ela_features import *
from pflacco.local_optima_network_features import compute_local_optima_network, calculate_lon_features
#__To___create_a_initial____sample
from pflacco.sampling import create_initial_sample