# Interactive visualisation of the prediction for a single subject

In [None]:
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from random import randint
import pandas as pd
import pickle
from preprocessing.preprocessing_tools.normalisation.reverse_normalisation import reverse_normalisation
from prediction.utils.shap_helper_functions import check_shap_version_compatibility
from prediction.utils.utils import flatten
from prediction.outcome_prediction.data_loading.data_loader import load_data
from prediction.utils.utils import aggregate_features_over_time
from sklearn.model_selection import train_test_split
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time, \
    link_patient_id_to_outcome, features_to_numpy, numpy_to_lookup_table

In [None]:

# SHAP values require very specific package versions; run the project's compatibility helper
# before any SHAP output is loaded (presumably it verifies the installed versions — see helper).
check_shap_version_compatibility()

In [None]:
# Machine-specific input/output locations.
# Features, outcomes and normalisation parameters all come from the same preprocessing run,
# so their paths are derived from one run identifier.
_prepro_run = '01012023_233050'
_prepro_dir = f'/Users/jk1/temp/opsum_prepro_output/gsu_prepro_{_prepro_run}'
features_path = f'{_prepro_dir}/preprocessed_features_{_prepro_run}.csv'
labels_path = f'{_prepro_dir}/preprocessed_outcomes_{_prepro_run}.csv'
normalisation_parameters_path = f'{_prepro_dir}/logs_{_prepro_run}/normalisation_parameters.csv'

# Pickled SHAP values and predictions share one model output directory.
_testing_dir = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/with_feature_aggregration/testing'
shap_over_time_path = f'{_testing_dir}/tree_explainer_shap_values_over_ts.pkl'
predictions_over_time_path = f'{_testing_dir}/predictions_over_timesteps.pkl'

# Target directory of the (commented-out) savefig calls.
out_dir = '/Users/jk1/Downloads'

In [None]:
# Parameters of the run that is visualised.
# NOTE(review): presumably these have to match the values used when the loaded predictions and
# SHAP values were produced — confirm against the training/testing configuration.
outcome = '3M mRS 0-2'  # outcome name passed to the data formatting helpers
test_size = 0.2  # fraction of patients held out by train_test_split
seed = 42  # random_state of the train/test split
n_splits = 5  # not referenced in the visible cells of this notebook
moving_average = False  # forwarded to aggregate_features_over_time
n_time_steps = 72  # number of time steps; used for reshaping X_test and as the plots' time axis
total_n_features = 84  # modulus that maps an aggregated feature index back to its original feature

In [None]:
# Number of features taken from each end of the SHAP ranking (largest and smallest values)
n_features = 3

In [None]:
# Read the pickled SHAP values and convert them to one numpy array.
# NOTE: pickle executes arbitrary code on load; only open files produced by a trusted pipeline.
with open(shap_over_time_path, 'rb') as shap_file:
    shap_values_over_time = np.array(pickle.load(shap_file))

# Parameters needed to undo the normalisation of the features.
normalisation_parameters_df = pd.read_csv(normalisation_parameters_path)

In [None]:
# Read the pickled predictions and convert them to one numpy array
# (indexed below as [time step, subject]).
# NOTE: pickle executes arbitrary code on load; only open files produced by a trusted pipeline.
with open(predictions_over_time_path, 'rb') as predictions_file:
    predictions_over_time = np.array(pickle.load(predictions_file))

## Load data

In [None]:
# Columns handed to features_to_numpy; the last one ('value') holds the feature values.
_numpy_columns = ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value']


def _first_outcome_per_admission(outcome_df, admission_ids):
    """Return the first outcome found in outcome_df for every admission id, as float32."""
    first_outcomes = [outcome_df[outcome_df.case_admission_id == cid].outcome.values[0]
                      for cid in admission_ids]
    return np.array(first_outcomes).astype('float32')


X, y = format_to_2d_table_with_time(feature_df_path=features_path, outcome_df_path=labels_path,
                                    outcome=outcome)

# Reduce every patient to a single outcome (to avoid duplicates) and split on patient level,
# stratified by outcome
all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)
_patient_ids = all_pids_with_outcome.patient_id.tolist()
_patient_outcomes = all_pids_with_outcome.outcome.tolist()
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(_patient_ids,
                                                                _patient_outcomes,
                                                                stratify=_patient_outcomes,
                                                                test_size=test_size,
                                                                random_state=seed)

