In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib import cm

import sys
sys.path.append('../../')

import geoplot as gplt
import geoplot.crs as gcrs
import seaborn as sns
import statsmodels.api as sm
import scipy.stats as stats
import matplotlib.dates as mdates

from scipy.stats import zscore
from shapely.geometry import MultiPolygon
from tabulate import tabulate

from utils.fitting.loss import Loss_Calculator

import os
import copy
import pickle
import re
from datetime import datetime, date, timedelta
from glob import glob

from utils.generic.config import read_config, make_date_key_str
from utils.generic.reichlab import *
from viz.reichlab import *

In [None]:
# Lookup from the CSV's 'state' column to its 'state_code' column; it is passed
# as `abbv_dict` to the scatter plots in full_comparison below.
us_states_abbv_df = pd.read_csv('../../data/data/us_states_abbv.csv')
us_states_abbv_dict = {
    state: state_code
    for state, state_code in zip(us_states_abbv_df['state'],
                                 us_states_abbv_df['state_code'])
}

In [None]:
# NOTE(review): absolute, user-specific path -- the notebook only runs as-is on
# a machine where this file exists.
predictions_pkl_filename = '/scratch/users/sansiddh/covid-modelling/2020_1111_162416/predictions_dict.pkl'

# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# pickles that come from a trusted source.
with open(predictions_pkl_filename, mode='rb') as pkl_file:
    predictions_dict = pickle.load(pkl_file)

In [None]:
def full_comparison(predictions_dict, us_states_abbv_dict,
                    reichlab_path='../../../covid19-forecast-hub'):
    """Compare the Wadhwani_AI forecasts against the other forecast-hub submissions.

    The fitting config is read from the first entry of ``predictions_dict``
    (``[...]['m2']['run_params']``); if that lookup fails, the ``'fitting'``
    section of ``us2.yaml`` is used instead. The config determines the target
    compartment and the submission date (config end date + 2 days). The
    function then builds the comparison tables through the project helpers,
    draws two heatmaps and two scatter plots, and prints a few summary counts.

    Args:
        predictions_dict: Mapping of location -> fit results. Each value is
            expected to carry ``['m2']['run_params']``.
        us_states_abbv_dict: Mapping passed as ``abbv_dict`` to the scatter
            plots.
        reichlab_path: Path of the forecast-hub checkout handed to the
            reichlab helpers. Defaults to the previously hard-coded location.

    Returns:
        Tuple ``(date_of_submission, df_comb, df_mape, df_rank, df_wadhwani)``
        where ``date_of_submission`` is a ``'%Y-%m-%d'`` string and the
        remaining items are the objects returned by ``compare_gt_pred`` (with
        the territory columns removed from ``df_mape`` / ``df_rank``) and
        ``combine_with_train_error``.

    Raises:
        ValueError: If the first loss compartment in the config is neither
            ``'deceased'`` nor ``'total'``.
    """
    comp_by_loss_compartment = {'deceased': 'cum_death', 'total': 'cum_case'}
    territories_to_drop = ['Guam', 'Virgin Islands', 'Northern Mariana Islands']
    stat_metrics = ('z_score', 'non_param_z_score')

    try:
        config = predictions_dict[list(predictions_dict.keys())[0]]['m2']['run_params']
    except (AttributeError, IndexError, KeyError, TypeError):
        # Only a failed lookup triggers the fallback; anything else (e.g.
        # KeyboardInterrupt) is allowed to propagate.
        config_filename = 'us2.yaml'
        config = read_config(config_filename)['fitting']

    loss_comp = config['loss']['loss_compartments'][0]
    if loss_comp not in comp_by_loss_compartment:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on `comp`.
        raise ValueError(
            f'Unsupported loss compartment {loss_comp!r}; expected one of '
            f'{sorted(comp_by_loss_compartment)}')
    comp = comp_by_loss_compartment[loss_comp]

    data_last_date = config['split']['end_date']
    date_of_submission = (data_last_date + timedelta(days=2)).strftime('%Y-%m-%d')
    print(comp)
    print(date_of_submission)

    list_of_models = get_list_of_models(date_of_submission, comp, reichlab_path=reichlab_path,
                                        num_submissions_filter=45)
    df_all_submissions = process_all_submissions(list_of_models, date_of_submission, comp,
                                                 reichlab_path=reichlab_path)
    _, _, df_gt_loss_wk, loc_name_to_key_dict = process_gt(comp, df_all_submissions,
                                                           reichlab_path=reichlab_path)

    df_wiai_submission = format_wiai_submission(predictions_dict, df_all_submissions, loc_name_to_key_dict,
                                                which_fit='m2', use_as_point_forecast='ensemble_mean',
                                                skip_percentiles=False)
    df_all_submissions = combine_wiai_subm_with_all(df_all_submissions, df_wiai_submission, comp)

    df_comb, df_mape, df_rank = compare_gt_pred(df_all_submissions, df_gt_loss_wk)
    # errors='ignore': a territory that is absent from the tables should not
    # abort the whole comparison.
    df_mape = df_mape.drop(columns=territories_to_drop, errors='ignore')
    df_rank = df_rank.drop(columns=territories_to_drop, errors='ignore')

    # One row per model in df_mape.
    num_models = len(df_mape)
    print(f'Total # of models - {num_models}')

    df = calculate_z_score(df_mape, df_rank, model_name='Wadhwani_AI')

    for metric in stat_metrics:
        create_heatmap(df, var_name=metric, center=0)

    df_wadhwani = combine_with_train_error(predictions_dict, df)

    # Note that the "-ve" bucket also counts scores that are exactly zero.
    for label, metric in (('Z score', 'z_score'), ('non param Z score', 'non_param_z_score')):
        print(f'# -ve {label} {len(df_wadhwani[df_wadhwani[metric] <= 0])}')
        print(f'# +ve {label} {len(df_wadhwani[df_wadhwani[metric] > 0])}')

    for metric in stat_metrics:
        create_scatter_plot_mape(df_wadhwani, annotate=True, abbv=True, abbv_dict=us_states_abbv_dict,
                                 stat_metric_to_use=metric, log_scale=True)

    return date_of_submission, df_comb, df_mape, df_rank, df_wadhwani

