In [None]:
import os
import statistics as st
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

In [None]:
# Run configuration shared by the cells below.

# Alternative country-level node labels, kept for reference:
# countries = ["Brazil", "Germany", "Spain", "France", "Britain", "India", "Italy", "Russia", "Turkey", "USA"]

# Node labels; position i is used as the title of the subplot for node index i.
continents = [
    "Africa",
    "North America",
    "South America",
    "Oceania",
    "Eastern Europe",
    "Western Europe",
    "Middle East",
    "South Asia",
    "Southeast-East Asia",
    "Central Asia",
]

MODEL_NAME = "DCSAGE"  # used in figure titles and output file names
NUM_MODELS = 100  # number of models iterated over when plotting
WINDOW_SIZE = 7  # number of input days per rolling window
REC_PRED_LEN = 30  # presumably the recursive prediction horizon in days; used in the window-count check
ANALYSIS_DIR = "2022-04-16-00_41_22"  # run folder under ./analysis-runs-multiple-models/

In [None]:
# Create the output folders of this analysis run.
# os.makedirs(..., exist_ok=True) creates the run directory itself when it is missing and
# does not raise when a folder already exists, so this cell can be re-run safely.
_run_dir = "./analysis-runs-multiple-models/" + ANALYSIS_DIR
for _subdir in (
    "rec_vs_pert_rec_1window_1model",
    "rec_vs_pert_rec_after_bias_correct",
    "bias_correction_visuals",
    "ev_unfiltered_fits_after_bias_correct",
    "ev_based_plots_after_bias_correct_signed",
    "ev_based_plots_after_bias_correct_unsigned",
    "normal_unfiltered_fits_after_bias_correct",
    "normal_based_plots_after_bias_correct",
):
    os.makedirs(_run_dir + "/" + _subdir, exist_ok=True)

# Loading Perturbed Predictions from Models
These cells load roll_win_pert_preds.npy, which is the perturbed recursive predictions saved
by the node perturbation analysis file when 10 countries are perturbed. We load it here and 
turn it from a Numpy array back into a Python list.

Only load this if you need the raw unperturbed and perturbed recursive predictions. If you 
just need to make sensitivity plots, skip below to where the sensitivity 10x10 array is loaded
into this notebook.

In [None]:
# Column names given to the last axis of every saved prediction slice.
_pert_pred_columns = ["Regular Predictions", "Ground Truth", "Extended Recursive Predictions", "Day Index"]

roll_win_pert_pred_nested_np = np.load("./analysis-runs-multiple-models/" + ANALYSIS_DIR + "/prediction_saves/roll_win_pert_preds.npy")

# Unroll the four leading axes (window, model, perturbed node, node) into nested Python
# lists and wrap every innermost 2-D slice in a DataFrame.
roll_win_pert_pred_nested_list = [
    [
        [
            [
                pd.DataFrame(data=node_slice, columns=_pert_pred_columns)
                for node_slice in perturbed_slice
            ]
            for perturbed_slice in model_slice
        ]
        for model_slice in window_slice
    ]
    for window_slice in roll_win_pert_pred_nested_np
]

In [None]:
# Sanity check: print the array shape, then the container type at each nesting depth,
# and finally the type and shape of the DataFrame found at the leaves.
print(roll_win_pert_pred_nested_np.shape)

_pert_level = roll_win_pert_pred_nested_list
for _depth in range(4):
    print(type(_pert_level))
    _pert_level = _pert_level[0]

# After four descents _pert_level is the leaf at [0][0][0][0].
print(type(_pert_level))
print(_pert_level.shape)
# print(_pert_level.head())

### Loading Unperturbed Predictions for Models
These cells load roll_win_unpert_preds.npy, which is the unperturbed recursive predictions saved
by the node perturbation analysis file. We load it here and turn it from a Numpy array back
into a Python list.

Since there are 10 countries to perturb but only 1 unperturbed dataloader, this array will have
1 fewer dimension than roll_win_pert_preds.npy.

Note: In the MPNN codebase, roll_win_unpert_preds and roll_win_pert_preds are combined into one
single array called complete_info_list.

In [None]:
# Column names given to the last axis of every saved prediction slice.
_unpert_pred_columns = ["Regular Predictions", "Ground Truth", "Extended Recursive Predictions", "Day Index"]

roll_win_unpert_pred_nested_np = np.load("./analysis-runs-multiple-models/" + ANALYSIS_DIR + "/prediction_saves/roll_win_unpert_preds.npy")

# Unroll the three leading axes (window, model, node) into nested Python lists and wrap
# every innermost 2-D slice in a DataFrame.
roll_win_unpert_pred_nested_list = [
    [
        [
            pd.DataFrame(data=node_slice, columns=_unpert_pred_columns)
            for node_slice in model_slice
        ]
        for model_slice in window_slice
    ]
    for window_slice in roll_win_unpert_pred_nested_np
]

In [None]:
# Sanity check: print the array shape, then the container type at each nesting depth,
# and finally the type and shape of the DataFrame found at the leaves.
print(roll_win_unpert_pred_nested_np.shape)

_unpert_level = roll_win_unpert_pred_nested_list
for _depth in range(3):
    print(type(_unpert_level))
    _unpert_level = _unpert_level[0]

# After three descents _unpert_level is the leaf at [0][0][0].
print(type(_unpert_level))
print(_unpert_level.shape)
# print(_unpert_level.head())

# Plot Recursive Coverage and Calculate Bias Correction

In [None]:
from dataloader.node_perturbation_dataloader import Covid10CountriesUnperturbedDataset

def plot_recursive_prediction_model_coverage_with_starting_input(
    roll_win_unpert_pred_nested_list,
    rolling_window,
    dataset_npz_path="/Users/syedrizvi/Desktop/Projects/GNN_Project/DCSAGE/Node-Perturbation/datasets/10_continents_dataset_v19_node_pert.npz",
):
    """
    Plot how well the models' unperturbed recursive predictions cover the ground truth
    for one rolling window.

    The figure is a 2x5 grid with one subplot per node. In every subplot the recursive
    prediction of each of the NUM_MODELS models is a thin gray line, the ground truth is a
    thick orange line, and the WINDOW_SIZE input days are a thick blue line drawn at
    negative day indices. The figure is saved to
    "./window{rolling_window}_{NUM_MODELS}models_rec_pred_coverage.png" and closed.

    Args:
        - roll_win_unpert_pred_nested_list: nested list indexed [window][model][node] whose
            leaves are DataFrames with "Ground Truth" and "Extended Recursive Predictions"
            columns, e.g. (523, 100, 10, 30, 4)
        - rolling_window: The window for which to create this plot
        - dataset_npz_path: location of the .npz dataset the starting input is read from.
            Defaults to the path that was previously hard-coded in this function.
    """
    dataset = Covid10CountriesUnperturbedDataset(
        dataset_npz_path=dataset_npz_path,
        window_size=WINDOW_SIZE,
        data_split="entire-dataset-smooth",
        avg_graph_structure=False)

    assert len(dataset.all_window_edge_attr) - WINDOW_SIZE - REC_PRED_LEN == len(roll_win_unpert_pred_nested_list), "Inconsistent window counts."

    window_preds = roll_win_unpert_pred_nested_list[rolling_window]  # indexed [model][node]
    horizon = len(window_preds[0][0]["Extended Recursive Predictions"])
    day_index = list(range(-1 * WINDOW_SIZE, horizon))

    def _left_pad(values):
        # Shift the curve right by WINDOW_SIZE days. Padding is written as 0 and every 0 is
        # then turned into NaN so it is not drawn (genuine zeros are hidden as well).
        padded = np.pad(np.array(values), (WINDOW_SIZE, 0), "constant", constant_values=(0, 0))
        padded[padded == 0] = np.nan
        return padded

    # Starting input covers the same day range as the padded curves: WINDOW_SIZE input
    # days followed by NaN over the prediction horizon. (It used to be sized by the number
    # of rolling windows, which has nothing to do with the x-axis of this plot.)
    first_window_cases = np.zeros((WINDOW_SIZE + horizon, len(continents)))
    first_window_cases[:WINDOW_SIZE, :] = dataset.all_window_node_feat[rolling_window, :, :, 1]  # Shape [7, 10]
    first_window_cases[first_window_cases == 0] = np.nan

    # Set the font size before the axes exist; tick and axis labels resolve their size
    # when they are created, so updating afterwards leaves the first figure inconsistent.
    plt.rcParams.update({'font.size': 20})
    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(40, 16))
    fig.suptitle("DCSAGE Unperturbed Recursive Predictions Coverage (Window {})".format(rolling_window), fontsize= 30)

    for idx, col in enumerate(ax.flat):
        # One column per model plus the shared day index.
        visual_dict = {
            "Model {}".format(i): _left_pad(window_preds[i][idx]["Extended Recursive Predictions"])
            for i in range(NUM_MODELS)
        }
        visual_dict["Day Index"] = day_index
        visual_df = pd.DataFrame(visual_dict)

        # Ground truth is same for all models, pick from first model
        visual_df2 = pd.DataFrame({
            "Ground Truth": _left_pad(window_preds[0][idx]["Ground Truth"]),
            "Day Index": day_index,
        })

        visual_df3 = pd.DataFrame({
            "Starting Input": first_window_cases[:, idx],
            "Day Index": day_index,
        })

        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df, ['Day Index']), palette=['gray'] * NUM_MODELS)
        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df2, ['Day Index']), linewidth = 8, palette=['orange'])
        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df3, ['Day Index']), linewidth = 8, palette=['blue'])
        col.set_title(continents[idx])
        col.set_ylim([0, 6])
        box = col.get_position()
        col.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        col.legend().remove()

    fig.savefig("./window{}_{}models_rec_pred_coverage.png".format(rolling_window, NUM_MODELS), bbox_inches='tight', facecolor='white')
    # Close explicitly: this function is called in a loop and figures are large.
    plt.close(fig)

