# MCMC CALIBRATION TECHNIQUES FOR DETERMINISTIC ODE-BASED INFECTIOUS DISEASE MODELS

# Application 2: The SEIR age-stratified model

## Prerequisites

In [None]:
import multiprocess as mp
import platform

# Use the "forkserver" start method for worker processes on non-Windows
# systems. force=True makes the cell safe to re-run: without it,
# set_start_method raises RuntimeError once the start method is already set.
if platform.system() != "Windows":
    mp.set_start_method('forkserver', force=True)


In [None]:
import Calibrate as cal #Runing the calibration process and gathering results
from calibs_utilities import get_all_priors, get_targets, load_data
from models.models import model1, model2, bcm_seir_age_strat, bcm_sir #All the models we design for the test
from Calibrate import plot_comparison_bars

# Combining targets and priors with our summer2 model in a BayesianCompartmentalModel (bcm_model_1)
from estival.model import BayesianCompartmentalModel
from estival.sampling.tools import likelihood_extras_for_idata
from estival.sampling.tools import likelihood_extras_for_samples


import pandas as pd
import numpy as np
# import plotly.express as px
import matplotlib.pyplot as plt
from typing import List

import pymc as pm

# We use estivals parallel tools to run the model evaluations
from estival.utils.parallel import map_parallel

import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
import pickle
from datetime import datetime
from plotly import graph_objects as go
# import jax
from jax import numpy as jnp
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting
# pd.options.plotting.backend = "matplotlib"
import pytensor

## Data for fitting
Here we define a target for each age category.

In [None]:
pd.options.plotting.backend = "plotly" # To allow plotly graphics. Switch to "matplotlib" if plotting causes trouble

# Load the observed new-case data. NOTE(review): presumably one row per
# (age, date) pair with a "cases" column -- confirm against the CSV.
df = pd.read_csv("./data/new_cases_England_2020.csv")
df["date"] = pd.to_datetime(df.date)
# Tot: the same data indexed by date only ("age" stays a column).
Tot = df.copy()
Tot.set_index(["date"], inplace=True)
# df: indexed by (age, date) so one age group can be selected with df.loc[age].
df.set_index(["age","date"], inplace=True)


In [None]:
# Total cases per age group over Aug-Nov 2020, shown as a bar chart.
# Sort by date first: partial-string date slicing on a non-monotonic
# DatetimeIndex (e.g. rows ordered by age, then date) can raise KeyError in
# recent pandas. Sorting is a no-op when the rows are already date-ordered.
Tot = Tot.sort_index()["Aug 2020":"Nov 2020"].groupby("age").sum()
Tot.plot(kind='bar')

In [None]:
# Plot the curve for one specific age group -- here the 60+ group.
# A 14-day (fortnightly) rolling mean smooths out day-to-day fluctuations.
# rolling(14) leaves the first 13 rows as NaN warm-up values, so exactly
# those 13 are dropped. Only one point per week (every 7th day) is plotted.
df.loc["60+"]["Aug 2020":"6 Dec 2020"].rolling(14).mean().iloc[13:]["cases"][::7].plot(kind="scatter")

In [None]:
# Observed cases summed over all age groups, per date.
total_cases = df.groupby("date").sum()
# 14-day rolling mean to smooth fluctuations. rolling(14) leaves the first
# 13 rows NaN, so drop exactly those warm-up rows.
total_cases = total_cases.rolling(14).mean().iloc[13:]

## Model Definition


In [None]:
# Configuration shared by the Bayesian wrapper and the raw model:
# SEIR compartments, population size and the simulation window.
model_config = dict(
    compartments=("S", "E", "I", "R"),
    population=56490045,  # England population size 2021
    start_time=datetime(2020, 7, 25),
    end_time=datetime(2020, 12, 1),
)
bcm_model_2 = bcm_seir_age_strat(model_config)
model_2 = model2(model_config)

In [None]:
# Display the raw model's default parameter values (last expression -> cell output).
model_2.get_default_parameters()

## Trial run

In [None]:
# Median (ppf(0.5)) of every "_disp" prior, used as a fixed value for the dispersion parameters.
disp_params = {k:v.ppf(0.5) for k,v in bcm_model_2.priors.items() if "_disp" in k} # Mandatory if you attempt to calibrate your target dispersion

