# MCMC CALIBRATION TECHNICS IN CONTEXT OF  INFECTIOUS DISEASE MODELING

## Prerequies

In [None]:
# pip install multiprocess
# %pip install --upgrade --force-reinstall multiprocess

#Compatible with latest jax version  
# %pip install summerepi2==1.3.6
# %pip install jinja2

In [1]:
import multiprocess as mp
import platform

# This is required for pymc parallel evaluation in notebooks 
# But has to be change while using a python script 
# Use this following instruction instead
# if __name__ == "__main__":
#     if platform.system() != "Windows":
#         mp.set_start_method('spawn')
    
    # rest of your code body here inside the if __name__
if platform.system() != "Windows":
    
    mp.set_start_method('forkserver')

In [2]:
import Calibrate as cal #Runing the calibration process and gathering results
from calibs_utilities import get_all_priors, get_targets, load_data
from models.models import model1, model2, bcm_seir_age_strat, bcm_sir #All the models we design for the test
from Calibrate import plot_comparison_bars

# Combining tagets and prior with our summer2 model in a BayesianCompartmentalModel (bcm_model_1)
from estival.model import BayesianCompartmentalModel
from estival.sampling.tools import likelihood_extras_for_idata
from estival.sampling.tools import likelihood_extras_for_samples


import pandas as pd
import numpy as np
# import plotly.express as px
import matplotlib.pyplot as plt
from typing import List

import pymc as pm

# We use estivals parallel tools to run the model evaluations
from estival.utils.parallel import map_parallel

import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
import pickle
from datetime import datetime
from plotly import graph_objects as go
# import jax
from jax import numpy as jnp
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting
# pd.options.plotting.backend = "matplotlib"
import pytensor



In [None]:
# %pip install pytensor==2.20
# 5.11
# pm.__version__
# %pip install pymc==5.11



# Application 1: The basic SIR model

## Model Definition and Configuration

A mechanistic model (ODE-Based) model discribing Infectious Disease transmission.

In [None]:
model_1 = model1() 

In [None]:
#Defining  a Bayesian Compartmental Model
#Targets and priors are already defined with the data 
#See models.py for the costumization
bcm_model_1 = bcm_sir()
# bcm_model_1 = BayesianCompartmentalModel(model_1, parameters, priors, targets)


### Trial run 

In [None]:
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting
output_labels = {"index": "time", "value": "number infectious"}

model_1.run(bcm_model_1.parameters)
plt.rcParams["figure.figsize"] = (5, 5)


df = pd.DataFrame(
    {
        "modelled": model_1.get_outputs_df()["I"],
        "observed": bcm_model_1.targets["active_cases"].data,
    }
)
df.plot(kind="scatter", labels=output_labels) #,figsize=(3,3));

### Sampling 

In [None]:
##____Uniform Initialisation_________
def init_uniform(num_chains, parameters):
    init_vals = []
    for c in range(num_chains):
        init_vals.append({param: np.random.uniform(0.0,1.0) for param in parameters.keys()})
    
    return init_vals


init_vals_nuts = {param: jnp.array(np.random.uniform(0.0,1.0, 4)) for param in bcm_model_1.parameters.keys()}

init_vals_4 = init_uniform(4,bcm_model_1.parameters)
init_vals_6 = init_uniform(6,bcm_model_1.parameters)



In [None]:
#Defining the Numpyro model using our likelihood from the BaysianCompartmentalModel
def nmodel():
    sampled = {k:numpyro.sample(k, dist.Uniform(0.0,1.0)) for k in bcm_model_1.parameters}
    ll = numpyro.factor("ll", bcm_model_1.loglikelihood(**sampled))

#### Simple Run

In [None]:
%%time
D = 2 # Dimension of the parameter's space
samplers =  [infer.NUTS] + [pm.sample_smc] + [pm.Metropolis] + [pm.DEMetropolisZ] + [pm.DEMetropolis]*2
Draws = [2000] + [2000] + [10000] + [8000]*3
# Tunes = [0] + [100, 1000]*5
Init = [init_vals_nuts] + [init_vals_4]*4 + [init_vals_6]
Chains = [4]*5 + [6]
results_df = pd.DataFrame()