In [None]:
# Render the coverage figure for every 50th rolling window.
_coverage_window_stride = 50
for rolling_window in range(0, len(roll_win_unpert_pred_nested_list), _coverage_window_stride):
    plot_recursive_prediction_model_coverage_with_starting_input(
        roll_win_unpert_pred_nested_list,
        rolling_window,
    )

### Bias correction code, correction factor is average of daywise ratio of ground truth over median curve.

Note: in np array, pd DataFrame column indices are:

columns=["Regular Predictions", "Ground Truth", "Extended Recursive Predictions", "Day Index"]

In [None]:
def bias_correction(window_prediction_np_array, window_idx):
    """
    Compute one multiplicative bias-correction factor per node for a single rolling window.

    For every node the factor is the mean, over the prediction days, of
    ground_truth / median_recursive_prediction, where the median is taken across models.

    Days whose ratio is not finite (for example because the median prediction is 0) are
    left out of the average, so a single degenerate day cannot turn the factor of a node
    into inf/NaN. A node without any finite ratio gets NaN.

    Args:
        - window_prediction_np_array: unperturbed predictions for all models on one window,
            shape (num_models, num_nodes, num_days, 4), e.g. (100, 10, 30, 4). Along the
            last axis, index 1 is the ground truth and index 2 is the extended recursive
            prediction.
        - window_idx: index of the window. It does not take part in the computation and is
            kept so that existing callers keep working.

    Returns:
        np.ndarray of shape (num_nodes,) holding the averaged day-wise ratios.
    """
    predictions = np.asarray(window_prediction_np_array)

    # Ground truth same for all models, get from 1st model. Shape (num_nodes, num_days)
    ground_truth = predictions[0, :, :, 1]
    # Median across the model axis. Shape (num_nodes, num_days)
    median_recursive_pred = np.median(predictions[:, :, :, 2], axis=0)

    # A zero median yields inf (or NaN for 0/0); silence the warning and mask those days below.
    with np.errstate(divide="ignore", invalid="ignore"):
        daywise_ratios = ground_truth / median_recursive_pred

    finite_days = np.isfinite(daywise_ratios)
    ratio_sums = np.where(finite_days, daywise_ratios, 0.0).sum(axis=1)
    finite_counts = finite_days.sum(axis=1)

    # 0 / 0 -> NaN for nodes that have no usable day at all.
    with np.errstate(divide="ignore", invalid="ignore"):
        averaged_ratios = ratio_sums / finite_counts  # shape (num_nodes,)
    return averaged_ratios

In [None]:
# bias_correction is fed slices of the raw numpy array (one window at a time), which is
# easier to index than the nested Python lists.
_per_window_factors = [
    bias_correction(roll_win_unpert_pred_nested_np[window_idx], window_idx)
    for window_idx in range(0, len(roll_win_unpert_pred_nested_np))
]
rolling_window_bias_corrections = np.array(_per_window_factors)
print(rolling_window_bias_corrections.shape)

# Average over the window axis: want 1 bias correction for each continent.
window_bias_corrections = np.mean(rolling_window_bias_corrections, axis=0)
print(window_bias_corrections.shape)

### Plot corrected recursive prediction coverage plot

In [None]:
def plot_corrected_recursive_prediction_model_coverage_with_starting_input(
    roll_win_unpert_pred_nested_list,
    rolling_window,
    bias_corrections=None,
    dataset_npz_path="/Users/syedrizvi/Desktop/Projects/GNN_Project/DCSAGE/Node-Perturbation/datasets/10_continents_dataset_v19_node_pert.npz",
):
    """
    Plot the bias-corrected coverage of the models' unperturbed recursive predictions on
    the ground truth for one rolling window.

    The figure is a 2x5 grid with one subplot per node. In every subplot the recursive
    prediction of each of the NUM_MODELS models, multiplied by the node's bias-correction
    factor, is a thin gray line; the ground truth is a thick orange line; and the
    WINDOW_SIZE input days are a thick blue line drawn at negative day indices. The input
    DataFrames are not modified. The figure is saved to
    "./window{rolling_window}_corrected_{NUM_MODELS}models_rec_pred_coverage.png" and closed.

    Args:
        - roll_win_unpert_pred_nested_list: nested list indexed [window][model][node] whose
            leaves are DataFrames with "Ground Truth" and "Extended Recursive Predictions"
            columns, e.g. (523, 100, 10, 30, 4)
        - rolling_window: The window for which to create this plot
        - bias_corrections: sequence with one multiplicative factor per node. When None,
            the notebook-level `window_bias_corrections` is used, as before.
        - dataset_npz_path: location of the .npz dataset the starting input is read from.
            Defaults to the path that was previously hard-coded in this function.
    """
    if bias_corrections is None:
        bias_corrections = window_bias_corrections

    dataset = Covid10CountriesUnperturbedDataset(
        dataset_npz_path=dataset_npz_path,
        window_size=WINDOW_SIZE,
        data_split="entire-dataset-smooth",
        avg_graph_structure=False)

    assert len(dataset.all_window_edge_attr) - WINDOW_SIZE - REC_PRED_LEN == len(roll_win_unpert_pred_nested_list), "Inconsistent window counts."

    window_preds = roll_win_unpert_pred_nested_list[rolling_window]  # indexed [model][node]
    horizon = len(window_preds[0][0]["Extended Recursive Predictions"])
    day_index = list(range(-1 * WINDOW_SIZE, horizon))

    def _left_pad(values):
        # Shift the curve right by WINDOW_SIZE days. Padding is written as 0 and every 0 is
        # then turned into NaN so it is not drawn (genuine zeros are hidden as well).
        padded = np.pad(values, (WINDOW_SIZE, 0), "constant", constant_values=(0, 0))
        padded[padded == 0] = np.nan
        return padded

    # Starting input covers the same day range as the padded curves: WINDOW_SIZE input
    # days followed by NaN over the prediction horizon. (It used to be sized by the number
    # of rolling windows, which has nothing to do with the x-axis of this plot.)
    first_window_cases = np.zeros((WINDOW_SIZE + horizon, len(continents)))
    first_window_cases[:WINDOW_SIZE, :] = dataset.all_window_node_feat[rolling_window, :, :, 1]  # Shape [7, 10]
    first_window_cases[first_window_cases == 0] = np.nan

    # Set the font size before the axes exist; tick and axis labels resolve their size
    # when they are created, so updating afterwards leaves the first figure inconsistent.
    plt.rcParams.update({'font.size': 20})
    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(40, 16))
    fig.suptitle("DCSAGE Unperturbed Median Ratio Corrected Recursive Predictions Coverage (Window {})".format(rolling_window), fontsize= 30)

    for idx, col in enumerate(ax.flat):
        visual_dict = {}
        for i in range(NUM_MODELS):
            # np.array copies the column, so scaling here leaves the DataFrame untouched.
            vals = np.array(window_preds[i][idx]["Extended Recursive Predictions"])  # Shape (30,)
            vals *= bias_corrections[idx]
            visual_dict["Model {}".format(i)] = _left_pad(vals)
        visual_dict["Day Index"] = day_index
        visual_df = pd.DataFrame(visual_dict)

        # Ground truth is same for all models, pick from first model
        visual_df2 = pd.DataFrame({
            "Ground Truth": _left_pad(np.array(window_preds[0][idx]["Ground Truth"])),
            "Day Index": day_index,
        })

        visual_df3 = pd.DataFrame({
            "Starting Input": first_window_cases[:, idx],
            "Day Index": day_index,
        })

        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df, ['Day Index']), palette=['gray'] * NUM_MODELS)
        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df2, ['Day Index']), linewidth = 8, palette=['orange'])
        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df3, ['Day Index']), linewidth = 8, palette=['blue'])
        col.set_title(continents[idx])
        col.set_ylim([0, 6])
        box = col.get_position()
        col.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        col.legend().remove()

    fig.savefig("./window{}_corrected_{}models_rec_pred_coverage.png".format(rolling_window, NUM_MODELS), bbox_inches='tight', facecolor='white')
    # Close explicitly: this function is called in a loop and figures are large.
    plt.close(fig)

In [None]:
# window_bias_corrections is averaged over all rolling windows, so the corrected plot can
# be drawn for any window; render every 50th one.
_corrected_plot_windows = range(0, len(roll_win_unpert_pred_nested_list), 50)
for window in _corrected_plot_windows:
    plot_corrected_recursive_prediction_model_coverage_with_starting_input(
        roll_win_unpert_pred_nested_list,
        rolling_window=window,
    )

# Apply bias correction to all rolling window prediction lists

In [None]:
def _apply_bias_correction_once(pred_df, factor):
    """
    Multiply the "Extended Recursive Predictions" column of one DataFrame by `factor`, in place.

    The DataFrame is tagged in `pred_df.attrs` once it has been scaled and is skipped on
    later calls, so re-running this cell cannot apply the correction a second time. Freshly
    loaded DataFrames carry no tag and are corrected normally.

    Returns:
        True when the correction was applied, False when the DataFrame was already corrected.
    """
    if pred_df.attrs.get("bias_corrected", False):
        return False
    pred_df["Extended Recursive Predictions"] *= factor
    pred_df.attrs["bias_corrected"] = True
    return True


_num_newly_corrected = 0

# Perturbed predictions, indexed [window][model][perturbed node][node]: (523, 100, 10, 10, 30, 4)
for _window_models in roll_win_pert_pred_nested_list:
    for _model_perturbations in _window_models:
        for _perturbed_nodes in _model_perturbations:
            for country_idx, _pred_df in enumerate(_perturbed_nodes):
                _num_newly_corrected += _apply_bias_correction_once(_pred_df, window_bias_corrections[country_idx])

