In [None]:
# Pick up edits to imported project modules live, without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import copy
import datetime
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import wandb
import yaml
from scipy.stats import entropy

# Project modules live two directories up; extend the path before importing them.
sys.path.append('../../')

from data.processing import get_data

import models

from main.seir.fitting import single_fitting_cycle
from main.seir.forecast import get_forecast, forecast_all_trials, create_all_trials_csv, create_decile_csv_new
from main.seir.sensitivity import calculate_sensitivity_and_plot
from utils.generic.create_report import save_dict_and_create_report
from utils.generic.config import read_config
from utils.generic.enums import Columns
from utils.fitting.loss import Loss_Calculator
from utils.generic.logging import log_wandb
from viz import plot_forecast, plot_top_k_trials, plot_ptiles
from viz.fit import plot_histogram, plot_all_histograms, plot_mean_variance

In [None]:
# Container for all fit results; filled per location by the fitting cell.
predictions_dict = dict()

In [None]:
# Load the experiment configuration; its 'fitting' section is deep-copied and
# customised per location by the fitting cell further down.
config_filename = 'default.yaml'
config = read_config(config_filename)

In [None]:
# Timestamped report directory, e.g. ../../misc/reports/2020_0615_142530
run_timestamp = datetime.datetime.now().strftime("%Y_%m%d_%H%M%S")
output_folder = '../../misc/reports/{}'.format(run_timestamp)

## Perform repeated fits (m0, m1, ...) for each location

In [None]:
# Number of repeated fitting runs performed for every location.
num_rep_trials = 5

# (state, district) pairs to fit.
# NOTE(review): a district of None presumably selects the whole state -- confirm
# against the data loader.
location_tuples = [('Tamil Nadu', 'Chennai'), ('Delhi', None)]
for loc in location_tuples:
    state, district = loc
    # Work on a private copy so per-location overrides never touch the loaded config.
    config_params = copy.deepcopy(config['fitting'])
    config_params['data']['dataloading_params']['state'] = state
    config_params['data']['dataloading_params']['district'] = district
    if district != 'Mumbai':
        config_params['data']['smooth_jump'] = False
    # Results are stored as predictions_dict[(state, district)]['m0'], ['m1'], ...
    predictions_dict[loc] = {}
    # The trial index has its own name instead of re-using (and clobbering) an outer `i`.
    for trial in range(num_rep_trials):
        predictions_dict[loc][f'm{trial}'] = single_fitting_cycle(**config_params)


# predictions_dict['fitting_date'] = datetime.datetime.now().strftime("%Y-%m-%d")

In [None]:
# predictions_dict is keyed by (state, district) first and by trial name ('m0', 'm1', ...)
# second, so a location must be selected before the trial. Shows trial m1 of the
# first location; change the index to inspect another location.
predictions_dict[location_tuples[0]]['m1']['best_params']

In [None]:
# predictions_dict is keyed by (state, district) first and by trial name ('m0', 'm1', ...)
# second, so a location must be selected before the trial. Shows trial m2 of the
# first location; change the index to inspect another location.
predictions_dict[location_tuples[0]]['m2']['best_params']

## Loss Dataframes

### M1 Loss DataFrame

In [None]:
# Loss table of trial m1 for the first location. predictions_dict is keyed by
# (state, district) before the trial name, so the location is selected first.
predictions_dict[location_tuples[0]]['m1']['df_loss']

### M2 Loss DataFrame

In [None]:
# Loss table of trial m2 for the first location. predictions_dict is keyed by
# (state, district) before the trial name, so the location is selected first.
predictions_dict[location_tuples[0]]['m2']['df_loss']

In [None]:
# Start a W&B run and prefix its auto-generated name with the experiment tag.
wandb.init(project="covid-modelling")
generated_run_name = wandb.run.name
wandb.run.name = "degeneracy-exps-location" + generated_run_name

# Log one histogram figure and one mean/variance figure per location.
# `histograms` keeps the value from the last location only; a later cell reads it.
for key, loc_dict in predictions_dict.items():
    fig, ax, histograms = plot_all_histograms(loc_dict, key)
    histogram_panel = f"histograms/{key[0]}_{key[1]}"
    wandb.log({histogram_panel: [wandb.Image(fig)]})

    fig, axs, df_mean_var = plot_mean_variance(loc_dict, key)
    mean_var_panel = f"mean_var/{key[0]}_{key[1]}"
    wandb.log({mean_var_panel: [wandb.Image(fig)]})

