# MCMC CALIBRATION TECHNIQUES IN THE CONTEXT OF INFECTIOUS DISEASE MODELING

## Prerequisites

In [None]:
# pip install multiprocess
# %pip install --upgrade --force-reinstall multiprocess

#Compatible with latest jax version  
# %pip install summerepi2==1.3.6
# %pip install jinja2

In [None]:
import multiprocess as mp
import platform

# Select the multiprocessing start method required for pymc parallel
# evaluation inside a notebook. `force=True` keeps the cell re-runnable:
# without it a second execution raises
# "RuntimeError: context has already been set".
#
# When running as a plain python script, this has to change; use a main
# guard instead, for example:
# if __name__ == "__main__":
#     if platform.system() != "Windows":
#         mp.set_start_method('spawn')
#     # rest of your code body here inside the if __name__
if platform.system() != "Windows":
    mp.set_start_method('forkserver', force=True)

In [None]:
import Calibrate as cal #Runing the calibration process and gathering results
from calibs_utilities import get_all_priors, get_targets, load_data
from models.models import model1, model2, bcm_seir_age_strat, bcm_sir #All the models we design for the test
from Calibrate import plot_comparison_bars

# Combining tagets and prior with our summer2 model in a BayesianCompartmentalModel (bcm_model_1)
from estival.model import BayesianCompartmentalModel
from estival.sampling.tools import likelihood_extras_for_idata
from estival.sampling.tools import likelihood_extras_for_samples


import pandas as pd
import numpy as np
# import plotly.express as px
import matplotlib.pyplot as plt
from typing import List

import pymc as pm

# We use estivals parallel tools to run the model evaluations
from estival.utils.parallel import map_parallel

import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
import pickle
from datetime import datetime
from plotly import graph_objects as go
# import jax
from jax import numpy as jnp
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting
# pd.options.plotting.backend = "matplotlib"

# Application 1: The basic SIR model

## Model Definition and Configuration

A mechanistic (ODE-based) model describing infectious disease transmission.

In [None]:
# Build the first model (model1 from models/models.py)
model_1 = model1() 

In [None]:
# Defining a Bayesian Compartmental Model
# Targets and priors are already defined with the data
# See models.py for the customization
bcm_model_1 = bcm_sir()
# bcm_model_1 = BayesianCompartmentalModel(model_1, parameters, priors, targets)


### Trial run 

In [None]:
# Plotly backend for pandas plots (switch to "matplotlib" if plotting misbehaves)
pd.options.plotting.backend = "plotly"
output_labels = {"index": "time", "value": "number infectious"}

# One trial run of model_1 at the parameter values stored on the BCM
model_1.run(bcm_model_1.parameters)
plt.rcParams["figure.figsize"] = (5, 5)

# Modelled infectious compartment next to the observed active cases
modelled_I = model_1.get_outputs_df()["I"]
observed_cases = bcm_model_1.targets["active_cases"].data
df = pd.DataFrame({"modelled": modelled_I, "observed": observed_cases})
df.plot(kind="scatter", labels=output_labels)

### Sampling 

In [None]:
# ---- Uniform initialisation ----
def init_uniform(num_chains, parameters):
    """Return a list of ``num_chains`` dicts of starting values.

    Each dict has the same keys as ``parameters``; every value is an
    independent draw from U(0, 1) via ``np.random.uniform``.
    """
    return [
        {name: np.random.uniform(0.0, 1.0) for name in parameters.keys()}
        for _ in range(num_chains)
    ]


# NUTS (numpyro) start values: one array of 4 draws (one per chain) per parameter
init_vals_nuts = {
    name: jnp.array(np.random.uniform(0.0, 1.0, 4))
    for name in bcm_model_1.parameters
}

# PyMC start values: one dict per chain
init_vals_4 = init_uniform(4, bcm_model_1.parameters)
init_vals_6 = init_uniform(6, bcm_model_1.parameters)


In [None]:
import numpyro
from numpyro import distributions as dist
def nmodel():
    """numpyro model for model 1.

    Samples every BCM parameter from a Uniform(0, 1) prior and adds the BCM
    log-likelihood as the factor site "ll".
    """
    params = {}
    for name in bcm_model_1.parameters:
        params[name] = numpyro.sample(name, dist.Uniform(0.0, 1.0))
    numpyro.factor("ll", bcm_model_1.loglikelihood(**params))


#### Simple Run

