In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell execution and edits to the project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib import cm

import geoplot as gplt
import geoplot.crs as gcrs
import seaborn as sns
import statsmodels.api as sm
import scipy.stats as stats
import matplotlib.dates as mdates

from scipy.stats import zscore

import os
import copy
import pickle
import re
from datetime import datetime, date, timedelta
from glob import glob

from utils.generic.config import read_config, make_date_key_str
from utils.generic.reichlab import *
from viz.reichlab import *

In [None]:
from matplotlib import rc

# Use the Helvetica sans-serif face for all figure text.
rc('font', family='sans-serif', **{'sans-serif': ['Helvetica']})
## for Palatino and other serif fonts use:
#rc('font',**{'family':'serif','serif':['Palatino']})
# Typeset figure text through LaTeX (requires a working LaTeX installation).
rc('text', usetex=True)

In [None]:
# Lookup table: value of the 'state' column -> value of the 'state_code' column.
us_states_abbv_df = pd.read_csv('../../data/data/us_states_abbv.csv')
us_states_abbv_dict = {
    state_name: state_code
    for state_name, state_code in zip(us_states_abbv_df['state'], us_states_abbv_df['state_code'])
}

In [None]:
# Identifiers of earlier runs, kept for reference; none of them is read by the
# code visible in this notebook.
# NOTE(review): they look like timestamped output-folder names (same format as
# the folder in predictions_pkl_filename) and the variable names look like data
# cut-off dates — confirm before relying on either.
first_run = '2020_1022_014310'
aug22 = '2020_1031_203437'
aug29 = '2020_1031_211038'
sept05 = '2020_1031_181912'
sept26 = '2020_1031_100231'

In [None]:
# Accumulator keyed by submission date: each value is the df_bad table returned
# by full_comparison (the locations with a positive z_score) for that run.
bad_dict = {}

In [None]:
# NOTE(review): absolute, machine-specific path — the notebook only runs where
# this file exists. pickle.load executes arbitrary code embedded in the file, so
# only point this at pickles produced by a trusted run.
predictions_pkl_filename = (
    '/scratch/users/sansiddh/covid-modelling/2020_1111_162416/predictions_dict.pkl'
)
with open(predictions_pkl_filename, mode='rb') as pkl_file:
    predictions_dict = pickle.load(pkl_file)

