In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import datetime
import os
import pickle as pkl
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import wandb
import yaml

sys.path.append('../../')

from data.processing import get_data

import models

from main.seir.main import single_fitting_cycle
from main.seir.forecast import get_forecast, forecast_all_trials, create_all_trials_csv, create_decile_csv_new
from utils.generic.create_report import save_dict_and_create_report
from utils.generic.config import read_config, make_date_key_str
from utils.generic.enums import Columns
from utils.fitting.loss import Loss_Calculator
from utils.generic.logging import log_wandb, log_mlflow
from viz import plot_forecast, plot_top_k_trials, plot_ptiles, plot_all_params, plot_all_losses
from viz.uncertainty import plot_beta_loss

In [None]:
# Accumulator for all fitting results; filled as predictions_dict[tag][config_name]['m<k>'].
predictions_dict = dict()

In [None]:
# YAML configurations to compare; each file name is parsed with read_config,
# so configs[i] corresponds to config_filenames[i].
config_filenames = ['default.yaml', 'undetected.yaml', 'seir_pu.yaml']
configs = list(map(read_config, config_filenames))

# Per-location settings, keyed by an upper-case tag.
# Each value is a tuple: (state, district, starting_date, ending_date, N, num_trials).
# The district is None for DELHI, and every location leaves both dates as None.
location_tuples = dict(
    MUMBAI=('Maharashtra', 'Mumbai', None, None, 2.0e+7, 250),
    PUNE=('Maharashtra', 'Pune', None, None, 0.6e+7, 250),
    DELHI=('Delhi', None, None, None, 2.0e+7, 250),
    RANCHI=('Jharkhand', 'Ranchi', None, None, 0.14e+7, 250),
    BOKARO=('Jharkhand', 'Bokaro', None, None, 0.06e+7, 250),
)

In [None]:
# Run single_fitting_cycle num_rep_trials times for every (location, config) pair and
# store each result under predictions_dict[tag][config_name]['m<k>'].
num_rep_trials = 5
banner = "****************"
for tag, loc in location_tuples.items():
    # Tuple layout: (state, district, starting_date, ending_date, N, num_trials).
    state, district, start_date, end_date, population, num_evals = loc
    tag_results = predictions_dict[tag] = {}
    for yaml_name, base_config in zip(config_filenames, configs):
        # Key results by the config file name without its extension.
        config_filename = yaml_name.split(".")[0]
        trial_results = tag_results[config_filename] = {}

        # Work on a deep copy so the location overrides never leak into the shared config.
        config_params = copy.deepcopy(base_config['fitting'])
        data_params = config_params['data']
        split_params = config_params['split']

        data_params['dataloading_params']['state'] = state
        data_params['dataloading_params']['district'] = district
        split_params['start_date'] = start_date
        split_params['end_date'] = end_date
        config_params['default_params']['N'] = population
        config_params['fitting_method_params']['num_evals'] = num_evals

        # smooth_jump is forced off for every district except Mumbai, which keeps the
        # value from its config file.
        if district != 'Mumbai':
            data_params['smooth_jump'] = False

        # NOTE(review): the same config_params object is passed to every repetition; if
        # single_fitting_cycle mutates nested dicts, later repetitions see it - confirm.
        for k in range(num_rep_trials):
            print(banner)
            print(tag, config_filename, k)
            print(banner)
            trial_results[f'm{k}'] = single_fitting_cycle(**config_params)

In [None]:
# Write one pickle file per location tag (<save_dir>/<tag>.pickle) so the fitted
# results can be reloaded later without rerunning the fits.
save_dir = '../../misc/predictions/deg_exp/'
# exist_ok=True creates the directory (and any missing parents) only when needed,
# avoiding the check-then-create race of a separate os.path.exists test.
os.makedirs(save_dir, exist_ok=True)
for tag, tag_dict in predictions_dict.items():
    with open(os.path.join(save_dir, tag + ".pickle"), 'wb') as handle:
        pkl.dump(tag_dict, handle)

### Use the pickle files to read the predictions_dict

In [None]:
# Reload every '<tag>.pickle' file found in save_dir into a single dict keyed by tag.
# NOTE: pickle.load can execute arbitrary code - only load files you produced yourself.
save_dir = '../../misc/predictions/deg_exp/'
predictions_dict_complete = {}
# Sorted so the resulting dict order does not depend on the filesystem's listing order.
files = sorted(os.listdir(save_dir))
for file_name in files:
    # The files were saved as tag + ".pickle"; splitext recovers the full tag even if
    # it contains dots.
    tag_name, extension = os.path.splitext(file_name)
    # Skip stray entries (e.g. '.DS_Store', sub-directories) that would otherwise make
    # open() or pkl.load() fail.
    if extension != '.pickle':
        continue
    with open(os.path.join(save_dir, file_name), 'rb') as handle:
        predictions_dict_complete[tag_name] = pkl.load(handle)

In [None]:
# Pass the reloaded predictions to plot_all_params with method='ensemble' and keep the
# returned value (presumably the per-parameter distributions - confirm in viz).
params_distribution = plot_all_params(
    predictions_dict_complete,
    method='ensemble',
)

In [None]:
# Pass the reloaded predictions to plot_all_losses with method='ensemble_loss_ra' and
# keep the returned value (presumably the loss distributions - confirm in viz).
losses_distribution = plot_all_losses(
    predictions_dict_complete,
    method='ensemble_loss_ra',
)