for sampler, draws, chains, init in zip(samplers, Draws, Chains, Init):
    # if sampler.__name__ == "NUTS":
    #     init = init_vals_nuts
    # else:
    #     init = init_vals_4
    results = cal.Single_analysis(sampler = sampler, 
            draws = draws,
            chains=chains,
            cores = chains,
            tune = 1000,
            bcm_model = bcm_model_1,
            # n_iterations = 1,
            nmodel=nmodel,
            # n_jobs = 1,
            initial_params = init

    )
            
    results_df = pd.concat([results_df,results])



results_df = results_df.reset_index(drop=True)


In [None]:
results_df
results_df.style.set_caption("MCMC COMPARISON") 


In [None]:

#Storing results on a pickle file
with open('./Results/Model_1/Simple_run_results_3.pkl', 'wb') as fp:
    pickle.dump(results_df, fp)

#### Multiple runs

In [None]:
all_results = dict()

In [None]:
sampler = pm.sample_smc
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 10000,
    tune = 1000,
    chains=4,
    cores=4, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 3,    
    )

In [None]:
sampler = pm.DEMetropolisZ
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 10000,
    chains=4,
    cores=4,
    tune = 1000, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 3,
    )

In [None]:
sampler = pm.Metropolis
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 10000,
    tune = 1000, 
    chains=4,
    cores=4,
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 3,
    )

In [None]:
sampler = infer.NUTS
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 10000,
    tune = 1000,
    chains=4,
    cores=4, 
    bcm_model = bcm_model_1,
    nmodel=nmodel,
    n_iterations = 100,
    n_jobs = 2,
    )

In [None]:
sampler = pm.DEMetropolis
all_results[sampler.__name__] = cal.multirun(
    sampler, 
    draws = 1000,
    tune = 1000,
    chains=6,
    cores=4, 
    bcm_model = bcm_model_1,
    n_iterations = 100,
    n_jobs = 3,
    )

In [None]:
all_results["DEMetropolis"]

In [None]:
#Storing the results for later analysis

with open('./Results/Model_1/Multi_run_results_without_init_1.pkl', 'wb') as fp:
    pickle.dump(all_results, fp)


#### Summarizing the 100 results

We call the function group_summary from the calibrate modules. This will help to figure out the average performance
of sampler over 100 runs

In [None]:
#Loading a pickle file
with open('./Results/Model_1/Multi_run_results_3.pkl', 'rb') as fp:
    multi_res = pickle.load(fp) #It's a dict
df = pd.concat(multi_res) #converting to dataframe

Here best results refers to the run with the min 'Max_rhat' for each sampler over the  runs

In [None]:
best_results = pd.DataFrame()
for sampler in multi_res.keys():
    temp = multi_res[sampler].round(3)
    temp
    best_rhat = temp.loc[[temp["Rhat_max"].idxmin()]]
    best_results = pd.concat([best_results,best_rhat])

In [None]:
best_results = best_results.reset_index(drop=True)
#Computing the Relative ESS
best_results["Rel_Ess"] = best_results['Min_Ess'].astype(float)/(best_results["Draws"].astype(float)*best_results['Chains'].astype(float))


In [None]:
summaries_mean, prcnt_succ = cal.group_summary(df)

In [None]:
summaries_mean.round(2).to_latex(column_format="|c|c|c|c|c|c|c|c|c|c|c|")

In [None]:
prcnt_succ = prcnt_succ.round(2)

In [None]:
prcnt_succ#.to_latex(column_format='|c|c|')

#### Using arviz for trace visualization

In [None]:
for idata, sampler in zip(best_results.Trace, best_results.Sampler):
    print(sampler)
    az.plot_rank(idata,figsize=(9,4))
    plt.show()

#### Bar Ploting Comparison

