# Interactive visualisation of the prediction for a single subject

In [None]:
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from random import randint
import pandas as pd
import pickle

from matplotlib.lines import Line2D
from prediction.utils.utils import smooth, filter_consecutive_numbers
from prediction.utils.visualisation_helper_functions import reverse_normalisation_for_subj, LegendTitle
from preprocessing.preprocessing_tools.normalisation.reverse_normalisation import reverse_normalisation
from prediction.utils.shap_helper_functions import check_shap_version_compatibility
from sklearn.model_selection import train_test_split
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time, \
    link_patient_id_to_outcome, features_to_numpy, numpy_to_lookup_table
from prediction.utils.visualisation_helper_functions import density_jitter
from matplotlib.legend_handler import HandlerTuple

In [None]:
# SHAP values require very specific package versions: verify the installed versions before doing
# any work (what happens on a mismatch is decided by check_shap_version_compatibility).
check_shap_version_compatibility()

In [None]:
# Input/output locations. The root directories can be overridden with environment variables so the
# notebook runs on another machine without editing this cell; the defaults are the original local paths.
prepro_root = os.environ.get('OPSUM_PREPRO_OUTPUT_DIR', '/Users/jk1/temp/opsum_prepro_output')
prediction_root = os.environ.get('OPSUM_PREDICTION_OUTPUT_DIR', '/Users/jk1/temp/opsum_prediction_output')

prepro_timestamp = '01012023_233050'
prepro_dir = os.path.join(prepro_root, f'gsu_prepro_{prepro_timestamp}')
features_path = os.path.join(prepro_dir, f'preprocessed_features_{prepro_timestamp}.csv')
labels_path = os.path.join(prepro_dir, f'preprocessed_outcomes_{prepro_timestamp}.csv')
normalisation_parameters_path = os.path.join(prepro_dir, f'logs_{prepro_timestamp}', 'normalisation_parameters.csv')
# The categorical encoding lives in the logs_<date>_<time> folder next to the features file;
# derive <date>_<time> from the features file name so it follows features_path if that is changed.
features_file_stem = os.path.basename(features_path).split(".")[0]
features_timestamp = '_'.join(features_file_stem.split("_")[-2:])
cat_encoding_path = os.path.join(os.path.dirname(features_path), f'logs_{features_timestamp}', 'categorical_variable_encoding.csv')

mrs02_evaluation_dir = os.path.join(prediction_root, 'transformer', '3M_mrs02', 'transformer_20230402_184459_test_set_evaluation')
shap_over_time_path = os.path.join(mrs02_evaluation_dir, 'explanability', 'transformer_explainer_shap_values_over_ts_3m_mrs02_captum_n1449_all_72_cv2.pkl')
mrs02_predictions_over_time_path = os.path.join(mrs02_evaluation_dir, 'predictions_over_timesteps_cv2.pkl')
# NOTE(review): the death model uses fold cv1 while the mRS 0-2 model uses cv2 — confirm this is intended.
death_predictions_over_time_path = os.path.join(prediction_root, 'transformer', '3M_Death', 'testing', 'predictions_over_timesteps_cv1.pkl')

# Directory for saved figures
out_dir = os.environ.get('OPSUM_FIGURE_OUTPUT_DIR', '/Users/jk1/Downloads')

In [None]:
# Experiment configuration. test_size and seed presumably have to match the values used when the
# model was trained so that the split below reproduces its test set — TODO confirm.
outcome = '3M mRS 0-2'
test_size = 0.2  # fraction of patients held out (train_test_split test_size)
seed = 42  # random_state for train_test_split
n_splits = 5  # NOTE(review): not used in the visible part of this notebook
n_time_steps = 72  # number of (hourly) time steps per subject
total_n_features = 84  # presumably the number of model input features — used to wrap feature indices

In [None]:
n_features = 3  # number of top positive and top negative SHAP features to display

In [None]:
# load the shap values computed over time steps
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
with open(shap_over_time_path, 'rb') as handle:
    original_shap_values = pickle.load(handle)

In [None]:
# Collapse the loaded SHAP explanations into one array indexed (time step, subject, feature).
only_last_timestep = False
if not only_last_timestep:
    # One explanation per time step; from each, keep the attributions at the last position of axis 1
    # (presumably the explained sequence position — TODO confirm), then put subjects first.
    last_position_shap = np.array([original_shap_values[ts_idx][:, -1, :]
                                   for ts_idx in range(len(original_shap_values))])
    shap_values = [last_position_shap.swapaxes(0, 1)]
else:
    # use predictions from last timestep (as it also produces output for other timesteps)
    shap_values = [original_shap_values[-1]]
shap_values_over_time = np.array(shap_values[0]).swapaxes(0, 1)

In [None]:
# Parameters used to reverse the feature normalisation (passed to reverse_normalisation* below)
normalisation_parameters_df = pd.read_csv(normalisation_parameters_path)

In [None]:
def _load_predictions_over_time(pickle_path):
    """Load a pickled sequence of per-time-step predictions and stack them into one numpy array.

    Every element is converted with ``.numpy()`` (presumably torch tensors — TODO confirm).
    NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    """
    with open(pickle_path, 'rb') as pickle_file:
        loaded_predictions = pickle.load(pickle_file)
    return np.array([step_prediction.numpy() for step_prediction in loaded_predictions])


# Model outputs per time step for the mRS 0-2 model and for the death model
predictions_over_time = _load_predictions_over_time(mrs02_predictions_over_time_path)
death_predictions_over_time = _load_predictions_over_time(death_predictions_over_time_path)

## Load data

In [None]:
X, y = format_to_2d_table_with_time(feature_df_path=features_path, outcome_df_path=labels_path,
                                    outcome=outcome)

# Reduce every patient to a single outcome (to avoid duplicates), then split patients
# (not admissions) into train and test sets, stratified by outcome
all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)
patient_ids = all_pids_with_outcome.patient_id.tolist()
patient_outcomes = all_pids_with_outcome.outcome.tolist()
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(
    patient_ids, patient_outcomes, stratify=patient_outcomes, test_size=test_size, random_state=seed)


