In [None]:
# Automatically reload all imported modules (including the project's utils.* / viz.*
# imported below) before each cell executes, so edits to them are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import geoplot as gplt
import geoplot.crs as gcrs
import seaborn as sns
import matplotlib.dates as mdates

from scipy.stats import zscore

import os
import copy
import pickle
import re
from datetime import datetime, date, timedelta
from glob import glob

from utils.generic.config import read_config, make_date_key_str
from utils.generic.reichlab import *
from viz.reichlab import *

In [None]:
# Identifiers of individual experiment runs (timestamp-like strings). The value of
# `aug29` is the directory the predictions pickle is loaded from.
# NOTE(review): the aug22/aug29/sept05 names presumably refer to each run's data
# cutoff date — TODO confirm.
first_run, aug22, aug29, sept05 = (
    '2020_1022_014310',
    '2020_1031_203437',
    '2020_1031_211038',
    '2020_1031_181912',
)

In [None]:
# Accumulator: submission date string -> DataFrame of the locations where the
# Wadhwani_AI forecast had a positive z-score (filled after each full_comparison call).
bad_dict = dict()

In [None]:
# Root directory holding the experiment outputs; presumably one sub-directory per
# run ID (the run IDs are the constants defined above).
RUNS_DIR = '/scratch/users/sansiddh/covid-modelling'

# Load the predictions of the `aug29` run. Swap the run-ID constant to analyse a
# different run instead of editing a hard-coded path.
# NOTE: pickle.load can execute arbitrary code — only load files produced by our
# own pipeline.
predictions_pkl_filename = os.path.join(RUNS_DIR, aug29, 'predictions_dict.pkl')
with open(predictions_pkl_filename, 'rb') as f:
    predictions_dict = pickle.load(f)

In [None]:
def full_comparison(predictions_dict, reichlab_path='../../../covid19-forecast-hub'):
    """Benchmark one Wadhwani_AI run against all Reich Lab forecast-hub submissions.

    Steps:
    - Read the run config from the first entry of ``predictions_dict``.
    - Derive the submission date (split end date + 2 days).
    - Derive the hub target: ``cum_death`` for a ``deceased`` loss, ``inc_case`` for
      a ``total`` loss.
    - Load the hub submissions and ground truth for that date, and merge in the
      Wadhwani_AI forecast.
    - Print per-model median MAPE and rank.
    - Draw z-score/rank heatmaps and MAPE scatter plots.
    - Return the locations where Wadhwani_AI's z-score is positive.

    Args:
        predictions_dict: mapping keyed by location/state. Each value must contain
            ``['m2']['run_params']`` (the run config). The whole mapping is also
            passed to ``format_wiai_submission`` and ``combine_with_train_error``.
        reichlab_path: path to a local checkout of the covid19-forecast-hub repo.

    Returns:
        tuple ``(date_of_submission, df_bad)``: the submission date formatted as
        ``'YYYY-MM-DD'``, and the rows of the Wadhwani_AI comparison frame whose
        ``z_score`` is > 0.

    Raises:
        ValueError: if ``predictions_dict`` is empty, or if the run's first loss
            compartment has no corresponding forecast-hub target.
    """
    if not predictions_dict:
        raise ValueError('predictions_dict is empty; nothing to compare')

    # The run config is read from the first location only — assumes every
    # location in the run shares it.
    config = next(iter(predictions_dict.values()))['m2']['run_params']
    loss_comp = config['loss']['loss_compartments'][0]
    data_last_date = config['split']['end_date']

    date_of_submission = (data_last_date + timedelta(days=2)).strftime('%Y-%m-%d')

    # Map the compartment the model was fit on to the forecast-hub target name.
    # Previously an unmapped compartment left `comp` unbound (UnboundLocalError).
    loss_comp_to_target = {'deceased': 'cum_death', 'total': 'inc_case'}
    if loss_comp not in loss_comp_to_target:
        raise ValueError(
            f'Unsupported loss compartment {loss_comp!r}; '
            f'expected one of {sorted(loss_comp_to_target)}'
        )
    comp = loss_comp_to_target[loss_comp]

    print(date_of_submission)
    print(comp)

    list_of_models = get_list_of_models(date_of_submission, comp, reichlab_path=reichlab_path)
    df_all_submissions = process_all_submissions(list_of_models, date_of_submission, comp, reichlab_path=reichlab_path)
    df_gt, df_gt_loss, df_gt_loss_wk, loc_name_to_key_dict = process_gt(comp, df_all_submissions, reichlab_path=reichlab_path)

    df_wiai_submission = format_wiai_submission(predictions_dict, df_all_submissions, loc_name_to_key_dict, use_as_point_forecast='best')
    df_all_submissions = combine_wiai_subm_with_all(df_all_submissions, df_wiai_submission, comp)

    df_comb, df_mape, df_rank = compare_gt_pred(df_all_submissions, df_gt_loss_wk)

    # Row-wise medians; len() of this is used as the number of compared models.
    median_mape = df_mape.median(axis=1)
    num_models = len(median_mape)
    print(f'Total # of models - {num_models}')
    print(median_mape.sort_values())
    print(df_rank.median(axis=1).sort_values())

    df = calculate_z_score(df_mape, df_rank, model_name='Wadhwani_AI')

    # The returned figures are not reused here.
    create_heatmap(df, var_name='z_score', center=0)
    create_heatmap(df, var_name='model_rank', center=num_models // 2)

    df_wadhwani = combine_with_train_error(predictions_dict, df)

    # NaN z-scores fall in neither bucket (both comparisons are False for NaN).
    is_bad = df_wadhwani['z_score'] > 0
    print(f'# -ve Z score {len(df_wadhwani[df_wadhwani["z_score"] <= 0])}')
    print(f'# +ve Z score {len(df_wadhwani[is_bad])}')

    create_scatter_plot_mape(df_wadhwani, annotate=True)
    create_scatter_plot_mape(df_wadhwani, annotate=False)

    df_bad = df_wadhwani[is_bad]

    return date_of_submission, df_bad

In [None]:
# Benchmark the loaded run and record its positive-z-score locations under the
# run's submission date.
date_of_submission, df_bad = full_comparison(predictions_dict)
bad_dict.update({date_of_submission: df_bad})

In [None]:
# One-row heatmap per state in `df_bad` (Wadhwani_AI z-score > 0), with one cell per
# row label of `df_mape` — presumably one per model.
# NOTE(review): `df_mape` is only a local variable inside `full_comparison`; no cell
# above defines it at notebook scope, so this cell raises NameError on a fresh kernel.
# TODO: return `df_mape` from `full_comparison` (or recompute it here).
# `.items()` replaces `.iteritems()`, which was removed in pandas 2.0.
for state, mapes in df_mape.loc[:, df_bad.index].items():
    fig, ax = plt.subplots(figsize=(18, 2))
    sns.heatmap(mapes.to_numpy().reshape(1, -1), cmap='Reds', ax=ax, xticklabels=mapes.index, annot=True)
    ax.set_title(state)
    plt.show()
    # Free the figure so a long loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# For each state with a positive Wadhwani_AI z-score, display the stored 'fit' plot
# of the 'm1' entry and then the 'm2' entry.
for state in df_bad.index:
    print(state)
    for model_key in ('m1', 'm2'):
        fig = predictions_dict[state][model_key]['plots']['fit']
        # NOTE(review): `show_figure` comes from a star import above; presumably it
        # attaches the stored (unpickled) figure to a live canvas so `.show()` can
        # render it — confirm.
        show_figure(fig)
        fig.show()