In [None]:
%%time
D = 2  # Dimension of the parameter space

# Parallel lists: entry i describes the i-th sampler configuration to compare
samplers = [infer.NUTS, pm.sample_smc, pm.Metropolis, pm.DEMetropolisZ] + [pm.DEMetropolis] * 2
Draws = [2000, 2000, 10000] + [8000] * 3
# Tunes = [0] + [100, 1000]*5
Init = [init_vals_nuts] + [init_vals_4] * 4 + [init_vals_6]
Chains = [4] * 5 + [6]

# zip() silently stops at the shortest list; fail loudly on a mismatch instead
if not len(samplers) == len(Draws) == len(Init) == len(Chains):
    raise ValueError("samplers, Draws, Init and Chains must have the same length")

# Collect one result frame per configuration and concatenate once at the end;
# concatenating inside the loop copies the accumulated frame on every iteration.
run_frames = []
for sampler, draws, chains, init in zip(samplers, Draws, Chains, Init):
    results = cal.Single_analysis(
        sampler=sampler,
        draws=draws,
        chains=chains,
        cores=chains,  # one core per chain
        tune=1000,
        bcm_model=bcm_model_1,
        nmodel=nmodel,
        initial_params=init,
    )
    run_frames.append(results)

results_df = pd.concat(run_frames).reset_index(drop=True)


In [None]:
# Only a cell's last expression is rendered, so show the captioned table directly
results_df.style.set_caption("MCMC COMPARISON")


In [None]:

from pathlib import Path

# Store the single-run results in a pickle file. Create ./Results/Model_1
# first so the dump does not fail with FileNotFoundError on a fresh checkout.
simple_run_file = Path("./Results/Model_1/Simple_run_results_3.pkl")
simple_run_file.parent.mkdir(parents=True, exist_ok=True)
with open(simple_run_file, "wb") as fp:
    pickle.dump(results_df, fp)

In [None]:
# Load the single-run results written by the cell above
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
with open('./Results/Model_1/Simple_run_results_3.pkl', 'rb') as fp:
    res = pickle.load(fp)

# res = pd.concat(res)

In [None]:
# Display the loaded single-run results
res

In [None]:
# Column of InferenceData traces, one per run
Trace = res.Trace

In [None]:
# NOTE(review): label 3 is the 4th row after reset_index; with the sampler order used in the
# simple run this is presumably the DEMetropolisZ run — confirm.
idata = Trace[3]

In [None]:
idata

##### Bar Plot Comparison

In [None]:
# Side-by-side bar comparison of the samplers (numeric columns rounded to 2 decimals)
plot_comparison_bars(res.round(2))

#### Multiple runs

In [None]:
all_results = dict()

In [None]:
sampler = pm.sample_smc
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 2000,
    tune = 1000,
    chains=4,
    cores=4, 
    bcm_model = bcm_model_1,
    nmodel=nmodel,
    n_iterations = 100,
    n_jobs = 3,
    initial_params = init_vals_4
    )

In [None]:
sampler = pm.DEMetropolis
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 8000,
    tune = 1000,
    chains=6,
    cores=4, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 3,
    initial_params = init_vals_6
    )

In [None]:
sampler = pm.DEMetropolisZ
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 8000,
    chains=4,
    cores=4,
    tune = 1000, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 3,
    initial_params = init_vals_4
    )

In [None]:
sampler = pm.Metropolis
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 10000,
    tune = 1000, 
    chains=4,
    cores=4,
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 3,
    initial_params = init_vals_4
    )

In [None]:
sampler = infer.NUTS
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 2000,
    tune = 1000,
    chains=4,
    cores=4, 
    bcm_model = bcm_model_1,
    nmodel=nmodel,
    n_iterations = 100,
    n_jobs = 2,
    initial_params = init_vals_nuts
    )

In [None]:
from pathlib import Path

# Store the repeated-run results for later analysis. Create ./Results/Model_1
# first so the dump does not fail with FileNotFoundError on a fresh checkout.
multi_run_file = Path("./Results/Model_1/Multi_run_results_3.pkl")
multi_run_file.parent.mkdir(parents=True, exist_ok=True)
with open(multi_run_file, "wb") as fp:
    pickle.dump(all_results, fp)


#### Summarizing the 100 results

We call the function `group_summary` from the `Calibrate` module. It summarizes the average performance
of each sampler over the 100 runs.