def _split_to_numpy(split_X_df, split_y_df):
    """Convert one split to numpy.

    Returns the array produced by ``features_to_numpy`` for the four requested columns (index 0 of
    the last axis is read as case_admission_id, index -1 as the value) and the float32 outcome of
    each admission, in the order of the first axis of that array.
    """
    split_X_np = features_to_numpy(split_X_df,
                                   ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value'])
    split_y_np = np.array([split_y_df[split_y_df.case_admission_id == cid].outcome.values[0]
                           for cid in split_X_np[:, 0, 0, 0]]).astype('float32')
    return split_X_np, split_y_np


# Preprocess overall train data (keep only the values)
train_X_df = X[X.patient_id.isin(pid_train)]
train_y_df = y[y.patient_id.isin(pid_train)]
train_X_np, train_y_np = _split_to_numpy(train_X_df, train_y_df)
train_X_np = train_X_np[:, :, :, -1].astype('float32')

# Preprocess overall test data
test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]
test_X_np, test_y_np = _split_to_numpy(test_X_df, test_y_df)
# The look-up table (case_admission_ids, sample_labels, relative_sample_date_hourly_cat) must be built
# before the identifier columns are dropped
test_features_lookup_table = numpy_to_lookup_table(test_X_np)
test_X_np = test_X_np[:, :, :, -1].astype('float32')

In [None]:
# Rebuild the patient split for the 3-month death outcome (presumably the split used by the death
# model — TODO confirm) to find which mRS test subjects also have a death prediction.
death_outcome = "3M Death"
death_X, death_y = format_to_2d_table_with_time(feature_df_path=features_path, outcome_df_path=labels_path,
                                                outcome=death_outcome)

# Reduce every patient to a single outcome (to avoid duplicates)
death_all_pids_with_outcome = link_patient_id_to_outcome(death_y, death_outcome)
death_patient_ids = death_all_pids_with_outcome.patient_id.tolist()
death_patient_outcomes = death_all_pids_with_outcome.outcome.tolist()
(death_pid_train, death_pid_test,
 death_y_pid_train, death_y_pid_test) = train_test_split(death_patient_ids, death_patient_outcomes,
                                                         stratify=death_patient_outcomes,
                                                         test_size=test_size, random_state=seed)

# Death outcomes of the death-model test patients, and the subset that is also in the mRS test set
in_death_test = death_y.patient_id.isin(death_pid_test)
death_test_y_df = death_y[in_death_test]
overlapping_death_test_y_df = death_test_y_df[death_test_y_df.patient_id.isin(pid_test)]


In [None]:
# Death-outcome rows of the patients in the death-model test split.
# NOTE(review): same selection as death_test_y_df in the previous cell.
death_y_df = death_y[(death_y.patient_id.isin(death_pid_test))]

In [None]:
# Short aliases for the test features / outcomes used by the plotting cells
X_test, y_test = test_X_np, test_y_np

In [None]:
# Feature labels from the test look-up table; presumably in the order of the feature axis of
# test_X_np (they are used as its column names below) — TODO confirm.
original_features = list(test_features_lookup_table['sample_label'])
feature_names = np.array(original_features)

In [None]:
# Training-set features with the normalisation reversed; used as reference population in the plots below
non_normalised_train_X_df = reverse_normalisation(train_X_df, normalisation_parameters_df)

In [None]:
# Find subject indices in test pop that also exist in death_test population
# Find subject indices in test pop that also exist in death_test population: row positions in
# test_y_df whose case_admission_id also appears in death_y_df
test_y_positions_df = test_y_df.reset_index()
death_case_admission_ids = death_y_df.reset_index().case_admission_id
is_in_death_test = test_y_positions_df.case_admission_id.isin(death_case_admission_ids)
subjs_in_both_test_sets = test_y_positions_df[is_in_death_test].index


## Create working data frame
Join the SHAP values and the reverse-normalised feature values of all test subjects into long-format data frames (one row per subject, feature and time step).

In [None]:
# Merge the one-hot columns of each categorical feature back into a single feature
reverse_categorical_encoding = True
# Merge the hourly split columns (median, min, max) of each downsampled feature into one feature
pool_hourly_split_values = True

In [None]:
# Long-format table of SHAP values: one row per (case_admission_id_idx, feature, time_step).
# The per-time-step frames are collected in a list and concatenated once: DataFrame.append is
# deprecated (removed in pandas 2.0) and appending inside a loop is quadratic.
shap_values_by_ts = np.array(shap_values[0]).swapaxes(0, 1)  # loop-invariant, computed once
ts_shap_frames = []
for ts in range(n_time_steps):
    ts_shap_values_df = pd.DataFrame(data=shap_values_by_ts[ts], columns=np.array(original_features))
    # the row index is the subject's position in the test set
    ts_shap_values_df = ts_shap_values_df.reset_index().rename(columns={'index': 'case_admission_id_idx'})
    ts_shap_values_df = ts_shap_values_df.melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')
    ts_shap_values_df['time_step'] = ts
    ts_shap_frames.append(ts_shap_values_df)
shap_values_df = pd.concat(ts_shap_frames) if ts_shap_frames else pd.DataFrame()


In [None]:
# Long-format table of reverse-normalised feature values: one row per
# (case_admission_id_idx, time_step, feature). Per-subject frames are concatenated once at the
# end (DataFrame.append is deprecated, removed in pandas 2.0, and quadratic inside a loop).
subj_feature_frames = []
for subj_idx in range(test_X_np.shape[0]):
    subj_feature_values_df = pd.DataFrame(data=test_X_np[subj_idx, :, :], columns=np.array(feature_names))
    subj_feature_values_df = reverse_normalisation_for_subj(subj_feature_values_df, normalisation_parameters_df)
    # the row index is the time step
    subj_feature_values_df = subj_feature_values_df.reset_index().rename(columns={'index': 'time_step'})
    subj_feature_values_df['case_admission_id_idx'] = subj_idx
    subj_feature_values_df = subj_feature_values_df.melt(id_vars=['case_admission_id_idx', 'time_step'], var_name='feature', value_name='feature_value')
    subj_feature_frames.append(subj_feature_values_df)
feature_values_df = pd.concat(subj_feature_frames) if subj_feature_frames else pd.DataFrame()

In [None]:
shap_aggregation_func = 'sum' # 'median' or 'sum' (but sum makes more sense)

