In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np

import sys
sys.path.append('../../')

import os
import copy
import pickle
from datetime import datetime, date, timedelta

from utils.generic.config import read_config, make_date_key_str
from utils.generic.reichlab import *
from viz.reichlab import *
from viz import plot_ptiles
from viz.uncertainty import plot_ptiles_reichlab

In [None]:
# Predictions of the 2020_1214_143227_comb run; this dict is later formatted
# with which_comp='death'.
# NOTE(review): pickle.load executes arbitrary code from the file - only point
# this at pickles produced by trusted runs.
predictions_pkl_filename = '/scratch/users/sansiddh/covid-modelling/2020_1214_143227_comb/predictions_dict.pkl'
with open(predictions_pkl_filename, mode='rb') as pkl_file:
    predictions_dict_d = pickle.load(pkl_file)

In [None]:
# Predictions of the 2020_1214_155731_comb run; this dict is later formatted
# with which_comp='case'.
# NOTE(review): pickle.load executes arbitrary code from the file - only point
# this at pickles produced by trusted runs.
predictions_pkl_filename = '/scratch/users/sansiddh/covid-modelling/2020_1214_155731_comb/predictions_dict.pkl'
with open(predictions_pkl_filename, mode='rb') as pkl_file:
    predictions_dict_t = pickle.load(pkl_file)

In [None]:
# Drop the territories that are excluded from the submission from both runs.
places_to_prune = ['Guam', 'Northern Mariana Islands', 'Virgin Islands']
for key in places_to_prune:
    # pop(..., None) instead of `del`: the cell stays re-runnable and does not
    # fail when a run simply has no entry for one of these places.
    predictions_dict_d.pop(key, None)
    predictions_dict_t.pop(key, None)

In [None]:
# Look up both location-name mappings (name -> code, name -> abbreviation) in one pass.
loc_name_to_key_dict, us_states_abbv_dict = (
    get_mapping(which=mapping_kind)
    for mapping_kind in ('location_name_to_code', 'location_name_to_abbv')
)

In [None]:
# Run parameters are read from the 'm2' fit of the first location in the deaths run.
config = predictions_dict_d[next(iter(predictions_dict_d))]['m2']['run_params']
loss_comp = config['loss']['loss_compartments'][0]
data_last_date = config['split']['end_date']
# The submission date is taken as two days after the last date of the data split.
date_of_submission = (data_last_date + timedelta(days=2)).strftime('%Y-%m-%d')

# Map the loss compartment of the run to the forecast target it corresponds to.
# An explicit lookup replaces two independent `if`s that left `comp` undefined
# (NameError at the print below) for any other compartment.
loss_comp_to_target = {'deceased': 'cum_death', 'total': 'cum_case'}
if loss_comp not in loss_comp_to_target:
    raise ValueError(
        f'Unsupported loss compartment {loss_comp!r}; '
        f'expected one of {sorted(loss_comp_to_target)}')
comp = loss_comp_to_target[loss_comp]
print(comp)
print(date_of_submission)

In [None]:
# Format the deaths run ('m2' fit) as a submission frame, using the ensemble
# mean as the point forecast and keeping all percentiles.
df_wiai_submission_death = format_wiai_submission(
    predictions_dict_d,
    loc_name_to_key_dict,
    which_fit='m2',
    formatting_mode='submission',
    use_as_point_forecast='ensemble_mean',
    which_comp='death',
    skip_percentiles=False,
)

In [None]:
# Format the cases run ('m2' fit) as a submission frame, using the ensemble
# mean as the point forecast and keeping all percentiles.
df_wiai_submission_cases = format_wiai_submission(
    predictions_dict_t,
    loc_name_to_key_dict,
    which_fit='m2',
    formatting_mode='submission',
    use_as_point_forecast='ensemble_mean',
    which_comp='case',
    skip_percentiles=False,
)

In [None]:
# Keep every death target from the deaths run, but only the incident-case
# targets from the cases run. The match is a literal substring test on the
# `target` column, done with the vectorized string accessor (regex=False) rather
# than a per-row Python lambda.
df_wiai_submission_death = df_wiai_submission_death[
    df_wiai_submission_death['target'].str.contains('death', regex=False)]