# Training split
train_X_df = X[X.patient_id.isin(pid_train)]
train_y_df = y[y.patient_id.isin(pid_train)]
train_X_np = features_to_numpy(train_X_df, _numpy_columns)
# the admission id of every sample is read at position [sample, 0, 0, 0]
train_y_np = _first_outcome_per_admission(train_y_df, train_X_np[:, 0, 0, 0])
# keep only the last column of the innermost axis (the values)
train_X_np = train_X_np[:, :, :, -1].astype('float32')

# Test split
test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]
test_X_np = features_to_numpy(test_X_df, _numpy_columns)
test_y_np = _first_outcome_per_admission(test_y_df, test_X_np[:, 0, 0, 0])
# the look-up table for case_admission_ids, sample_labels and relative_sample_date_hourly_cat
# has to be built while the identifier columns are still part of the array
test_features_lookup_table = numpy_to_lookup_table(test_X_np)
# drop the identifier columns, keep only the values
test_X_np = test_X_np[:, :, :, -1].astype('float32')

In [None]:
# Run the project's feature aggregation on the test data, then arrange the result as
# (-1, n_time_steps, number of features in the aggregated output).
X_test, y_test = aggregate_features_over_time(test_X_np, test_y_np, moving_average=moving_average)
_aggregated_shape = (-1, n_time_steps, X_test.shape[-1])
X_test = X_test.reshape(*_aggregated_shape).astype('float32')

In [None]:
# Names of the original (non-aggregated) features, read from the 'sample_label' entry of the
# look-up table built from the test data
original_features = list(test_features_lookup_table['sample_label'])

In [None]:
# Build the names of the aggregated features by prefixing every sample label with avg_/min_/max_
_sample_labels = list(test_features_lookup_table['sample_label'])
avg_feature_names, min_feature_names, max_feature_names = (
    [f'{prefix}_{label}' for label in _sample_labels] for prefix in ('avg', 'min', 'max')
)

# Full list of names: original features first, then the avg, min and max variants
feature_names = flatten([original_features, avg_feature_names, min_feature_names, max_feature_names])

In [None]:
# Training features passed through reverse_normalisation with the stored normalisation
# parameters (presumably back to original units — confirm in the helper); used further down as
# the reference population for the feature histograms.
non_normalised_train_X_df = reverse_normalisation(train_X_df, normalisation_parameters_df)

## Choose subject and load prediction

In [None]:
# Draw a random subject index from the test set. randint includes BOTH end points, so the upper
# bound has to be the last valid index; len(test_X_np) itself would be out of range.
subj = randint(0, len(test_X_np) - 1)

In [None]:
# Fixed subject index; running this cell overrides the random draw from the cell above
subj = 182

In [None]:
# Show the chosen subject index together with its prediction at the last time step
print(subj, predictions_over_time[-1,subj])

In [None]:
# Predictions of the chosen subject at every time step (plotted against the time axis below)
subj_pred_over_ts = predictions_over_time[:,subj]

## Plot overall subject prediction & explanation

In [None]:
# Inspect the dimensions of the loaded SHAP array
shap_values_over_time.shape

In [None]:
# Preview: indices of the n_features largest SHAP values of the subject at the last time step,
# largest first
np.squeeze(shap_values_over_time[-1])[subj].argsort()[-n_features:][::-1]

In [None]:
# (Re)set the number of features taken from each end of the SHAP ranking for the plots below
n_features = 3

In [None]:
from prediction.utils.visualisation_helper_functions import reverse_normalisation_for_subj

# Explain the final prediction of the chosen subject: a bar plot of the strongest SHAP
# contributions at the last time step, next to a table with the subject's median value of the
# corresponding original features.

# SHAP values of the subject at the last time step, ranked once in ascending order
subj_final_shap_values = np.squeeze(shap_values_over_time[-1])[subj]
shap_ranking = subj_final_shap_values.argsort()

# indices of the n_features features with the biggest positive SHAP impact (largest first)
selected_positive_features = shap_ranking[-n_features:][::-1]

# indices of the n_features features with the biggest negative SHAP impact (most negative last)
selected_negative_features = shap_ranking[:n_features][::-1]

selected_features = np.concatenate((selected_positive_features, selected_negative_features))

fig1 = plt.figure(figsize=(15,5))
ax1 = fig1.add_subplot(121)
ax = sns.barplot(y=np.array(feature_names)[selected_features], x=subj_final_shap_values[selected_features],
                 palette="RdBu_r", ax=ax1)
ax.title.set_text(f'SHAP values for subj {subj} ')