if reverse_categorical_encoding:
    # Only 'median' and 'sum' are supported; fail loudly instead of silently picking one of them
    if shap_aggregation_func not in ('median', 'sum'):
        raise ValueError(f"shap_aggregation_func must be 'median' or 'sum', got {shap_aggregation_func!r}")

    group_keys = ['case_admission_id_idx', 'feature', 'time_step']
    cat_encoding_df = pd.read_csv(cat_encoding_path)
    for i in range(len(cat_encoding_df)):
        cat_label = cat_encoding_df.sample_label[i]
        cat_basename = cat_label.lower().replace(' ', '_')
        # other_categories is stored as the string form of a list, e.g. "['a', 'b']"
        cat_item_list = cat_encoding_df.other_categories[i].replace('[', '').replace(']', '').replace('\'', '').split(', ')
        # names of the one-hot columns of this categorical feature
        cat_item_list = [cat_basename + '_' + item.replace(' ', '_').lower() for item in cat_item_list]
        for cat_item_idx, cat_item in enumerate(cat_item_list):
            #  retrieve the dominant category for this subject (0 being default category):
            # scale the one-hot column by its code (idx + 1) and sum the columns per subject/time step
            # (assumes the one-hot columns hold 0/1 values)
            is_cat_item = feature_values_df.feature == cat_item
            feature_values_df.loc[is_cat_item, 'feature_value'] *= cat_item_idx + 1
            feature_values_df.loc[is_cat_item, 'feature'] = cat_label
            feature_values_df = feature_values_df.groupby(group_keys).sum().reset_index()

            # aggregate the shap values of all one-hot columns of this feature for each subject
            shap_values_df.loc[shap_values_df.feature == cat_item, 'feature'] = cat_label
            shap_grouped = shap_values_df.groupby(group_keys)
            if shap_aggregation_func == 'median':
                shap_values_df = shap_grouped.median().reset_index()
            else:
                shap_values_df = shap_grouped.sum().reset_index()

    # give a numerical encoding to the categorical features: re-map the category codes obtained
    # above to the hand-chosen numerical order below
    cat_to_numerical_encoding = {
        'Prestroke disability (Rankin)': {0:0, 1:5, 2:4, 3:2, 4:1, 5:3},
        'categorical_onset_to_admission_time': {0:1, 1:2, 2:3, 3:4, 4:0},
        'categorical_IVT': {0:2, 1:3, 2:4, 3:1, 4:0},
        'categorical_IAT': {0:1, 1:0, 2:3, 3:2}
    }

    for cat_feature, cat_encoding in cat_to_numerical_encoding.items():
        is_cat_feature = feature_values_df.feature == cat_feature
        feature_values_df.loc[is_cat_feature, 'feature_value'] = feature_values_df.loc[is_cat_feature, 'feature_value'].map(cat_encoding)


In [None]:

# For features that are downsampled to hourly values (split into several columns: median, min, max),
# merge all columns of a feature under one readable name and pool their values.
if pool_hourly_split_values:
    hourly_split_features = ['NIHSS', 'systolic_blood_pressure', 'diastolic_blood_pressure', 'mean_blood_pressure', 'heart_rate', 'respiratory_rate', 'temperature', 'oxygen_saturation']
    pooling_keys = ['case_admission_id_idx', 'feature', 'time_step']
    for split_feature in hourly_split_features:
        readable_name = (split_feature[0].upper() + split_feature[1:]).replace('_', ' ')

        # SHAP values: pooled with the configured aggregation ('median' or 'sum'; anything else leaves them unpooled)
        shap_values_df.loc[shap_values_df.feature.str.contains(split_feature), 'feature'] = readable_name
        if shap_aggregation_func in ('median', 'sum'):
            shap_values_df = getattr(shap_values_df.groupby(pooling_keys), shap_aggregation_func)().reset_index()

        # feature values: always pooled with the median
        feature_values_df.loc[feature_values_df.feature.str.contains(split_feature), 'feature'] = readable_name
        feature_values_df = feature_values_df.groupby(pooling_keys).median().reset_index()


In [None]:
# Replace feature names with their english names: load the correspondence table from the repository.
# `__file__` is not defined in a notebook, so the *string* '__file__' is resolved against the current
# working directory; the first dirname strips that dummy component and the next five climb from the
# working directory (presumably up to the repository root — depends on where the kernel is started).
repository_root = os.path.abspath('__file__')
for _ in range(6):
    repository_root = os.path.dirname(repository_root)
feature_to_english_name_correspondence_path = os.path.join(
    repository_root, 'preprocessing/preprocessing_tools/feature_name_to_english_name_correspondence.xlsx')
feature_to_english_name_correspondence = pd.read_excel(feature_to_english_name_correspondence_path)

In [None]:
# Map internal feature names to their English display names (first match in the correspondence table
# wins); names without an entry are left unchanged. Both long-format tables are updated in place.
english_name_lookup = (feature_to_english_name_correspondence
                       .drop_duplicates(subset='feature_name', keep='first')
                       .set_index('feature_name')['english_name']
                       .to_dict())

for long_format_df in (shap_values_df, feature_values_df):
    for internal_name in long_format_df.feature.unique():
        if internal_name in english_name_lookup:
            long_format_df.loc[long_format_df.feature == internal_name, 'feature'] = english_name_lookup[internal_name]

In [None]:
use_simplified_shap_values = True

if use_simplified_shap_values:
    # Rebuild the (time step, subject, feature) SHAP array from the pooled / renamed long-format table.
    # The table is grouped once instead of boolean-filtering all of it for every (time step, subject)
    # pair, which scaled as n_time_steps * n_subjects * n_rows. groupby keeps the original row order
    # within each group, so every vector has exactly the row order the per-pair filter produced.
    shap_vectors_by_ts_and_subj = {
        (int(ts_key), int(subj_key)): group.shap_value.values
        for (ts_key, subj_key), group in shap_values_df.groupby(['time_step', 'case_admission_id_idx'], sort=False)
    }
    # what a per-pair filter returns when a (time step, subject) pair has no rows
    no_shap_values = shap_values_df.shap_value.values[:0]
    shap_values_over_time = np.array([
        np.array([shap_vectors_by_ts_and_subj.get((ts, subj_idx), no_shap_values)
                  for subj_idx in range(len(test_X_np))])
        for ts in range(n_time_steps)
    ])

    # feature labels in order of first appearance in shap_values_df
    feature_names = shap_values_df.feature.unique()