In [None]:
# Run the comparison for the loaded predictions; the five returned objects are
# kept as notebook-level variables and reused by the cells below.
date_of_submission, df_comb, df_mape, df_rank, df_wadhwani = full_comparison(predictions_dict, us_states_abbv_dict)

In [None]:
# Performance table built from the error and rank tables; displayed as the
# cell output on the last line.
merged = create_performance_table(df_mape, df_rank)

# full_comparison derives date_of_submission as (config end date + 2 days),
# so stepping back 2 days recovers the last date of the data.
x = datetime.strptime(date_of_submission, '%Y-%m-%d')
data_last_date = x - timedelta(days=2)
test_period_end = data_last_date + timedelta(days=28)

print('Data last date -  {}'.format(data_last_date.strftime('%Y-%m-%d')))
print('Test period till -  {}'.format(test_period_end.strftime('%Y-%m-%d')))
merged

In [None]:
# Load the same config file twice: once with default processing (`config`) and
# once with preprocess=False, whose result is passed through make_date_key_str
# (`wandb_config`).
config_filename = 'us4.yaml'
config = read_config(config_filename)
wandb_config = make_date_key_str(read_config(config_filename, preprocess=False))

In [None]:
# Keyword arguments for the uncertainty method. The fixed entries come first and
# the config's uncertainty_params are merged afterwards, so they override any
# fixed entry with the same key.
uncertainty_args = dict(
    predictions_dict=predictions_dict['Texas'],
    fitting_config=config['fitting'],
    forecast_config=config['forecast'],
    process_trials=False,
)
uncertainty_args.update(config['uncertainty']['uncertainty_params'])

# config['uncertainty']['method'] is called directly, i.e. the config holds a
# callable here rather than a name.
uncertainty_method = config['uncertainty']['method']
uncertainty = uncertainty_method(**uncertainty_args)