# Unperturbed predictions, indexed [window][model][node]: (523, 100, 10, 30, 4)
for _window_models in roll_win_unpert_pred_nested_list:
    for _model_nodes in _window_models:
        for country_idx, _pred_df in enumerate(_model_nodes):
            _num_newly_corrected += _apply_bias_correction_once(_pred_df, window_bias_corrections[country_idx])

# 0 here means the cell was re-run on data that had already been corrected.
print("Bias correction newly applied to {} DataFrames.".format(_num_newly_corrected))

## Make Recursive vs Perturbed Recursive Plots
These plots are also for 1 window, 1 model.

In [None]:
def plot_recursive_and_perturbed_recursive_per_country(perturb_df_nested_lists, regular_df_nested_list, rolling_window, model_num):
    """
    Compare, for one model and one rolling window, the unperturbed recursive prediction of
    every node with the recursive predictions obtained when each other node is perturbed.

    The figure is a 2x5 grid with one subplot per node. The thick black line is the node's
    unperturbed recursive prediction, the thick orange line is its ground truth, and the
    remaining nine lines are the node's recursive predictions when each of the other nine
    nodes is perturbed. The y-axis spans +/- 2 around the mean of all curves, with the
    ground truth weighted 11 times. The figure is saved to
    "./window{rolling_window}_model{model_num}_rec_vs_pert_rec_2x5.png" and closed.

    Args:
        - perturb_df_nested_lists: indexed [perturbed node][node] -> DataFrame, (10, 10, 30, 4)
        - regular_df_nested_list: indexed [node] -> DataFrame, (10, 30, 4)
        - rolling_window: window index, used in the title and file name only
        - model_num: model index, used in the title and file name only
    """
    pred_col = "Extended Recursive Predictions"

    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(60, 16))
    fig.suptitle("DCSAGE Recursive vs Perturbed Recursive (Window {}, Model {})".format(rolling_window, model_num), fontsize= 30)

    for idx, col in enumerate(ax.flat):
        unperturbed_curve = regular_df_nested_list[idx][pred_col]
        ground_truth_curve = regular_df_nested_list[idx]["Ground Truth"]
        day_index = list(range(len(unperturbed_curve)))

        # Predictions for this node under each perturbation (perturbed node index first);
        # the node's own perturbation is dropped from the plot.
        perturbed_curves = {
            "{} Perturbed".format(continents[pert_idx]): perturb_df_nested_lists[pert_idx][idx][pred_col]
            for pert_idx in range(10)
        }
        perturbed_curves["Day Index"] = day_index
        visual_df = pd.DataFrame(perturbed_curves).drop(columns=[continents[idx] + " Perturbed"])

        visual_df2 = pd.DataFrame({
            "Unperturbed Recursive": unperturbed_curve,
            "Day Index": day_index,
        })

        visual_df3 = pd.DataFrame({
            "Ground Truth": ground_truth_curve,
            "Day Index": day_index,
        })

        # Pool the values that decide the centre of the y-axis: all ten perturbed curves,
        # the unperturbed curve, and the ground truth 11 times to balance out the
        # perturbed curves (a middle value between perturbed and unperturbed is wanted).
        pooled_values = []
        for pert_idx in range(10):
            pooled_values.extend(list(perturb_df_nested_lists[pert_idx][idx][pred_col]))
        pooled_values.extend(list(unperturbed_curve))
        pooled_values.extend(list(ground_truth_curve) * 11)
        y_center = sum(pooled_values) / len(pooled_values)

        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df, ['Day Index']))
        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df2, ['Day Index']), linewidth = 4, palette=['black'])
        sns.lineplot(ax=col, x='Day Index', y='value', hue='variable', data=pd.melt(visual_df3, ['Day Index']), linewidth = 4, palette=['orange'])
        col.set_title(continents[idx])
        col.set_ylim([y_center - 2, y_center + 2])
        col.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    fig.tight_layout()
    plt.savefig("./window{}_model{}_rec_vs_pert_rec_2x5.png".format(rolling_window, model_num), bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Render the recursive vs perturbed-recursive figure for a sample of rolling windows and,
# within each window, for every fifth model.
_sampled_windows = [0, 100, 200, 300, 400, 500]
_sampled_models = range(0, NUM_MODELS, 5)

for rolling_window in _sampled_windows:
    for model_idx in _sampled_models:
        plot_recursive_and_perturbed_recursive_per_country(
            roll_win_pert_pred_nested_list[rolling_window][model_idx],
            roll_win_unpert_pred_nested_list[rolling_window][model_idx],
            rolling_window=rolling_window,
            model_num=model_idx,
        )

# Compute or Load 10x10 sensitivity array

In [None]:
def node_perturbation_difference_heatmap(perturb_df_nested_lists, regular_df_nested_list, unsigned=False):
    """
    Build the node-by-node sensitivity matrix of one model on one rolling window.

    Entry [p][c] is the sum over days of (prediction for node c when node p is perturbed)
    minus (unperturbed prediction for node c), taken from the
    "Extended Recursive Predictions" column. Diagonal entries (p == c) are NaN.

    The main loop over rolling windows and models calls this helper once per
    (window, model) pair.

    Args:
        - perturb_df_nested_lists: indexed [perturbed node][node] -> pd.DataFrame,
            e.g. (10, 10, 30, 4)
        - regular_df_nested_list: indexed [node] -> pd.DataFrame, e.g. (10, 30, 4). Its
            length defines the number of nodes.
        - unsigned: when True, sum the absolute day-wise differences instead of the signed
            ones. Defaults to False (signed), which is the previous behavior.

    Returns:
        Nested list of shape (num_nodes, num_nodes).
    """
    pred_col = 'Extended Recursive Predictions'
    num_nodes = len(regular_df_nested_list)

    aggreg_differences_lists = []
    for perturbed_country_idx in range(num_nodes):
        aggreg_differences = []
        for country_idx in range(num_nodes):
            if country_idx == perturbed_country_idx:
                # A node's response to its own perturbation is excluded.
                aggreg_differences.append(np.nan)
                continue

            difference_list = perturb_df_nested_lists[perturbed_country_idx][country_idx][pred_col] - regular_df_nested_list[country_idx][pred_col]
            if unsigned:
                difference_list = difference_list.abs()
            aggreg_differences.append(difference_list.sum())

        aggreg_differences_lists.append(aggreg_differences)
    return aggreg_differences_lists

Important:
The code cell below uses the helper function above and computes the sensitivity 10x10 array - 
shape will be (num_windows, num_models, 10, 10) - from the perturbed and unperturbed arrays.

This cell takes significant time to compute the array. If you already have the 
sensitivity array saved, you don't need to run this cell; just run np.load() on the saved
sensitivity 10x10 array from Google Drive.

In [None]:
# Build the sensitivity list, indexed [num_windows, NUM_MODELS, 10, 10]: one matrix from
# node_perturbation_difference_heatmap per (rolling window, model) pair.
_num_windows = len(roll_win_unpert_pred_nested_list)

roll_win_aggreg_diff_nested_list = []
for roll_win_idx in range(_num_windows):
    # Progress message every 20 windows.
    if not roll_win_idx % 20:
        print("Rolling window", roll_win_idx)

    roll_win_aggreg_diff_nested_list.append([
        node_perturbation_difference_heatmap(
            roll_win_pert_pred_nested_list[roll_win_idx][model_idx],
            roll_win_unpert_pred_nested_list[roll_win_idx][model_idx])
        for model_idx in range(len(roll_win_unpert_pred_nested_list[0]))
    ])

If you want to save the 10x10 sensitivity list that has been created, run this cell

In [None]:
# Convert the nested lists to one ndarray and write it to the working directory.
roll_win_aggreg_diff_nested_list = np.array(roll_win_aggreg_diff_nested_list)
print(roll_win_aggreg_diff_nested_list.shape)

_signed_save_path = "./{}_7day_100model_meanagg_v19_10x10_bias_corrected_signed.npy".format(MODEL_NAME)
np.save(_signed_save_path, np.array(roll_win_aggreg_diff_nested_list))
# For an unsigned sensitivity array, save under this name instead:
# np.save("./{}_7day_100model_meanagg_v19_10x10_bias_corrected_unsigned.npy".format(MODEL_NAME), np.array(roll_win_aggreg_diff_nested_list))

If you are loading 10x10 sensitivity array that was downloaded from google drive, run this cell

In [None]:
# Load a previously saved sensitivity array from the run's prediction_saves folder.
# NOTE(review): this reads the "*_unsigned.npy" file, whereas the save cell writes
# "*_signed.npy" to the working directory - confirm which variant is intended here.
_sensitivity_npy_path = (
    "./analysis-runs-multiple-models/"
    + ANALYSIS_DIR
    + "/prediction_saves/DCSAGE_7day_100model_meanagg_v19_10x10_bias_corrected_unsigned.npy"
)
roll_win_aggreg_diff_nested_list = np.load(_sensitivity_npy_path)
print(roll_win_aggreg_diff_nested_list.shape)

### Plot edge perturbation heatmap
Aggregated 10x10 heatmap across model dimension and rolling window dimension.
Plot both summation and mean heatmap.

In [None]:
# def plot_edge_perturbation_heatmap(roll_win_aggreg_diff_nested_list, title, save_path, kind="mean"):
#     assert kind in ["mean", "sum"]

#     if kind == "mean":
#         # Average across model dimension, then across windows
#         matr = np.nanmean(np.nanmean(roll_win_aggreg_diff_nested_list, axis=1), axis=0)
#         plt.figure(figsize=(12,8), dpi=100)
#     else:
#         matr = np.nansum(np.nansum(roll_win_aggreg_diff_nested_list, axis=1), axis=0)
#         plt.figure(figsize=(16,10), dpi=100)

#     sns.heatmap(matr, annot=True, cmap="Blues", fmt=".3f", xticklabels=continents, yticklabels=continents)
#     plt.title(title)
#     plt.savefig(save_path, bbox_inches="tight", facecolor="white")
#     plt.close()


In [None]:
# plot_edge_perturbation_heatmap(roll_win_aggreg_diff_nested_list, 
#     title="{} {}-day Edge Perturbation Mean Heatmap ({} Models, {} Windows)".format(MODEL_NAME, WINDOW_SIZE, NUM_MODELS, len(roll_win_aggreg_diff_nested_list)), 
#     save_path="./{}_mean_edge_pert_heatmap.png".format(MODEL_NAME), 
#     kind="mean")
# plot_edge_perturbation_heatmap(roll_win_aggreg_diff_nested_list, 
#     title="{} {}-day Edge Perturbation Summed Heatmap ({} Models, {} Windows)".format(MODEL_NAME, WINDOW_SIZE, NUM_MODELS, len(roll_win_aggreg_diff_nested_list)), 
#     save_path="./{}_summed_edge_pert_heatmap.png".format(MODEL_NAME), 
#     kind="sum")

### Make Figure for Sensitivity Score Trends Over Rolling Windows by Each Model

In [None]:
def plot_multiple_model_sens_score_trends_lineplot(roll_win_aggreg_diff_nested_list):
    """
    Plot, for every node, the sensitivity score of each model over the rolling windows.

    The figure is a 2x5 grid with one subplot per node; the x-axis is the rolling window
    index and every model is its own line (no averaging across models). A sensitivity
    score is the sum over the last axis of the 10x10 sensitivity array, with NaN entries
    counted as 0. The figure is saved to "./{NUM_MODELS}_models_sens_score_trends.png"
    and closed.

    Args:
        - roll_win_aggreg_diff_nested_list: 10x10 sensitivity array, shape (num_windows, NUM_MODELS, 10, 10)
    """
    # Sum the fourth dimension to get sensitivity scores, (num_windows, NUM_MODELS, 10).
    sensitivity_scores = np.nansum(np.array(roll_win_aggreg_diff_nested_list), axis=3)
    window_index = list(range(len(roll_win_aggreg_diff_nested_list)))

    # Set the font size before the axes exist; tick and axis labels resolve their size
    # when they are created, so updating afterwards leaves the first figure inconsistent.
    plt.rcParams.update({'font.size': 20})
    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(30,15))
    fig.suptitle("{} {} Model Sensitivity Score Trends".format(MODEL_NAME, NUM_MODELS), fontsize= 30)

    for node_idx, col in enumerate(ax.flat):
        # One column per model holding that model's score for this node in every window.
        node_subplot_dict = {
            "Model_{}".format(model_idx): sensitivity_scores[:, model_idx, node_idx]
            for model_idx in range(NUM_MODELS)
        }
        node_subplot_dict["Rolling Window Index"] = window_index
        visual_df = pd.DataFrame(node_subplot_dict)

        sns.lineplot(ax=col, x='Rolling Window Index', y='Sensitivity Scores', hue='Model', data=pd.melt(visual_df, ['Rolling Window Index'], value_name="Sensitivity Scores", var_name="Model"))
        col.set_title(continents[node_idx])
        col.legend().remove()
        col.set_ylim([-70, 70])

    filename = str(NUM_MODELS) + "_models_sens_score_trends"
    plt.tight_layout()
    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Draw and save the per-model sensitivity score trends from the sensitivity array