## Choose subject and load prediction

In [None]:
# Pick a random test-set subject. random.randint includes its upper bound, so use len - 1 to always
# get a valid index.
subj = randint(0, len(test_X_np) - 1)

In [None]:
# Fixed subject used for the figures (overrides the random choice above)
subj = 32

In [None]:
print(subj, f'Outcome for {outcome}:', y_test[subj])
print('Predicted probability of mrs02: ', predictions_over_time[-1,subj])

# Look up the same admission in death_y_df (by row position) to report the death-model prediction.
# NOTE(review): assumes row `subj` of test_y_df is the admission of test_X_np[subj], and that
# death_predictions_over_time follows the row order of death_y_df — confirm.
death_y_positions_df = death_y_df.reset_index()
subj_case_admission_id = test_y_df.iloc[subj].case_admission_id
subj_found_in_death_df = death_y_positions_df[death_y_positions_df.case_admission_id == subj_case_admission_id]
if not subj_found_in_death_df.empty:
    death_idx = subj_found_in_death_df.index[0]
    death_probability = death_predictions_over_time[-1,death_idx]
    print('Predicted probability of mrs3-5: ', 1 - predictions_over_time[-1,subj] - death_probability)
    print('Predicted probability of death: ', death_probability, f'(Outcome: {death_y_positions_df.iloc[death_idx].outcome})')

In [None]:
# Predicted probability of the outcome for the selected subject at every time step
subj_pred_over_ts = predictions_over_time[:,subj]

## Plot overall subject prediction & explanation

In [None]:
n_features = 3  # number of top positive and top negative features shown (same value as set earlier)

In [None]:
# plot a bar plot showing impact of most important features on the prediction of the selected
# subject at the last time step, next to a table of the subject's feature values at that time step
last_ts_subj_shap = np.squeeze(shap_values_over_time[-1])[subj]
shap_ranking = last_ts_subj_shap.argsort()
# index of the n_features features with biggest positive shap impact (biggest first)
selected_positive_features = shap_ranking[-n_features:][::-1]
# index of the n_features features with biggest negative shap impact (least negative first)
selected_negative_features = shap_ranking[:n_features][::-1]
selected_features = np.concatenate((selected_positive_features, selected_negative_features))

fig1 = plt.figure(figsize=(15,5))
ax1 = fig1.add_subplot(121)
ax = sns.barplot(y=np.array(feature_names)[selected_features], x=last_ts_subj_shap[selected_features], palette="RdBu_r", ax=ax1)
ax.title.set_text(f'SHAP values for subj {subj} ')

# Reverse-normalised feature values of the subject, one row per time step
if reverse_categorical_encoding:
    non_norm_subj_df = feature_values_df[feature_values_df.case_admission_id_idx == subj].drop(columns=['case_admission_id_idx']).pivot(index='time_step', columns='feature', values='feature_value')
else:
    non_norm_subj_df = reverse_normalisation_for_subj(pd.DataFrame(data=test_X_np[subj], columns = feature_names), normalisation_parameters_df)
# Values of the selected features at the last time step (despite the name: neither medians nor normalised)
median_norm_feature_df = non_norm_subj_df.iloc[-1][np.array(feature_names)[selected_features]]

ax2 = fig1.add_subplot(122)
font_size=12
bbox=[0, 0, 1, 1]
ax2.axis('off')
cell_text = [[median_norm_feature_df.iloc[row].astype(str)] for row in range(len(median_norm_feature_df))]
# the values are reverse-normalised, so the column must not be labelled as normalised
mpl_table = ax2.table(cellText = cell_text, rowLabels = median_norm_feature_df.index, bbox=bbox, colLabels=['Value at last time step'], cellLoc='center', colLoc='center', loc='center')
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

fig1.set_tight_layout(True)
# set figure title
fig1.suptitle(f'Explanation of prediction for subj {subj} with a probability of good outcome of {subj_pred_over_ts[-1]:.2f}', fontsize=20)

plt.show()

In [None]:
# fig1.savefig(os.path.join(out_dir, 'final_prediction.png'), dpi=600)

In [None]:
# SHAP bar plot alone (no value table) for the selected subject at the last time step
fig1 = plt.figure(figsize=(15,5))
ax1 = fig1.add_subplot(111)

# biggest positive features first, then the negative ones (least negative first)
subj_last_ts_shap = np.squeeze(shap_values_over_time[-1])[subj]
shap_rank_order = subj_last_ts_shap.argsort()
selected_positive_features = shap_rank_order[-n_features:][::-1]
selected_negative_features = shap_rank_order[:n_features][::-1]
selected_features = np.concatenate((selected_positive_features, selected_negative_features))

ax = sns.barplot(y=np.array(feature_names)[selected_features], x=subj_last_ts_shap[selected_features], palette="RdBu_r", ax=ax1)

# the subject's feature values as a (time step x feature) table
non_norm_subj_df = (feature_values_df[feature_values_df.case_admission_id_idx == subj]
                    .drop(columns=['case_admission_id_idx'])
                    .pivot(index='time_step', columns='feature', values='feature_value'))



## Plot relevant features in relation to training population

In [None]:
# Histogram of each selected feature in the reverse-normalised training population, with the subject's
# value at the last time step marked on the x axis.
fig2 = plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)
plt.suptitle("Selected features", fontsize=18, y=0.99, x=0.52, horizontalalignment='center')

ncols = 3
nrows = -(-len(selected_features) // ncols)  # ceiling division: enough rows for all features

for n, feature in enumerate(selected_features):
    ax = plt.subplot(nrows, ncols, n + 1)
    feature_label = feature_names[feature % total_n_features]

    temp_pop_df = non_normalised_train_X_df[non_normalised_train_X_df.sample_label == feature_label]
    sns.histplot(temp_pop_df.value, ax=ax)
    plt.scatter(median_norm_feature_df[feature_label], 0, marker='o', s=500)

    # the middle column carries a bold section header (positive features first, then negative ones)
    if n % ncols != 1:
        ax.set_title(feature_label)
    elif n <= len(selected_features) / 2:
        ax.set_title(r"$\bf{Positive\ features}$" +f'\n\n{feature_label}')
    else:
        ax.set_title(r"$\bf{Negative\ features}$" + f'\n\n{feature_label}')

plt.tight_layout()

In [None]:
# fig2.savefig(os.path.join(out_dir, 'features_histogram_comparison.png'), dpi=600)

In [None]:
# Violin plot of the training-population distribution of each selected feature, with the subject's
# value at the last time step overlaid as a marker whose area scales with |SHAP value|.
fig2_5 = plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)
plt.suptitle("Selected features", fontsize=18, y=0.99, x=0.52, horizontalalignment='center')

