In [None]:
# Reload edited project modules before each cell runs, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import entropy
import datetime
import copy
import time
import wandb
import pickle as pkl

import sys
sys.path.append('../../')

from data.processing import get_data

import models

from main.seir.fitting import single_fitting_cycle
from main.seir.forecast import get_forecast, forecast_all_trials, create_all_trials_csv, create_decile_csv_new
from main.seir.sensitivity import calculate_sensitivity_and_plot
from utils.generic.create_report import save_dict_and_create_report
from utils.generic.config import read_config
from utils.generic.enums import Columns
from utils.fitting.loss import Loss_Calculator
from utils.generic.logging import log_wandb
from viz import plot_forecast, plot_top_k_trials, plot_ptiles
from viz.fit import plot_histogram, plot_all_histograms, plot_mean_variance, plot_scatter, plot_kl_divergence, plot_heatmap_distribution_sigmas, plot_all_params, plot_all_losses, plot_all_buckets
import yaml

In [None]:
# Filled by the fitting cell as: tag (end-date index) -> model name -> trial key -> fit output.
predictions_dict = dict()

In [None]:
# Timestamped report directory for this run (e.g. ../../misc/reports/2024_0115_123045).
output_folder = f"../../misc/reports/{datetime.datetime.now():%Y_%m%d_%H%M%S}"

In [None]:
# Each model label is paired with the config file that drives it. Later cells
# look up model_names[j] by the index of the config, so keeping the pairs in one
# place prevents the two lists from drifting out of order.
MODEL_CONFIG_FILES = [
    ('MCMC', 'uncer.yaml'),
    ('BO', 'default.yaml'),
]
model_names = [model_name for model_name, _ in MODEL_CONFIG_FILES]
config_filenames = [config_file for _, config_file in MODEL_CONFIG_FILES]

In [None]:
# Parsed config for each entry of config_filenames, in the same order.
configs = list(map(read_config, config_filenames))

In [None]:
# End dates for the training split: currently a single date, 20 days before now.
dates = [
    datetime.datetime.now() - datetime.timedelta(days=20)
    for _ in range(1)
]
dates

In [None]:
# Fit every model on every end date, num_rep_trials times each, then run the
# uncertainty method configured for that model on each fit.
num_rep_trials = 1
# Keeps every uncertainty forecast, keyed like predictions_dict
# (tag -> model name -> trial key). Previously each forecast overwrote a single
# variable, so all but the last one were thrown away.
uncertainty_forecasts_dict = {}
# NOTE: the loop variables tag / j / k are intentionally left with these names;
# later cells in this notebook have referred to their final values.
for tag, end_date in enumerate(dates):
    predictions_dict[tag] = {}
    uncertainty_forecasts_dict[tag] = {}
    for j, config in enumerate(configs):
        model_name = model_names[j]
        predictions_dict[tag][model_name] = {}
        uncertainty_forecasts_dict[tag][model_name] = {}
        # deepcopy so overriding the split end date does not mutate configs[j]
        config_params = copy.deepcopy(config['fitting'])
        config_params['split']['end_date'] = end_date.date()
        for k in range(num_rep_trials):
            trial_key = f'm{k}'
            predictions_dict[tag][model_name][trial_key] = single_fitting_cycle(**config_params)
            uncertainty_args = {
                'predictions_dict': predictions_dict[tag][model_name][trial_key],
                **config['uncertainty']['uncertainty_params'],
            }
            # config['uncertainty']['method'] is called like a class/factory here
            uncertainty = config['uncertainty']['method'](**uncertainty_args)
            uncertainty_forecasts = uncertainty.get_forecasts()
            uncertainty_forecasts_dict[tag][model_name][trial_key] = uncertainty_forecasts

In [None]:
import os

# Persist the fitted runs so the slow fitting cell can be skipped on later runs.
# open(..., 'wb') fails if the target directory is missing, so create it first.
predictions_path = '../../misc/predictions/predictions_dict_perc.pickle'
os.makedirs(os.path.dirname(predictions_path), exist_ok=True)
with open(predictions_path, 'wb') as handle:
    pkl.dump(predictions_dict, handle)

In [None]:
# Reload the persisted fits (replaces whatever predictions_dict holds now).
# pickle can execute arbitrary code on load, so only point this at files this
# notebook wrote itself.
_pickle_path = '../../misc/predictions/predictions_dict_perc.pickle'
with open(_pickle_path, 'rb') as pickle_file:
    predictions_dict = pkl.load(pickle_file)