# Trial run at the BCM's current parameter values, merged with the dispersion medians.
res = bcm_model_2.run(bcm_model_2.parameters | disp_params).derived_outputs

# res = model_2.get_outputs_df()
# Names of the per-age-group incidence outputs for ages 0, 5, ..., 60 (13 groups).
Infec = [f"InciXage_{i}" for i in range(0,65,5)] # Selecting the per-age incidence outputs
# Summing over the age groups gives the total predicted incidence per output time.
total_cases_pred = pd.DataFrame(res[Infec].sum(axis=1))

In [None]:
# Compare the trial-run total predicted incidence (line) with the smoothed
# observed total cases (markers) over Aug-Nov 2020.
plot_start_date = datetime(2020, 8, 1)
analysis_end_date = datetime(2020, 11, 30)

# plot = model_2.get_outputs_df()["IXage_60"].plot()
plot = pd.DataFrame(total_cases_pred).plot()
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
plot.add_trace(go.Scatter(x=total_cases.index, y=total_cases["cases"], mode='markers', name='total_cases'))
#pd.options.plotting.backend = "plotly" # To allow plotly graphics. Switch to "matplotlib" if plotting causes trouble
#pivot_df["total_cases"].plot.area()

## Testing by optimization

In [None]:
import nevergrad as ng
import pandas as pd 
import numpy as np
# Import our convenience wrapper
from estival.wrappers.nevergrad import optimize_model

#### Running several optimization processes to get an idea of the range of our parameters

In [None]:
# Call the optimiser 8 times (30000 evaluations each) and keep every call's
# recommended parameters in mle_params, to gauge the plausible range of each parameter.
# NOTE(review): the same `orunner` is reused across calls, so later calls may
# continue from earlier optimiser state rather than restart -- confirm in estival.
# TwoPointsDE is a good suggested default for some of our typical use cases
opt_class = ng.optimizers.TwoPointsDE
orunner = optimize_model(bcm_model_2, opt_class=opt_class)
# Here we run the optimizer in a loop, inspecting the current best point at each iteration
# Using the loss information at each step can provide the basis for stopping conditions
mle_params = []
for i in range(8):
    # Run the minimizer for a specified 'budget' (ie number of evaluations)
    rec = orunner.minimize(30000)
    # Print the loss (objective function value) of the current recommended parameters
    print(rec.loss)
    # rec.value[1]: the recommendation's keyword-parameter dict
    # (presumably nevergrad's (args, kwargs) value pair -- confirm).
    mle_params.append(rec.value[1])


# mle_params = rec.value[1]



### We can refine our priors using the optimization results

In [None]:
from estival import priors as esp

# Data-driven priors: every optimised parameter except S_0, seed and
# detect_prop (which get explicit priors in the next cell) is given
# Uniform(0, max), where max is the largest value it reached across the
# optimisation runs stored in `mle_params`.
mle_max = pd.DataFrame(mle_params).max()
# Round the bound *up* to 7 decimals. Ordinary rounding can pull it just below
# the largest optimised value, leaving that optimum outside the prior support.
mle_upper = np.ceil(mle_max * 1e7) / 1e7
L = list(mle_max.keys())
# NOTE(review): a parameter whose maximum is 0 would get a degenerate (0, 0) prior.
T = {
    param: esp.UniformPrior(param, (0.0, val_max))
    for param, val_max in mle_upper.items()
    if param not in ["S_0", "seed", "detect_prop"]
}

In [None]:
# Explicit uniform priors for the three parameters excluded from the data-driven bounds.
# NOTE(review): these S_0 (1e6 to 3e6) and seed (1 to 1600) ranges differ from the
# ones nmodel_2 uses for NUTS (1.8e6 to 56490045 and 1 to 2000) -- confirm which is intended.
T["S_0"] = esp.UniformPrior("S_0",(1000000,3e6))
T["seed"] = esp.UniformPrior("seed",(1,1600))
T["detect_prop"] = esp.UniformPrior("detect_prop",(0.0,1.0))

In [None]:
# Replace the BCM's priors with the refined set T.
# NOTE(review): assumes assigning .priors is enough for estival to use the new
# priors (no cached prior-derived state) -- confirm.
bcm_model_2.priors  = T