overlay_population_scatter = True
overlay_subject_scatter = False
plot_legend = True
shap_value_base = 1000  # marker area per unit of |SHAP value|
label_font_size = 12

# Colour per selected feature: "mako" shades for positive, "flare_r" shades for negative features, in
# selection order. Built here because `feature_color_dict` is only created in a later cell: relying on
# it breaks Restart & Run All and can silently reuse colours of a different feature selection.
positive_palette = sns.color_palette("mako", n_colors=len(selected_positive_features))
negative_palette = sns.color_palette("flare_r", n_colors=len(selected_negative_features))
violin_color_dict = {}
for palette, palette_features in ((positive_palette, selected_positive_features),
                                  (negative_palette, selected_negative_features)):
    for color_idx, palette_feature in enumerate(palette_features):
        violin_color_dict.setdefault(palette_feature, palette[color_idx])

# one column per feature of each sign: positive features in the first row, negative ones in the second
ncols = n_features
nrows = 2

for n, feature in enumerate(selected_features):
    ax = plt.subplot(nrows, ncols, n + 1)
    feature_label = feature_names[feature % total_n_features]
    feature_color = violin_color_dict[feature]

    # training population: one (median) value per admission; only `value` is aggregated so that
    # non-numeric columns do not break the median (pandas >= 2.0)
    temp_pop_df = (non_normalised_train_X_df[non_normalised_train_X_df.sample_label == feature_label]
                   .groupby('case_admission_id')[['value']].median())
    sns.violinplot(temp_pop_df.value, ax=ax, color=feature_color, alpha=0.1, inner=None)
    plt.setp(ax.collections, alpha=.1)

    if overlay_population_scatter:
        ys = density_jitter(temp_pop_df.value.values, width=0.25, cluster_factor=1)
        ax.scatter(0 + ys, temp_pop_df.value, alpha=0.005, color='grey')

    if overlay_subject_scatter:
        subj_feature_values = non_norm_subj_df[feature_label].values
        ys = density_jitter(subj_feature_values, width=0.5)
        plt.scatter(0 + ys, subj_feature_values, marker='o', zorder=10, color='grey', alpha=0.1)

    # subject value at the last time step; marker area proportional to |SHAP value|,
    # z-order keeps it on top of the population plots
    marker_size = np.abs(shap_values_over_time[-1, subj, feature]) * shap_value_base
    plt.scatter(0, median_norm_feature_df[feature_label], marker='o', s=marker_size, zorder=10, color=feature_color, alpha=0.8)

    if (n % ncols) == 1:
        if n <= len(selected_features) / 2:
            ax.set_title(r"$\bf{Positive\ features}$" +f'\n\n{feature_label}')
        else:
            ax.set_title(r"$\bf{Negative\ features}$" + f'\n\n{feature_label}')
    else:
        ax.set_title(feature_label)

    ax.set_xlim(-1, 1)
    ax.set_ylim(0, temp_pop_df.value.max() + temp_pop_df.value.max() / 5)

    # legend for marker size and distributions, placed right of the last plot of the first row
    if plot_legend and n == len(selected_features) / 2 - 1:
        legend_markers, legend_labels = ax.get_legend_handles_labels()
        legend_features = selected_features[:3]

        subj_marker = tuple(
            plt.scatter([], [], marker='o', s=shap_value_base / size_divisor, color=violin_color_dict[legend_feature], alpha=0.8)
            for size_divisor, legend_feature in zip((4, 8, 16), legend_features))
        legend_markers.append(subj_marker)
        legend_labels.append('Subject feature value\n(Size proportional to\nweight in model)')

        violin_plot_marker = tuple(mpatches.Patch(color=violin_color_dict[legend_feature], alpha=0.1)
                                   for legend_feature in legend_features)
        legend_markers.append(violin_plot_marker)
        legend_labels.append('Distribution in\ntraining population')

        if overlay_population_scatter:
            # marker for the individual population values
            legend_markers.append(plt.scatter([], [], marker='o', s=10, color='grey', alpha=0.3))
            legend_labels.append('Individual values in\ntraining population')

        # legend outside the right side of the plot; handle size keeps markers off the legend border
        ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)},
                  bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.,
                  handleheight=2, handlelength=4)

plt.tight_layout()


## Smooth shap values over time

In [None]:
# Smooth each subject's SHAP trajectory of each feature along the time axis; the result keeps the
# (time step, subject, feature) layout of shap_values_over_time.
smoothing_window = 15
smoothed_shap_values_over_time = np.stack(
    [
        np.stack([smooth(shap_values_over_time[:, subj_idx, feature_idx], smoothing_window)
                  for feature_idx in range(shap_values_over_time.shape[2])], axis=-1)
        for subj_idx in range(shap_values_over_time.shape[1])
    ],
    axis=1,
)

## Plot evolution of prediction & explanation over time

In [None]:
overall_prevailing_features = False  # True: features prevailing over cumulative time; False: features prevailing at the last time point
weigh_by_feature_value = False  # scale the SHAP bands by the subject's feature value relative to its test-set maximum
use_smoothed_shap_values = True  # use smoothed_shap_values_over_time instead of the raw SHAP values

In [None]:
if use_smoothed_shap_values:
    working_shap_values = smoothed_shap_values_over_time
else:
    working_shap_values = shap_values_over_time

# Sum of the SHAP values over all features, per time step and subject
cumulative_shap_values_over_time = np.array([working_shap_values[ts].sum(axis=1) for ts in range(n_time_steps)])

# find index of n_features features with biggest positive & biggest negative shap impact
if overall_prevailing_features:
    # prevailing features over cumulative time: rank this subject's features by their SHAP value summed
    # over time (cumulative_shap_values_over_time is summed over features, so its argsort would give
    # time-step indices instead of feature indices)
    subj_shap_summed_over_time = working_shap_values[:, subj, :].sum(axis=0)
    selected_negative_features = subj_shap_summed_over_time.argsort()[:n_features][::-1]
    selected_positive_features = subj_shap_summed_over_time.argsort()[-n_features:][::-1]