# that was computed or loaded above.
plot_multiple_model_sens_score_trends_lineplot(roll_win_aggreg_diff_nested_list)

### Make Figure for Sensitvity Score Distribution Before Filtering

In [None]:
def plot_multiple_model_roll_win_sens_distribution(roll_win_aggreg_diff_nested_list):
    """
    This function plots the 5x2 figure of sensitivity score distributions for each of the
    10 continents and saves it as a PNG. Each distribution pools the scores of every rolling
    window and every model (num_windows * num_models values per continent).
    Sensitivity scores are the NaN-ignoring sum over the last axis of the sensitivity array.

    Each panel title reports the mode, mean, median and standard deviation of the pooled scores.

    Args:
        - roll_win_aggreg_diff_nested_list: 10x10 sensitivity array, shape (num_windows, NUM_MODELS, 10, 10)
    """
    sensitivty_score_nested_np = np.nansum(np.array(roll_win_aggreg_diff_nested_list), axis=3)

    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(35,16))
    plt.rcParams.update({'font.size': 20})
    fig.suptitle("{} {} Model Sensitivity Score Distributions".format(MODEL_NAME, NUM_MODELS), fontsize=30)

    # ax.flat visits the panels row by row, matching the order of `continents`.
    for pert_node_idx, col in enumerate(ax.flat):
        # Pool every (window, model) score of this node into one flat sample.
        pooled_scores = sensitivty_score_nested_np[:,:,pert_node_idx].flatten()
        country_sensitivity_scores = list(pooled_scores)
        print("{} has {} sensitivity scores in distribution.".format(continents[pert_node_idx], len(country_sensitivity_scores)))

        visual_df = pd.DataFrame({
            "Sensitivity Score": country_sensitivity_scores,
        })

        sns.histplot(ax=col, x='Sensitivity Score', data=visual_df, kde=True)
        # Summary statistics shown in the panel title; the array is reused instead of being
        # rebuilt from the list for every statistic.
        mode = st.mode(country_sensitivity_scores)
        median = np.median(pooled_scores)
        mean = np.mean(pooled_scores)
        stddev = pooled_scores.std()
        col.set_title(continents[pert_node_idx] + "\nMode " + str(round(mode, 2)) + ", Mean: " + str(round(mean, 2)) + "\nMedian: " + str(round(median, 2)) + ", Std: " + str(round(stddev, 2)))
        col.set_xlim([-40, 40])

    plt.tight_layout()
    # Single filename assignment (the earlier duplicate was immediately overwritten).
    filename = "{}_models_sens_score_distrib".format(NUM_MODELS)
    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor="white")
    plt.clf()
    plt.close()

In [None]:
# Render and save the pooled (all windows x all models) score distribution figure.
plot_multiple_model_roll_win_sens_distribution(roll_win_aggreg_diff_nested_list)

### Sum sensitivity over other 10 nodes

In [None]:
# Get sensitivity score distribution for each country
# The NaN-ignoring sum over the last axis turns the 4-D sensitivity array into a 3-D score
# array. The result is stored at notebook level; later cells read and overwrite it.
print(roll_win_aggreg_diff_nested_list.shape)
sensitivty_score_nested_np = np.nansum(np.array(roll_win_aggreg_diff_nested_list), axis=3)
print("Shape:", sensitivty_score_nested_np.shape)

In [None]:
# No filtering of models is performed
# filtering_K only appears in plot titles and filenames; -1 is the placeholder used here.
filtering_K = -1
# Indices (along the model axis) of models to drop; an empty list keeps every model.
outlier_model_indices = []

## Plotting Sensitivity Score Lineplot After Filtering

In [None]:
def plot_multiple_model_sens_score_trends_after_filtering(roll_win_aggreg_diff_nested_list):
    """
    Same 2x5 per-model sensitivity score lineplot figure as the unfiltered version, but the
    models listed in the notebook-level outlier_model_indices are dropped first. The figure
    is saved as a PNG whose name contains the remaining model count and filtering_K.

    Args:
        - roll_win_aggreg_diff_nested_list: 10x10 sensitivity array, shape (num_windows, NUM_MODELS, 10, 10)
    """
    num_models_after_filtering = NUM_MODELS - len(outlier_model_indices)

    # Collapse the last axis -> (num_windows, NUM_MODELS, 10), then drop outlier models along axis 1.
    kept_scores = np.delete(
        np.nansum(np.array(roll_win_aggreg_diff_nested_list), axis=3),
        outlier_model_indices,
        axis=1,
    )
    print("Shape of sensitivity array after filtering:", kept_scores.shape)

    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(30,15))
    plt.rcParams.update({'font.size': 20})
    figure_title = "{} {} Model Sensitivity Score Trends After {} IQR Filtering".format(MODEL_NAME, num_models_after_filtering, filtering_K)
    fig.suptitle(figure_title, fontsize= 30)

    window_indices = list(range(len(roll_win_aggreg_diff_nested_list)))
    # ax.flat walks the grid row by row, so panel k is labelled continents[k].
    for node_idx, panel in enumerate(ax.flat):
        wide_table = {}
        for model_idx in range(num_models_after_filtering):
            wide_table["Model_{}".format(model_idx)] = kept_scores[:, model_idx, node_idx]
        wide_table["Rolling Window Index"] = window_indices

        long_table = pd.melt(pd.DataFrame(wide_table), ['Rolling Window Index'], value_name="Sensitivity Scores", var_name="Model")
        sns.lineplot(ax=panel, x='Rolling Window Index', y='Sensitivity Scores', hue='Model', data=long_table)
        panel.set_title(continents[node_idx])
        panel.legend().remove()
        panel.set_ylim([-50, 50])

    filename = "{}_models_sens_score_trends_after_{}IQR_filtering".format(num_models_after_filtering, filtering_K)
    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Render and save the per-model trend figure with outlier_model_indices removed.
plot_multiple_model_sens_score_trends_after_filtering(roll_win_aggreg_diff_nested_list)

## Plotting Sensitivity Score Distribution After Filtering

