# MCMC CALIBRATION TECHNIQUES FOR DETERMINISTIC ODE-BASED INFECTIOUS DISEASE MODELING

##

## Prerequisites

In [None]:
import multiprocess as mp
import platform
if platform.system() != "Windows":
    # Use the "forkserver" start method for worker processes ("forkserver" is
    # not available on Windows, which only supports "spawn").
    # force=True makes the cell safe to re-execute: without it a second call
    # raises RuntimeError("context has already been set").
    mp.set_start_method('forkserver', force=True)

In [None]:
import Calibrate as cal #Runing the calibration process and gathering results
from calibs_utilities import get_all_priors, get_targets, load_data
from models.models import model1, model2, bcm_seir_age_strat, bcm_sir #All the models we design for the test
from Calibrate import plot_comparison_bars

# Combining tagets and prior with our summer2 model in a BayesianCompartmentalModel (bcm_model_1)
from estival.model import BayesianCompartmentalModel
from estival.sampling.tools import likelihood_extras_for_idata
from estival.sampling.tools import likelihood_extras_for_samples


import pandas as pd
import numpy as np
# import plotly.express as px
import matplotlib.pyplot as plt
from typing import List

import pymc as pm

# We use estivals parallel tools to run the model evaluations
from estival.utils.parallel import map_parallel

import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
import pickle
from datetime import datetime
from plotly import graph_objects as go
# import jax
from jax import numpy as jnp
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting
# pd.options.plotting.backend = "matplotlib"
import pytensor

# Application 1: The basic SIR model

## Model Definition and Configuration

A mechanistic (ODE-based) model describing infectious disease transmission.

In [None]:
# Build the plain SIR model (without the Bayesian wrapper); see models/models.py.
model_1 = model1() #Only the SIR model

In [None]:
# Define a Bayesian Compartmental Model.
# Targets and priors are already defined together with the data;
# see models.py for the model customization.
bcm_model_1 = bcm_sir() #Directly by the function bcm_sir

# Alternatively, combine the SIR model with BayesianCompartmentalModel yourself:
# bcm_model_1 = BayesianCompartmentalModel(model_1, parameters, priors, targets)


### Trial run 

Here you can see a test of the SIR model output.

In [None]:
pd.options.plotting.backend = "plotly" #To allow plotly graphics. Switch to "matplotlib" if facing some troubles while plotting
# Axis labels for the plotly scatter plots (also reused by the "Model fitting" plot).
output_labels = {"index": "time", "value": "number infectious"}

model_1.run(bcm_model_1.parameters) #Running the model with the BCM's parameter values
# NOTE(review): rcParams only affect the matplotlib backend; with plotly selected
# above, this figure size has no effect on the plot below.
plt.rcParams["figure.figsize"] = (5, 5)


# NOTE(review): the "modelled" column is commented out, so the model run above is
# not plotted and only the observed "active_cases" target is shown.
df = pd.DataFrame(
    {
        # "modelled": model_1.get_outputs_df()["I"],
        "observed": bcm_model_1.targets["active_cases"].data,
    }
)
df.plot(kind="scatter", labels=output_labels) #,figsize=(3,3));

### Sampling 

To run MCMC sampling, you can either provide explicit initial values or let the algorithm choose its default starting points.

Here we choose a uniform sample from the parameter range. Each chain has its own starting point.

In [None]:
def init_uniform(num_chains, parameters, low=0.0, high=1.0):
    """Draw one independent uniform starting point per MCMC chain.

    Args:
        num_chains: Number of chains, i.e. how many starting points to draw.
        parameters: Mapping whose keys are the names of the parameters to
            initialise (e.g. ``bcm_model_1.parameters``).
        low: Lower bound of the uniform draw (default 0.0).
        high: Upper bound of the uniform draw (default 1.0). The defaults match
            the ``dist.Uniform(0.0, 1.0)`` priors used by ``nmodel``.

    Returns:
        A list of ``num_chains`` dicts, each mapping every parameter name to a
        float drawn from ``np.random.uniform(low, high)``.
    """
    # Chains form the outer loop and parameters the inner one, so the global
    # NumPy RNG is consumed in the same order as before.
    return [
        {param: np.random.uniform(low, high) for param in parameters.keys()}
        for _ in range(num_chains)
    ]


# Parameter name -> array of 4 uniform draws (one value per chain); this is the
# layout offered for the NUTS option in the run cell below.
init_vals_nuts = {param: jnp.array(np.random.uniform(0.0,1.0, 4)) for param in bcm_model_1.parameters.keys()}