In [None]:
# Load the repeated-run results (dict keyed by sampler name)
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
with open('./Results/Model_1/Multi_run_results_3.pkl', 'rb') as fp:
    multi_res = pickle.load(fp)

In [None]:
# Concatenate the per-sampler results; pd.concat on a dict uses the keys (sampler names)
# as the outer index level. Assumes each value is a DataFrame — confirm in Calibrate.multirun.
df = pd.concat(multi_res)

In [None]:
# Presumably per-sampler mean metrics and success percentages — see Calibrate.group_summary
summaries_mean, prcnt_succ = cal.group_summary(df)

In [None]:
prcnt_succ

## Using arviz for trace visualization

In [None]:
# Trace plot of every single-run calibration in `res`.
# NOTE: the loop variables (idata, Run, draws, tune) are notebook globals and stay bound to
# the last run afterwards; draws and tune are not used in the body.
for idata, Run, draws, tune in zip(res.Trace, res.Run, res.Draws, res.Tune):
    # slice(0, None) keeps every draw; raise the start to discard a burn-in
    subset = idata.sel(draw=slice(0, None), groups="posterior")
    print("Run = ",Run)
    # figure height scales with the number of posterior variables
    az.plot_trace(subset, figsize=(16,3.2*len(subset.posterior)),compact=False)
    plt.show()



### Fitting test

Here we test whether the model fits the data well, using the results from the single run.

In [None]:
# Posterior traces of every single-run calibration
IDATA = res["Trace"]

In [None]:
# Run cal.fitting_test for each trace, keyed by run label (presumably a MAP-based
# model run, given the name map_res — confirm in Calibrate.fitting_test).
# NOTE: `idata` remains bound to the last trace after this loop; the uncertainty
# sampling cells below use that value.
map_res = dict()
for idata, run in zip(IDATA,res["Run"]):
    map_res[run] = cal.fitting_test(idata, bcm_model_1, model_1)

In [None]:
# Compare the NUTS fit with the observed active cases; the key must match the
# Run label produced by cal.Single_analysis exactly.
modelled = map_res["NUTS\nDraws=2000\nTune=1000"]["I"]


df = pd.DataFrame(
    {
        "modelled": modelled,
        "observed": bcm_model_1.targets["active_cases"].data,
    }
)
df.plot(kind="scatter", labels=output_labels, title="Model fitting")

## Uncertainty sampling

In [None]:
# Use the arviz extract method to obtain some samples, then convert to a DataFrame
# NOTE(review): `idata` is not assigned in this section; it is whatever the previous loop left
# bound (the last entry of res["Trace"] when cells run in order) — confirm that is the intended run.
sample_idata = az.extract(idata, num_samples=4000)
# Drop the chain/draw bookkeeping columns so each row holds only one sample's parameter values
samples_df = sample_idata.to_dataframe().drop(columns=["chain","draw"])


In [None]:
# map_parallel needs a plain function of a single argument; this wrapper reaches
# bcm_model_1 through the notebook's global namespace.
def run_sample(idx_sample):
    """Run the BCM on one (index, parameter row) pair and return (index, result)."""
    sample_index, sample_params = idx_sample
    model_run = bcm_model_1.run(sample_params)
    return sample_index, model_run


# Push every posterior sample row through the BCM using 4 workers
# (map_parallel's default is cpu_count/2, which assumes hyperthreading)
sample_res = map_parallel(run_sample, samples_df.iterrows(), n_workers=4)


In [None]:
# We'll use xarray for this step; aside from computing things very quickly, it's useful
# to persist the run results to netcdf/zarr etc

import xarray as xr

In [None]:
# Reference model run for the trace currently bound to `idata`
map_res = cal.fitting_test(idata, bcm_model_1, model_1)

In [None]:
map_res["I"]

In [None]:
# Build a DataArray out of our results, then assign coords for indexing
# NOTE(review): the "variable" coordinate is the single label "I" (the column name of
# map_res["I"]); this only lines up if each run has exactly one derived output — confirm.
xres = xr.DataArray(np.stack([r.derived_outputs for idx, r in sample_res]), 
                    dims=["sample","time","variable"])
xres = xres.assign_coords(sample=sample_idata.coords["sample"], 
                          time=map_res.index, variable=pd.DataFrame(map_res["I"]).columns)