else:
    # prevailing features at last timepoint
    selected_positive_features = np.squeeze(working_shap_values[-1])[subj].argsort()[-n_features:][::-1]
    selected_negative_features = np.squeeze(working_shap_values[-1])[subj].argsort()[:n_features][::-1]

selected_features = np.concatenate((selected_positive_features, selected_negative_features))

fig3 = plt.figure(figsize=(15,10))

k = 0.1  # scaling of the SHAP bands relative to the probability curve
alpha = 0.3

timestep_axis = np.array(range(n_time_steps))

positive_color_palette = sns.color_palette("mako", n_colors=len(selected_positive_features))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(selected_negative_features))

ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='probability', linewidth = 2)

# Stack each feature's positive SHAP contributions above the probability curve and its negative
# contributions below it, one coloured band per feature.
pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count = 0
neg_count = 0
feature_color_dict = {}

for i, feature in enumerate(selected_features):
    subj_feature_shap_value_over_time = working_shap_values[:, subj, feature]
    positive_portion = (subj_feature_shap_value_over_time > 0)
    negative_portion = (subj_feature_shap_value_over_time < 0)

    pos_function = subj_feature_shap_value_over_time.copy()
    pos_function[negative_portion] = 0

    neg_function = subj_feature_shap_value_over_time.copy()
    neg_function[positive_portion] = 0

    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    feature_color_dict[feature] = feature_color

    # NOTE(review): with weigh_by_feature_value, `feature` indexes feature_names; when
    # use_simplified_shap_values is on, these no longer match the feature axis of X_test — confirm.
    positive_feature = pos_baseline + k * pos_function
    if weigh_by_feature_value:
        positive_feature *= X_test[subj, :, feature] / X_test[:, :, feature].max()
    ax.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha, label=feature_names[feature])
    pos_baseline = positive_feature

    negative_feature = neg_baseline + k * neg_function
    if weigh_by_feature_value:
        negative_feature *= X_test[subj, :, feature] / X_test[:, :, feature].max()
    ax.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
    neg_baseline = negative_feature

ax.legend(fontsize='x-large')

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

plt.show()



In [None]:
# fig3.savefig(os.path.join(out_dir, 'prediction_over_time.png'), dpi=600)

## Plot selected features over time

In [None]:
# One line plot per selected feature showing the subject's reverse-normalised values over time,
# coloured like the corresponding band in the prediction-over-time figure.
fig4 = plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)
plt.suptitle("Selected features", fontsize=18, y=0.99, x=0.52, horizontalalignment='center')

# set number of columns (use 3 to demonstrate the change)
ncols = 3
# calculate number of rows
nrows = len(selected_features) // ncols + (len(selected_features) % ncols > 0)

pos_count = 0
neg_count = 0
# Iterate in selection order (positive features first, then negative), dropping duplicates. A set has
# arbitrary order, which mismatched the colours and the Positive/Negative headers with the features.
for n, feature in enumerate(dict.fromkeys(selected_features)):
    # add a new subplot iteratively using nrows and cols
    ax = plt.subplot(nrows, ncols, n + 1)
    feature_label = feature_names[feature % total_n_features]

    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    elif feature in selected_negative_features:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    # NOTE(review): the two names below are only defined in a later cell; these branches are only
    # reached when selected_features contains features outside the positive/negative selections.
    elif feature in selected_negative_features_by_impact:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    elif feature in selected_positive_features_by_impact:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1

    sns.lineplot(y=feature_label, x=non_norm_subj_df.index.name, data=non_norm_subj_df.reset_index(), color=feature_color, ax=ax)

    if (n % ncols) == 1:
        if n <= len(selected_features) / 2:
            ax.set_title(r"$\bf{Positive\ features}$" +f'\n\n{feature_label}')
        else:
            ax.set_title(r"$\bf{Negative\ features}$" + f'\n\n{feature_label}')
    else:
        ax.set_title(feature_label)
plt.tight_layout()


In [None]:
# fig4.savefig(os.path.join(out_dir, 'features_over_time.png'), dpi=600)

## Identify features driving changes in prediction over time

In [None]:
threshold = 0.02
n_features_selection = 0
k=0.25
alpha=0.3
only_non_static_features = True
use_smoothed_shap_values = True

display_significant_slopes = True
n_slope_steps = 5
slope_threshold = 1.5 * threshold

display_text_labels = True
display_legend = True
display_title = False
plot_NIHSS_continuously = True
ts_marker_level = 'baseline' # 'baseline' (marker on predicted probability function) or 'shap' (marker on SHAP value function)

tick_label_size = 13
label_font_size = 16

if use_smoothed_shap_values:
    working_shap_values = smoothed_shap_values_over_time
else:
    working_shap_values = shap_values_over_time

# identify significant changes in prediction over time by a change in threshold X% in prediction
significant_positive_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) > threshold)[0])
significant_negative_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) < -threshold)[0])
significant_timesteps = np.concatenate((significant_positive_timesteps, significant_negative_timesteps))

non_norm_subj_df = feature_values_df[feature_values_df.case_admission_id_idx == subj].drop(columns=['case_admission_id_idx']).pivot(index='time_step', columns='feature', values='feature_value')

# for each timestep, identify the feature that has the largest impact on the prediction
if only_non_static_features:
    # find non static columns in non normed df
    non_static_features = np.where(non_norm_subj_df.std() > 0.01)[0]
    if use_simplified_shap_values:
       non_static_features = np.where(np.isin(feature_names, np.array(non_norm_subj_df.std()[non_norm_subj_df.std() > 0.01].index)))[0]
    selected_positive_features_by_impact = np.diff(working_shap_values[:, subj, non_static_features], axis=0)[significant_positive_timesteps].argmax(axis=1)
    selected_positive_features_by_impact = non_static_features[selected_positive_features_by_impact]
    selected_negative_features_by_impact = np.diff(working_shap_values[:, subj, non_static_features], axis=0)[significant_negative_timesteps].argmin(axis=1)
    selected_negative_features_by_impact = non_static_features[selected_negative_features_by_impact]