In [None]:
# plot_comparison_Bars(results_df=res)
best_results["Run"] = best_results.Sampler + "\nDraws=" + best_results.Draws.astype(str) + "\nTune=" + best_results.Tune.astype(str) +"\nChains=" + best_results.Chains.astype(str)
plot_comparison_bars(best_results.round(2))

### Fitting test

Here we test if the model is well fitted to the data, we will use the results from the single run

In [None]:
IDATA = best_results["Trace"]

In [None]:
map_res = dict()
for idata, sampler in zip(IDATA,best_results["Sampler"]):
    map_res[sampler] = cal.fitting_test(idata, bcm_model_1, model_1)

In [None]:
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting

NUTS = map_res["NUTS"]["I"]
SMC = map_res["sample_smc"]["I"]
DEM = map_res["DEMetropolis"]["I"]
MH = map_res["Metropolis"]["I"]
DEMZ = map_res["DEMetropolisZ"]["I"]



df = pd.DataFrame(
    {   
        "MH": MH,
        "NUTS": NUTS,
        "SMC": SMC,
        "DEM": DEM,
        "DEMZ": DEMZ,
        "observed": bcm_model_1.targets["active_cases"].data,
    }
)
df.plot(kind="scatter", labels=output_labels, title="Model fitting")#

## Uncertainty sampling

In [None]:
# Wrapper function captures our bcm from the main namespace to pass into map_parallel
# Using this idiom in closures/factory functions is typical
def run_sample(idx_sample):
    idx, params = idx_sample
    return idx, bcm_model_1.run(params)

# Run the samples through our BCM using the above function
# map_parallel takes a function and an iterable as input

# We use 4 workers here, default is cpu_count/2 (assumes hyperthreading)
sample_res = dict()
for idata,sampler in zip(IDATA,best_results["Sampler"]):
    sample_idata = az.extract(idata, num_samples=4000)
    samples_df = sample_idata.to_dataframe().drop(columns=["chain","draw"])
    sample_res[sampler] = map_parallel(run_sample, samples_df.iterrows(), n_workers=4)

In [None]:
# We'll use xarray for this step; aside from computing things very quickly, it's useful
# to persist the run results to netcdf/zarr etc

import xarray as xr

In [None]:
# Build a DataArray out of our results, then assign coords for indexing
xres = xr.DataArray(np.stack([r.derived_outputs for idx, r in sample_res["NUTS"]]), 
                    dims=["sample","time","variable"])
xres = xres.assign_coords(sample=sample_idata.coords["sample"], 
                          time=map_res["NUTS"].index, variable=pd.DataFrame(map_res["NUTS"]["I"]).columns)

In [None]:
# Set some quantiles to calculate
quantiles = (0.5,0.75,0.95)

# Generate a new DataArray containing the quantiles
xquantiles = xres.quantile(quantiles,dim=["sample"])

In [None]:
# Extract these values to a pandas DataFrame for ease of plotting

uncertainty_df = xquantiles.to_dataframe(name="value").reset_index().set_index("time").pivot(columns=("variable","quantile"))["value"]

In [None]:
variable = "active_cases"
pd.options.plotting.backend = "matplotlib" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting

fig = uncertainty_df["I"].plot.area(title=variable,alpha=0.7)
pd.Series(map_res["NUTS"]["I"]).plot(label = "modelled",style='--')
bcm_model_1.targets[variable].data.plot(label = "observed",style='.',color="black", ms=5, alpha=0.8)
plt.legend()

# Application 2: The SEIR age-stratified model

## Data for fitting
Here we will define a target for each age category

In [3]:
pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting

df = pd.DataFrame()
df = pd.read_csv("./data/new_cases_England_2020.csv")
df["date"] = pd.to_datetime(df.date)
df.set_index(["age","date"], inplace=True)



In [4]:
total_cases = df.groupby("date").sum()
#Rolling by 14 days to discard fluctuations
total_cases = total_cases.rolling(14).mean().iloc[14:]
# D = total_cases["Jun 2020":"Nov 2020"].iloc[::5]