In [None]:
def full_comparison(predictions_dict, us_states_abbv_dict, reichlab_path='../../../covid19-forecast-hub'):
    """Compare the Wadhwani AI forecasts with the other forecast-hub submissions.

    Reads the run configuration, derives the submission date and the target
    from it, loads the competing submissions and the ground truth through the
    reichlab helpers, adds the Wadhwani AI submission, and then prints and
    plots the per-location MAPE / rank comparison.

    Args:
        predictions_dict: Mapping from location to fit results. The config is
            read from the ``['m2']['run_params']`` entry of the first key; if
            that lookup fails, the ``'fitting'`` section of ``us2.yaml`` is
            used instead.
        us_states_abbv_dict: Handed to the scatter plots as ``abbv_dict``.
        reichlab_path: Path handed to the reichlab helpers as
            ``reichlab_path``. The default is the path this notebook has
            always used.

    Returns:
        Tuple ``(date_of_submission, df_mape, df_rank, df_bad, df_wadhwani)``.
        ``date_of_submission`` is a ``'YYYY-MM-DD'`` string two days after the
        config's split end date; ``df_bad`` holds the rows of ``df_wadhwani``
        whose ``z_score`` is positive.

    Raises:
        ValueError: If the first loss compartment in the config is neither
            ``'deceased'`` nor ``'total'``.
    """
    model_name = 'Wadhwani_AI'
    territories = ['Guam', 'Virgin Islands', 'Northern Mariana Islands']
    target_by_loss_comp = {'deceased': 'cum_death', 'total': 'cum_case'}

    try:
        config = predictions_dict[list(predictions_dict.keys())[0]]['m2']['run_params']
    except (AttributeError, IndexError, KeyError, TypeError):
        # Only a missing/malformed run_params entry triggers the fallback;
        # anything else (e.g. KeyboardInterrupt) now propagates.
        config_filename = 'us2.yaml'
        config = read_config(config_filename)['fitting']

    loss_comp = config['loss']['loss_compartments'][0]
    if loss_comp not in target_by_loss_comp:
        # Previously this fell through and crashed on an unbound `comp`.
        raise ValueError(
            f'Unsupported loss compartment {loss_comp!r}; '
            f'expected one of {sorted(target_by_loss_comp)}'
        )
    comp = target_by_loss_comp[loss_comp]
    data_last_date = config['split']['end_date']
    date_of_submission = (data_last_date + timedelta(days=2)).strftime('%Y-%m-%d')
    print(comp)
    print(date_of_submission)

    list_of_models = get_list_of_models(date_of_submission, comp, reichlab_path=reichlab_path,
                                        num_submissions_filter=45)
    df_all_submissions = process_all_submissions(list_of_models, date_of_submission, comp,
                                                 reichlab_path=reichlab_path)
    df_gt, df_gt_loss, df_gt_loss_wk, loc_name_to_key_dict = process_gt(comp, df_all_submissions,
                                                                        reichlab_path=reichlab_path)

    df_wiai_submission = format_wiai_submission(predictions_dict, df_all_submissions, loc_name_to_key_dict,
                                                which_fit='m2', use_as_point_forecast='ensemble_mean',
                                                skip_percentiles=True)
    df_all_submissions = combine_wiai_subm_with_all(df_all_submissions, df_wiai_submission, comp)

    df_comb, df_mape, df_rank = compare_gt_pred(df_all_submissions, df_gt_loss_wk)
    # Territories are excluded from the comparison; tolerate their absence
    # instead of raising KeyError when a column is missing.
    df_mape = df_mape.drop(columns=territories, errors='ignore')
    df_rank = df_rank.drop(columns=territories, errors='ignore')

    # One row per model, so the row count is the number of models.
    num_models = len(df_mape.index)
    print(f'Total # of models - {num_models}')
    # Medians are taken only over locations where this model has a value.
    print(df_mape.loc[:, df_mape.loc[model_name, :].notna()].median(axis=1).sort_values())
    print(df_rank.loc[:, df_rank.loc[model_name, :].notna()].median(axis=1).sort_values())

    df = calculate_z_score(df_mape, df_rank, model_name=model_name)

    # The rank heatmap is centred on the middle rank, the z-score ones on zero.
    for var_name, center in (('z_score', 0), ('non_param_z_score', 0), ('model_rank', num_models // 2)):
        create_heatmap(df, var_name=var_name, center=center)

    df_wadhwani = combine_with_train_error(predictions_dict, df)

    print(f'# -ve Z score {len(df_wadhwani[df_wadhwani["z_score"] <= 0])}')
    print(f'# +ve Z score {len(df_wadhwani[df_wadhwani["z_score"] > 0])}')

    for log_scale in (True, False):
        create_scatter_plot_mape(df_wadhwani, annotate=True, abbv=True, abbv_dict=us_states_abbv_dict,
                                 log_scale=log_scale)

    df_bad = df_wadhwani[df_wadhwani['z_score'] > 0]

    return date_of_submission, df_mape, df_rank, df_bad, df_wadhwani

In [None]:
# Show the keys of the loaded predictions; further down the dict is indexed by
# the entries of df_bad.index, so these are the available locations.
predictions_dict.keys()

In [None]:
# Run the comparison once and remember the badly-scored locations for this
# submission date.
comparison_outputs = full_comparison(predictions_dict, us_states_abbv_dict)
date_of_submission, df_mape, df_rank, df_bad, df_wadhwani = comparison_outputs
bad_dict[date_of_submission] = df_bad

In [None]:
# Two stacked maps: the z_score on top, the non_param_z_score below it.
fig, axs = plt.subplots(figsize=(14, 14), nrows=2, ncols=1)
z_score_vars = ['z_score', 'non_param_z_score']

# Draw both maps first, then attach the colorbars (same order as before).
for ax, var in zip(axs.flat, z_score_vars):
    create_geoplot_choropleth(df_wadhwani, var=var, vcenter=0, cmap='coolwarm', ax=ax)

# NOTE(review): the colorbars use a fixed -1..1 scale centred on 0; whether the
# maps themselves are normalised to the same range depends on
# create_geoplot_choropleth — confirm.
for ax in axs.flat:
    colorbar_norm = colors.TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)
    fig.colorbar(cm.ScalarMappable(norm=colorbar_norm, cmap='coolwarm'), ax=ax)

In [None]:
# Stand-alone map of the non_param_z_score column, centred on 0 with the
# blue-white-red colour map.
fig = create_geoplot_choropleth(df_wadhwani, var='non_param_z_score', vcenter=0, cmap='bwr')

In [None]:
# Scatter the two z-score variants against each other, one point per row.
fig, ax = plt.subplots(figsize=(12, 12))
z_param = df_wadhwani['z_score']
z_non_param = df_wadhwani['non_param_z_score']
ax.scatter(z_param, z_non_param)
# Reference diagonal: points on it have identical values for both variants.
ax.plot(z_param, z_param, '-r', label='y=x line')
ax.set(
    xlabel='Z Score (Mean, Std)',
    ylabel='Non Param Z Score (Median, MAD)',
    title='Non Param Z Score vs Z Score ',
)
ax.grid()
# Optional per-point labels (disabled):
# for i, (index, row) in enumerate(df_wadhwani.iterrows()):
#     annot_str = us_states_abbv_dict[index]
#     ax.annotate(annot_str, (row['z_score'], row['non_param_z_score']))
ax.legend()

In [None]:
# Centre the diverging colour map on the middle rank. df_mape has one row per
# model, so this is the same midpoint (num_models // 2) that full_comparison
# uses for the rank heatmap, rather than a hard-coded 13.
num_models = len(df_mape.index)
fig = create_geoplot_choropleth(df_wadhwani, var='model_rank', vcenter=num_models // 2, cmap='bwr')

In [None]:
# Locations (columns of df_mape) for which the Wadhwani_AI row has a value.
columns = df_mape.loc[:, df_mape.loc['Wadhwani_AI', :].notna()].columns

# Size the grid from the data: the former fixed 15x3 grid raised IndexError as
# soon as there were more than 45 locations.
ncols = 3
nrows = max(1, -(-len(columns) // ncols))  # ceiling division
fig, axs = plt.subplots(figsize=(21, 6 * nrows), nrows=nrows, ncols=ncols, squeeze=False)

for ax, state in zip(axs.flat, columns):
    # Distribution of every model's MAPE for this location ...
    sns.ecdfplot(data=df_mape[state], ax=ax)
    # ... with the Wadhwani AI value marked on it.
    ax.axvline(df_mape.loc['Wadhwani_AI', state], ls=':', c='red', label='Wadhwani AI Submission')
#     ax.axvline(df_mape.loc['UMass-MechBayes', state], ls=':', c='maroon', label='UMass-MechBayes Submission (lowest rank)')
    ax.set_title(state)
    ax.legend()

# Hide the panels left over in the last row.
for ax in axs.flat[len(columns):]:
    ax.set_visible(False)

fig.suptitle('Empirical Cumulative Distribution Function Plots for all states')
fig.subplots_adjust(top=0.97)

In [None]:
# Locations (columns of df_mape) for which the Wadhwani_AI row has a value.
columns = df_mape.loc[:, df_mape.loc['Wadhwani_AI', :].notna()].columns

# Size the grid from the data: the former fixed 15x3 grid raised IndexError as
# soon as there were more than 45 locations.
ncols = 3
nrows = max(1, -(-len(columns) // ncols))  # ceiling division
fig, axs = plt.subplots(figsize=(21, 6 * nrows), nrows=nrows, ncols=ncols, squeeze=False)

for ax, state in zip(axs.flat, columns):
    # Drop missing MAPEs so that NaN entries are not treated as observations
    # when the sample quantiles are formed.
    state_mapes = df_mape[state].dropna()
#     sm.qqplot(state_mapes, dist=stats.norm, fit=True, line='45', ax=ax)
    # Compare against a normal distribution located and scaled with the
    # mean_mape / std_mape values stored for this location in df_wadhwani.
    sm.qqplot(state_mapes, dist=stats.norm, loc=df_wadhwani.loc[state, 'mean_mape'],
              scale=df_wadhwani.loc[state, 'std_mape'], line='45', ax=ax)
    ax.set_title(state)

# Hide the panels left over in the last row.
for ax in axs.flat[len(columns):]:
    ax.set_visible(False)

fig.suptitle('Q-Q plots for all states')
fig.subplots_adjust(top=0.97)

In [None]:
# One single-row heatmap per badly-scored location, showing every model's MAPE.
# DataFrame.iteritems() was removed in pandas 2.0; items() is the equivalent
# column-wise iterator and exists in older versions as well.
for col_name, mapes in df_mape.loc[:, df_bad.index].items():
    fig, ax = plt.subplots(figsize=(18, 2))
    sns.heatmap(mapes.to_numpy().reshape(1, -1), cmap='Reds', ax=ax, xticklabels=mapes.index, annot=True)
    ax.set_title(col_name)
    plt.show()
    # Release the figure so that looping over many locations does not pile up
    # open figures.
    plt.close(fig)

In [None]:
# For every badly-scored location, display the stored 'fit' figure of the m1
# run followed by the one of the m2 run.
for state in df_bad.index:
    print(state)
    for fit_key in ('m1', 'm2'):
        fig = predictions_dict[state][fit_key]['plots']['fit']
        show_figure(fig)
        fig.show()

In [None]:
# Locations flagged as bad in every stored run. Previously this indexed dfs[1]
# unconditionally and raised IndexError unless at least two runs had been
# stored in bad_dict. With exactly two runs the result is unchanged
# (sorted, unique labels present in both indexes).
dfs = list(bad_dict.values())
common_bad_locations = np.array([], dtype=object)
if dfs:
    common_bad_locations = np.unique(dfs[0].index)
    for df_bad_run in dfs[1:]:
        common_bad_locations = np.intersect1d(common_bad_locations, df_bad_run.index)
common_bad_locations