In [None]:
# (Disabled) Reassign the other transmission-rate values to the value of age_transmission_rate_0 in mle_params
# for param in bcm_model_2.parameters :
#     if param not in ["S_0","seed","age_transmission_rate_0", "incubation_period", "infectious_period", "detect_prop"]:
#         mle_params[param] = mle_params["age_transmission_rate_0"]


#print(mle_params[-1])

# Run the BCM at the recommendation from the final optimisation call.
res_opt = bcm_model_2.run(mle_params[-1])

In [None]:
# Total predicted incidence (sum over the per-age incidence outputs) at the optimised parameters.
total_cases_pred_opt = pd.DataFrame(res_opt.derived_outputs[Infec].sum(axis=1))

# total_cases_pred_opt_0 = pd.DataFrame(res_opt.derived_outputs['InciXage_0'])

In [None]:
# Every 7th row (weekly points) of the smoothed observed total cases, Aug-Dec 2020.
D = total_cases["Aug 2020":"Dec 2020"][::7]
# D = cases_0["Aug 2020":"nov 2020"]

In [None]:
# Compare the optimised-parameter prediction (line) with the weekly observed totals (markers).
plot_start_date = datetime(2020, 7, 1)
analysis_end_date = datetime(2020, 12, 7)

# plot = model_2.get_outputs_df()["IXage_60"].plot()
plot = pd.DataFrame(total_cases_pred_opt).plot()
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
plot.add_trace(go.Scatter(x=D.index, y=D["cases"], mode='markers', name='total_cases'))

## Calibration

In [None]:
# ##____Uniform Initialisation for each chain_________
# chains = 4
# init_vals = []
# for c in range(chains):
#     temp = {param: np.random.uniform(0.0,1.0) for param in list(bcm_model_2.parameters.keys())[:-3]}
#     temp["seed"] = np.random.uniform(1,1200)
#     # temp["incubation_period"] = np.random.uniform(1.,15.) 
#     # temp["infectious_period"] = np.random.uniform(1.,15.)
#     init_vals.append(temp)


# init_vals_nuts = {param: jnp.array(np.random.uniform(0.0,1.0, 4)) for param in list(bcm_model_2.parameters)[:-3]}
# init_vals_nuts["seed"] = jnp.array(np.random.uniform(1.,1200, 4))
# # init_vals_nuts["infectious_period"] = jnp.array(np.random.uniform(1.,15.0, 4))

In [None]:
def nmodel_2():
    """NumPyro model used by the NUTS sampler for the age-stratified SEIR BCM.

    Draws the calibrated parameters from uniform priors (numpyro.sample sites)
    and adds the BCM log-likelihood of those draws as a ``numpyro.factor``.
    The order of the sample sites matters: NumPyro splits its PRNG key per
    site in order, so reordering them changes the random draws.
    """
    # NOTE(review): relies on parameter ORDER -- every parameter except the last
    # five in bcm_model_2.parameters gets a Uniform(0, 0.8) prior. Presumably the
    # last five are S_0, seed, detect_prop and the incubation/infectious periods;
    # a different ordering would silently change what is sampled -- confirm.
    unif_priors = list(bcm_model_2.parameters)[:-5]
    sampled = {k:numpyro.sample(k, dist.Uniform(0.0,0.8)) for k in unif_priors}

    # NOTE(review): these ranges differ from the priors assigned to bcm_model_2
    # via T (S_0 in (1e6, 3e6), seed in (1, 1600)), so NUTS explores a different
    # support than the PyMC samplers -- confirm this is intended.
    sampled["S_0"] = numpyro.sample("S_0", dist.Uniform(1800000.0, 56490045.0))
    sampled["seed"] = numpyro.sample("seed", dist.Uniform(1.0,2000))
    sampled["detect_prop"] = numpyro.sample("detect_prop", dist.Uniform(0.0,1.0))

    # Adding the normal priors for the incubation and infectious periods (currently disabled)

    # sampled["incubation_period"] = numpyro.sample("incubation_period", dist.TruncatedNormal(7.3, 2.0, low=1., high=14.))
    # sampled["infectious_period"] = numpyro.sample("infectious_period", dist.TruncatedNormal(5.4, 3.0, low=1., high=14.))
    # Log-likelihood
    # Dispersion parameters are fixed at the median of their priors rather than sampled.
    disp_params = {k:v.ppf(0.5) for k,v in bcm_model_2.priors.items() if "_disp" in k}

    # Only the likelihood is added here; the priors enter through the sample sites above.
    log_likelihood = bcm_model_2.loglikelihood(**sampled | disp_params)

    numpyro.factor("ll",log_likelihood)