# Per-chain starting points (list of dicts) for 4 and 6 chains.
# NOTE(review): no random seed is set, so these starting points differ on every execution.
init_vals_4 = init_uniform(4,bcm_model_1.parameters)
init_vals_6 = init_uniform(6,bcm_model_1.parameters)



In [None]:
# Defining the NumPyro model using our likelihood from the BayesianCompartmentalModel
def nmodel():
    """NumPyro model wrapping the log-likelihood of ``bcm_model_1``.

    Each parameter of ``bcm_model_1`` is sampled from a Uniform(0, 1) prior, and
    the BCM log-likelihood at those values is added to the joint log-density
    through ``numpyro.factor``.

    NOTE(review): the Uniform(0, 1) priors are hard-coded here rather than read
    from the BCM's own priors; confirm they match.
    """
    sampled = {k: numpyro.sample(k, dist.Uniform(0.0, 1.0)) for k in bcm_model_1.parameters}
    # factor() registers the log-likelihood term with the model; its return value is not used.
    numpyro.factor("ll", bcm_model_1.loglikelihood(**sampled))

#### Simple Run

In what follows, model_1's parameters are calibrated by simply running each algorithm once.


In [None]:
%%time
D = 2 # Dimension of the parameter's space

# Run configuration: entry i of each list describes run i.
# Options used previously (kept for reference):
#   samplers: infer.NUTS, pm.sample_smc, pm.Metropolis, pm.DEMetropolisZ, pm.DEMetropolis (x2)
#   Draws:    2000, 2000, 10000, 8000 (x3)
#   Init:     init_vals_nuts for NUTS, init_vals_4 (x4), init_vals_6 (6 chains)
#   Chains:   4 (x5), 6
samplers = [pm.Metropolis]
Draws = [8000]
Init = [init_vals_4]
Chains = [4]

# zip() silently drops runs when these lists drift out of sync while editing.
if len({len(samplers), len(Draws), len(Init), len(Chains)}) != 1:
    raise ValueError("samplers, Draws, Init and Chains must all have the same length")

run_frames = []
for sampler, draws, chains, init in zip(samplers, Draws, Chains, Init):
    results = cal.Single_analysis(
        sampler=sampler,
        draws=draws,
        chains=chains,
        cores=chains,  # one core per chain
        tune=1000,
        bcm_model=bcm_model_1,
        nmodel=nmodel,
        initial_params=init,
    )
    run_frames.append(results)

# Concatenate once (not inside the loop, which re-copies every previous row)
# and renumber the rows 0..n-1.
results_df = pd.concat(run_frames, ignore_index=True) if run_frames else pd.DataFrame()


In [None]:
# Only the last expression of a cell is rendered, so show the captioned table directly.
results_df.style.set_caption("MCMC COMPARISON")


In [None]:

# Storing results into a pickled file (uncomment to save).
# Kept as comments rather than a bare triple-quoted string, which Jupyter would
# echo as the cell's output.
# with open('./Results/Model_1/Simple_run_results_3.pkl', 'wb') as fp:
#     pickle.dump(results_df, fp)

#### Multiple runs
Here we perform several runs of each algorithm (each calibrating the same model) using the function `multirun` from the module `cal`.
Note that we let each algorithm use its default starting points.

In [None]:
# Sampler name -> results returned by cal.multirun (filled by the cells below).
all_results = {}

In [None]:
# Sequential Monte Carlo: 100 repeated calibrations of bcm_model_1
# (n_jobs presumably sets how many run in parallel; see cal.multirun).
sampler = pm.sample_smc
all_results[sampler.__name__] = cal.multirun(
    sampler,
    bcm_model=bcm_model_1,
    draws=10000,
    tune=1000,
    chains=4,
    cores=4,
    n_iterations=100,
    n_jobs=3,
)

In [None]:
# DEMetropolisZ: 100 repeated calibrations of bcm_model_1
# (n_jobs presumably sets how many run in parallel; see cal.multirun).
sampler = pm.DEMetropolisZ
all_results[sampler.__name__] = cal.multirun(
    sampler,
    bcm_model=bcm_model_1,
    draws=10000,
    tune=1000,
    chains=4,
    cores=4,
    n_iterations=100,
    n_jobs=3,
)

In [None]:
# Metropolis: 100 repeated calibrations of bcm_model_1
# (n_jobs presumably sets how many run in parallel; see cal.multirun).
sampler = pm.Metropolis
all_results[sampler.__name__] = cal.multirun(
    sampler,
    bcm_model=bcm_model_1,
    draws=10000,
    tune=1000,
    chains=4,
    cores=4,
    n_iterations=100,
    n_jobs=3,
)

