In [None]:
# Load the autoreload extension and reload all imported modules before each
# cell executes, so edits to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import entropy
import datetime
import copy
import time
import wandb
import pickle as pkl

import sys
sys.path.append('../../')

from data.processing import get_data

import models

from main.seir.fitting import single_fitting_cycle
from main.seir.forecast import get_forecast, forecast_all_trials, create_all_trials_csv, create_decile_csv_new
from main.seir.sensitivity import calculate_sensitivity_and_plot
from utils.generic.create_report import save_dict_and_create_report
from utils.generic.config import read_config
from utils.generic.enums import Columns
from utils.fitting.loss import Loss_Calculator
from utils.generic.logging import log_wandb
from viz import plot_forecast, plot_top_k_trials, plot_ptiles
from viz.fit import plot_histogram, plot_all_histograms, plot_mean_variance, plot_scatter, plot_kl_divergence, plot_heatmap_distribution_sigmas, plot_all_params, plot_all_losses, plot_all_buckets
import yaml

In [None]:
# Results container, filled by the fitting loop as
# location tuple -> model name -> 'm<k>' -> output of single_fitting_cycle.
predictions_dict = dict()

In [None]:
# One experiment config per model. The order here must match the key order of
# `model_params` below: later cells pair `model_names[j]` with `configs[j]`
# purely by position.
config_filenames = [
    'experiments/seirhd.yaml',
    'experiments/undetected.yaml',
    'experiments/seir_pu.yaml',
    'experiments/pu_testing.yaml',
]

# Parameter names associated with each model (passed to plot_all_params).
# Each name appears once per model: the SEIRHD list previously repeated 'T_inf'.
model_params = {
    'SEIRHD': [
        'lockdown_R0', 'T_inc', 'T_inf', 'T_recov', 'T_recov_fatal',
        'P_fatal', 'E_hosp_ratio', 'I_hosp_ratio',
    ],
    'SEIR_Undetected': [
        'T_inc', 'T_inf_D', 'T_inf_U', 'T_recov', 'T_recov_fatal', 'beta', 'd',
        'P_fatal', 'I_D_hosp_ratio', 'I_U_hosp_ratio', 'E_hosp_ratio',
    ],
    'SEIR_PU': [
        'T_inc', 'T_inf_U', 'T_recov', 'T_recov_fatal', 'beta', 'd',
        'P_fatal', 'I_hosp_ratio', 'E_hosp_ratio', 'Pu_pop_ratio',
    ],
    'SEIR_PU_Testing': [
        'T_inc', 'T_inf_U', 'T_recov', 'T_recov_fatal', 'beta', 'd',
        'P_fatal', 'I_hosp_ratio', 'E_hosp_ratio', 'Pu_pop_ratio',
    ],
}
model_names = list(model_params.keys())

# Fail fast if the positional pairing between models and config files is broken,
# rather than hitting an IndexError (or mislabelled results) mid-fit.
assert len(config_filenames) == len(model_names), (
    'config_filenames and model_params must list the same number of models'
)
configs = [read_config(config_filename) for config_filename in config_filenames]

In [None]:
# Report directory named after the current local time (YYYY_MMDD_HHMMSS).
run_timestamp = datetime.datetime.now().strftime("%Y_%m%d_%H%M%S")
output_folder = f'../../misc/reports/{run_timestamp}'

In [None]:
# Display the location tuples that currently have results stored. On a fresh
# top-to-bottom run this is empty, because the fitting loop runs in a later cell.
predictions_dict.keys()

In [None]:
# Locations to fit, one (state, district) tuple each.
# NOTE(review): this comment used to list (state, district, starting_date,
# ending_date), but the tuples below have only two fields and the fitting loop
# reads only loc[0] and loc[1].
# TODO: take config blocks instead of state or district
location_tuples = [
    ('Maharashtra', 'Mumbai'),
    ('Maharashtra', 'Pune'),
]

In [None]:
# Run every model on every location `num_rep_trials` times. Results are stored
# as predictions_dict[loc][model_name]['m<k>'] for trial k.
num_rep_trials = 2
for i, loc in enumerate(location_tuples):
    per_model_runs = {}
    predictions_dict[loc] = per_model_runs
    for j, config in enumerate(configs):
        trial_runs = {}
        per_model_runs[model_names[j]] = trial_runs

        # Work on a private copy so the shared config is never mutated.
        config_params = copy.deepcopy(config['fitting'])
        data_params = config_params['data']
        data_params['dataloading_params']['state'] = loc[0]
        data_params['dataloading_params']['district'] = loc[1]

        # smooth_jump is switched off for every district except Mumbai
        # (presumably only Mumbai's data needs it; verify against the configs).
        if not loc[1] == 'Mumbai':
            data_params['smooth_jump'] = False

        for k in range(num_rep_trials):
            trial_runs[f'm{k}'] = single_fitting_cycle(**config_params)

In [None]:
# Persist the fitted results to disk with pickle.
predictions_pickle_path = '../../misc/predictions/predictions_dict.pickle'
with open(predictions_pickle_path, 'wb') as handle:
    pkl.dump(predictions_dict, handle)

In [None]:
# Start a Weights & Biases run and prefix its name with this experiment's tag.
wandb.init(project="covid-modelling")
run_name_prefix = "degeneracy-exps-location"
wandb.run.name = run_name_prefix + wandb.run.name

In [None]:
# Parameter plots for everything in predictions_dict, using the
# 'ensemble_combined' method (presumably for the names listed in model_params).
plot_all_params(
    predictions_dict,
    model_params,
    method='ensemble_combined',
)

In [None]:
# Map each model name to the loss compartments declared in its config, then
# plot train and validation losses for those compartments.
which_compartments = {
    model_names[idx]: cfg['fitting']['loss']['loss_compartments']
    for idx, cfg in enumerate(configs)
}
plot_all_losses(
    predictions_dict,
    which_losses=['train', 'val'],
    which_compartments=which_compartments,
)

In [None]:
# Bucket plots for the S, I, E and I_U buckets, compared by location.
plot_all_buckets(
    predictions_dict,
    which_buckets=['S', 'I', 'E', 'I_U'],
    compare='location',
)