### Performing a single run for each algorithm

In [None]:

samplers = [pm.sample_smc, pm.Metropolis, pm.DEMetropolisZ, pm.DEMetropolis]  # , infer.NUTS]
# One draw budget per sampler, kept in lockstep with `samplers`.
Draws = [40000] * len(samplers)
# Init = [init_vals_4]*4

# Collect each sampler's result frame and concatenate once at the end.
# Calling pd.concat inside the loop is quadratic, and concatenating onto an
# empty DataFrame triggers pandas' empty-entry dtype FutureWarning.
run_frames = []
for sampler, draws in zip(samplers, Draws):  # , Init):
    results = cal.Single_analysis(
        sampler=sampler,
        draws=draws,
        chains=4,
        cores=4,
        tune=5000,
        bcm_model=bcm_model_2,
    )
    run_frames.append(results)

results_df = pd.concat(run_frames).reset_index(drop=True)

In [None]:
#Storing results on a pickle file
# with open('./Results/Model_2/Simple_run_results_3.pkl', 'wb') as fp:
    # pickle.dump(results_df, fp)

#Reading a stored results
# results_df = pd.read_pickle("./Results/Model_2/Simple_run_results_3.pkl")

In [None]:
# Aggregate the calibration results via Calibrate.group_summary. Judging by the
# names, this gives per-sampler mean summaries and success percentages -- confirm there.
summaries_mean, prcnt_succ = cal.group_summary(results_df)

In [None]:
# Keep the per-run posterior traces handy for the cells below.
Trace = results_df["Trace"]
# Last expression of the cell, so Jupyter actually displays it (a bare
# expression that is not the last line is evaluated and discarded).
prcnt_succ

In [None]:
# Posterior trace of the first calibration run (row 0 of results_df).
idata = Trace[0]

In [None]:
# ArviZ summary statistics of that posterior, with 94% HDI bounds.
summary_df = az.summary(idata, hdi_prob=0.94)


In [None]:
# Display the summary with the parameter names as a regular "index" column.
summary_df.reset_index()

In [None]:
import seaborn as sns
# Extract the summary statistics for plotting (recomputed from the current idata)
summary_df = az.summary(idata, hdi_prob=0.94)

# Prepare data for plotting
# NOTE(review): [1:] drops the FIRST summary row positionally, presumably to hide
# one parameter on a much larger scale -- dropping it by name would be more robust.
plot_data = summary_df.reset_index()[1:]
# Long format: one row per (parameter, statistic) pair.
plot_data = plot_data.melt(id_vars='index', value_vars=['mean', 'hdi_3%', 'hdi_97%'], var_name='stat', value_name='value')
plot_data['parameter'] = plot_data['index']

# Create the box plot
# NOTE(review): each (parameter, stat) box contains a single value, so every
# "box" collapses to a line; a point plot may be clearer.
plt.figure(figsize=(10, 6))
sns.boxplot(data=plot_data, x='parameter', y='value', hue='stat', palette='Set2')

# Add plot labels and title
plt.title('Box Plot of Parameter Estimates with HDI Intervals')
plt.xlabel('Parameter')
plt.ylabel('Value')
plt.legend(title='Statistic', loc='best')

# Show plot
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


In [12]:
# ADDING THE NUTS SAMPLER
# NUTS will probably take a while to finish -- reduce the number of draws
# and/or the warm-up (tune) if it is too slow.
res_nuts = cal.Single_analysis(
    sampler=infer.NUTS,
    draws=20000,
    tune=2000,
    chains=4,
    cores=4,
    bcm_model=bcm_model_2,
    nmodel=nmodel_2,
    # initial_params=init_vals_nuts,
)

results_df = pd.concat([results_df, res_nuts])
results_df = results_df.reset_index(drop=True)
# reset_index(drop=True) never creates an "index" column, so drop one only if a
# result frame happened to carry it; an unconditional drop raises KeyError.
results_df = results_df.drop(columns="index", errors="ignore")

In [None]:
#Storing results on a pickle file
# with open('./Results/Model_2/Simple_run_results_3.pkl', 'wb') as fp:
    # pickle.dump(results_df, fp)