In [None]:
# NUTS (NumPyro): the only multirun that also receives the `nmodel` wrapper.
sampler = infer.NUTS
all_results[sampler.__name__] = cal.multirun(
    sampler,
    bcm_model=bcm_model_1,
    nmodel=nmodel,
    draws=10000,
    tune=1000,
    chains=4,
    cores=4,
    n_iterations=100,
    n_jobs=2,
)

In [None]:
# DEMetropolis: 100 repeated calibrations of bcm_model_1 with 6 chains on 4 cores.
# NOTE(review): draws=1000 here vs 10000 for the other samplers; confirm this is intended.
sampler = pm.DEMetropolis
all_results[sampler.__name__] = cal.multirun(
    sampler,
    bcm_model=bcm_model_1,
    draws=1000,
    tune=1000,
    chains=6,
    cores=4,
    n_iterations=100,
    n_jobs=3,
)

In [None]:
# Inspect the per-run results collected for DEMetropolis.
all_results["DEMetropolis"]

In [None]:
# Stack every sampler's results into one frame; because a dict is passed, its keys
# (sampler names) become the outer index level.
df = pd.concat(all_results)

In [None]:
#Storing the results for later analysis

# with open('./Results/Model_1/Multi_run_results_without_init_1.pkl', 'wb') as fp:
#     pickle.dump(all_results, fp)


#### Summarizing the 100 results

We call the function `group_summary` from the `Calibrate` module (imported as `cal`). This helps to assess the average performance
of each sampler over the 100 runs.

In [None]:
#Loading a pickle file 
# with open('./Results/Model_1/Multi_run_results_3.pkl', 'rb') as fp:
#     multi_res = pickle.load(fp) #It's a dict
# df = pd.concat(multi_res) #converting to dataframe

Here, the best result for each sampler is the run with the smallest `Rhat_max` across all of its runs.

In [None]:
# For each sampler, keep the single run with the smallest Rhat_max
# (all values are rounded to 3 decimals first).
best_rows = []
for sampler in all_results.keys():
    temp = all_results[sampler].round(3)
    # Select by position, not by label: if the run index repeats labels
    # (e.g. every run indexed 0), .loc[[label]] would return several rows.
    best_pos = temp["Rhat_max"].reset_index(drop=True).idxmin()
    best_rhat = temp.iloc[[best_pos]]
    best_rows.append(best_rhat)

# Concatenate once instead of growing the frame inside the loop.
best_results = pd.concat(best_rows) if best_rows else pd.DataFrame()

In [None]:
# Relative ESS = minimum effective sample size / total number of draws (draws x chains).
best_results["Rel_Ess"] = (
    best_results["Min_Ess"].astype(float)
    .div(best_results["Draws"].astype(float).mul(best_results["Chains"].astype(float)))
)


In [None]:
# Aggregate the runs in `df` per sampler: averaged summaries and (presumably)
# the percentage of successful runs; see cal.group_summary.
summaries_mean, prcnt_succ = cal.group_summary(df)

In [None]:
# Display the averaged summaries.
summaries_mean

In [None]:
# Round to 2 decimals for display (idempotent, so re-running the cell is harmless).
prcnt_succ = prcnt_succ.round(2)

In [None]:
# Display the rounded success table.
prcnt_succ

#### Using arviz for trace visualization

In [None]:
# Rank plot (az.plot_rank) of the best run of each sampler, one figure per sampler.
# Roughly uniform rank histograms across chains suggest the chains mixed well.
for idata, sampler in zip(best_results.Trace, best_results.Sampler):
    print(sampler)
    az.plot_rank(idata,figsize=(9,4))
    plt.show()

#### Bar Plot Comparison

In [None]:
# Multi-line bar label for each best run: sampler name plus its draws/tune/chains settings.
best_results["Run"] = (
    best_results["Sampler"]
    + "\nDraws=" + best_results["Draws"].astype(str)
    + "\nTune=" + best_results["Tune"].astype(str)
    + "\nChains=" + best_results["Chains"].astype(str)
)
plot_comparison_bars(best_results.round(2))

### Fitting test

Here we test whether the model fits the data well, using the best run of each sampler.

In [None]:
# Trace of the best run of each sampler, in the same row order as best_results["Sampler"].
IDATA = best_results["Trace"]

In [None]:
# Sampler name -> output of cal.fitting_test for that sampler's best run
# (the plotting cells read its "I" column).
map_res = {}
for idata, sampler in zip(IDATA, best_results.Sampler):
    map_res[sampler] = cal.fitting_test(idata, bcm_model_1, model_1)