In [None]:
# Set some quantiles to calculate
# NOTE: median plus two upper quantiles only; add lower ones (e.g. 0.05, 0.25) for a two-sided band.
quantiles = (0.5,0.75,0.95)

# Generate a new DataArray containing the quantiles
xquantiles = xres.quantile(quantiles,dim=["sample"])

In [None]:
# Extract these values to a pandas DataFrame for ease of plotting
# (indexed by time, one column per (variable, quantile) pair)

uncertainty_df = xquantiles.to_dataframe(name="value").reset_index().set_index("time").pivot(columns=("variable","quantile"))["value"]

In [None]:
variable = "active_cases"
# This figure uses the matplotlib backend (area plot with overlaid pandas line/marker series)
pd.options.plotting.backend = "matplotlib"

# Quantile curves of the sampled runs. stacked=False is required: pandas stacks
# area plots by default, which would draw cumulative sums of the quantiles
# rather than the quantiles themselves.
fig = uncertainty_df["I"].plot.area(title=variable, alpha=0.7, stacked=False)
pd.Series(map_res["I"]).plot(label="modelled", style='--')
bcm_model_1.targets[variable].data.plot(label="observed", style='.', color="black", ms=5, alpha=0.8)
plt.legend()

## Analysing the posterior likelihood landscape using Exploratory Landscape Analysis (ELA)

In [None]:
# !pip install pflacco
from pflacco.classical_ela_features import *
from pflacco.local_optima_network_features import compute_local_optima_network, calculate_lon_features
#__To___create_a_initial____sample
from pflacco.sampling import create_initial_sample

# Application 2: The SEIR age-stratified model

## Data for fitting
Here we define one target for each age category.

In [None]:
# Plotly backend for pandas plots (switch to "matplotlib" if plotting misbehaves)
pd.options.plotting.backend = "plotly"

# New cases per age group, indexed by (age, date)
df = pd.read_csv("./data/new_cases_England_2020.csv")
df["date"] = pd.to_datetime(df["date"])
df = df.set_index(["age", "date"])



In [None]:
# Age-group lower bounds "0", "5", ..., "60"
age_strat = [str(lower) for lower in range(0, 65, 5)]

# Starting values: one transmission rate per age group, then the two periods
parameters = {f"age_transmission_rate_{age}": 0.25 for age in age_strat}
parameters["incubation_period"] = 6
parameters["infectious_period"] = 7.3

In [None]:
# Age labels as they appear in the data: "00_04", "05_09", ..., "55_59", then "60+"
ages_labels = [f"{lo:02}_{lo + 4:02}" for lo in range(0, 60, 5)] + ["60+"]
# One date-indexed case series per age group
targets_data = {label: df.loc[label] for label in ages_labels}


In [None]:
d = pd.concat(targets_data)
d = d.groupby("date").sum()
total_cases = d.rolling(14).mean().iloc[14:]

In [None]:
total_cases

## Model Definition


In [None]:
bcm_model_2 = bcm_seir_age_strat()

In [None]:
D = targets_data["00_04"].rolling(14).mean()[14:]

In [None]:
T = bcm_model_2.targets
D = pd.DataFrame(T["incX0"].data)
D

## Trial run

In [None]:
# model_2 is not created anywhere else in this notebook (model2 is imported at the
# top but otherwise unused), so instantiate it here, mirroring model_1 = model1().
# NOTE(review): assumes model2() takes no arguments, like model1() — confirm in models.py.
model_2 = model2()
model_2.run(parameters)

# Sum the IXage_* output columns across the 13 age groups
res = model_2.get_outputs_df()
Infec = [f"IXage_{i}" for i in range(0,65,5)]
total_cases_pred = res[Infec].sum(axis=1)

In [None]:

# Date window shown on the x-axis
plot_start_date = datetime(2020, 8, 1)
analysis_end_date = datetime(2020, 11, 30)

# Predicted total (summed IXage_* columns) against the smoothed observed total cases
plot = pd.DataFrame(total_cases_pred).plot()
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
plot.add_trace(
    go.Scatter(x=total_cases.index, y=total_cases["cases"], mode="markers", name="total_cases")
)

## Calibration

In [None]:
# Defining a Bayesian Compartmental Model
# bcm_model_2 is already built by bcm_seir_age_strat() in the "Model Definition"
# section, mirroring bcm_model_1 = bcm_sir() for the SIR model. The manual
# construction below is kept for reference only: `priors` and `targets` are not
# defined anywhere in this notebook, so executing it raises NameError.
# bcm_model_2 = BayesianCompartmentalModel(model_2, parameters, priors, targets)
# T = bcm_model_2.targets

