In [None]:
# Reload edited project modules automatically before each cell and render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import copy
import datetime
import math
import os
import pickle as pkl
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
import yaml
from scipy.stats import entropy

sys.path.append('../../')
from data.processing import get_data
import models
from main.seir.fitting import single_fitting_cycle
from main.seir.forecast import get_forecast, forecast_all_trials, create_all_trials_csv, create_decile_csv_new
from main.seir.sensitivity import calculate_sensitivity_and_plot
from utils.generic.create_report import save_dict_and_create_report
from utils.generic.config import read_config
from utils.generic.enums import Columns
from utils.fitting.loss import Loss_Calculator
from utils.generic.logging import log_wandb
from viz import plot_forecast, plot_top_k_trials, plot_ptiles
from viz.fit import plot_histogram, plot_all_histograms, plot_mean_variance, plot_scatter, plot_kl_divergence, plot_heatmap_distribution_sigmas, plot_all_params, plot_all_losses, plot_all_buckets, plot_cv_in_params, plot_recovery_loss, plot_confidence_interval
from data.dataloader import SimulatedDataLoader

In [None]:
def update_hosp_ratios(param_set, data_config, model_config,
                       data_dir='../../data/data/simulated_data/'):
    """Write compartment-to-active ratios at the training-window start into ``param_set``.

    Reads the simulated-data CSV ``data_config['output_file_name']`` from
    ``data_dir`` and locates the first row of the train+val+test window.
    The window ends ``end_date`` rows before the last row. The function then
    stores that row's ``E/active`` and ``I/active`` ratios.

    Args:
        param_set: dict of parameters; mutated in place and also returned.
        data_config: must contain 'output_file_name' and 'model'
            ('SEIRHD', 'SEIRHD_Beta' or 'SEIR_PU'; other models are a no-op).
        model_config: must contain 'end_date' and the three integer periods
            'train_period', 'val_period' and 'test_period'. 'end_date' may be:
            - falsy, meaning the window ends at the last row;
            - a non-positive int offset from the last row;
            - a ``datetime.date`` present in the CSV's 'date' column.
        data_dir: directory holding the CSV (defaults to the original
            hard-coded location).

    Returns:
        The same ``param_set`` dict.

    Raises:
        ValueError: if an integer end_date is positive, or a date end_date
            is not present in the data.
        TypeError: if end_date has an unsupported type.
        KeyError: for 'SEIR_PU', whose Pu-ratio denominator column is unspecified.
    """
    df = pd.read_csv(os.path.join(data_dir, data_config['output_file_name']), index_col=0)

    # Resolve end_date into a non-positive row offset from the last row (0 == last row).
    end_date_cfg = model_config['end_date']
    if not end_date_cfg:
        end_date = 0
    elif isinstance(end_date_cfg, datetime.date):
        # A datetime never compares equal to a plain date, so normalise it first.
        target = end_date_cfg.date() if isinstance(end_date_cfg, datetime.datetime) else end_date_cfg
        dates = pd.to_datetime(df['date']).dt.date
        # Use the row *position*, not the index label, so a non-RangeIndex CSV works too.
        matches = np.flatnonzero((dates == target).to_numpy())
        if len(matches) == 0:
            raise ValueError(f'end_date {target} not found in the simulated data')
        end_date = int(matches[0]) - len(df) + 1
    elif isinstance(end_date_cfg, (int, np.integer)):
        if end_date_cfg > 0:
            raise ValueError('Please enter a negative value for end_date if entering an integer')
        end_date = int(end_date_cfg)
    else:
        raise TypeError(f'Unsupported end_date type: {type(end_date_cfg).__name__}')

    window_len = model_config['train_period'] + model_config['val_period'] + model_config['test_period']
    train_start_row = df.iloc[len(df) - window_len + end_date]

    model = data_config['model']
    if model in ('SEIRHD', 'SEIRHD_Beta', 'SEIR_PU'):
        param_set['E_hosp_ratio'] = train_start_row['E'] / train_start_row['active']
        param_set['I_hosp_ratio'] = train_start_row['I'] / train_start_row['active']
    if model == 'SEIR_PU':
        # NOTE(review): the original computed train_start_row['Pu'] / train_start_row[''],
        # i.e. indexed an empty-string column that read_csv can never produce.
        # The intended denominator is unknown — TODO confirm and set 'Pu_pop_ratio'.
        raise KeyError("Pu_pop_ratio: denominator column for 'Pu' is unspecified")
    return param_set

# Create a short pickle file (best params, loss tables and run tuples only)

In [None]:
# Load the full experiment-3 results pickle that is slimmed down below.
file_name = 'exp3_fixed_params_multiple_val.pickle'
save_dir = '../../misc/predictions/exp3'
exp3_path = os.path.join(save_dir, file_name)
with open(exp3_path, 'rb') as handle:
    input_dict = pkl.load(handle)

In [None]:
# For each prediction keep only its best params, its loss table and its run tuple.
output = [
    {
        'prediction_dict': {
            'best_params': pred['prediction_dict']['best_params'],
            'df_loss': pred['prediction_dict']['df_loss'],
        },
        'run_tuple': pred['run_tuple'],
    }
    for pred in input_dict['predictions_dicts']
]