else:
    selected_positive_features_by_impact = np.diff(working_shap_values[:, subj], axis=0)[significant_positive_timesteps].argmax(axis=1)
    selected_negative_features_by_impact = np.diff(working_shap_values[:, subj], axis=0)[significant_negative_timesteps].argmin(axis=1)

selected_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_negative_features_by_impact))

if display_significant_slopes:
    # Also capture gentler, sustained changes: the prediction moving by more than `slope_threshold` over
    # `n_slope_steps` timesteps. Start timesteps already flagged as single-step changes in the same direction are
    # excluded, then runs are filtered with `filter_consecutive_numbers` (as for single-step changes).
    subj_pred = np.asarray(subj_pred_over_ts)
    n_slope_windows = len(subj_pred) - n_slope_steps
    # pred_slope[t] = prediction at (t + n_slope_steps) - prediction at t
    pred_slope = subj_pred[n_slope_steps:] - subj_pred[:n_slope_windows]
    # np.setdiff1d returns a sorted int array; a set (as used before) has no guaranteed iteration order.
    significant_positive_slope = np.asarray(filter_consecutive_numbers(
        np.setdiff1d(np.where(pred_slope > slope_threshold)[0], significant_positive_timesteps)), dtype=int)
    significant_negative_slope = np.asarray(filter_consecutive_numbers(
        np.setdiff1d(np.where(pred_slope < -slope_threshold)[0], significant_negative_timesteps)), dtype=int)

    # Candidate features: the non-static ones when requested, otherwise every feature. Chosen here so this cell no
    # longer depends on `non_static_features` existing when only_non_static_features is False.
    if only_non_static_features:
        slope_candidate_features = non_static_features
    else:
        slope_candidate_features = np.arange(working_shap_values.shape[-1])
    subj_candidate_shap = working_shap_values[:, subj, slope_candidate_features]
    # delta_shap_by_features[t, j] = change of candidate j's SHAP value between t and t + n_slope_steps
    delta_shap_by_features = subj_candidate_shap[n_slope_steps:] - subj_candidate_shap[:n_slope_windows]
    selected_positive_features_by_slope = slope_candidate_features[delta_shap_by_features[significant_positive_slope].argmax(axis=1)]
    selected_negative_features_by_slope = slope_candidate_features[delta_shap_by_features[significant_negative_slope].argmin(axis=1)]

    # Append in the same order (positive, then negative) everywhere: selected_features_by_impact[i] is paired with
    # significant_timesteps[i] when shading and marking.
    selected_features_by_impact = np.concatenate((selected_features_by_impact, selected_positive_features_by_slope, selected_negative_features_by_slope))
    significant_timesteps = np.concatenate((significant_timesteps, significant_positive_slope, significant_negative_slope)).astype(int)
    selected_positive_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_positive_features_by_slope))
    selected_negative_features_by_impact = np.concatenate((selected_negative_features_by_impact, selected_negative_features_by_slope))

# Always-shown features: the most positive and the most negative SHAP values at the final timestep.
# NOTE(review): the guard tests `n_features_selection` while the slices use `n_features` — confirm both refer
# to the same top-N setting.
if n_features_selection == 0:
    selected_positive_features = np.array([], dtype=int)
    selected_negative_features = np.array([], dtype=int)
else:
    final_shap_order = np.argsort(working_shap_values[-1, subj])  # ascending SHAP value
    # Slicing the descending order avoids the `[-0:]` pitfall (which returns every feature) when n_features is 0.
    selected_positive_features = final_shap_order[::-1][:n_features]
    selected_negative_features = final_shap_order[:n_features][::-1]

# All features that get a shaded band: top features at the final timestep plus those selected at significant steps.
selected_features = np.concatenate((selected_positive_features, selected_positive_features_by_impact, selected_negative_features, selected_negative_features_by_impact)).astype(int)

fig3, ax = plt.subplots(figsize=(15, 10))

# One palette colour per distinct positively / negatively contributing feature.
positive_color_palette = sns.color_palette("mako", n_colors=len(set(np.concatenate((selected_positive_features, selected_positive_features_by_impact)))))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(set(np.concatenate((selected_negative_features, selected_negative_features_by_impact)))))

# plot prediction over time
timestep_axis = np.array(range(n_time_steps))
sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='Predicted probability', linewidth=2, ax=ax)

# Stack each selected feature's SHAP contribution (scaled by k) as a coloured band: positive parts are stacked
# above the prediction curve, negative parts below it.
pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count = 0
neg_count = 0
feature_color_dict = {}
for feature in set(selected_features):
    subj_feature_shap_value_over_time = working_shap_values[:, subj, feature]

    # Split the curve into its positive part and its negative part (the other sign set to 0).
    pos_function = subj_feature_shap_value_over_time.copy()
    neg_function = subj_feature_shap_value_over_time.copy()
    pos_function[subj_feature_shap_value_over_time < 0] = 0
    neg_function[subj_feature_shap_value_over_time > 0] = 0

    if feature in selected_features_by_impact:
        # Hide the band up to (and including) this feature's EARLIEST significant timestep, so every marker drawn
        # for it lands on a shaded region. NIHSS may optionally be shown throughout.
        first_significant_ts = int(significant_timesteps[selected_features_by_impact == feature].min())
        if not (plot_NIHSS_continuously and feature_names[feature] == 'NIHSS'):
            pos_function[:first_significant_ts + 1] = 0
            neg_function[:first_significant_ts + 1] = 0

    # Colour family: top-positive features first, then any negative selection, else positive-by-impact.
    if feature in selected_positive_features:
        use_positive_palette = True
    elif feature in selected_negative_features or feature in selected_negative_features_by_impact:
        use_positive_palette = False
    else:
        use_positive_palette = True
    if use_positive_palette:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    feature_color_dict[feature] = feature_color

    if np.any(pos_function):
        positive_feature = pos_baseline + k * pos_function
        ax.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha)
        pos_baseline = positive_feature

    if np.any(neg_function):
        negative_feature = neg_baseline + k * neg_function
        ax.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
        neg_baseline = negative_feature

    # add a legend entry for the feature fill
    ax.scatter([], [], color=feature_color, alpha=alpha, label=feature_names[feature], marker="s", s=200)