non_norm_subj_df = reverse_normalisation_for_subj(pd.DataFrame(data=test_X_np[subj], columns = original_features), normalisation_parameters_df)
# Display the median of the original feature value (and not of its aggregation) - therefore the
# modulo is taken. The lookup is positional, so `.iloc` is used: plain `[]` with integer keys on
# a string-labelled Series is deprecated in recent pandas and will be treated as label lookup.
median_norm_feature_df = non_norm_subj_df.median(axis=0).iloc[selected_features % total_n_features]

ax2 = fig1.add_subplot(122)
font_size=12
bbox=[0, 0, 1, 1]
ax2.axis('off')
# one single-cell row per selected feature
cell_text = [[value.astype(str)] for value in median_norm_feature_df.to_numpy()]
# NOTE(review): the values come from the reverse-normalised dataframe, so the column label
# 'Normalised value' looks misleading — confirm the intended wording before changing it.
mpl_table = ax2.table(cellText = cell_text, rowLabels = median_norm_feature_df.index, bbox=bbox, colLabels=['Normalised value'], cellLoc='center', colLoc='center', loc='center')
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

fig1.set_tight_layout(True)
# set figure title
fig1.suptitle(f'Explanation of prediction for subj {subj} with a probability of good outcome of {subj_pred_over_ts[-1]:.2f}', fontsize=20)

plt.show()

In [None]:
# fig1.savefig(os.path.join(out_dir, 'final_prediction.png'), dpi=600)

## Plot relevant features in relation to training population

In [None]:
# Compare the subject's median value of every selected feature (large dot on the x-axis) with
# the distribution of that feature in the reverse-normalised training data (histogram).
fig2 = plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)
plt.suptitle("Selected features", fontsize=18, y=0.99, x=0.52, horizontalalignment='center')

# number of subplot columns; the group headers below are attached to the middle column
ncols = 3
# number of rows needed to fit all selected features (ceiling division)
nrows = len(selected_features) // ncols + (len(selected_features) % ncols > 0)

# loop through the selected features and keep track of the subplot position
for n, feature in enumerate(selected_features):
    # add a new subplot iteratively using nrows and ncols
    ax = plt.subplot(nrows, ncols, n + 1)

    # aggregated features are mapped back to their original feature with the modulo
    feature_label = original_features[feature % total_n_features]

    temp_pop_df = non_normalised_train_X_df[non_normalised_train_X_df.sample_label == feature_label]
    sns.histplot(temp_pop_df.value, ax=ax)
    # The median is read from the subject's dataframe instead of median_norm_feature_df: that
    # Series has one entry per selected feature, so a label lookup returns several values (and
    # scatter fails) as soon as two selected features map to the same original feature.
    ax.scatter(non_norm_subj_df[feature_label].median(), 0, marker='o', s=500)

    if (n % ncols) == 1:
        # middle column: prepend the group header (first half positive, second half negative)
        if n <= len(selected_features) / 2:
            ax.set_title(r"$\bf{Positive\ features}$" +f'\n\n{feature_label}')
        else:
            ax.set_title(r"$\bf{Negative\ features}$" + f'\n\n{feature_label}')
    else:
        ax.set_title(feature_label)

plt.tight_layout()

In [None]:
# fig2.savefig(os.path.join(out_dir, 'features_histogram_comparison.png'), dpi=600)

## Plot evolution of prediction & explanation over time

In [None]:
# False: select the features by their SHAP value at the last time step.
# True: use the "prevailing features over cumulative time" branch of the next cell.
overall_prevailing_features = False

In [None]:
# Plot how the predicted probability of the subject evolves over time. The scaled SHAP values
# of the selected features are stacked on the curve: positive parts above, negative parts below.

# SHAP values of every time step summed over axis 1; not needed for the plot itself but kept
# because it is a notebook-level variable.
cumulative_shap_values_over_time = np.array([shap_values_over_time[ts].sum(axis=1) for ts in range(n_time_steps)])

# find index of the n_features features with biggest positive shap impact & index of the
# n_features features with biggest negative shap impact
if overall_prevailing_features:
    # prevailing features over cumulative time: rank the features by the subject's SHAP values
    # summed over all time steps. The ranking must run over the feature axis, i.e. the same
    # axis that `shap_values_over_time[:, subj, feature]` indexes in the loop below.
    subj_shap_ranking = shap_values_over_time[:, subj].sum(axis=0).argsort()
else:
    # prevailing features at last timepoint
    subj_shap_ranking = np.squeeze(shap_values_over_time[-1])[subj].argsort()

selected_positive_features = subj_shap_ranking[-n_features:][::-1]
selected_negative_features = subj_shap_ranking[:n_features][::-1]

selected_features = np.concatenate((selected_positive_features, selected_negative_features))

fig3 = plt.figure(figsize=(15,10))