In [None]:
def fitting_test(sampler, idata, bcm, model):
    """Run ``model`` at the posterior sample with the highest log-posterior.

    The log-posterior of each posterior sample in ``idata`` is computed with
    estival's likelihood helpers. The best sample is taken as the MAP
    parameter set, ``model`` is run with it, and its derived outputs are
    returned.

    Args:
        sampler: Name of the sampler that produced ``idata``. Traces from
            "DEMetropolis" and "DEMetropolisZ" go through
            ``likelihood_extras_for_idata``; all others go through
            ``likelihood_extras_for_samples`` on ``idata.posterior``.
        idata: ArviZ ``InferenceData`` with a ``posterior`` group.
        bcm: The ``BayesianCompartmentalModel`` that was calibrated.
        model: Model object run at the MAP parameters; must provide ``run``
            and ``get_derived_outputs_df``.

    Returns:
        ``(derived_outputs_df, map_params)``: the model's derived outputs at
        the MAP sample, and the MAP parameter dict.
    """
    from estival.sampling.tools import likelihood_extras_for_samples, likelihood_extras_for_idata

    if sampler in ("DEMetropolis", "DEMetropolisZ"):
        likelihood_df = likelihood_extras_for_idata(idata, bcm)  # Faster
    else:
        likelihood_df = likelihood_extras_for_samples(idata.posterior, bcm)

    ldf_sorted = likelihood_df.sort_values(by="logposterior", ascending=False)

    # Parameters of the best sample. NOTE(review): assumes the likelihood
    # frame's index labels (presumably (chain, draw)) match those of
    # posterior.to_dataframe().
    map_params = idata.posterior.to_dataframe().loc[ldf_sorted.index[0]].to_dict()

    # Run the model with these parameters and return its derived outputs.
    model.run(map_params)
    return model.get_derived_outputs_df(), map_params

In [None]:
# MAP outputs and MAP parameters for every calibration result, keyed by sampler name.
# A later row with the same sampler name overwrites an earlier one, and the loop
# rebinds the notebook-level `idata` and `sampler` names.
Map_res = dict()
Map_params = dict()
for row in results_df.index:
    idata = results_df.Trace.loc[row]
    sampler =  results_df.at[row, "Sampler"]
    Map_res[sampler],Map_params[sampler] = fitting_test(sampler,idata, bcm_model_2,model_2)    

In [None]:
# Re-run model_2 at each sampler's MAP parameters and store freshly fetched derived outputs.
# NOTE(review): fitting_test already returned these outputs; presumably this re-run
# guards against the returned frames sharing state with model_2 -- confirm whether needed.
for sampler in Map_res.keys():
    model_2.run(Map_params[sampler])
    Map_res[sampler] = model_2.get_derived_outputs_df()

In [None]:
# Scatter each sampler's MAP value for every parameter on one shared axis, so
# the samplers can be compared parameter by parameter. S_0, detect_prop and seed
# are plotted separately in the next cell; the incubation/infectious periods
# are left out.
excluded_params = {"S_0", "detect_prop", "seed", "incubation_period", "infectious_period"}

fig, ax = plt.subplots(figsize=(10, 7))
for sampler, sampler_params in Map_params.items():
    kept = {k: v for k, v in sampler_params.items() if k not in excluded_params}
    parameters = list(kept)
    values = list(kept.values())

    ax.scatter(parameters, values, s=80, label=sampler)  # s: size of the points

    # Label each point with its value.
    for param, value in zip(parameters, values):
        ax.text(param, value, f'{value:.2f}', fontsize=9, ha='right', va='bottom')

# Rotate x-axis labels for better readability (once, after all samplers are drawn).
plt.xticks(rotation=45, ha='right')

# Add titles and labels
plt.title('Parameter Values')
plt.xlabel('Parameters')
plt.ylabel('Values')
plt.legend()

# Display the plot
plt.tight_layout()
plt.show()

