In [1]:
%load_ext autoreload
# Reload edited project modules automatically before each cell executes.
%autoreload 2
%matplotlib inline

In [2]:
import copy
import datetime
import math
import os
import pickle as pkl
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
import yaml
from scipy.stats import entropy

# Make the repository root importable for the project-local imports below.
sys.path.append('../../')
from data.processing import get_data
import models
from main.seir.fitting import single_fitting_cycle
from main.seir.forecast import get_forecast, forecast_all_trials, create_all_trials_csv, create_decile_csv_new
from main.seir.sensitivity import calculate_sensitivity_and_plot
from utils.generic.create_report import save_dict_and_create_report
from utils.generic.config import read_config
from utils.generic.enums import Columns
from utils.fitting.loss import Loss_Calculator
from utils.generic.logging import log_wandb
from viz import plot_forecast, plot_top_k_trials, plot_ptiles
from viz.fit import plot_histogram, plot_all_histograms, plot_mean_variance, plot_scatter, plot_kl_divergence, plot_heatmap_distribution_sigmas, plot_all_params, plot_all_losses, plot_all_buckets, plot_cv_in_params, plot_recovery_loss, plot_confidence_interval
from data.dataloader import SimulatedDataLoader

In [3]:
def update_hosp_ratios(param_set, data_config, model_config,
                       data_dir='../../data/data/simulated_data/'):
    """Add compartment/active ratios taken at the train-start row to ``param_set``.

    Reads the simulated-data CSV named by ``data_config['output_file_name']``,
    locates the first row of the train+val+test window (the window ends
    ``end_date`` rows before the last row), and stores E/active and
    I/active ratios from that row.

    Args:
        param_set: dict of parameters. It is updated in place and also returned.
        data_config: simulated-data config. Reads 'output_file_name' and 'model'
            ('SEIRHD' or 'SEIR_PU'; any other model adds no keys).
        model_config: split config. Reads 'end_date' (optional), 'train_period',
            'val_period' and 'test_period'. 'end_date' may be falsy (window ends
            at the last row), a non-positive int row offset, or a
            ``datetime.date``/``datetime.datetime`` present in the 'date' column.
        data_dir: directory containing the simulated-data CSV.

    Returns:
        The same ``param_set`` dict with the ratio keys added.

    Raises:
        ValueError: if ``end_date`` is a positive int, a date not present in
            the data, or if the window does not fit inside the data.
        TypeError: if ``end_date`` has an unsupported type.
    """
    df = pd.read_csv(os.path.join(data_dir, data_config['output_file_name']), index_col=0)

    end_date = model_config.get('end_date')
    if not end_date:
        # Falsy (missing / None / 0) means the window ends at the last row.
        end_date = 0
    elif isinstance(end_date, int):
        if end_date > 0:
            raise ValueError('Please enter a negative value for end_date if entering an integer')
    elif isinstance(end_date, datetime.date):
        # datetime.datetime is a subclass of date; compare on the calendar day.
        if isinstance(end_date, datetime.datetime):
            end_date = end_date.date()
        # The CSV stores dates as text, so parse before using the .dt accessor.
        dates = pd.to_datetime(df['date']).dt.date
        matches = df.index[(dates == end_date).to_numpy()]
        if len(matches) == 0:
            raise ValueError(f'end_date {end_date} not found in the simulated data')
        # Convert the matching index label into a non-positive offset from the last row.
        end_date = int(matches[0]) - len(df) + 1
    else:
        raise TypeError(f'Unsupported end_date type: {type(end_date).__name__}')

    window = model_config['train_period'] + model_config['val_period'] + model_config['test_period']
    start_pos = len(df) - window + end_date
    # A negative position would silently wrap around in iloc and pick the wrong row.
    if not 0 <= start_pos < len(df):
        raise ValueError(
            f'Train/val/test window of {window} rows ending at offset {end_date} '
            f'does not fit in data of length {len(df)}')
    train_start_row = df.iloc[start_pos]

    if data_config['model'] in ('SEIRHD', 'SEIR_PU'):
        param_set['E_hosp_ratio'] = train_start_row['E'] / train_start_row['active']
        param_set['I_hosp_ratio'] = train_start_row['I'] / train_start_row['active']
    if data_config['model'] == 'SEIR_PU':
        # FIXME: the denominator column name was left blank in the original code,
        # so this lookup raises KeyError. Fill in the correct population column.
        param_set['Pu_pop_ratio'] = train_start_row['Pu'] / train_start_row['']
    return param_set

In [4]:
# Config files for this experiment: the SEIR fitting config and the simulated-data config.
model_config_filename = 'default.yaml'
simulated_config_filename = 'seirhd_fixed.yaml'

# yaml.safe_load is shorthand for yaml.load(..., Loader=yaml.SafeLoader).
with open(os.path.join("../../configs/seir/", model_config_filename)) as configfile:
    model_config = yaml.safe_load(configfile)

with open(os.path.join("../../configs/simulated_data/", simulated_config_filename)) as configfile:
    simulated_config = yaml.safe_load(configfile)

# Ground-truth simulation parameters, extended with the E/I ratios at the train start.
actual_params = update_hosp_ratios(simulated_config['params'], simulated_config,
                                   model_config['fitting']['split'])

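# Create a short pickle file
Keep only each run's best parameters, loss table and run tuple, so later analysis can load a smaller file than the full experiment pickle.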

In [14]:
# Full output of the exp3 runs, as written by the experiment script.
save_dir = '../../misc/predictions/exp3'
file_name = 'exp3_trial1.pickle'

# NOTE(review): unpickling can execute arbitrary code -- only load files this project produced.
with open(os.path.join(save_dir, file_name), 'rb') as handle:
    input_dict = pkl.loads(handle.read())

In [25]:
def shorten_prediction_dict(run):
    """Return a copy of one run entry keeping only best params, loss table and run tuple.

    Everything else under ``run['prediction_dict']`` (e.g. the trials) is dropped.
    """
    return {
        'prediction_dict': {
            'best_params': run['prediction_dict']['best_params'],
            'df_loss': run['prediction_dict']['df_loss'],
        },
        'run_tuple': run['run_tuple'],
    }


# The loop variable used to be named `pd`, which shadowed pandas for the rest of
# the kernel session; a comprehension with a distinct name avoids that.
output = [shorten_prediction_dict(run) for run in input_dict['predictions_dicts']]

In [26]:
# Bundle the trimmed runs with the experiment metadata.
output_dict = {
    'predictions_dicts': output,
    'model_config': input_dict['model_config'],
    'val_periods': input_dict['val_periods'],
    'train_periods': input_dict['train_periods'],
    'end_dates': input_dict['end_dates'],
}

# Name the output after the input file so the two stay in sync
# (e.g. 'exp3_trial1.pickle' -> 'exp3_trial1_short.pickle').
short_file_name = os.path.splitext(file_name)[0] + '_short.pickle'
with open(os.path.join(save_dir, short_file_name), 'wb') as handle:
    pkl.dump(output_dict, handle)

In [None]:
run_dict = copy.deepcopy(output_dict['predictions_dicts'])
for runs in run_dict:
    