In [None]:
# Re-import edited project modules automatically before each cell runs, and
# render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import copy
import datetime
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
import yaml
from IPython.display import display
from scipy.stats import entropy

# Make the directory two levels up (presumably the repository root) importable
# so the project packages below resolve.
sys.path.append('../../')

from data.processing import get_data

import models

from main.seir.fitting import single_fitting_cycle
from main.seir.forecast import get_forecast, forecast_all_trials, create_all_trials_csv, create_decile_csv_new
from main.seir.sensitivity import calculate_sensitivity_and_plot
from utils.generic.create_report import save_dict_and_create_report
from utils.generic.config import read_config
from utils.generic.enums import Columns
from utils.fitting.loss import Loss_Calculator
from utils.generic.logging import log_wandb
from viz import plot_forecast, plot_top_k_trials, plot_ptiles
from viz.fit import plot_histogram, plot_all_histograms, plot_mean_variance, plot_scatter, plot_kl_divergence, plot_heatmap_distribution_sigmas

In [None]:
# Fit results, keyed first by (state, district) location tuple and then by
# trial name ('m0', 'm1', ...). Filled by the fitting loop.
predictions_dict = dict()

In [None]:
# Load the experiment configuration via the project's read_config helper;
# the fitting loop consumes its 'fitting' section.
config_filename = "default.yaml"
config = read_config(config_filename)

In [None]:
# Timestamped report directory for this run.
# NOTE(review): not referenced by any cell in this notebook — confirm it is still needed.
run_timestamp = datetime.datetime.now().strftime("%Y_%m%d_%H%M%S")
output_folder = f'../../misc/reports/{run_timestamp}'

In [None]:
# NOTE(review): predictions_dict was just initialised to {} and nothing has
# filled it yet, so on a fresh kernel this always shows dict_keys([]).
# Keys appear here only if cells were run out of order (stale state).
predictions_dict.keys()

## Perform Fits

In [None]:
# (state, district) pairs to fit, in fitting/logging order. A district of None
# presumably means the state-level series is fit — confirm against get_data.
location_tuples = [
    ('Maharashtra',    'Mumbai'),
    ('Maharashtra',    'Pune'),
    ('West Bengal',    'Kolkata'),
    ('Karnataka',      'Bengaluru Urban'),
    ('Karnataka',      'Mysuru'),
    ('Delhi',          None),
    ('Assam',          None),
    ('Telangana',      None),
    ('Tamil Nadu',     'Chennai'),
    ('Andhra Pradesh', 'East Godavari'),
    ('Andhra Pradesh', 'Chittoor'),
    ('Jharkhand',      'Ranchi'),
    ('Uttar Pradesh',  'Lucknow'),
    ('Uttar Pradesh',  'Agra'),
    ('Bihar',          'Patna'),
    ('Maharashtra',    'Nashik'),
    ('Maharashtra',    'Nagpur'),
    ('Maharashtra',    'Thane'),
    ('Gujarat',        'Ahmedabad'),
    ('Rajasthan',      'Jaipur'),
]

In [None]:
# Number of repeated fits per location; the histogram / mean-variance cells
# compare results across these repeats.
num_rep_trials = 5

for loc in location_tuples:
    state, district = loc
    config_params = copy.deepcopy(config['fitting'])
    config_params['data']['dataloading_params']['state'] = state
    config_params['data']['dataloading_params']['district'] = district
    # smooth_jump keeps its configured value only for Mumbai; it is forced off
    # for every other location.
    if district != 'Mumbai':
        config_params['data']['smooth_jump'] = False
    predictions_dict[loc] = {}
    for trial_idx in range(num_rep_trials):
        # Give every repeat its own copy of the parameters so that, if
        # single_fitting_cycle mutates its (nested) arguments, one trial
        # cannot leak state into the next.
        trial_params = copy.deepcopy(config_params)
        predictions_dict[loc][f'm{trial_idx}'] = single_fitting_cycle(**trial_params)


In [None]:
# Compact overview of what was fitted: each location tuple mapped to its trial
# keys. Rendering the full nested results would flood the cell output; inspect
# predictions_dict[loc][trial] directly when details are needed.
{loc: list(trials) for loc, trials in predictions_dict.items()}

In [None]:
# Start a W&B run in the "covid-modelling" project and prefix its
# auto-generated name so these experiments are easy to find.
wandb.init(project="covid-modelling")
auto_name = wandb.run.name
wandb.run.name = "degeneracy-exps-location" + auto_name

In [None]:
# For each location, plot the histograms and the mean/variance summary across
# its repeated trials, keep the returned data, and log each figure to W&B.
# Two figures per location would otherwise stay open until the cell ends
# (~40 here; matplotlib warns past 20), so each is displayed inline right away
# and then closed.
mean_var_dict = {}
histograms_dict = {}
for key, loc_dict in predictions_dict.items():
    fig, ax, histograms_dict[key] = plot_all_histograms(loc_dict, key)
    wandb.log({f"histograms/{key[0]}_{key[1]}": [wandb.Image(fig)]})
    display(fig)
    plt.close(fig)

    fig, axs, mean_var_dict[key] = plot_mean_variance(loc_dict, key)
    wandb.log({f"mean_var/{key[0]}_{key[1]}": [wandb.Image(fig)]})
    display(fig)
    plt.close(fig)

In [None]:
# Scatter plots from mean_var_dict for selected parameter pairs, each logged
# to W&B under "scatter/<x_param>_<y_param>".
scatter_param_pairs = [
    ('E_hosp_ratio', 'I_hosp_ratio'),
    ('T_recov_fatal', 'P_fatal'),
]
for x_param, y_param in scatter_param_pairs:
    fig, axs = plot_scatter(mean_var_dict, x_param, y_param)
    wandb.log({f"scatter/{x_param}_{y_param}": [wandb.Image(fig)]})

In [None]:
# Per-location KL-divergence plots built from the histograms above (presumably
# comparing the repeated trials), collected in kl_dict and logged to W&B.
kl_dict = {}
for key in histograms_dict:
    histograms = histograms_dict[key]
    fig, axs, kl_value = plot_kl_divergence(histograms, key)
    kl_dict[key] = kl_value
    log_key = "kl_divergence/{}_{}".format(key[0], key[1])
    wandb.log({log_key: [wandb.Image(fig)]})

In [None]:
# Heatmap from plot_heatmap_distribution_sigmas over mean_var_dict for the
# chosen statistic; the figure is logged to W&B under "sigma_by_mu/<stat>"
# and the returned comparison table is displayed below.
stat_measure = 'mean'
fig, df_comparison = plot_heatmap_distribution_sigmas(mean_var_dict, stat_measure=stat_measure)
wandb.log({f"sigma_by_mu/{stat_measure}": [wandb.Image(fig)]})

In [None]:
# Comparison table returned by plot_heatmap_distribution_sigmas, rendered as
# this cell's output.
df_comparison