In [None]:
df.loc["05_09"].rolling(14).mean().iloc[14:].plot()
df.loc["10_14"].rolling(14).mean().iloc[14:].plot()

## Model Definition


In [5]:
model_config = {"compartments": ("S", "E","I","R"), # "Ip","Ic", "Is", "R"),
        "population": 56490045, #England population size 2021
        "start_time": datetime(2020, 8, 1),
        "end_time": datetime(2020, 11, 30)
}
bcm_model_2 = bcm_seir_age_strat(model_config)
model_2 = model2(model_config)

In [None]:
model_2.get_default_parameters()

## Trial run

In [None]:
disp_params = {k:v.ppf(0.5) for k,v in bcm_model_2.priors.items() if "_disp" in k} #Mandatory if you tempt to calibrate your target dispersion

res = bcm_model_2.run(bcm_model_2.parameters | disp_params).derived_outputs

# res = model_2.get_outputs_df()
Infec = [f"IXage_{i}" for i in range(0,65,5)] #Selecting the Infectious compartments
#Summing over the infectious compartments 
total_cases_pred = pd.DataFrame(res[Infec].sum(axis=1))

In [None]:

plot_start_date = datetime(2020, 8, 1)
analysis_end_date = datetime(2020, 11, 30)

# plot = model_2.get_outputs_df()["IXage_60"].plot()
plot = pd.DataFrame(total_cases_pred).plot()
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
plot.add_trace(go.Scatter(x=total_cases.index, y=total_cases["cases"], mode='markers', name='total_cases'))
#pd.options.plotting.backend = "plotly" #To allow plotly graphic. Swich to "matplotlib" if facing some troubles while ploting
#pivot_df["total_cases"].plot.area()

## Testing by optimization

In [None]:
import nevergrad as ng
import pandas as pd 
import numpy as np
# Import our convenience wrapper
from estival.wrappers.nevergrad import optimize_model

In [None]:
# TwoPointsDE is a good suggested default for some of our typical use cases
opt_class = ng.optimizers.TwoPointsDE
orunner = optimize_model(bcm_model_2, opt_class=opt_class)
# Here we run the optimizer in a loop, inspecting the current best point at each iteration
# Using the loss information at each step can provide the basis for stopping conditions

# for i in range(8):
    # Run the minimizer for a specified 'budget' (ie number of evaluations)
rec = orunner.minimize(1000)
    # Print the loss (objective function value) of the current recommended parameters
# print(rec.loss)
mle_params = rec.value[1]
res_opt = bcm_model_2.run(mle_params)

In [None]:
total_cases_pred_opt = pd.DataFrame(res_opt.derived_outputs[Infec].sum(axis=1))

In [None]:
D = total_cases["Aug 2020":"Nov 2020"]

In [None]:
plot_start_date = datetime(2020, 8, 1)
analysis_end_date = datetime(2020, 11, 30)

# plot = model_2.get_outputs_df()["IXage_60"].plot()
plot = pd.DataFrame(total_cases_pred_opt).plot()
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
plot.add_trace(go.Scatter(x=D.index, y=D["cases"], mode='markers', name='total_cases'))

## Calibration

In [None]:
##____Uniform Initialisation for each chain_________
chains = 4
init_vals = []
for c in range(chains):
    temp = {param: np.random.uniform(0.0,1.0) for param in list(bcm_model_2.parameters.keys())[:-3]}
    temp["seed"] = np.random.uniform(1,1200)
    # temp["incubation_period"] = np.random.uniform(1.,15.) 
    # temp["infectious_period"] = np.random.uniform(1.,15.)
    init_vals.append(temp)


init_vals_nuts = {param: jnp.array(np.random.uniform(0.0,1.0, 4)) for param in list(bcm_model_2.parameters)[:-3]}
init_vals_nuts["seed"] = jnp.array(np.random.uniform(1.,1200, 4))
# init_vals_nuts["infectious_period"] = jnp.array(np.random.uniform(1.,15.0, 4))