# k scales the SHAP values before they are stacked on the probability curve
k=0.05
alpha=0.3

timestep_axis = np.array(range(n_time_steps))

positive_color_palette = sns.color_palette("mako", n_colors=len(selected_positive_features))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(selected_negative_features))

ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='probability', linewidth = 2)

# running upper/lower edge of the stacked bands; both start on the probability curve
pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count = 0
neg_count = 0
for feature in selected_features:
    subj_feature_shap_value_over_time = shap_values_over_time[:, subj, feature]

    # split the series into its non-negative and non-positive part
    pos_function = np.clip(subj_feature_shap_value_over_time, 0, None)
    neg_function = np.clip(subj_feature_shap_value_over_time, None, 0)

    # one colour per feature, taken from the palette of the group the feature was selected in
    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1

    # only the band above the curve carries the legend label, so each feature is listed once
    positive_feature = pos_baseline + k * pos_function
    ax.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha, label=feature_names[feature])
    pos_baseline = positive_feature

    negative_feature = neg_baseline + k * neg_function
    ax.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
    neg_baseline = negative_feature

ax.legend(fontsize='x-large')

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

# ax.set_ylim(0,1)

plt.show()



In [None]:
# fig3.savefig(os.path.join(out_dir, 'prediction_over_time.png'), dpi=600)

## Plot selected features over time

In [None]:
# One subplot per selected feature showing the subject's reverse-normalised values of the
# corresponding original feature against the row index of the subject's dataframe.
fig4 = plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)
plt.suptitle("Selected features", fontsize=18, y=0.99, x=0.52, horizontalalignment='center')

# grid layout: fixed number of columns, as many rows as needed (ceiling division)
ncols = 3
nrows = len(selected_features) // ncols + (len(selected_features) % ncols > 0)

# the dataframe with the row index exposed as an 'index' column is the same for every subplot
_subj_timeline_df = non_norm_subj_df.reset_index()

pos_count = 0
neg_count = 0
for n, feature in enumerate(selected_features):
    ax = plt.subplot(nrows, ncols, n + 1)
    # aggregated features are mapped back to their original feature with the modulo
    _original_name = original_features[feature % total_n_features]

    # colour from the palette of the group (positive / negative) the feature was selected in
    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    sns.lineplot(y=_original_name, x='index', data=_subj_timeline_df, color=feature_color, ax=ax)

    if (n % ncols) != 1:
        ax.set_title(feature_names[feature])
    elif n <= len(selected_features) / 2:
        # middle column of the first half: header of the positive group
        ax.set_title(r"$\bf{Positive\ features}$" +f'\n\n{_original_name}')
    else:
        # middle column of the second half: header of the negative group
        ax.set_title(r"$\bf{Negative\ features}$" + f'\n\n{_original_name}')
plt.tight_layout()


In [None]:
# fig4.savefig(os.path.join(out_dir, 'features_over_time.png'), dpi=600)

## Plot contribution of a specific feature

In [None]:
# Display all original feature names, to pick from in the next cell
np.array(original_features)

In [None]:
# Hand-picked features (by name) whose contribution is plotted below.
# NOTE: this rebinds `selected_features`, which the cells above use as an array of feature
# indices — re-run the selection cell before re-running any of the earlier plots.
selected_features = ["median_mean_blood_pressure", "median_diastolic_blood_pressure", "median_systolic_blood_pressure"]


In [None]:
# Position of every hand-picked feature name within original_features
# (first match; fails with an IndexError if a name is not present)
_original_features_arr = np.array(original_features)
selected_features_idx = [np.where(_original_features_arr == name)[0][0] for name in selected_features]
selected_features_idx

In [None]:
# Plot the prediction over time together with the contribution of the hand-picked features
# (stacked bands on the left axis) and the subject's values of these features (dots, right axis).

# NOTE(review): `shap_values_over_time[ts][0]` here and `[:, subj, feature]` on the result below
# expect a different array layout than the `shap_values_over_time[:, subj, feature]` access of
# the "evolution of prediction" cell above — confirm which layout the loaded pickle has.
cumulative_shap_values_over_time = np.array([shap_values_over_time[ts][0].sum(axis=1) for ts in range(n_time_steps)])
subj_pred_over_ts = predictions_over_time[:,subj]

# NOTE: rebinds fig3 of the earlier "evolution of prediction" cell
fig3 = plt.figure(figsize=(15,10))

# k scales the SHAP values before they are stacked on the probability curve
k=1
alpha=0.3

timestep_axis = np.array(range(n_time_steps))