In [None]:
# Bundle the slimmed predictions with the run metadata and write them to disk.
# NOTE(review): the input is the exp3 pickle, but the output lands in exp2/ as
# exp2_trial2_short.pickle (the file read in "Plot 2") — confirm this is intended.
output_dict = {'predictions_dicts': output}
for meta_key in ('model_config', 'val_periods', 'train_periods', 'end_dates'):
    output_dict[meta_key] = input_dict[meta_key]
short_path = os.path.join('../../misc/predictions/exp2', "exp2_trial2_short.pickle")
with open(short_path, 'wb') as handle:
    pkl.dump(output_dict, handle)

## Plots

In [None]:
# Load the short exp2 trial-1 results used for the first plot.
file_name = 'exp2_trial1_short.pickle'
save_dir = '../../misc/predictions/exp2'
trial1_path = os.path.join(save_dir, file_name)
with open(trial1_path, 'rb') as handle:
    run_dict = pkl.load(handle)

In [None]:
# Compare over the val periods stored in the pickle, for runs ending on 31-12-2020.
format_str = '%d-%m-%Y'  # day-month-year
end_date = datetime.datetime.strptime('31-12-2020', format_str).date()
val_periods = run_dict['val_periods']

In [None]:
# One list of per-run losses for each val period.
plot_dict = dict((period, []) for period in val_periods)

In [None]:
# Collect each selected run's mean `which_loss` into the bucket of its val period.
# Runs are selected by train period and end date.
# `which_loss` is defined here (it used to be set only in a later cell, which made
# this cell fail with NameError on Restart & Run All).
which_loss = 'val'
train_period = 28
for run in run_dict['predictions_dicts']:
    run_tuple = run['run_tuple']
    if run_tuple['train'] != train_period or run_tuple['end_date'] != end_date:
        continue
    losses = run['prediction_dict']['df_loss']
    plot_dict[run_tuple['val']].append(np.mean(losses[which_loss]))

In [None]:
# Replace each val period's list of run losses by its mean.
plot_dict.update({
    period: np.mean(np.array(plot_dict[period]))
    for period in val_periods
})

In [None]:
# Inspect: (val periods, mean loss per val period)
(plot_dict.keys(), plot_dict.values())

In [None]:
# Plot the average loss against val period for the single end date selected above.
# The previous version indexed plot_dict['val'] and iterated a nested
# {end_date: {val: loss}} dict. plot_dict is now a flat {val_period: mean_loss}
# mapping, so that version raised KeyError. It also drew an unused empty second subplot.
which_loss = 'val'
fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(list(plot_dict.keys()), list(plot_dict.values()), label=str(end_date))
ax.set_title(which_loss)
ax.set_xlabel("Val periods")
ax.set_ylabel("Average MAPE loss")
ax.legend(title="end date");

# Plot 2: average loss per val period for exp2 trial 2

In [None]:
# Load the short exp2 trial-2 results used for the second plot.
file_name = 'exp2_trial2_short.pickle'
save_dir = '../../misc/predictions/exp2'
trial2_path = os.path.join(save_dir, file_name)
with open(trial2_path, 'rb') as handle:
    run_dict = pkl.load(handle)

In [None]:
# Inspect which end dates and val periods this pickle contains.
(run_dict['end_dates'], run_dict['val_periods'])

In [None]:
# Val periods 14, 28, ..., 112 (multiples of 14), for runs ending on 1 Jan 2021.
format_str = '%d-%m-%Y'  # day-month-year
val_periods = list(range(14, 113, 14))
end_date = datetime.date(2021, 1, 1)

In [None]:
# One list of per-run losses for each val period.
plot_dict_1 = dict((period, []) for period in val_periods)

In [None]:
# For each run matching the train period and end date, append its mean
# "val_<period>" loss to the bucket of every val period.
train_period = 28
for run in run_dict['predictions_dicts']:
    meta = run['run_tuple']
    if meta['train'] == train_period and meta['end_date'] == end_date:
        df_loss = run['prediction_dict']['df_loss']
        for val in val_periods:
            plot_dict_1[val].append(np.mean(df_loss[f"val_{val}"]))

In [None]:
# Collapse each val period's per-run losses into a mean (plot_dict_1) and a
# standard deviation (plot_dict_1_std). Both are taken from the raw list,
# *before* plot_dict_1[val] is overwritten. The previous version took np.mean of
# the already-averaged scalar, so its "std" was just a copy of the mean.
plot_dict_1_std = {}
for val in val_periods:
    run_losses = np.array(plot_dict_1[val])
    plot_dict_1[val] = np.mean(run_losses)
    plot_dict_1_std[val] = np.std(run_losses)

In [None]:
# Inspect: (val periods, mean loss per val period)
(plot_dict_1.keys(), plot_dict_1.values())

In [None]:
# Overlay the two scenarios' average val loss against val period.
for curve_label, curve_data in (("sc1", plot_dict), ("sc2", plot_dict_1)):
    plt.plot(list(curve_data), list(curve_data.values()), label=curve_label)
plt.legend(loc='best')
plt.xlabel("Val period")
plt.ylabel("Average Val Loss (MAPE %)")
plt.title("sc1: None \n sc2: T_recov_fatal, T_inf, T_inc")