In [None]:
def plot_multiple_model_sens_distrib_after_filtering(roll_win_aggreg_diff_nested_list):
    """
    Same 2x5 figure of pooled sensitivity score distributions as the unfiltered version, but
    the models listed in the notebook-level outlier_model_indices are dropped first. Each
    panel title reports mode, median and standard deviation; the figure is saved as a PNG.

    Args:
        - roll_win_aggreg_diff_nested_list: 10x10 sensitivity array, shape (num_windows, NUM_MODELS, 10, 10)
    """
    num_models_after_filtering = NUM_MODELS - len(outlier_model_indices)
    # Collapse the last axis, then drop outlier models along the model axis.
    kept_scores = np.delete(
        np.nansum(np.array(roll_win_aggreg_diff_nested_list), axis=3),
        outlier_model_indices,
        axis=1,
    )

    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(35,12))
    plt.rcParams.update({'font.size': 20})
    figure_title = "{} {} Model Sensitivity Score Distributions After {} IQR Filtering".format(MODEL_NAME, num_models_after_filtering, filtering_K)
    fig.suptitle(figure_title, fontsize=30)

    # ax.flat visits the panels row by row, matching the order of `continents`.
    for pert_node_idx, panel in enumerate(ax.flat):
        # Pool every remaining (window, model) score for this node.
        pooled_scores = kept_scores[:,:,pert_node_idx].flatten()
        print("{} has {} sensitivity scores in distribution.".format(continents[pert_node_idx], len(pooled_scores)))

        sns.histplot(ax=panel, x='Sensitivity Score', data=pd.DataFrame({"Sensitivity Score": pooled_scores}), kde=True)

        mode = st.mode(pooled_scores)
        median = np.median(pooled_scores)
        stddev = pooled_scores.std()
        title_parts = [
            continents[pert_node_idx],
            "\nMode ", str(round(mode, 2)),
            ", Median: ", str(round(median, 2)),
            ", Std: ", str(round(stddev, 2)),
        ]
        panel.set_title("".join(title_parts))
        panel.set_xlim([-40, 40])

    filename = "{}_models_sens_score_distrib_after_{}IQR_filtering".format(num_models_after_filtering, filtering_K)
    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor="white")
    plt.clf()
    plt.close()

In [None]:
# Render and save the pooled score distribution figure with outlier_model_indices removed.
plot_multiple_model_sens_distrib_after_filtering(roll_win_aggreg_diff_nested_list)

# Fit EV and Normal Distribution on Windows Per Country
See which distribution tends to fit best based on lowest KL-divergence score

In [None]:
def kl_divergence(p, q):
    """
    Return the sum over i of p[i] * log2(p[i] / q[i]) (KL divergence of p from q, in bits).

    Args:
        - p: indexable sequence of values; its length decides how many terms are summed
        - q: indexable sequence with at least len(p) entries

    No zero handling is done here: callers must make sure no entry of p or q is zero.
    """
    total = 0
    for i in range(len(p)):
        total = total + p[i] * np.log2(p[i] / q[i])
    return total

In [None]:
# Sanity check: show the shape of the notebook-level score array before outlier removal.
sensitivty_score_nested_np.shape

In [None]:
# Remove outliers
# The scores are rebuilt from the raw sensitivity array before deleting, so re-running this
# cell never removes the outlier columns a second time from an already-filtered array.
sensitivty_score_nested_np = np.delete(
    np.nansum(np.array(roll_win_aggreg_diff_nested_list), axis=3),
    outlier_model_indices,
    axis=1,
)
print("Shape after removing models:", sensitivty_score_nested_np.shape)

In [None]:
# Fit several distributions per window and node, see which gives lowest KL-Divergence
# Result: roll_win_list[window][node] is the name of the best-fitting candidate distribution.
distributions = ['gumbel_r', 'norm']  # 'gumbel_l'
roll_win_list = []

for roll_win in range(len(sensitivty_score_nested_np)):
    if roll_win % 50 == 0:
        print("Rolling window", roll_win)
    node_best_fit = []
    for country_idx in range(10):
        values = sensitivty_score_nested_np[roll_win,:,country_idx].flatten()

        # The empirical curve depends only on the data, not on the candidate distribution,
        # so it is drawn once per (window, node) instead of once per candidate.
        # NOTE(review): this reads the KDE line seaborn draws on top of the histogram (not the
        # bin heights) and compares it with a fitted pdf — confirm both are on the intended scale.
        hist_axes = sns.histplot(x=values, kde=True, stat="probability")
        xvals, empirical_probs = hist_axes.get_lines()[0].get_data()
        plt.clf()
        # Replace zeros with a small value, qk cannot be zero
        empirical_probs[empirical_probs == 0] = 0.001

        distrib_KL_diverg_scores = []
        for distrib_name in distributions:
            distribution = getattr(stats, distrib_name)
            parameters = distribution.fit(values)

            # Two-parameter families return (loc, scale); the others lead with a shape parameter.
            if distrib_name == "gumbel_l" or distrib_name == "gumbel_r" or distrib_name == "norm":
                theoretical_probs = distribution.pdf(xvals, loc=parameters[0], scale=parameters[1])
            elif distrib_name == "genextreme":
                theoretical_probs = distribution.pdf(xvals, c=parameters[0], loc=parameters[1], scale=parameters[2])
            elif distrib_name == "lognorm":
                theoretical_probs = distribution.pdf(xvals, s=parameters[0], loc=parameters[1], scale=parameters[2])
            else:
                raise Exception("Unknown distribution specified: {}".format(distrib_name))

            theoretical_probs[theoretical_probs == 0] = 0.001
            kl_diverg = kl_divergence(p=empirical_probs, q=theoretical_probs)
            distrib_KL_diverg_scores.append(kl_diverg)

        # Lowest divergence wins; on ties the candidate listed first in `distributions` is kept.
        idx_min_KL_diverg = distrib_KL_diverg_scores.index(min(distrib_KL_diverg_scores))
        node_best_fit.append(distributions[idx_min_KL_diverg])
    roll_win_list.append(node_best_fit)

# The figure was only a scratch pad for extracting the KDE curve; do not leave it open.
plt.close()

In [None]:
# Tally how often each candidate distribution was the best fit across all windows and nodes.
distrib_counts = { "gumbel_r": 0, "norm": 0 }
for node_best_fit in roll_win_list:
    for idx in range(10):
        winner = node_best_fit[idx]
        distrib_counts[winner] = distrib_counts[winner] + 1

distrib_counts

In [None]:
# Spot check: best-fitting distribution name per node for the first rolling window.
roll_win_list[0]

## Fit Best Distribution For Each Country on Each Window
Remove filtered models

In [None]:
# Outlier models were already removed in an earlier cell, so the delete stays commented out
# here and the cell only reports the current shape.
# sensitivty_score_nested_np = np.delete(sensitivty_score_nested_np, outlier_model_indices, axis=1)
print("Shape after removing models:", sensitivty_score_nested_np.shape)
# print("Shape unfiltered:", sensitivty_score_nested_np.shape)

In [None]:
# Switch for the cells below: True fits stats.norm, False fits stats.gumbel_r.
FIT_NORMAL = False

Fit best distribution for each country on each rolling window

In [None]:
# Fit the chosen distribution to the model scores of every (window, node) pair and keep the
# first two fitted parameters (location, scale) in two parallel nested lists.
fit_distribution = stats.norm.fit if FIT_NORMAL else stats.gumbel_r.fit

roll_win_loc_params = []
roll_win_scale_params = []
for window_scores in sensitivty_score_nested_np:
    # window_scores is the (model, node) slice of one rolling window.
    node_params = [fit_distribution(window_scores[:, node_idx]) for node_idx in range(10)]
    roll_win_loc_params.append([params[0] for params in node_params])
    roll_win_scale_params.append([params[1] for params in node_params])

In [None]:
# Convert the nested lists to arrays so later cells can index them as [window, node].
roll_win_loc_params = np.array(roll_win_loc_params)
roll_win_scale_params = np.array(roll_win_scale_params)
print(roll_win_loc_params.shape)
print(roll_win_scale_params.shape)

In [None]:
# Checking what the fitted distribution looks like for each node on a handful of windows:
# the fitted pdf is overlaid on a density histogram of the model scores and saved as a PNG.

for single_country_idx in range(10):
    continent = continents[single_country_idx]
    for single_window_idx in [0, 100, 200, 300, 400, 500]:
        x = np.linspace(60, 0, 100)
        fitted_loc = roll_win_loc_params[single_window_idx, single_country_idx]
        fitted_scale = roll_win_scale_params[single_window_idx, single_country_idx]
        if FIT_NORMAL:
            plt.plot(x, stats.norm.pdf(x, fitted_loc, fitted_scale), 'r-', label='normal pdf')
        else:
            plt.plot(x, stats.gumbel_r.pdf(x, fitted_loc, fitted_scale), 'r-', label='EV pdf')

        plt.rcParams.update({'font.size': 16})
        window_scores = sensitivty_score_nested_np[single_window_idx,:,single_country_idx].flatten()
        sns.histplot(x=window_scores, stat="density", label=continent + " scores")

        # Empirical statistics of the raw scores (not the fitted parameters) go in the title.
        mean = np.mean(window_scores)
        median = np.median(window_scores)
        std = np.std(window_scores)

        if FIT_NORMAL:
            title_template = "Normal Distribution fitted on {} on window {}\nMean: {:.4f}, Median: {:.4f}, STD: {:.4f}"
            path_template = "./normal_unfiltered_fit_{}_win{}.png"
        else:
            title_template = "EV Distribution fitted on {} on window {}\nMean: {:.4f}, Median: {:.4f}, STD: {:.4f}"
            path_template = "./ev_unfiltered_fit_{}_win{}.png"

        title = title_template.format(continent, single_window_idx, mean, median, std)
        plt.title(title, fontsize=16)
        plt.savefig(path_template.format(continent, single_window_idx), bbox_inches="tight", facecolor="white")

        plt.clf()

Run this line. The subsequent plots will be made using mu and sigma parameters of the fitted normal or EV distribution

In [None]:
# Aliases used by the plotting cells below. These are the fitted location and scale
# parameters; the "means"/"stds" names are kept for the downstream function signatures.
roll_win_fitted_means = roll_win_loc_params
roll_win_fitted_stds = roll_win_scale_params

## Plot Model Average Sensitivity Score Trends After Filtering
This plot will be made with the location and scale parameters of the distribution fitted above (Gumbel_R, or normal when FIT_NORMAL is True).