# Mark each significant timestep with a triangle in the colour of the feature selected for it:
# '^' for upward changes and 'v' for downward changes. Markers are placed at the outer edge of the stacked bands
# (ts_marker_level == 'shap') or next to the prediction curve (ts_marker_level == 'baseline').
# `significant_timesteps` was built as: upward steps, downward steps, then (optionally) upward and downward slopes.
# The direction of each entry is taken from that grouping, so the marker always matches how the feature was
# selected (argmax vs argmin), also for multi-step slopes where t vs t + 1 alone can point the other way.
direction_groups = [(significant_positive_timesteps, True), (significant_negative_timesteps, False)]
if display_significant_slopes:
    direction_groups += [(significant_positive_slope, True), (significant_negative_slope, False)]
is_upward_ts = np.concatenate([np.full(len(group), is_up, dtype=bool) for group, is_up in direction_groups])

if ts_marker_level not in ('shap', 'baseline'):
    raise ValueError(f"ts_marker_level must be 'shap' or 'baseline', got {ts_marker_level!r}")

for ts_idx, feature in enumerate(selected_features_by_impact):
    ts = int(significant_timesteps[ts_idx])
    if is_upward_ts[ts_idx]:
        # upward inflection point: drawn below the negative bands / curve
        marker = '^'
        anchor_level = neg_baseline[ts] if ts_marker_level == 'shap' else subj_pred_over_ts[ts]
        marker_y_level = anchor_level - 0.005
        text_y_level = marker_y_level - 0.015
    else:
        # downward inflection point: drawn above the positive bands / curve
        marker = 'v'
        anchor_level = pos_baseline[ts] if ts_marker_level == 'shap' else subj_pred_over_ts[ts]
        marker_y_level = anchor_level + 0.005
        text_y_level = marker_y_level + 0.01

    ax.scatter(ts, marker_y_level, color=feature_color_dict[feature], s=100, marker=marker, alpha=1, edgecolors='white')
    # insert a label on the plot
    if display_text_labels:
        ax.text(ts + 0.01, text_y_level, feature_names[feature], fontsize=12, color='black')


if display_title:
    ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)

ax.set_xlabel('Time from admission (hours)', fontsize=label_font_size)
ax.set_ylabel('Probability of favorable outcome', fontsize=label_font_size)
ax.tick_params(axis='both', labelsize=tick_label_size)

if display_legend:
    # Handles collected so far: the prediction line first, then one square per shaded feature.
    legend_markers, legend_labels = ax.get_legend_handles_labels()

    # shap value shades
    shap_shades_markers = legend_markers[1:]
    shap_shades_labels = legend_labels[1:]
    legend_markers = [legend_markers[0]]
    legend_labels = [legend_labels[0]]

    # Proxy handles (no data) for the timestep markers, shown side by side via HandlerTuple.
    ts_marker_down = ax.scatter([], [], marker='v', color='grey', s=50, alpha=0.8)
    ts_marker_up = ax.scatter([], [], marker='^', color='grey', s=50, alpha=0.8)
    ts_label = 'Positive / Negative impact on inflection of prediction'
    legend_markers.append((ts_marker_up, ts_marker_down))
    legend_labels.append(ts_label)

    # Spacer row, then a subtitle (string handles are rendered by LegendTitle) for the SHAP shades.
    legend_markers.append('')
    legend_labels.append('')
    legend_markers.append('Weight & direction of influence on model prediction')
    legend_labels.append('')

    legend_markers += shap_shades_markers
    legend_labels += shap_shades_labels

    ax.legend(legend_markers, legend_labels, fontsize=label_font_size, title='Influence on model prediction', title_fontsize=label_font_size,
              handler_map={tuple: HandlerTuple(ndivide=None), str: LegendTitle({'fontsize': label_font_size})})

# Render the figure (plt.plot() with no arguments drew nothing and only echoed `[]`).
plt.show()

In [None]:
# fig3.savefig(os.path.join(out_dir, f'prediction_inflection_points_{subj}_nf{shap_values_over_time.shape[-1]}_k{k}_t{threshold}_text{int(display_text_labels)}.svg'), dpi=1200, bbox_inches='tight')

Display the SHAP value of a given feature over time for the selected subject

In [None]:
# SHAP values of the selected subject as a table: one row per timestep, one column per feature name.
shap_val_df = pd.DataFrame(data=working_shap_values[:, subj], columns=feature_names)

In [None]:
# SHAP value of Total Cholesterol for the selected subject at each timestep (labelled so the figure stands alone).
cholesterol_shap_ax = shap_val_df['Total Cholesterol'].plot(title=f'SHAP value of Total Cholesterol over time (subject {subj})')
cholesterol_shap_ax.set_xlabel('Time from admission (hours)')
cholesterol_shap_ax.set_ylabel('SHAP value');

In [None]:
# SHAP value of NIHSS for the selected subject at each timestep (labelled so the figure stands alone).
nihss_shap_ax = shap_val_df['NIHSS'].plot(title=f'SHAP value of NIHSS over time (subject {subj})')
nihss_shap_ax.set_xlabel('Time from admission (hours)')
nihss_shap_ax.set_ylabel('SHAP value');

Find interesting subjects: those whose predicted probability changes the most between the first and the last timestep

In [None]:
# Rank subjects by the absolute change of their prediction between the first and the last timestep (largest first).
# NOTE: despite the variable names, the top 100 subjects are kept, not 10.
top_10_diff_subj = np.argsort(np.abs(predictions_over_time[-1, :] - predictions_over_time[0, :]))[::-1][:100]
# Keep only those that are also part of both test sets (result is sorted by subject index).
top_10_diff_subj_in_both_sets = np.intersect1d(top_10_diff_subj, subjs_in_both_test_sets)
top_10_diff_subj_in_both_sets

In [None]:
# Final-timestep predictions of the shortlisted subjects.
predictions_over_time[-1][top_10_diff_subj_in_both_sets]

In [None]:
# get case admission idx with highes proBNP in feature_values_df
feature_values_df.loc[(feature_values_df['case_admission_id_idx'].isin(top_10_diff_subj_in_both_sets)) & (feature_values_df.feature == 'proBNP')].sort_values(by='feature_value', ascending=False)

Interesting subjects:
- 54

In both test sets:
- 449, 32, 20

Best example: subject 32