In [None]:
# Trial keys used in predictions_dict[tag][model], one per repeated fit.
trials = [f'm{trial_idx}' for trial_idx in range(num_rep_trials)]
trials

In [None]:
# Loss splits and compartments compared between the models below.
loss_type = [
    "train",
    "val",
]
compartments = [
    "total",
    "recovered",
    "deceased",
]

In [None]:
import pandas as pd

In [None]:
# For each loss split, draw one bar chart per compartment comparing the
# per-trial losses of the MCMC and BO fits for the first end date (tag 0).
for l in loss_type:
    # One column per compartment (a hard-coded 4 left an empty panel for the
    # 3 compartments). squeeze=False keeps AX 2-D even for a single compartment.
    fig, AX = plt.subplots(nrows=1, ncols=len(compartments), sharex=True,
                           figsize=(15, 8), squeeze=False)
    for col, c in enumerate(compartments):
        # Distinct comprehension variable so it is not confused with the panel index.
        MC_loss = [predictions_dict[0]['MCMC'][trial]['df_loss'][l][c] for trial in trials]
        BO_loss = [predictions_dict[0]['BO'][trial]['df_loss'][l][c] for trial in trials]
        data = {"MCMC": MC_loss, "BO": BO_loss}
        df = pd.DataFrame(data, columns=["MCMC", "BO"])
        ax = AX[0, col]
        sns.barplot(data=df, ax=ax, palette='bright')
        ax.title.set_text(c + " " + l + " loss")

In [None]:
# Parameters whose fitted values are compared across the two methods.
# 'T_inf' used to be listed twice in each list, which plots the same parameter
# twice. The two lists were identical, so the list is defined once and each
# model gets its own copy.
_common_params = ['lockdown_R0', 'T_inc', 'T_inf', 'T_recov', 'T_recov_fatal',
                  'P_fatal', 'E_hosp_ratio', 'I_hosp_ratio', 'sigma']
model_params = {
    'MCMC': list(_common_params),
    'BO': list(_common_params),
}

In [None]:
# Compare the fitted parameters listed in model_params across models.
# NOTE(review): 'ensemble_combined' presumably pools the trials — confirm in viz.fit.
plot_all_params(
    predictions_dict,
    model_params,
    method='ensemble_combined',
)

In [None]:
# Loss compartments each model was fitted on, read from its own fitting config.
which_compartments = dict(
    (model_names[idx], cfg['fitting']['loss']['loss_compartments'])
    for idx, cfg in enumerate(configs)
)
plot_all_losses(
    predictions_dict,
    which_losses=['train', 'val'],
    which_compartments=which_compartments,
)

In [None]:
from main.seir.forecast import _get_top_k_trials as topk

In [None]:
# Inspect the top-k trials of the last fitted run: last end date, last model,
# last repeated trial. These indices used to come from the tag/j/k variables
# leaking out of the fitting loop, which do not exist when predictions_dict was
# reloaded from the pickle in a fresh kernel.
last_tag = max(predictions_dict)
params, losses = topk(predictions_dict[last_tag][model_names[-1]][trials[-1]])

In [None]:
# For each model, gather the losses of the top 5 trials from every repeated fit
# of the first end date into one flat list: loss[model_name] -> list of losses.
loss = {}
for tag in range(1):
    for model_name in model_names:
        collected = []
        for trial in trials:
            _, trial_losses = topk(predictions_dict[tag][model_name][trial], k=5)
            collected.extend(trial_losses)
        # NOTE(review): MCMC losses are multiplied by 4, presumably to put them on
        # the same scale as the BO losses — confirm the reason for this factor.
        if model_name == 'MCMC':
            collected = [4 * value for value in collected]
        loss[model_name] = collected

In [None]:
# Display the (scaled) MCMC top-k losses.
loss['MCMC']

In [None]:
# Bar chart comparing the collected top-k losses of each model. The title was
# hard-coded as "Top 50 losses", but each column holds len(trials) * 5 values
# (5 with num_rep_trials = 1), so the count is now derived from the data.
df = pd.DataFrame(loss, columns=["MCMC", "BO"])
fig, ax = plt.subplots(figsize=(7, 12))
sns.barplot(data=df, ax=ax)
ax.set_title(f"Top {len(df)} losses")