In [None]:
def plot_multiple_model_average_sens_score_trend_after_filtering(roll_win_fitted_means, roll_win_fitted_stds):
    """
    Plot, for each of the 10 continents, the fitted center over rolling windows with a shaded
    band of +/- one fitted spread around it, and save the 2x5 figure as a PNG.

    Title and filename depend on whether the notebook-level outlier_model_indices is empty.

    Args:
        - roll_win_fitted_means: shape (num_windows, 10)
        - roll_win_fitted_stds: shape (num_windows, 10)
    """
    num_models_after_filtering = NUM_MODELS - len(outlier_model_indices)

    # Decide title and filename together: both only depend on whether models were filtered.
    if len(outlier_model_indices) != 0:
        figure_title = "{} {} Model Average Sensitivity Score After {} IQR Filtering".format(MODEL_NAME, num_models_after_filtering, filtering_K)
        filename = "{}_models_avg_sens_scores_after_{}IQR_filtering".format(num_models_after_filtering, filtering_K)
    else:
        figure_title = "{} {} Model Average Fitted Sensitivity Score Mean (Unfiltered)".format(MODEL_NAME, num_models_after_filtering)
        filename = "{}_models_avg_sens_scores_unfiltered".format(num_models_after_filtering)

    fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(30,15))
    plt.rcParams.update({'font.size': 20})
    fig.suptitle(figure_title, fontsize= 30)

    window_axis = list(range(len(roll_win_fitted_means)))
    # ax.flat walks the grid row by row, so panel k is labelled continents[k].
    for node_idx, panel in enumerate(ax.flat):
        center = roll_win_fitted_means[:, node_idx]
        spread = roll_win_fitted_stds[:, node_idx]
        panel.plot(center)
        panel.fill_between(window_axis, (center - spread), (center + spread), color='b', alpha=.1)
        panel.set_title(continents[node_idx])
        panel.set_xlabel("Rolling Window Index")
        panel.set_ylabel("Fitted Sensitivity Score")
        panel.set_ylim([0, 40])

    plt.tight_layout()
    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Render and save the fitted center +/- spread trend figure.
plot_multiple_model_average_sens_score_trend_after_filtering(roll_win_fitted_means, roll_win_fitted_stds)

In [None]:
def plot_violin_plot_fitted_mu_after_filtering(roll_win_fitted_means):
    """
    Save a violin plot with one violin per continent showing the fitted mu parameter over all
    rolling windows, divided by the global maximum so all continents share one y-axis.
    Violins are ordered by descending median and each continent keeps its palette colour.

    Title and filename depend on whether the notebook-level outlier_model_indices is empty.

    Args:
        - roll_win_fitted_means: [num_windows, 10] array
    """
    num_models_after_filtering = NUM_MODELS - len(outlier_model_indices)
    scaled_means = roll_win_fitted_means / roll_win_fitted_means.max()
    y_label = "Fitted Sensitivity Mu Parameter (Scaled to 0 - 1)"

    # Spaces become line breaks so that long names wrap under the x-axis ticks.
    continent_names = [
        label.replace(" ", "\n")
        for label in ["Africa", "North America", "South America", "Oceania", "Eastern Europe", "Western Europe", "Middle East", "South Asia", "Southeast- East Asia", "Central Asia"]
    ]
    colors = sns.color_palette().as_hex()

    per_node_values = [scaled_means[:, node_idx].flatten() for node_idx in range(10)]
    medians = [np.median(node_values) for node_values in per_node_values]

    # Long-form table: one row per (continent, window) value.
    visual_df = pd.DataFrame({
        "Continent": [continent_names[node_idx] for node_idx, node_values in enumerate(per_node_values) for _ in node_values],
        y_label: [value for node_values in per_node_values for value in node_values],
    })

    # Sort colours and names with the same key so each continent keeps its own colour.
    colors_sorted = [color for _, color in sorted(zip(medians, colors), reverse=True)]
    continent_names_sorted = [label for _, label in sorted(zip(medians, continent_names), reverse=True)]

    if len(outlier_model_indices) != 0:
        figure_title = "{} {} Models Fitted Sensitivity Mu Parameter Violin Plot After {} IQR Filtering".format(MODEL_NAME, num_models_after_filtering, filtering_K)
        filename = "{}_models_fitted_mu_violin_plot_after_{}IQR_filtering".format(num_models_after_filtering, filtering_K)
    else:
        figure_title = "{} {} Models Fitted Sensitivity Mu Parameter Violin Plot (Unfiltered)".format(MODEL_NAME, num_models_after_filtering)
        filename = "{}_models_fitted_mu_violin_plot_unfiltered".format(num_models_after_filtering)

    plt.figure(figsize=(20,8))
    plt.rcParams.update({'font.size': 20})
    sns.violinplot(data=visual_df, x="Continent", y=y_label,
                    order=continent_names_sorted, palette=colors_sorted)
    plt.title(figure_title, fontsize=24)

    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Render and save the per-continent violin plot of the scaled fitted mu parameter.
plot_violin_plot_fitted_mu_after_filtering(roll_win_fitted_means)

In [None]:
def plot_node_avg_sens_trend_after_filtering(roll_win_fitted_means):
    """
    Save a single lineplot with one line per continent showing the fitted mu parameter over
    rolling windows, divided by the global maximum.

    Title and filename depend on whether the notebook-level outlier_model_indices is empty.

    Args:
        - roll_win_fitted_means: shape (num_windows, 10)
    """
    num_models_after_filtering = NUM_MODELS - len(outlier_model_indices)
    scaled_means = roll_win_fitted_means / roll_win_fitted_means.max()
    y_label = 'Fitted Sensitivity Mu Parameter (Scaled 0 - 1)'

    # Wide table (one column per continent), melted to long form for seaborn's hue.
    wide_table = {}
    for node_idx in range(10):
        wide_table[continents[node_idx]] = scaled_means[:, node_idx]
    wide_table["Rolling Window Index"] = list(range(len(scaled_means)))
    long_table = pd.melt(pd.DataFrame(wide_table), ['Rolling Window Index'], value_name=y_label, var_name="Continent")

    if len(outlier_model_indices) != 0:
        figure_title = "{} {} Models Fitted Sensitivity Mu Parameter Trend After {}IQR Filtering".format(MODEL_NAME, num_models_after_filtering, filtering_K)
        filename = "{}_models_node_avg_sens_score_after_{}IQR_filtering".format(num_models_after_filtering, filtering_K)
    else:
        figure_title = "{} {} Models Fitted Sensitivity Mu Parameter Trend Unfiltered".format(MODEL_NAME, num_models_after_filtering)
        filename = "{}_models_node_avg_sens_score_unfiltered".format(num_models_after_filtering)

    plt.figure(figsize=(16, 8), dpi=80)
    plt.rcParams.update({'font.size': 20})
    sns.lineplot(x='Rolling Window Index', y=y_label, hue='Continent', data=long_table)
    plt.title(figure_title, fontsize=24)
    plt.xlabel('Rolling Window Idx')
    plt.ylabel(y_label)
    plt.ylim(0, 1)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Render and save the all-continents trend plot of the scaled fitted mu parameter.
plot_node_avg_sens_trend_after_filtering(roll_win_fitted_means)

## Plotting Rankings

In [None]:
def plot_ranks_by_sensitivty_score_mean_after_filtering(roll_win_aggreg_diff_nested_list):
    """
    Save a rank chart of the continents over rolling windows, ranked by the NaN-ignoring mean
    sensitivity score across models (rank 1 = highest mean). This is a comparison baseline for
    the fitted-distribution ranking, not the final ranking.

    Args:
        - roll_win_aggreg_diff_nested_list: 10x10 sensitivity array, shape (num_windows, NUM_MODELS, 10, 10)
    """
    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
                'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
    models_were_filtered = len(outlier_model_indices) != 0

    scores = np.nansum(np.array(roll_win_aggreg_diff_nested_list), axis=3)
    if models_were_filtered:
        scores = np.delete(scores, outlier_model_indices, axis=1)
    # Average over the model axis -> (num_windows, 10).
    window_node_means = np.nanmean(scores, axis=1)

    # Compute ranks: argsort of the descending order gives each node's 0-based position.
    ranks = np.zeros((len(window_node_means), len(window_node_means[0])))
    for window_idx, window_scores in enumerate(window_node_means):
        descending_order = np.argsort(-window_scores)
        ranks[window_idx, :] = np.argsort(descending_order) + 1

    if models_were_filtered:
        figure_title = "Ranking by Sensitivity Score Mean After {}IQR Filtering".format(filtering_K)
        filename = "ranking_by_sensitivity_mean_after_{}IQR_filtering".format(filtering_K)
    else:
        figure_title = "Ranking by Sensitivity Score Mean Unfiltered"
        filename = "ranking_by_sensitivity_mean_unfiltered"

    plt.figure(figsize=(20, 6))
    plt.rcParams.update({'font.size': 20})
    for node_idx in range(10):
        plt.plot(ranks[:, node_idx], "o-", mfc="w", label=continents[node_idx], color=colors[node_idx])

    # Rank 1 should appear at the top of the chart.
    plt.gca().invert_yaxis()
    plt.title(figure_title, fontsize=24)
    plt.xlabel("Rolling Window Index")
    plt.ylabel("Ranking")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Render and save the baseline ranking (by model-averaged sensitivity score).
plot_ranks_by_sensitivty_score_mean_after_filtering(roll_win_aggreg_diff_nested_list)