# Any number of the hand-picked features may end up in either group, so both palettes provide
# one colour per hand-picked feature.
positive_color_palette = sns.color_palette("mako", n_colors=len(selected_features_idx))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(selected_features_idx))

ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='probability', linewidth = 2)
ax2 = ax.twinx()

# the dataframe with the row index exposed as an 'index' column is the same for every feature
subj_values_df = non_norm_subj_df.reset_index()

# running upper/lower edge of the stacked bands; both start on the probability curve
pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count, neg_count = 0, 0
for feature in selected_features_idx:
    # selected_features_idx holds positions within original_features, which are also the column
    # names of non_norm_subj_df
    feature_name = original_features[feature]

    subj_cumulative_shap_value_over_time = cumulative_shap_values_over_time[:, subj, feature]
    positive_portion = (subj_cumulative_shap_value_over_time > 0)
    negative_portion = (subj_cumulative_shap_value_over_time < 0)

    pos_function = subj_cumulative_shap_value_over_time.copy()
    pos_function[negative_portion] = 0

    neg_function = subj_cumulative_shap_value_over_time.copy()
    neg_function[positive_portion] = 0

    # colour by the sign of the feature's contribution summed over time
    if subj_cumulative_shap_value_over_time.sum() > 0:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1

    # only the band above the curve carries the legend label, so each feature is listed once
    positive_feature = pos_baseline + k * pos_function
    ax.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha, label=feature_name)
    pos_baseline = positive_feature

    negative_feature = neg_baseline + k * neg_function
    ax.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
    neg_baseline = negative_feature

    sns.scatterplot(y=feature_name, x='index', data=subj_values_df, ax=ax2, legend=False, color=feature_color)

ax.legend(fontsize='x-large')

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

plt.show()

In [None]:
# One subplot per hand-picked feature showing the subject's reverse-normalised values of that
# feature against the row index of the subject's dataframe.
# NOTE: rebinds fig4 of the earlier "selected features over time" cell
fig4 = plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)
plt.suptitle("Selected features", fontsize=18, y=0.99, x=0.52, horizontalalignment='center')

# grid layout: fixed number of columns, as many rows as needed (ceiling division)
ncols = 3
nrows = len(selected_features_idx) // ncols + (len(selected_features_idx) % ncols > 0)

# the dataframe with the row index exposed as an 'index' column is the same for every subplot
subj_values_df = non_norm_subj_df.reset_index()

pos_count = 0
neg_count = 0
# loop through the hand-picked features and keep track of the subplot position
for n, feature in enumerate(selected_features_idx):
    # add a new subplot iteratively using nrows and ncols
    ax = plt.subplot(nrows, ncols, n + 1)

    # selected_features_idx holds positions within original_features, which are also the column
    # names of non_norm_subj_df
    feature_name = original_features[feature]

    # Recomputed for every feature so that each subplot is coloured by the sign of its own
    # summed contribution, not by whichever series the previous cell's loop ended on.
    subj_cumulative_shap_value_over_time = cumulative_shap_values_over_time[:, subj, feature]
    if subj_cumulative_shap_value_over_time.sum() > 0:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    sns.lineplot(y=feature_name, x='index', data=subj_values_df, color=feature_color, ax=ax)

    ax.set_title(feature_name)
plt.tight_layout()

In [None]:
# Inspect the shape of the first entry of the last time step's SHAP values
# (the array that is flattened for the dependence plots in the next cell)
shap_values_over_time[-1][0].shape

In [None]:
import shap

# SHAP dependence plots of the hand-picked features.
# auto=True lets shap pick the interaction feature; auto=False colours by "median_NIHSS".
auto=False

# Number of original features: the last axis of test_X_np, which is labelled with
# original_features wherever a dataframe is built from it.
n_channels = test_X_np.shape[-1]

# one row per (subject, time step), one column per original feature
X_test_2D = test_X_np.reshape(-1,n_channels)
# NOTE(review): dependence_plot needs one row of SHAP values per row of X_test_2D — confirm that
# `shap_values_over_time[-1][0]` has that layout for the loaded pickle.
shap_values_2D = shap_values_over_time[-1][0].reshape(-1,n_channels)
x_test_2d = pd.DataFrame(data=X_test_2D, columns = original_features)

for feature in selected_features_idx:
    # selected_features_idx holds positions within original_features
    feature_name = original_features[feature]
    if auto:
        # automatic choice of interaction
        shap.dependence_plot(feature_name, shap_values_2D, x_test_2d)
    else:
        shap.dependence_plot(feature_name, shap_values_2D, x_test_2d, interaction_index="median_NIHSS")