In [None]:
# One subplot per parameter excluded from the shared plot (seed, detect_prop,
# S_0), each comparing that parameter's MAP value across all samplers.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, paramet in zip(axes, ["seed", "detect_prop", "S_0"]):
    for sampler, sampler_params in Map_params.items():
        parameters = [paramet]
        values = [v for k, v in sampler_params.items() if k == paramet]

        ax.scatter(parameters, values, s=80, label=sampler)  # s: size of the points

        # Label each point with its value.
        for param, value in zip(parameters, values):
            ax.text(param, value, f'{value:.2f}', fontsize=9, ha='right', va='bottom')

    # Rotate x-axis labels for better readability
    ax.tick_params(axis='x', rotation=45)

    # Each subplot shows ONE parameter for ALL samplers, so title it by the
    # parameter rather than by whichever sampler the inner loop ended on.
    ax.set_title(f'{paramet} across samplers')
    ax.set_xlabel('Parameters')
    ax.set_ylabel('Values')
    ax.legend()

# Lay out once, after every subplot is populated.
plt.tight_layout()
plt.show()

In [None]:
# Run model_2 at the SMC sampler's MAP parameters.
model_2.run(Map_params["sample_smc"])


In [None]:
# Per-age incidence output names for ages 0, 5, ..., 60 (same list as in the trial run).
Infec = [f"InciXage_{i}" for i in range(0,65,5)]


In [None]:
# Derived outputs of the most recent model_2 run (the SMC MAP run, if the cells ran in order).
map_res = model_2.get_derived_outputs_df()

In [None]:
# Total predicted incidence (sum over age groups) from the outputs fetched above.
# total_cases_pred_DEM = map_res_smc[Infec].sum(axis=1)
# map_res = model_2.get_outputs_df() # Map_res["sample_smc"]
total_cases_pred_SMC = map_res[Infec].sum(axis=1)

In [None]:
# Keep every 14th row of the smoothed observed totals over Aug-Nov 2020 for the comparison plot.
# NOTE(review): this overwrites total_cases, so re-running the cell thins the series again.
total_cases = total_cases["Aug 2020":"Nov 2020"].iloc[::14]

In [None]:
# Total predicted incidence per age group (summed over all output times) for
# each sampler's MAP run: one single-column frame ("Tot_cases") per sampler,
# indexed by the incidence output name ("Inci_per_age").
Tot_cases_sampler = dict()
for sampler in Map_res.keys():
    # Use a dedicated name: assigning to `df` here would overwrite the
    # observed case-data frame loaded at the top of the notebook.
    sampler_tot = pd.DataFrame(
        {"Tot_cases": [Map_res[sampler][lab].sum() for lab in Infec]},
        index=pd.Index(list(Infec), name="Inci_per_age"),
    )
    Tot_cases_sampler[sampler] = sampler_tot

In [None]:
# Total predicted incidence at each sampler's MAP parameters (one line per
# sampler) against the thinned observed totals (markers), Aug-Nov 2020.
plot_start_date = datetime(2020, 8, 1)
analysis_end_date = datetime(2020, 11, 30)
# Axis/legend labels passed to the plotly backend.
output_labels = {"name": "Sampler","index": "Time", "value": "Incidence"}

#plot= map_res["IXage_5"].plot()
plot = pd.DataFrame({"SMC":Map_res["sample_smc"][Infec].sum(axis = 1),
                     "NUTS": Map_res["NUTS"][Infec].sum(axis=1),
                     "DEMZ": Map_res["DEMetropolisZ"][Infec].sum(axis=1),
                     "DEM": Map_res["DEMetropolis"][Infec].sum(axis=1),
                     "MH": Map_res["Metropolis"][Infec].sum(axis=1)
}).plot(labels=output_labels)
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
# plot.add_trace(go.Scatter(x=df.loc["05_09"].index, y=df.loc["05_09"]["cases"], mode='markers', name='observed'))
plot.add_trace(go.Scatter(x=total_cases.index, y=total_cases["cases"], mode='markers', name='Tot obs cases'))
plot.update_layout(
    legend_title_text='Sampler'
)
# plot.add_trace(go.Scatter(x=total_cases.index, y=total_cases_pred_DEM, name='DEM'))



In [None]:
# Relative ESS: minimum effective sample size divided by the total number of draws (Draws x Chains).
results_df["Rel_Ess"] = results_df['Min_Ess'].astype(float)/(results_df["Draws"].astype(float)*results_df['Chains'].astype(float))


In [None]:
# Compare the samplers' metrics with Calibrate.plot_comparison_bars (values rounded to 4 decimals).
plot_comparison_bars(results_df.round(4))