In [None]:
# Retrieve the forecasts from the uncertainty object constructed in the previous cell.
uncertainty_forecasts = uncertainty.get_forecasts()

In [None]:
# Fill 'perc_loss_ape' for every quantile-type row of df_comb; all other rows
# keep NaN. Each value is computed from that row's forecast, true value and
# quantile via the loss calculator.
lc = Loss_Calculator()
df_comb['perc_loss_ape'] = np.nan

quantile_rows = df_comb[df_comb['type'] == 'quantile']
for row_label, quantile_row in quantile_rows.iterrows():
    forecast = np.array([quantile_row['forecast_value']])
    truth = np.array([quantile_row['true_value']])
    df_comb.loc[row_label, 'perc_loss_ape'] = lc._calc_mape_perc(forecast, truth, quantile_row['quantile'])

In [None]:
# Rebuild the model x location table from the 0.05-quantile forecasts only.
# NOTE: this rebinds df_mape and df_rank, replacing the frames returned by
# full_comparison; every later cell that reads them sees these versions.
df_temp = df_comb[df_comb['quantile'] == 0.05]

# Aggregate only 'ape', the single column the pivot below needs. A blanket
# .mean() over the grouped frame also tries to average non-numeric columns such
# as 'type', which raises TypeError in pandas >= 2.0.
df_mape = (
    df_temp
    .groupby(['model', 'location', 'location_name'])['ape']
    .mean()
    .reset_index()
)

# NOTE(review): this pivots 'ape', not the 'perc_loss_ape' column filled in the
# previous cell -- confirm that is intended.
df_mape = df_mape.pivot(index='model', columns='location_name',
                        values='ape')

df_rank = df_mape.rank()

In [None]:
# When the cells run top to bottom, df_mape / df_rank here are the
# quantile == 0.05 tables rebuilt in the previous cell, not the frames
# returned by full_comparison.
create_performance_table(df_mape, df_rank)

In [None]:
# Load the shapefile that the choropleth cells below plot against
# (presumably US state boundaries, judging by the filename).
# NOTE(review): the path is relative; how preprocess_shape_file resolves it is
# not visible here -- confirm.
gdf = preprocess_shape_file(filename='cb_2018_us_state_5m/cb_2018_us_state_5m.shp')

In [None]:
# One choropleth per entry: column name -> colour-scale settings.
# NOTE(review): the model_rank limits (0 / 13 / 26) are hard-coded; presumably
# they reflect the number of models being compared -- confirm.
vars_to_plot = {
    'non_param_z_score': dict(cmap='RdYlGn_r', vmin=-1, vcenter=0, vmax=1),
    'model_rank': dict(cmap='Purples', vmin=0, vcenter=13, vmax=26),
}
fig, axs = plot_multiple_choropleths(df_wadhwani, gdf, vars_to_plot)

In [None]:
# Same pair of maps as the previous cell, but using the parametric z_score
# column instead of non_param_z_score.
# NOTE(review): the model_rank limits (0 / 13 / 26) are hard-coded; presumably
# they reflect the number of models being compared -- confirm.
vars_to_plot = {
    'z_score': dict(cmap='RdYlGn_r', vmin=-1, vcenter=0, vmax=1),
    'model_rank': dict(cmap='Purples', vmin=0, vcenter=13, vmax=26),
}
fig, axs = plot_multiple_choropleths(df_wadhwani, gdf, vars_to_plot)

In [None]:
# Scatter plot of the z-score columns in df_wadhwani (the last item returned by full_comparison).
fig, ax = create_scatter_plot_zscores(df_wadhwani)

In [None]:
# NOTE: when the cells run top to bottom, df_mape is the quantile == 0.05 table
# rebuilt above, not the frame returned by full_comparison.
fig, axs = plot_ecdf_all_states(df_mape)

In [None]:
# NOTE: when the cells run top to bottom, df_mape is the quantile == 0.05 table
# rebuilt above, while df_wadhwani still comes from full_comparison.
fig, axs = plot_qq_all_states(df_mape, fit=False, df_wadhwani=df_wadhwani)