df_wiai_submission_cases = df_wiai_submission_cases[
    df_wiai_submission_cases['target'].str.contains('inc case', regex=False)]

In [None]:
# For the case targets keep only the point rows and these seven quantiles.
# NOTE(review): presumably this is the quantile set accepted for case targets - confirm.
is_point_row = df_wiai_submission_cases['type'] == 'point'
is_kept_quantile = df_wiai_submission_cases['quantile'].isin(
    [0.025, 0.100, 0.250, 0.500, 0.750, 0.900, 0.975])
df_wiai_submission_cases = df_wiai_submission_cases[is_point_row | is_kept_quantile]

In [None]:
# Renumber the rows 0..n-1 after the filtering above, discarding the old index.
df_wiai_submission_death = df_wiai_submission_death.reset_index(drop=True)
df_wiai_submission_cases = df_wiai_submission_cases.reset_index(drop=True)

In [None]:
# Stack the death rows on top of the case rows into one frame with a fresh index.
df_wiai_submission_comb = pd.concat(
    objs=[df_wiai_submission_death, df_wiai_submission_cases],
    ignore_index=True,
)
df_wiai_submission_comb

In [None]:
# Translate the names of the excluded territories into their location codes.
places_to_prune = ['Guam', 'Northern Mariana Islands', 'Virgin Islands']
places_to_prune_code = [loc_name_to_key_dict[place_name] for place_name in places_to_prune]

In [None]:
# Remove every row whose location code belongs to an excluded territory,
# then renumber the remaining rows.
is_pruned_location = df_wiai_submission_comb['location'].isin(places_to_prune_code)
df_wiai_submission_comb = df_wiai_submission_comb[~is_pruned_location].reset_index(drop=True)

In [None]:
from utils.generic.enums import Columns
from matplotlib.lines import Line2D
from adjustText import adjust_text
from viz.utils import axis_formatter