In [6]:
def nmodel_2():
    unif_priors = list(bcm_model_2.parameters)[:-3]
    sampled = {k:numpyro.sample(k, dist.Uniform(0.0,1.0)) for k in unif_priors}
    sampled["seed"] = numpyro.sample("seed", dist.Uniform(1.0,1000))
    #Adding the normal priors for the incubation and infectious periods

    # sampled["incubation_period"] = numpyro.sample("incubation_period", dist.TruncatedNormal(7.3, 2.0, low=1., high=14.))
    # sampled["infectious_period"] = numpyro.sample("infectious_period", dist.TruncatedNormal(5.4, 3.0, low=1., high=14.))
    # Log-likelihood
    disp_params = {k:v.ppf(0.5) for k,v in bcm_model_2.priors.items() if "_disp" in k}

    log_likelihood = bcm_model_2.loglikelihood(**sampled | disp_params)

    numpyro.factor("ll",log_likelihood)

### Performing a Single run for each algorithm

In [10]:

samplers = [pm.sample_smc,pm.Metropolis,pm.DEMetropolisZ,pm.DEMetropolis] #, infer.NUTS]
Draws = [20000]*4
# Init = [init_vals_4]*4
results_df = pd.DataFrame()

for sampler, draws in zip(samplers, Draws):#, Init):
    results = cal.Single_analysis(sampler = sampler, 
            draws = draws,
            chains= 4,
            cores = 4,
            tune = 5000,
            bcm_model = bcm_model_2,
            # initial_params = init_vals,
            nmodel=nmodel_2,
            # n_iterations=1,
            # n_jobs=2,
            )
            
    results_df = pd.concat([results_df,results])



results_df = results_df.reset_index(drop=True)

Initializing SMC sampler...
Sampling 4 chains in 4 jobs