In [None]:
# scipy.stats distribution family fitted to each parameter's trial values by the
# distribution-fit plotting cell (which looks families up by parameter name).
# Requires `import scipy.stats` in the notebook's import cell.
param_distributions = {
    'E_hosp_ratio': scipy.stats.expon,
    'I_hosp_ratio': scipy.stats.gamma,
    'P_fatal': scipy.stats.beta,
    'T_inc': scipy.stats.norm,
    'T_inf': scipy.stats.norm,
    'T_recov_fatal': scipy.stats.norm,
    'T_recov_severe': scipy.stats.norm,
    'lockdown_R0': scipy.stats.norm
}

In [None]:
# Pairwise KL divergence between the per-run histograms of every parameter,
# drawn as one annotated heatmap per parameter.
# NOTE(review): `params_dicts` is not defined anywhere in the visible notebook, and
# `histograms` is whatever the last iteration of the wandb logging loop left behind
# (i.e. the last location only) -- confirm both before relying on this cell.
param_names = list(params_dicts['m1'].keys())
# Ceiling division so an odd number of parameters still gets an axis each.
n_rows = max(1, (len(param_names) + 1) // 2)
fig, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(18, 6 * n_rows), squeeze=False)
run_names = list(histograms.keys())
for i, param in enumerate(param_names):
    ax = axs.flat[i]
    # entropy(p, q) with two arguments is the KL divergence of p from q.
    kl_matrix = np.array([
        [entropy(histograms[run1][param]['probability'], histograms[run2][param]['probability'])
         for run2 in run_names]
        for run1 in run_names
    ])
    # NOTE(review): the factor 500 presumably converts a run index into a trial count -- confirm.
    tick_labels = np.arange(1, kl_matrix.shape[0]+1)*500
    sns.heatmap(kl_matrix, annot=True, xticklabels=tick_labels, yticklabels=tick_labels, vmax=10, ax=ax)
    ax.set_title(f'KL Divergence matrix of parameter {param}')
# Hide the spare axis left over when the parameter count is odd.
for unused_ax in axs.ravel()[len(param_names):]:
    unused_ax.set_visible(False)
plt.show()
# fig.savefig('constrained-kl-matrix.png')

In [None]:
# Standard deviation of the sampled values, per run and per parameter.
# (Despite the name, only np.std is stored here.)
params_mean_var = {
    run_name: {
        param_name: np.std(samples)
        for param_name, samples in run_params.items()
    }
    for run_name, run_params in params_dicts.items()
}

In [None]:
# Tabulate the spread: one column per run, one row per parameter.
pd.DataFrame.from_dict(params_mean_var, orient='columns')

In [None]:
# For every run, overlay the fitted distribution (family taken from
# `param_distributions`) on the histogram of each parameter's trial values.
for run in params_dicts.keys():
    param_names = list(params_dicts[run].keys())
    # Size the grid from the parameters actually plotted; ceiling division keeps
    # an axis available when the count is odd.
    n_rows = max(1, (len(param_names) + 1) // 2)
    fig, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(18, 6 * n_rows), squeeze=False)
    for i, param in enumerate(param_names):
        dist = param_distributions[param]
        param_trials = params_dicts[run][param]
        # fit() returns (shape parameters..., loc, scale).
        dist_fit = dist.fit(param_trials)
        sampling_points = np.linspace(np.min(param_trials), np.max(param_trials), len(param_trials))
        pdf_fitted = dist.pdf(sampling_points, *dist_fit[:-2], loc=dist_fit[-2], scale=dist_fit[-1])

        ax = axs.flat[i]
        ax.hist(param_trials, density=True)
        ax.plot(sampling_points, pdf_fitted)
        ax.set_title(f'Histogram of parameter {param} for run {run}')
        ax.set_ylabel('Density')
    # Hide the spare axis left over when the parameter count is odd.
    for unused_ax in axs.ravel()[len(param_names):]:
        unused_ax.set_visible(False)
    plt.show()

In [None]:
# Leftover inspection: `param_trials` holds the values of the last (run, parameter)
# pair visited by the plotting loop in the previous cell.
param_trials