In [None]:
##____Uniform Initialisation for each chain_________
chains = 4
init_vals = []
for c in range(chains):
    temp = {param: np.random.uniform(0.0,1.0) for param in list(parameters.keys())[:-2]}
    temp["incubation_period"] = np.random.uniform(1.,15.) 
    temp["infectious_period"] = np.random.uniform(1.,15.)
    init_vals.append(temp)


init_vals_nuts = {param: jnp.array(np.random.uniform(0.0,1.0, 4)) for param in list(bcm_model_2.parameters)[:-2]}
init_vals_nuts["incubation_period"] = jnp.array(np.random.uniform(1.,15.0, 4))
init_vals_nuts["infectious_period"] = jnp.array(np.random.uniform(1.,15.0, 4))

In [None]:
init_vals


In [None]:
def nmodel_2():
    """numpyro model for the age-stratified SEIR BCM.

    Every BCM parameter except the incubation and infectious periods gets a
    Uniform(0, 1) prior. The two periods get normal priors truncated to
    [1, 15]. The BCM log-likelihood is added as the factor site "ll".
    """
    period_names = ("incubation_period", "infectious_period")
    # Select the uniform-prior parameters by name, not by position ([:-2]), so
    # this does not silently rely on the two periods being the last keys
    # (otherwise they would be sampled twice under the same site name).
    unif_priors = [k for k in bcm_model_2.parameters if k not in period_names]
    sampled = {k: numpyro.sample(k, dist.Uniform(0.0, 1.0)) for k in unif_priors}

    # Truncated-normal priors for the incubation and infectious periods.
    # NOTE(review): the prior means (7.3 incubation, 5.4 infectious) differ from the
    # starting values in `parameters` (6 and 7.3) — confirm they are not swapped.
    sampled["incubation_period"] = numpyro.sample(
        "incubation_period", dist.TruncatedNormal(7.3, 2.0, low=1., high=15.)
    )
    sampled["infectious_period"] = numpyro.sample(
        "infectious_period", dist.TruncatedNormal(5.4, 3.0, low=1., high=15.)
    )

    numpyro.factor("ll", bcm_model_2.loglikelihood(**sampled))

In [None]:
# Long DEMetropolisZ calibration of the age-stratified model
res = cal.Single_analysis(
    sampler=pm.DEMetropolisZ,
    bcm_model=bcm_model_2,
    nmodel=nmodel_2,
    initial_params=init_vals,
    draws=100000,
    tune=5000,
    chains=4,
    cores=4,
)

In [None]:
idata = res["Trace"]
idata = idata[0]
burn_in = 50000
subset = idata.sel(draw=slice(burn_in, None), groups="posterior")


In [None]:
az.summary(subset)

In [None]:
az.plot_trace(subset, figsize=(12,2.5*len(idata.posterior)),compact=False, legend=True)
plt.tight_layout(pad = 0.005)


In [None]:
# Posterior densities after discarding the burn-in (`subset`); `idata` still contains the burn-in draws
az.plot_posterior(subset)

In [None]:
map_res = fitting_test(subset, bcm_model_2,model_2)

In [None]:
total_cases_pred = map_res()[Infec].sum(axis=1)

In [None]:
total_cases_pred

In [None]:
# Fitted total (summed IXage_* columns) against the smoothed observed total cases,
# over the plot_start_date .. analysis_end_date window
plot = total_cases_pred.plot()
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
plot.add_trace(
    go.Scatter(x=total_cases.index, y=total_cases["cases"], mode="markers", name="total_cases")
)

In [None]:
# Plot the modelled incidence and the "incidence" target as two separate figures
variable = "incidence"

pd.options.plotting.backend = "plotly"
# NOTE(review): earlier cells index map_res like a DataFrame (map_res["I"], map_res[Infec]);
# a DataFrame has no .derived_outputs attribute — confirm what cal.fitting_test returns.
pd.DataFrame(map_res.derived_outputs["incidence"]).plot(title = f"{variable} (MLE)")
# NOTE(review): the only bcm_model_2 target key visible in this notebook is "incX0";
# confirm an "incidence" target exists, otherwise this raises KeyError.
pd.DataFrame(bcm_model_2.targets[variable].data).plot(style='.')