In [None]:
pd.options.plotting.backend = "plotly" #To allow plotly graphics. Switch to "matplotlib" if facing some troubles while plotting

# Short plot label for each sampler name used as a key in map_res (plot column order).
sampler_labels = {
    "Metropolis": "MH",
    "NUTS": "NUTS",
    "sample_smc": "SMC",
    "DEMetropolis": "DEM",
    "DEMetropolisZ": "DEMZ",
}

# Fitted "I" series of every sampler that was actually run, plus the observed target.
# Samplers missing from map_res are skipped instead of raising KeyError. A separate
# name is used so the combined results frame `df` (input of group_summary) is not
# overwritten.
fit_df = pd.DataFrame(
    {
        **{
            label: map_res[name]["I"]
            for name, label in sampler_labels.items()
            if name in map_res
        },
        "observed": bcm_model_1.targets["active_cases"].data,
    }
)
fit_df.plot(kind="scatter", labels=output_labels, title="Model fitting")

## Uncertainty sampling

This section will be discarded.

In [None]:
# Module-level wrapper so map_parallel can call it; it reads bcm_model_1 from
# the notebook's main namespace (closure/factory-style idiom).
def run_sample(idx_sample):
    """Run bcm_model_1 for one posterior sample.

    ``idx_sample`` is an ``(index, parameters)`` pair, as yielded by
    ``DataFrame.iterrows()``. The index is returned next to the run result so
    each result can be matched to the sample that produced it.
    """
    sample_index, sample_params = idx_sample
    run_result = bcm_model_1.run(sample_params)
    return sample_index, run_result

# Push posterior samples of each sampler's best run through the BCM with
# estival's map_parallel (a function plus an iterable of inputs), using 4
# workers (map_parallel's default is cpu_count/2, assuming hyperthreading).
sample_res = {}
for idata, sampler in zip(IDATA, best_results["Sampler"]):
    # 4000 posterior samples with chain/draw stacked into one "sample" dimension.
    sample_idata = az.extract(idata, num_samples=4000)
    samples_df = sample_idata.to_dataframe().drop(columns=["chain", "draw"])
    sample_res[sampler] = map_parallel(run_sample, samples_df.iterrows(), n_workers=4)

In [None]:
# We'll use xarray for this step; aside from computing things very quickly, it's useful
# to persist the run results to netcdf/zarr etc

import xarray as xr

In [None]:
# Build a (sample, time, variable) DataArray out of the NUTS run results,
# then assign coords for indexing.
nuts_runs = sample_res["NUTS"]
xres = xr.DataArray(
    np.stack([r.derived_outputs for idx, r in nuts_runs]),
    dims=["sample", "time", "variable"],
)
# Label samples by position. The loop variable `sample_idata` left over from the
# sampling cell belongs to whichever sampler was processed last, so its "sample"
# coordinate does not describe these NUTS runs. Distinct integer labels are
# enough for aggregating over samples.
xres = xres.assign_coords(
    sample=np.arange(len(nuts_runs)),
    time=map_res["NUTS"].index,
    variable=pd.DataFrame(map_res["NUTS"]["I"]).columns,
)

In [None]:
# Quantile levels to compute over the posterior samples.
# NOTE(review): only the median and upper quantiles are computed, so no lower
# bound of an uncertainty interval is available; confirm this is intended.
quantiles = (0.5,0.75,0.95)

# New DataArray holding those quantiles, reduced over the "sample" dimension.
xquantiles = xres.quantile(quantiles,dim=["sample"])

In [None]:
# Flatten the quantiles into a pandas DataFrame for plotting:
# one row per time point, one column per (variable, quantile) pair.
uncertainty_df = (
    xquantiles.to_dataframe(name="value")
    .reset_index()
    .set_index("time")
    .pivot(columns=["variable", "quantile"])["value"]
)

In [None]:
variable = "active_cases"
pd.options.plotting.backend = "matplotlib"  # Static matplotlib figure for the uncertainty plot.

# stacked=False draws each quantile curve at its own value. DataFrame.plot.area
# stacks columns by default, which would draw the 0.75 and 0.95 bands at the
# cumulative sum of the lower quantiles instead of at their actual values.
ax = uncertainty_df["I"].plot.area(title=variable, alpha=0.7, stacked=False)
pd.Series(map_res["NUTS"]["I"]).plot(ax=ax, label="modelled", style="--")
bcm_model_1.targets[variable].data.plot(
    ax=ax, label="observed", style=".", color="black", ms=5, alpha=0.8
)
ax.legend()