█

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Metropolis: [seed]
>Metropolis: [age_transmission_rate_0]
>Metropolis: [age_transmission_rate_5]
>Metropolis: [age_transmission_rate_10]
>Metropolis: [age_transmission_rate_15]
>Metropolis: [age_transmission_rate_20]
>Metropolis: [age_transmission_rate_25]
>Metropolis: [age_transmission_rate_30]
>Metropolis: [age_transmission_rate_35]
>Metropolis: [age_transmission_rate_40]
>Metropolis: [age_transmission_rate_45]
>Metropolis: [age_transmission_rate_50]
>Metropolis: [age_transmission_rate_55]
>Metropolis: [age_transmission_rate_60]
Sampling 4 chains for 5_000 tune and 20_000 draw iterations (20_000 + 80_000 draws total) took 736 seconds.
Multiprocess sampling (4 chains in 4 jobs)
DEMetropolisZ: [seed, age_transmission_rate_0, age_transmission_rate_5, age_transmission_rate_10, age_transmission_rate_15, age_transmission_rate_20, age_transmission_rate_25, age_transmission_rate_30, age_transmission_rate_35, age_transmission_rate_40, a

In [11]:
#Storing results on a pickle file
with open('./Results/Model_2/Simple_run_results.pkl', 'wb') as fp:
    pickle.dump(results_df, fp)

In [None]:
Trace = results_df["Trace"]

In [None]:
idata = Trace[0]

In [None]:
az.summary(idata)

In [14]:
#Probabily the nuts will take a while to finish--We need to reduce the number of iterations and or warmup
res_nuts = cal.Single_analysis(sampler = infer.NUTS, 
            draws = 10000,
            tune = 1000,
            chains = 4,
            cores=4,
            bcm_model = bcm_model_2,
            nmodel=nmodel_2,
            # initial_params = init_vals_nuts)
)

results_df = pd.concat([results_df,res_nuts])


In [29]:
results_df

Unnamed: 0,Sampler,Draws,Chains,Tune,Time,Mean_ESS,Min_Ess,Ess_per_sec,Mean_Rhat,Rhat_max,Trace,Run
0,sample_smc,20000,4,5000,125.15409,14659.059689,9641.789905,117.128091,1.00027,1.000619,"(posterior, sample_stats)",sample_smc\nDraws=20000\nTune=5000\nChains=4
1,Metropolis,20000,4,5000,775.881797,8239.111571,3304.558787,10.619029,1.000568,1.001775,"(posterior, sample_stats)",Metropolis\nDraws=20000\nTune=5000\nChains=4
2,DEMetropolisZ,20000,4,5000,105.000675,239.984603,24.349807,2.285553,1.048324,1.112197,"(posterior, sample_stats)",DEMetropolisZ\nDraws=20000\nTune=5000\nChains=4
3,DEMetropolis,20000,4,5000,149.753788,516.484206,310.195933,3.448889,1.006941,1.024682,"(posterior, sample_stats)",DEMetropolis\nDraws=20000\nTune=5000\nChains=4
0,NUTS,10000,4,1000,3778.305681,27826.43898,21931.361857,7.364793,1.00014,1.000347,"(posterior, log_likelihood, sample_stats, obse...",NUTS\nDraws=10000\nTune=1000\nChains=4


In [30]:
results_df = results_df.reset_index(drop=True)

In [15]:
#Storing results on a pickle file
with open('./Results/Model_2/Simple_run_results_with_NUTS.pkl', 'wb') as fp:
    pickle.dump(results_df, fp)

In [34]:
def fitting_test(sampler,idata, bcm, model):
    from estival.sampling.tools import likelihood_extras_for_samples,likelihood_extras_for_idata
    if sampler in ["DEMetropolis", "DEMetropolisZ"]:
        likelihood_df = likelihood_extras_for_idata(idata, bcm) #More faster
    else :
        likelihood_df = likelihood_extras_for_samples(idata.posterior, bcm)

    ldf_sorted = likelihood_df.sort_values(by="logposterior",ascending=False)

    # Extract the parameters from the calibration samples
    map_params = idata.posterior.to_dataframe().loc[ldf_sorted.index[0]].to_dict()
    map_params['incubation_period']= 5.4
    map_params['infectious_period'] = 7.3
    print(map_params)
    bcm.loglikelihood(**map_params), ldf_sorted.iloc[0]["loglikelihood"]
    # Run the model with these parameters
    model.run(map_params)
    # ...and plot some results
    return model.get_outputs_df()

In [31]:
trace = results_df.Trace.loc[0]
az.summary(trace)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
seed,337.993,330.626,1.007,995.775,2.263,1.601,18330.0,13097.0,1.0
age_transmission_rate_0,0.275,0.206,0.0,0.658,0.002,0.001,13223.0,6349.0,1.0
age_transmission_rate_5,0.264,0.204,0.0,0.649,0.002,0.001,12506.0,9528.0,1.0
age_transmission_rate_10,0.275,0.207,0.0,0.665,0.001,0.001,17369.0,16182.0,1.0
age_transmission_rate_15,0.294,0.21,0.0,0.679,0.001,0.001,17391.0,17952.0,1.0
age_transmission_rate_20,0.294,0.212,0.0,0.684,0.002,0.001,13392.0,10302.0,1.0
age_transmission_rate_25,0.283,0.208,0.0,0.672,0.002,0.001,13896.0,11699.0,1.0
age_transmission_rate_30,0.272,0.201,0.0,0.648,0.001,0.001,15571.0,11370.0,1.0
age_transmission_rate_35,0.291,0.212,0.0,0.679,0.002,0.001,12098.0,8868.0,1.0
age_transmission_rate_40,0.278,0.206,0.0,0.661,0.002,0.001,9888.0,5886.0,1.0


In [35]:
Map_res = dict()
for row in results_df.index:
    idata = results_df.Trace.loc[row]
    sampler =  results_df.at[row, "Sampler"]
    Map_res[sampler] = fitting_test(sampler,idata, bcm_model_2,model_2)    

{'seed': 539.2389387130777, 'age_transmission_rate_0': 0.015467143650231486, 'age_transmission_rate_5': 0.03276435051090649, 'age_transmission_rate_10': 0.08613424950645021, 'age_transmission_rate_15': 0.26488851518984957, 'age_transmission_rate_20': 0.37165195426142345, 'age_transmission_rate_25': 0.3697583303443199, 'age_transmission_rate_30': 0.15543790393804446, 'age_transmission_rate_35': 0.2899538272375832, 'age_transmission_rate_40': 0.00046120626812804497, 'age_transmission_rate_45': 0.1196386607946489, 'age_transmission_rate_50': 0.360847897208016, 'age_transmission_rate_55': 0.32477525242134986, 'age_transmission_rate_60': 0.3046208927082216, 'incubation_period': 5.4, 'infectious_period': 7.3}
{'seed': 1146.0761132244909, 'age_transmission_rate_0': 0.029057189104429955, 'age_transmission_rate_5': 0.021899528743239218, 'age_transmission_rate_10': 0.09906803819266106, 'age_transmission_rate_15': 0.18318457961581724, 'age_transmission_rate_20': 0.31136775324034843, 'age_transmis

In [39]:
T = [{'seed': 539.2389387130777, 'age_transmission_rate_0': 0.015467143650231486, 'age_transmission_rate_5': 0.03276435051090649, 'age_transmission_rate_10': 0.08613424950645021, 'age_transmission_rate_15': 0.26488851518984957, 'age_transmission_rate_20': 0.37165195426142345, 'age_transmission_rate_25': 0.3697583303443199, 'age_transmission_rate_30': 0.15543790393804446, 'age_transmission_rate_35': 0.2899538272375832, 'age_transmission_rate_40': 0.00046120626812804497, 'age_transmission_rate_45': 0.1196386607946489, 'age_transmission_rate_50': 0.360847897208016, 'age_transmission_rate_55': 0.32477525242134986, 'age_transmission_rate_60': 0.3046208927082216, 'incubation_period': 5.4, 'infectious_period': 7.3},
{'seed': 1146.0761132244909, 'age_transmission_rate_0': 0.029057189104429955, 'age_transmission_rate_5': 0.021899528743239218, 'age_transmission_rate_10': 0.09906803819266106, 'age_transmission_rate_15': 0.18318457961581724, 'age_transmission_rate_20': 0.31136775324034843, 'age_transmission_rate_25': 0.2164782040853706, 'age_transmission_rate_30': 0.24465238926223354, 'age_transmission_rate_35': 0.23368336007944734, 'age_transmission_rate_40': 0.11319694145272417, 'age_transmission_rate_45': 0.20044761073859862, 'age_transmission_rate_50': 0.4732435410753902, 'age_transmission_rate_55': 0.20119950900578554, 'age_transmission_rate_60': 0.25277139849043323, 'incubation_period': 5.4, 'infectious_period': 7.3},
{'seed': 737.8351851156198, 'age_transmission_rate_0': 0.005748200831073296, 'age_transmission_rate_5': 0.002125276984314101, 'age_transmission_rate_10': 0.012283957329623318, 'age_transmission_rate_15': 0.28044069231375346, 'age_transmission_rate_20': 0.1635040254497786, 'age_transmission_rate_25': 0.4440302587942281, 'age_transmission_rate_30': 0.344902741528238, 'age_transmission_rate_35': 0.36666635653772645, 'age_transmission_rate_40': 0.1951797720185409, 'age_transmission_rate_45': 0.2813771477790979, 'age_transmission_rate_50': 0.2998889259623433, 'age_transmission_rate_55': 0.11310119637218947, 'age_transmission_rate_60': 0.26616513146459564, 'incubation_period': 5.4, 'infectious_period': 7.3},
{'seed': 569.3289333661925, 'age_transmission_rate_0': 0.08848608486432526, 'age_transmission_rate_5': 0.03718579687258949, 'age_transmission_rate_10': 0.02341189844055642, 'age_transmission_rate_15': 0.18428572469016213, 'age_transmission_rate_20': 0.4315234335671583, 'age_transmission_rate_25': 0.2662633052913372, 'age_transmission_rate_30': 0.15699562933223365, 'age_transmission_rate_35': 0.15156443136208286, 'age_transmission_rate_40': 0.329080777317896, 'age_transmission_rate_45': 0.25399081704238835, 'age_transmission_rate_50': 0.23565856313718944, 'age_transmission_rate_55': 0.1768109031919722, 'age_transmission_rate_60': 0.32631544701209847, 'incubation_period': 5.4, 'infectious_period': 7.3},
{'age_transmission_rate_0': 0.04824111514966805, 'age_transmission_rate_10': 0.05246681798350639, 'age_transmission_rate_15': 0.25534189918772193, 'age_transmission_rate_20': 0.5033705061489184, 'age_transmission_rate_25': 0.3575984389894895, 'age_transmission_rate_30': 0.4494389069922813, 'age_transmission_rate_35': 0.01491539786595375, 'age_transmission_rate_40': 0.11341583217025224, 'age_transmission_rate_45': 0.4211545523347463, 'age_transmission_rate_5': 0.08123980708961236, 'age_transmission_rate_50': 0.11527856700214348, 'age_transmission_rate_55': 0.09720420860277336, 'age_transmission_rate_60': 0.24796993010907342, 'seed': 968.7900546134347, 'incubation_period': 5.4, 'infectious_period': 7.3}]
Map_params = dict()
i=0
for sampler in Map_res.keys():
    Map_params[sampler] = T[i]
    i=i+1

In [42]:
Map_params.to_dataframe

{'sample_smc': {'seed': 539.2389387130777,
  'age_transmission_rate_0': 0.015467143650231486,
  'age_transmission_rate_5': 0.03276435051090649,
  'age_transmission_rate_10': 0.08613424950645021,
  'age_transmission_rate_15': 0.26488851518984957,
  'age_transmission_rate_20': 0.37165195426142345,
  'age_transmission_rate_25': 0.3697583303443199,
  'age_transmission_rate_30': 0.15543790393804446,
  'age_transmission_rate_35': 0.2899538272375832,
  'age_transmission_rate_40': 0.00046120626812804497,
  'age_transmission_rate_45': 0.1196386607946489,
  'age_transmission_rate_50': 0.360847897208016,
  'age_transmission_rate_55': 0.32477525242134986,
  'age_transmission_rate_60': 0.3046208927082216,
  'incubation_period': 5.4,
  'infectious_period': 7.3},
 'Metropolis': {'seed': 1146.0761132244909,
  'age_transmission_rate_0': 0.029057189104429955,
  'age_transmission_rate_5': 0.021899528743239218,
  'age_transmission_rate_10': 0.09906803819266106,
  'age_transmission_rate_15': 0.183184579615

In [None]:
model_2.run(map_params)

map_res_smc = model_2.get_outputs_df()


In [None]:
Infec = [f"IXage_{i}" for i in range(0,65,5)]


In [None]:
# total_cases_pred_DEM = map_res_smc[Infec].sum(axis=1)
total_cases_pred_SMC = map_res_smc[Infec].sum(axis=1)

In [None]:
total_cases = total_cases["Aug 2020":"Nov 2020"].iloc[::14]

In [None]:
plot_start_date = datetime(2020, 8, 1)
analysis_end_date = datetime(2020, 11, 30)

#plot= map_res["IXage_5"].plot()
plot = total_cases_pred_SMC.plot()
plot.update_xaxes(range=(plot_start_date, analysis_end_date))
# plot.add_trace(go.Scatter(x=df.loc["05_09"].index, y=df.loc["05_09"]["cases"], mode='markers', name='observed'))
plot.add_trace(go.Scatter(x=total_cases.index, y=total_cases["cases"], mode='markers', name='total_cases'))
# plot.add_trace(go.Scatter(x=total_cases.index, y=total_cases_pred_DEM, name='DEM'))