In [None]:
def plot_ptiles_reichlab(df_comb, model, location, target='inc death', plot_true=False, plot_point=True, plot_individual_curves=True):
    """Plot the point forecast and the quantile curves of one model at one location.

    NOTE: this notebook-level definition shadows the ``plot_ptiles_reichlab``
    imported from ``viz.uncertainty`` in the import cell.

    Args:
        df_comb (pd.DataFrame): long-format forecast frame. The columns read here
            are ``model``, ``location``, ``target``, ``target_end_date``, ``type``,
            ``quantile`` and ``value`` (plus ``true_value`` when ``plot_true``).
        model (str): value of the ``model`` column to select.
        location (str): value of the ``location`` column to select.
        target (str, optional): substring that the ``target`` column must contain.
            'death' in it selects the deceased compartment (otherwise total);
            'inc' in it labels the plot incident (otherwise cumulative).
        plot_true (bool, optional): also plot the mean ``true_value`` per
            ``target_end_date`` as a dashed line.
        plot_point (bool, optional): plot the rows of type 'point' in black.
        plot_individual_curves (bool, optional): if True, draw one labelled line
            per quantile; if False, hand all quantile rows to a single
            ``sns.lineplot`` call.

    Returns:
        tuple: the matplotlib ``(fig, ax)`` pair that was drawn on.
    """
    compartment = Columns.from_name('deceased' if 'death' in target else 'total')
    mode = 'incident' if 'inc' in target else 'cumulative'

    # DataFrame.copy() instead of a bare `copy(...)` call: the notebook does
    # `import copy`, which binds the (non-callable) module unless one of the
    # later star imports happens to rebind the name.
    selected_rows = (df_comb['model'] == model) & (df_comb['location'] == location)
    df_plot = df_comb.loc[selected_rows, :].copy()
    # Boolean Series mask rather than a Python list: an empty list would be
    # interpreted as "select no columns" instead of "select no rows".
    df_plot = df_plot[df_plot['target'].str.contains(target, regex=False, na=False)]

    fig, ax = plt.subplots(figsize=(12, 12))
    texts = []
    if plot_true:
        # Aggregate only the column that is plotted; averaging every column
        # fails on the non-numeric ones in recent pandas versions.
        df_true = df_plot.groupby('target_end_date')['true_value'].mean().reset_index()
        ax.plot(df_true['target_end_date'].to_numpy(), df_true['true_value'].to_numpy(),
                '--o', color=compartment.color)
    if plot_point:
        df_point = df_plot[df_plot['type'] == 'point']
        ax.plot(df_point['target_end_date'].to_numpy(), df_point['value'].to_numpy(),
                '-o', color='black')

    df_quantiles = df_plot[df_plot['type'] == 'quantile']
    # Sorted distinct quantile levels (NaN excluded), without summing unrelated columns.
    quantiles = sorted(df_quantiles['quantile'].dropna().unique())
    if plot_individual_curves:
        for qtile in quantiles:
            df_qtile = df_quantiles[df_quantiles['quantile'] == qtile].infer_objects()
            # Label as a whole percent when the level is (numerically) integral.
            # Comparing against the nearest integer also catches values that
            # land just *below* it, e.g. 0.29 * 100 == 28.999999999999996.
            percent = qtile * 100
            label = round(percent) if abs(percent - round(percent)) < 1e-8 else round(percent, 1)
            sns.lineplot(x='target_end_date', y='value', data=df_qtile, ls='-', ax=ax)
            # Annotate the end of each curve with its percentile.
            texts.append(ax.text(
                x=df_qtile['target_end_date'].iloc[-1],
                y=df_qtile['value'].iloc[-1], s=str(label)))
    else:
        # Plot against 'target_end_date', the date column every other access in
        # this function uses (previously Columns.date.name).
        # NOTE(review): confirm no caller passes a frame keyed by Columns.date.name.
        sns.lineplot(x='target_end_date', y='value', data=df_quantiles,
                     ls='-', label=f'{compartment.label}', ax=ax)

    # Leave room on the right for the percentile labels.
    x_left, x_right = ax.get_xlim()
    ax.set_xlim(x_left, x_right + 10)
    if texts:
        # Nothing to de-overlap when no per-quantile labels were drawn.
        adjust_text(texts, arrowprops=dict(arrowstyle="->", color='r', lw=0.5))
    axis_formatter(ax)

    # Hand-built legend: one entry per kind of curve rather than one per line.
    legend_elements = []
    if plot_true:
        legend_elements += [
            Line2D([0], [0], ls='--', marker='o', color=compartment.color,
                   label=f'{mode.title()} {compartment.label} (Observed)')]
    if plot_point:
        legend_elements += [
            Line2D([0], [0], ls='-', marker='o', color='black',
                   label=f'{mode.title()} {compartment.label} Point Forecast')]

    legend_elements += [
        Line2D([0], [0], ls='-', color='blue',
               label=f'{mode.title()} {compartment.label} Percentiles'),
    ]
    ax.legend(handles=legend_elements)
    fig.suptitle('Forecast for {}, {}, {} {}'.format(model, location,
                                                     mode.title(), compartment.label), fontsize=16)
    fig.subplots_adjust(top=0.96)

    return fig, ax

In [None]:
# Sanity-check plot: incident-death forecast of the submitted model at location code '18'.
fig, ax = plot_ptiles_reichlab(
    df_comb=df_wiai_submission_comb,
    model='Wadhwani_AI-BayesOpt',
    location='18',
    target='inc death',
)

In [None]:
# Drop the `model` column before writing the submission file.
# drop(..., errors='ignore') keeps the cell re-runnable: a second `del` on the
# already-removed column would raise KeyError.
# NOTE: the plotting cell above filters on `model`, so it must run before this one.
df_wiai_submission_comb = df_wiai_submission_comb.drop(columns=['model'], errors='ignore')

In [None]:
# Name of the model and the forecast date (taken from the first row) used to
# build the submission file name.
model_name = 'Wadhwani_AI-BayesOpt'
forecast_date = df_wiai_submission_comb.at[0, 'forecast_date'].strftime('%Y-%m-%d')

In [None]:
# Write the submission into the local checkout of covid19-forecast-hub, under
# the model's data-processed folder, without the frame index.
submission_csv_path = f'../../../covid19-forecast-hub/data-processed/{model_name}/{forecast_date}-{model_name}.csv'
df_wiai_submission_comb.to_csv(submission_csv_path, index=False)