In [None]:
def plot_ranks_by_fitted_mean_after_filtering(roll_win_fitted_means):
    """
    Save a rank chart of the continents over rolling windows, ranked by the fitted
    distribution parameter passed in (rank 1 = highest value in that window).

    Title and filename depend on whether the notebook-level outlier_model_indices is empty.

    Args:
        - roll_win_fitted_means: shape (num_windows, 10)
    """
    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
                'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

    # Compute ranks: argsort of the descending order gives each node's 0-based position.
    ranks = np.zeros(roll_win_fitted_means.shape)
    for window_idx, window_means in enumerate(roll_win_fitted_means):
        descending_order = np.argsort(-window_means)
        ranks[window_idx, :] = np.argsort(descending_order) + 1

    if len(outlier_model_indices) != 0:
        figure_title = "Ranking by Fitted Distribution Mean After {}IQR Filtering".format(filtering_K)
        filename = "ranking_by_fitted_distrib_mean_after_{}IQR_filtering".format(filtering_K)
    else:
        figure_title = "Ranking by Fitted Distribution Mean Unfiltered"
        filename = "ranking_by_fitted_distrib_mean_unfiltered"

    plt.figure(figsize=(20, 6))
    plt.rcParams.update({'font.size': 20})
    for node_idx in range(10):
        plt.plot(ranks[:, node_idx], "o-", mfc="w", label=continents[node_idx], color=colors[node_idx])

    # Rank 1 should appear at the top of the chart.
    plt.gca().invert_yaxis()
    plt.title(figure_title, fontsize=24)
    plt.xlabel("Rolling Window Index")
    plt.ylabel("Ranking")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    plt.savefig("./" + filename + '.png', bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# Render and save the ranking by fitted distribution parameter.
plot_ranks_by_fitted_mean_after_filtering(roll_win_fitted_means)

In [None]:
def save_csv_for_sensitivity_rank_bump_chart(roll_win_fitted_means):
    """
    Write ./bump_chart_data.csv with one row per (rolling window, continent) holding the
    fitted parameter divided by the global maximum. Rows are ordered window by window, and
    within a window in the order of `continents`.

    Args:
        - roll_win_fitted_means: shape (num_windows, 10)
    """
    scaled_means = roll_win_fitted_means / roll_win_fitted_means.max()

    # Window-major enumeration of every (window, node) pair.
    window_node_pairs = [
        (window_idx, node_idx)
        for window_idx in range(len(scaled_means))
        for node_idx in range(10)
    ]

    bump_chart_df = pd.DataFrame({
        "Continent": [continents[node_idx] for _, node_idx in window_node_pairs],
        "Rolling Window Index": [window_idx for window_idx, _ in window_node_pairs],
        "Fitted Sensitivity Center": [scaled_means[window_idx, node_idx] for window_idx, node_idx in window_node_pairs]
    })
    bump_chart_df.to_csv("./bump_chart_data.csv", index=False)

In [None]:
# Export the scaled fitted parameters in long form to ./bump_chart_data.csv.
save_csv_for_sensitivity_rank_bump_chart(roll_win_fitted_means)

# Correlation Plot Between Flight Rankings and Fitted Distribution Mean Rankings

Note: the code from here on is old and does not need to be run immediately for node perturbation results.
Code needs to be updated for dataset v18, if we want to make some plots between flights/cases and sensitivity trends.

In [None]:
# Load the dataset archive; the last expression lists the array names it contains.
datasetv19 = np.load("./datasets/10_continents_dataset_v19_node_pert.npz")
datasetv19.files

In [None]:
# NaN-ignoring sum of the unscaled flight matrix over axis 1. The notebook treats this as
# outgoing flights, using the same axis as the "Outgoing" case of the flight ranking function.
# Presumably axis 0 is the day index — TODO confirm against the dataset builder.
daywise_outgoing_flights = np.nansum(datasetv19['flight_matrix_unscaled'], axis=1)
daywise_outgoing_flights.shape

Change flight counts from days to rolling windows by summing across each recursive prediction rolling window. The length of a rolling window is 30 + WINDOW_SIZE, because the model takes 1 window and then predicts 30 days, so it takes 30 + WINDOW_SIZE days of flight data in total for 1 rolling window.

The range is `len(daywise_outgoing_flights) - 30 - WINDOW_SIZE - WINDOW_SIZE` because a rolling window is 30 + WINDOW_SIZE in length, and subtracting the second WINDOW_SIZE accounts for the dataloader creating windows from days, which leaves WINDOW_SIZE fewer windows than days to begin with.

In [None]:
# Sum daily outgoing flights over each recursive-prediction rolling window.
# One rolling window spans the input window plus the recursive prediction horizon; the
# horizon comes from the notebook constant REC_PRED_LEN instead of a repeated literal 30.
rollwin_span_days = REC_PRED_LEN + WINDOW_SIZE
# The extra WINDOW_SIZE is subtracted to match the number of windows the dataloader produces.
num_rollwins = len(daywise_outgoing_flights) - rollwin_span_days - WINDOW_SIZE
rollwin_outgoing_flights = np.array([
    daywise_outgoing_flights[idx: idx + rollwin_span_days, :].sum(axis=0)
    for idx in range(num_rollwins)
])
rollwin_outgoing_flights.shape

In [None]:
# Uncomment if want to log10 transform flights
# NOTE: when enabling this, also pass out= (e.g. a zero array); with where= alone, the entries
# skipped by the mask are left uninitialized by NumPy.
# rollwin_outgoing_flights = np.log10(rollwin_outgoing_flights, where=rollwin_outgoing_flights != 0)

In [None]:
# Sanity check: the first dimension should match rollwin_outgoing_flights for the scatterplot below.
roll_win_fitted_means.shape

### Plotting Fitted Distribution Mu Parameter against number of flights Scatterplot

In [None]:
# Scatterplot of scaled fitted sensitivity against summed outgoing flights, one point per
# (continent, rolling window), coloured by continent.
num_models_after_filtering = NUM_MODELS - len(outlier_model_indices)
roll_win_fitted_means_normalized = roll_win_fitted_means / roll_win_fitted_means.max()

# Flatten column by column so the three lists stay aligned continent by continent.
fitted_means = [score for i in range(10) for score in roll_win_fitted_means_normalized[:,i]]
flight_numbers = [count for i in range(10) for count in rollwin_outgoing_flights[:,i]]
continents_name = [continents[i] for i in range(10) for _ in range(len(roll_win_fitted_means_normalized))]

# The three lengths must agree for the DataFrame below to build.
for plot_column in (fitted_means, flight_numbers, continents_name):
    print(len(plot_column))

sensitivity_label = "Fitted Distribution Sensitivity (Scaled to 0 - 1)"
flights_label = "Summed Outgoing Flights Over Rolling Window"
visual_df = pd.DataFrame({
    sensitivity_label: fitted_means,
    flights_label: flight_numbers,
    "Continent": continents_name
})

plt.figure(figsize=(8,8))
plt.rcParams.update({'font.size': 16})
sns.scatterplot(data=visual_df, y=sensitivity_label, x=flights_label, hue="Continent")
plt.title("{} {}-day {} Models Flights vs Fitted Sensitivity".format(MODEL_NAME, WINDOW_SIZE, num_models_after_filtering), fontsize=18)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.savefig("./{}_models_fitted_mu_vs_flights_scatterplot.png".format(num_models_after_filtering), bbox_inches="tight", facecolor="white")
plt.clf()
plt.close()

### Plotting Fitted Distribution Mu Parameter Against Num Cases Scatterplot

In [None]:
# Take feature index 1 of the smoothed feature matrix (used below as the case counts) and
# print its range as a sanity check before rescaling.
num_cases_matrix = datasetv19['feature_matrix_smooth'][:,:,1]
print(num_cases_matrix.max())
print(num_cases_matrix.min())
print(num_cases_matrix.mean())
num_cases_matrix.shape

In [None]:
# Scale back up to normal scale, cannot sum log values
# Start again from the stored log-scale feature so re-running this cell cannot exponentiate twice.
log10_cases = datasetv19['feature_matrix_smooth'][:,:,1]
# where= skips the zero entries; out= is required so those skipped entries hold a defined
# value (0) instead of uninitialized memory.
# Presumably a stored 0 means "no cases" rather than log10(1) — TODO confirm.
num_cases_matrix = np.power(10, log10_cases, where=log10_cases != 0, out=np.zeros_like(log10_cases))

In [None]:
# Sum daily case counts over each recursive-prediction rolling window.
# One rolling window spans the input window plus the recursive prediction horizon; the
# horizon comes from the notebook constant REC_PRED_LEN instead of a repeated literal 30.
rollwin_span_days = REC_PRED_LEN + WINDOW_SIZE
# The extra WINDOW_SIZE is subtracted to match the number of windows the dataloader produces.
num_rollwins = len(num_cases_matrix) - rollwin_span_days - WINDOW_SIZE
rollwin_ncases = np.array([
    num_cases_matrix[idx: idx + rollwin_span_days, :].sum(axis=0)
    for idx in range(num_rollwins)
])
rollwin_ncases.shape

In [None]:
# Uncomment if want to log10 transform summation back down to log10 scale
# NOTE: when enabling this, also pass out= (e.g. a zero array); with where= alone, the entries
# skipped by the mask are left uninitialized by NumPy.
# rollwin_ncases = np.log10(rollwin_ncases, where=rollwin_ncases != 0)

In [None]:
# Sanity check: the first dimension should match rollwin_ncases for the scatterplot below.
roll_win_fitted_means_normalized.shape

In [None]:
# Scatterplot of scaled fitted sensitivity against summed cases, one point per
# (continent, rolling window), coloured by continent.
# Flatten column by column so the three lists stay aligned continent by continent.
fitted_mus = [score for i in range(10) for score in roll_win_fitted_means_normalized[:,i]]
ncases = [count for i in range(10) for count in rollwin_ncases[:,i]]
continents_name = [continents[i] for i in range(10) for _ in range(len(roll_win_fitted_means_normalized))]

# The three lengths must agree for the DataFrame below to build.
for plot_column in (fitted_mus, ncases, continents_name):
    print(len(plot_column))

sensitivity_label = "Fitted Distribution Sensitivity (Scaled to 0 - 1)"
cases_label = "Summed Cases Over Rolling Window"
visual_df = pd.DataFrame({
    sensitivity_label: fitted_mus,
    cases_label: ncases,
    "Continents": continents_name
})

plt.figure(figsize=(8,8))
plt.rcParams.update({'font.size': 16})
sns.scatterplot(data=visual_df, x=cases_label, y=sensitivity_label, hue="Continents")
plt.title("{} {}-day Cases vs Fitted Distribution Sensitivity".format(MODEL_NAME, WINDOW_SIZE))
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.savefig("./{}_models_fitted_mu_vs_ncases_scatterplot.png".format(num_models_after_filtering), bbox_inches="tight", facecolor="white")
plt.clf()
plt.close()

## Plot Containment Index vs Sensitivity

In [None]:
# Take feature index 0 of the smoothed feature matrix (used below as the containment index)
# and print its range as a sanity check before rescaling.
containment_matrix = datasetv19['feature_matrix_smooth'][:,:,0]
print(containment_matrix.max())
print(containment_matrix.min())
print(containment_matrix.mean())
containment_matrix.shape

In [None]:
# Scale back up to normal scale, cannot sum log values
# Start again from the stored log-scale feature so re-running this cell cannot exponentiate twice.
log10_containment = datasetv19['feature_matrix_smooth'][:,:,0]
# where= skips the zero entries; out= is required so those skipped entries hold a defined
# value (0) instead of uninitialized memory.
# Presumably a stored 0 should stay 0 after unscaling — TODO confirm.
containment_matrix = np.power(10, log10_containment, where=log10_containment != 0, out=np.zeros_like(log10_containment))
print(containment_matrix.max())
print(containment_matrix.min())
print(containment_matrix.mean())

In [None]:
# Average containment index across rolling window
# One rolling window spans the input window plus the recursive prediction horizon; the
# horizon comes from the notebook constant REC_PRED_LEN instead of a repeated literal 30.
rollwin_span_days = REC_PRED_LEN + WINDOW_SIZE
# The extra WINDOW_SIZE is subtracted to match the number of windows the dataloader produces.
num_rollwins = len(containment_matrix) - rollwin_span_days - WINDOW_SIZE
rollwin_containment = np.array([
    containment_matrix[idx: idx + rollwin_span_days, :].mean(axis=0)
    for idx in range(num_rollwins)
])
rollwin_containment.shape

In [None]:
# Uncomment if want to log10 transform mean containment back down to log10 scale
# NOTE: when enabling this, also pass out= (e.g. a zero array); with where= alone, the entries
# skipped by the mask are left uninitialized by NumPy.
# rollwin_containment = np.log10(rollwin_containment, where=rollwin_containment != 0)

In [None]:
# Sanity check: the first dimension should match rollwin_containment for the scatterplot below.
roll_win_fitted_means_normalized.shape

In [None]:
# Scatterplot of scaled fitted sensitivity against window-averaged containment, one point per
# (continent, rolling window), coloured by continent.
# Flatten column by column so the three lists stay aligned continent by continent.
fitted_mus = [score for i in range(10) for score in roll_win_fitted_means_normalized[:,i]]
containment = [level for i in range(10) for level in rollwin_containment[:,i]]
continents_name = [continents[i] for i in range(10) for _ in range(len(roll_win_fitted_means_normalized))]

# The three lengths must agree for the DataFrame below to build.
for plot_column in (fitted_mus, containment, continents_name):
    print(len(plot_column))

sensitivity_label = "Fitted Distribution Sensitivity (Scaled to 0 - 1)"
containment_label = "Averaged Containment Over Rolling Window (in millions)"
visual_df = pd.DataFrame({
    sensitivity_label: fitted_mus,
    containment_label: containment,
    "Continents": continents_name
})

plt.figure(figsize=(8,8))
plt.rcParams.update({'font.size': 16})
sns.scatterplot(data=visual_df, x=containment_label, y=sensitivity_label, hue="Continents")
plt.title("{} {}-day Containment vs Fitted Distribution Sensitivity".format(MODEL_NAME, WINDOW_SIZE))
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.savefig("./{}_models_fitted_mu_vs_containment_scatterplot.png".format(num_models_after_filtering), bbox_inches="tight", facecolor="white")
plt.clf()
plt.close()

# Ranking Continents by Number of Flights

In [None]:
def plot_daywise_ranks_by_unscaled_flight_averages(dataset_flight_matrix, flight_type="Outgoing"):
    """
    Plot the rank of each continent by flight total for every index along axis 0, and save
    the figure to ./daywise_ranking_by_total_<flight_type>_flights_unfiltered.png.

    Totals are NaN-ignoring sums. Rank 1 is the largest total, and the y-axis is inverted so
    rank 1 is drawn at the top. The figure is cleared and closed after saving; nothing is returned.

    Args:
        - dataset_flight_matrix: 3-D array, presumably (num_days, 10, 10) - TODO confirm
        - flight_type: "Incoming", "Outgoing", or "Combined". "Outgoing" sums over axis 1,
          "Incoming" sums over axis 2, and "Combined" adds both sums.
          NOTE(review): which axis holds origins vs destinations is not visible here - confirm.

    Raises:
        AssertionError: if flight_type is not one of the three accepted values.
    """
    summed_axes = {"Outgoing": (1,), "Incoming": (2,), "Combined": (1, 2)}
    assert flight_type in summed_axes
    line_colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
                   'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

    axis_totals = [np.nansum(dataset_flight_matrix, axis=ax) for ax in summed_axes[flight_type]]
    if len(axis_totals) == 1:
        daywise_flights = axis_totals[0]
    else:
        daywise_flights = axis_totals[0] + axis_totals[1]

    # Double argsort of the negated totals gives each column's position in descending
    # order within its row; +1 makes the ranks 1-based.
    descending_order = np.argsort(-daywise_flights, axis=1)
    ranks = (np.argsort(descending_order, axis=1) + 1).astype(float)

    plt.figure(figsize=(20, 6))
    for continent_idx, line_color in enumerate(line_colors):
        plt.plot(ranks[:, continent_idx], "o-", mfc="w", label=continents[continent_idx], color=line_color)

    plt.gca().invert_yaxis()
    plt.title("Daywise Ranking by Total Number of {} Flights (Unfiltered)".format(flight_type), fontsize=18)
    plt.xlabel("Day Index")
    plt.ylabel("Ranking")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    output_path = "./" + "daywise_ranking_by_total_{}_flights_unfiltered".format(flight_type) + '.png'
    plt.savefig(output_path, bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# One daywise ranking figure per flight direction, saved by the plotting function.
for flight_type in ("Outgoing", "Incoming", "Combined"):
    plot_daywise_ranks_by_unscaled_flight_averages(datasetv19['flight_matrix_unscaled'], flight_type=flight_type)

In [None]:
def plot_rollwin_ranks_by_unscaled_flight_averages(dataset_flight_matrix, flight_type="Outgoing"):
    """
    Plot the rank of each continent by flight total summed over rolling windows, and save the
    figure to ./roll_win_ranking_by_total_<flight_type>_flights_unfiltered.png.

    Per-index totals (NaN-ignoring) are summed over windows of 30 + WINDOW_SIZE consecutive
    entries along axis 0. Within each window rank 1 is the largest total, and the y-axis is
    inverted so rank 1 is drawn at the top. The figure is cleared and closed after saving;
    nothing is returned.

    Args:
        - dataset_flight_matrix: 3-D array, presumably (num_days, 10, 10) - TODO confirm
        - flight_type: "Incoming", "Outgoing", or "Combined". "Outgoing" sums over axis 1,
          "Incoming" sums over axis 2, and "Combined" adds both sums.
          NOTE(review): which axis holds origins vs destinations is not visible here - confirm.

    Raises:
        AssertionError: if flight_type is not one of the three accepted values.
    """
    assert flight_type in ("Incoming", "Outgoing", "Combined")
    line_colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
                   'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

    if flight_type == "Combined":
        daywise_flights = np.nansum(dataset_flight_matrix, axis=1) + np.nansum(dataset_flight_matrix, axis=2)
    else:
        daywise_flights = np.nansum(dataset_flight_matrix, axis=1 if flight_type == "Outgoing" else 2)

    # Window starts stop a further WINDOW_SIZE entries short of the last position where a
    # full window would still fit.
    window_len = 30 + WINDOW_SIZE
    num_windows = len(daywise_flights) - window_len - WINDOW_SIZE
    roll_win_flights = np.array([
        daywise_flights[start:start + window_len, :].sum(axis=0)
        for start in range(num_windows)
    ])

    # Rank within each window: double argsort of the negated totals, made 1-based.
    ranks = np.zeros(roll_win_flights.shape)
    for window_idx, window_totals in enumerate(roll_win_flights):
        descending_order = np.argsort(-window_totals)
        ranks[window_idx, :] = np.argsort(descending_order) + 1

    plt.figure(figsize=(20, 6))
    for continent_idx, line_color in enumerate(line_colors):
        plt.plot(ranks[:, continent_idx], "o-", mfc="w", label=continents[continent_idx], color=line_color)

    plt.gca().invert_yaxis()
    plt.title("Rolling Window Ranking by Total Number of {} Flights (Summed Over Rolling Window, Unfiltered)".format(flight_type), fontsize=18)
    plt.xlabel("Rolling Window Index")
    plt.ylabel("Ranking")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    output_path = "./" + "roll_win_ranking_by_total_{}_flights_unfiltered".format(flight_type) + '.png'
    plt.savefig(output_path, bbox_inches='tight', facecolor='white')
    plt.clf()
    plt.close()

In [None]:
# One rolling-window ranking figure per flight direction, saved by the plotting function.
for flight_type in ("Outgoing", "Incoming", "Combined"):
    plot_rollwin_ranks_by_unscaled_flight_averages(datasetv19['flight_matrix_unscaled'], flight_type=flight_type)