## SHAP additive explanation models

Requirements:
- TensorFlow 1.14
- Python 3.7
- Downgrade protobuf to 3.20: `pip install protobuf==3.20`
- Downgrade h5py to 2.10: `pip install h5py==2.10`
- Turn off masking in the LSTM

In [None]:
import shap
import numpy as np
import pandas as pd

from prediction.outcome_prediction.LSTM.testing.shap_helper_functions import check_shap_version_compatibility

In [None]:
# SHAP values require very specific library versions (see the requirements at the top
# of this notebook). check_shap_version_compatibility is a project helper; presumably
# it raises or warns on a mismatch — TODO confirm.
check_shap_version_compatibility()

In [None]:
# Load shap's JavaScript visualization code into the notebook (needed by the
# interactive force plots below).
shap.initjs()

In [None]:
# Machine-specific inputs: trained LSTM weights (.hdf5) and the preprocessed feature
# and outcome CSV files. Adjust these to your local environment.
model_weights_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS02/2023_01_02_1057/test_LSTM_sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3/sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3.hdf5'
features_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv'
labels_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv'

In [None]:
# Model configuration. These values mirror the name of the weights file in
# model_weights_path and should match the architecture those weights were trained with.
outcome = '3M mRS 0-2'
# NOTE: not used below — the model is built with override_masking_value instead.
masking = True
units = 16
activation = 'sigmoid'
dropout = 0.2
layers = 2
optimizer = 'RMSprop'
# Train/test split settings (seed and test fraction passed to train_test_split).
seed = 42
test_size = 0.20

Masking has to be overridden to `False` for Shapley values to be computed.

In [None]:
# Passed as `masking` to lstm_generator below; must be False for SHAP values to be computed.
override_masking_value = False

In [None]:
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time

# Load features and outcomes (presumably long format: one row per case, time bin and
# feature label — TODO confirm against format_to_2d_table_with_time).
X, y = format_to_2d_table_with_time(
    feature_df_path=features_path,
    outcome_df_path=labels_path,
    outcome=outcome,
)

# Number of time steps = largest hourly bin index + 1 (assumes bins start at 0).
n_time_steps = X.relative_sample_date_hourly_cat.max() + 1
# One input channel per distinct sample label.
n_channels = len(X.sample_label.unique())

In [None]:
from sklearn.model_selection import train_test_split
from prediction.outcome_prediction.data_loading.data_formatting import features_to_numpy, \
    link_patient_id_to_outcome, numpy_to_lookup_table

# Reduce every patient to a single outcome (to avoid duplicates) so that the split is
# done (and stratified) at patient level.
all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(all_pids_with_outcome.patient_id.tolist(),
                                                                all_pids_with_outcome.outcome.tolist(),
                                                                stratify=all_pids_with_outcome.outcome.tolist(),
                                                                test_size=test_size,
                                                                random_state=seed)

test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]
train_X_df = X[X.patient_id.isin(pid_train)]
train_y_df = y[y.patient_id.isin(pid_train)]

feature_columns = ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value']
train_X_np = features_to_numpy(train_X_df, feature_columns)
test_X_np = features_to_numpy(test_X_df, feature_columns)


def _outcomes_for_cases(y_df, case_admission_ids):
    """Return the outcome of each case_admission_id (in the given order) as float32.

    For every id, the outcome of the FIRST row of y_df with that id is used (the same
    row a per-id boolean filter followed by `.values[0]` would pick). The lookup dict
    is built once, instead of scanning all of y_df once per id.

    Raises KeyError if an id has no row in y_df.
    """
    first_rows = y_df.drop_duplicates(subset='case_admission_id', keep='first')
    outcome_by_case = dict(zip(first_rows.case_admission_id, first_rows.outcome))
    return np.array([outcome_by_case[cid] for cid in case_admission_ids]).astype('float32')


# The case_admission_id of every sample sits in the first column ([..., 0]) of the array,
# following the order of feature_columns.
train_y_np = _outcomes_for_cases(train_y_df, train_X_np[:, 0, 0, 0])
test_y_np = _outcomes_for_cases(test_y_df, test_X_np[:, 0, 0, 0])


# create look-up table for case_admission_ids, sample_labels and relative_sample_date_hourly_cat
test_features_lookup_table = numpy_to_lookup_table(test_X_np)
train_features_lookup_table = numpy_to_lookup_table(train_X_np)

# Keep only the last column ('value'): drops the case_admission_id, time step and sample_label columns
test_X_np = test_X_np[:, :, :, -1].astype('float32')
train_X_np = train_X_np[:, :, :, -1].astype('float32')

In [None]:
from prediction.utils.scoring import precision, recall, matthews
from prediction.outcome_prediction.LSTM.LSTM import lstm_generator

# Rebuild the LSTM architecture (with masking overridden) so the trained weights can be loaded.
model = lstm_generator(
    x_time_shape=n_time_steps,
    x_channels_shape=n_channels,
    masking=override_masking_value,
    n_units=units,
    activation=activation,
    dropout=dropout,
    n_layers=layers,
)

# Compile with binary cross-entropy and the project's scoring metrics.
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy', precision, recall, matthews],
)

model.load_weights(model_weights_path)

In [None]:
# Training input shape after dropping the id columns; presumably (n_cases, n_time_steps, n_channels).
train_X_np.shape

In [None]:
# Print the layer structure of the rebuilt model.
model.summary()

In [None]:
import tensorflow as tf
# Display the installed TensorFlow version (the requirements pin 1.14).
tf.__version__

In [None]:
# Fail fast, with a readable message, if the installed TensorFlow is not the pinned 1.14.0
# (an `assert` would give no diagnostic and is skipped when Python runs with -O).
if tf.__version__ != '1.14.0':
    raise RuntimeError(f'TensorFlow 1.14.0 is required (see requirements), found {tf.__version__}')

## DeepSHAP


In [None]:
# Shapes of the training data (used as SHAP background) and of the test data to explain.
train_X_np.shape, test_X_np.shape

In [None]:
# DeepExplainer with the full training set as background data. A smaller background
# sample would be faster: each explained prediction runs the model on a batch of
# 2 * (background size) inputs.
explainer = shap.DeepExplainer(model, train_X_np)

# Explain every test instance (a subset would also do).
shap_values = explainer.shap_values(test_X_np)

# Inject shap's JS again so the force plots below render.
shap.initjs()

In [None]:
# Base value(s) of the explanation, one per model output (index [0] is used below);
# presumably the mean model output over the background data — see shap docs.
explainer.expected_value

In [None]:
# shap_values is a list with one array per model output; presumably 1 for this single-output model.
len(shap_values)

In [None]:
# Test input shape; presumably (n_test_cases, n_time_steps, n_channels).
test_X_np.shape

In [None]:
# SHAP values for output 0; expected to have the same shape as test_X_np.
shap_values[0].shape

In [None]:
# SHAP values of the first test subject; presumably (n_time_steps, n_channels).
shap_values[0][0].shape

In [None]:
# Feature (channel) names, in the lookup table's key order. They are used as DataFrame
# column names for the channel axis below, so this assumes the key order matches the
# channel index order — TODO confirm in numpy_to_lookup_table.
features = list(test_features_lookup_table['sample_label'])

In [None]:
# Show the feature names and how many there are.
print(features)
print(len(features))

In [None]:
# Example (time step, subject) indices to inspect.
ts=0
subj=0

In [None]:
# SHAP values (output 0) of every feature for this subject at time step ts.
shap_values[0][subj][ts]

In [None]:
# Input feature vector for the same subject and time step; presumably (n_channels,).
test_X_np[subj][ts].shape

In [None]:
# Force plot for a single (subject, time step) pair.
ts, subj = 56, 11
single_step_values = test_X_np[subj][ts].reshape(1, n_channels)
x_test_df = pd.DataFrame(single_step_values, columns=features)
shap.force_plot(explainer.expected_value[0], shap_values[0][subj][ts], x_test_df)

In [None]:
# Same subject, with both features and SHAP values averaged over all n_time_steps.
mean_features = test_X_np[subj].mean(axis=0).reshape(1, n_channels)
x_test_df = pd.DataFrame(mean_features, columns=features)
mean_shap_over_time = shap_values[0][subj].mean(axis=0)
shap.force_plot(explainer.expected_value[0], mean_shap_over_time, x_test_df)

## Local accuracy: Check sum of shap values vs prediction


In [None]:
# Display the installed shap version.
shap.__version__

In [None]:
# SHAP values of the current subject; presumably (n_time_steps, n_channels).
shap_values[0][subj].shape

In [None]:
# verifying local accuracy of explainer model:
# base value + SHAP values summed over ALL time steps and features should reproduce
# the model output for this subject.
subj = 11
pred_i = model.predict(test_X_np[subj:subj+1])
sum_shap_i = shap_values[0][subj].sum() + explainer.expected_value[0]

pred_i, sum_shap_i

As expected, these are the same.

In [None]:
from random import randint

# Force plot for ONE randomly chosen test subject; every time step is one row.
# randint includes both end points, so the upper bound must be the last valid index.
subj = randint(0, len(test_X_np) - 1)
print(subj, model.predict(test_X_np[subj:subj+1]))

x_test_df = pd.DataFrame(data=test_X_np[subj], columns = features)
shap.force_plot(explainer.expected_value[0], shap_values[0][subj], x_test_df)
## Problem: this view cannot take many observations into account at the same time.
### The force plot explains only one subject over 72 time steps, each time step having 85 features.

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# SHAP values of the current subject, indexed as [time step, feature].
subject_shap = shap_values[0][subj]
# Mean SHAP value of each feature over time (axis 0 is the time axis).
mean_shap_over_time = subject_shap.mean(axis=0)
feature_order = mean_shap_over_time.argsort()

# indices of the 3 features with the largest positive mean SHAP value (largest first)
selected_positive_features = feature_order[-3:][::-1]
# indices of the 3 features with the most negative mean SHAP value (least negative of the three first)
selected_negative_features = feature_order[:3][::-1]

selected_features = np.concatenate((selected_positive_features, selected_negative_features))

# Normalise each feature's SHAP trajectory by that FEATURE's mean SHAP value over time.
# A zero mean yields inf/nan (warnings silenced; those points are simply not drawn),
# and a negative mean flips the sign of the trajectory.
with np.errstate(divide='ignore', invalid='ignore'):
    normalized_shap_values = subject_shap / mean_shap_over_time

fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))
for feature in selected_features:
    sns.scatterplot(y=subject_shap[:, feature], x=range(n_time_steps), ax=ax1, label=features[feature])
    sns.scatterplot(y=normalized_shap_values[:, feature], x=range(n_time_steps), ax=ax2, label=features[feature])
ax1.set_title('SHAP value per time step')
ax2.set_title('SHAP value / mean SHAP value over time')
ax1.set_xlabel('time step')
ax2.set_xlabel('time step')


In [None]:
# plot a bar plot showing impact of most important features on the prediction across all n_time_steps
# randint includes both end points, so the upper bound must be the last valid index.
subj = randint(0, len(test_X_np) - 1)

# Total SHAP impact of each feature for this subject, summed over all time steps.
subject_shap_sum = shap_values[0][subj].sum(axis=0)
feature_names = np.array(features)

# find index of 3 features with biggest positive shap impact (largest first)
selected_positive_features = subject_shap_sum.argsort()[-3:][::-1]
# find index of 3 features with biggest negative shap impact; reversed so that the
# concatenated selection is in descending order of impact
selected_negative_features = subject_shap_sum.argsort()[:3][::-1]

selected_features = np.concatenate((selected_positive_features, selected_negative_features))

fig = plt.figure(figsize=(15,5))
ax1 = fig.add_subplot(121)
ax = sns.barplot(y=feature_names[selected_features], x=subject_shap_sum[selected_features], palette="RdBu_r", ax=ax1)
ax.title.set_text(f'SHAP values for subj {subj} ')

x_subj_df = pd.DataFrame(data=test_X_np[subj], columns = features)
# selected_features holds column POSITIONS, so select positionally with .iloc: integer
# keys on a label-indexed Series only fall back to positions in a deprecated pandas path.
median_norm_feature_df = x_subj_df.median(axis=0).iloc[selected_features]

ax2 = fig.add_subplot(122)
font_size=12
bbox=[0, 0, 1, 1]
ax2.axis('off')
# One table row per selected feature: its median (normalised) value over time.
cell_text = [[value.astype(str)] for value in median_norm_feature_df.values]
mpl_table = ax2.table(cellText = cell_text, rowLabels = median_norm_feature_df.index, bbox=bbox, colLabels=['Normalised value'], cellLoc='center', colLoc='center', loc='center')
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)
fig.set_tight_layout(True)
# set figure title
fig.suptitle(f'Explanation of prediction for subj {subj} with a probability of good outcome of {model.predict(test_X_np[subj:subj+1])[0][0]:.2f}', fontsize=20)

plt.show()


In [None]:
################# Plot AVERAGE shap values for ALL subjects  #####################
## Consider ABSOLUTE of SHAP values ##
# Mean of |SHAP| over the subject axis (axis 0): one value per (time step, feature).
shap_average_value = np.mean(np.abs(shap_values[0]), axis=0)

x_average_value = pd.DataFrame(test_X_np.mean(axis=0), columns=features)
# Base value 0: absolute values do not add up around the model's expected value.
shap.force_plot(0, shap_average_value, x_average_value)

In [None]:
################# Plot AVERAGE shap values for ALL subjects  #####################
## Consider average (+ is different from -)
# Signed mean over the subject axis (axis 0): one value per (time step, feature).
shap_average_value = np.mean(shap_values[0], axis=0)

x_average_value = pd.DataFrame(test_X_np.mean(axis=0), columns=features)
shap.force_plot(explainer.expected_value[0], shap_average_value, x_average_value)

### Feature importance
Find most important features by mean absolute SHAP value

In [None]:
# Feature indices sorted by mean |SHAP| over subjects and time steps, most important first.
# NOTE: despite the variable name, the top 13 features are kept (the summary plot below
# uses max_display=13).
ten_most_important_features_by_mean_abs_shap = np.abs(shap_values[0]).mean(axis=(0, 1)).argsort()[::-1][0:13]

In [None]:
# Names of the selected most important features.
np.array(features)[ten_most_important_features_by_mean_abs_shap]

Plot the sum of SHAP values over time per feature (colour-coded by the feature's mean value over time)

In [None]:
# Summary plot of the top features: SHAP values summed over time steps, coloured by
# each feature's mean value over time.
top_feature_idx = ten_most_important_features_by_mean_abs_shap
top_feature_names = np.array(features)[top_feature_idx]
summed_shap_top = shap_values[0].sum(axis=1)[:, top_feature_idx]
mean_x_top = pd.DataFrame(data=test_X_np.mean(axis=1), columns=features)[top_feature_names]
shap.summary_plot(summed_shap_top, mean_x_top, max_display=13, show=True)


### Reduce time dimension

The time dimension can be reduced by reshaping (flattening subjects and time steps into rows), by averaging, or by summing.

In [None]:
# Flatten the subject and time axes: every (subject, time step) pair becomes one row.
# Rows are subject-major (all time steps of subject 0, then subject 1, ...).
shap_values_2D = np.reshape(shap_values[0], (-1, n_channels))
X_test_2D = np.reshape(test_X_np, (-1, n_channels))

shap_values_2D.shape, X_test_2D.shape

In [None]:
# Table of the flattened test inputs, one column per feature name.
x_test_2d = pd.DataFrame(X_test_2D, columns=features)

In [None]:
# Pairwise correlation between features (pandas default: Pearson) over all rows.
x_test_2d.corr()

In [None]:
import matplotlib.pyplot as plt
import os
# Beeswarm summary plot over all rows of the flattened test set.
shap.summary_plot(shap_values_2D, x_test_2d,max_display=100, show=True)
# NOTE(review): with show=True the figure is presumably already displayed/closed here, so
# to save it, call summary_plot with show=False before plt.savefig — confirm.
# plt.savefig(os.path.join('/Users/jk1/Downloads', 'shap_summary_plot.png'), bbox_inches='tight')

In [None]:
import seaborn as sns

feature = 11
# Reverse lookup: the first sample_label whose channel index equals `feature`.
label_to_index = test_features_lookup_table['sample_label']
labels = list(label_to_index.keys())
indices = list(label_to_index.values())
feature_name = labels[indices.index(feature)]
sns.scatterplot(x=shap_values_2D[:, feature], y=x_test_2d[feature_name], hue=x_test_2d[feature_name])

In [None]:
import math
import matplotlib.pyplot as plt

# One scatter panel per feature: SHAP value vs. feature value (flattened rows).
n_columns = 4
# Exactly enough rows for all panels (no spare empty row when n_channels % n_columns == 0).
n_rows = max(1, math.ceil(n_channels / n_columns))

# squeeze=False keeps `axes` 2-D even when there is a single row of panels.
fig, axes = plt.subplots(n_rows, n_columns, figsize=(n_columns*3.5, n_rows*3.5), squeeze=False)

# Invert the label -> channel-index lookup once instead of re-scanning it per channel;
# setdefault keeps the first label for an index, matching a list.index() search.
index_to_label = {}
for label, idx in test_features_lookup_table['sample_label'].items():
    index_to_label.setdefault(idx, label)

for f in range(n_channels):
    feature_name = index_to_label[f]
    ax = axes[f // n_columns, f % n_columns]
    sns.scatterplot(y=shap_values_2D[:, f], x=x_test_2d[feature_name], hue=x_test_2d[feature_name], ax=ax)
    ax.set_title(feature_name)

# Hide the unused panels of the last row.
for unused_ax in axes.flat[n_channels:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Global importance bar plot (mean |SHAP| per feature) over all flattened rows.
shap.summary_plot(shap_values_2D, x_test_2d, plot_type="bar", max_display=100)

### Sum over time dimension

In [None]:
# Collapse the time axis: SHAP values are SUMMED over time steps (so base value + row sum
# still equals the subject's prediction), feature values are AVERAGED over time.
# This overwrites shap_values_2D / X_test_2D / x_test_2d with one row per subject.
shap_values_2D = np.sum(shap_values[0], axis=1)
X_test_2D = np.mean(test_X_np, axis=1)
x_test_2d = pd.DataFrame(X_test_2D, columns=features)

shap_values_2D.shape, X_test_2D.shape

In [None]:
import math
import matplotlib.pyplot as plt

# One scatter panel per feature: time-summed SHAP value vs. time-averaged feature value.
n_columns = 4
# Exactly enough rows for all panels (no spare empty row when n_channels % n_columns == 0).
n_rows = max(1, math.ceil(n_channels / n_columns))

# squeeze=False keeps `axes` 2-D even when there is a single row of panels.
fig, axes = plt.subplots(n_rows, n_columns, figsize=(n_columns*3.5, n_rows*3.5), squeeze=False)

# Invert the label -> channel-index lookup once instead of re-scanning it per channel;
# setdefault keeps the first label for an index, matching a list.index() search.
index_to_label = {}
for label, idx in test_features_lookup_table['sample_label'].items():
    index_to_label.setdefault(idx, label)

for f in range(n_channels):
    feature_name = index_to_label[f]
    ax = axes[f // n_columns, f % n_columns]
    sns.scatterplot(y=shap_values_2D[:, f], x=x_test_2d[feature_name], hue=x_test_2d[feature_name], ax=ax)
    ax.set_title(feature_name)

# Hide the unused panels of the last row.
for unused_ax in axes.flat[n_channels:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

### SHAP dependence plots

In [None]:
# SHAP value of 'uree' vs. its value (per subject, time-aggregated), coloured by 'creatinine'.
shap.dependence_plot("uree", shap_values_2D, x_test_2d, interaction_index="creatinine")

As expected, there is a strong interaction between urea (`uree`) and creatinine.

In [None]:
# automatic choice of interaction (shap picks the colouring feature itself)
shap.dependence_plot("age", shap_values_2D, x_test_2d)


## SHAP at each timestep

In [None]:
# Number of rows in X_test_2D. NOTE(review): after the "Sum over time dimension" cell,
# X_test_2D holds ONE ROW PER TEST SUBJECT (time averaged out), not one row per
# (subject, time step) as after the flattening cell.
len_test_set = X_test_2D.shape[0]
len_test_set

In [None]:
## SHAP for each time step
# Slice the 3-D arrays directly; both are indexed as [subject, time step, feature].
# The previous modulo indexing over shap_values_2D / x_test_2d assumed the flattened
# (subject x time step) layout. Those variables hold time-summed / time-averaged values
# (one row per subject) by now, so it selected every n_time_steps-th SUBJECT instead.
for step in range(n_time_steps):
    shap_values_2D_step = shap_values[0][:, step, :]
    x_test_2d_step = pd.DataFrame(data=test_X_np[:, step, :], columns=features)
    print("_______ time step {} ___________".format(step))
    shap.summary_plot(shap_values_2D_step, x_test_2d_step, plot_type="bar")
    shap.summary_plot(shap_values_2D_step, x_test_2d_step)
    print("\n")

In [None]:
# plot for the last time step (derived from n_time_steps rather than hard-coding 71)
step = n_time_steps - 1
# Slice the 3-D arrays ([subject, time step, feature]) directly, independent of the
# current contents of shap_values_2D / x_test_2d.
shap_values_2D_step = shap_values[0][:, step, :]
x_test_2d_step = pd.DataFrame(data=test_X_np[:, step, :], columns=features)
shap.summary_plot(shap_values_2D_step, x_test_2d_step, plot_type="bar", show=True)
plt.close()
shap.summary_plot(shap_values_2D_step, x_test_2d_step, show=True)


In [None]:
# Feature to zoom in on, and its column position in the per-time-step table.
feature = "age"
feature_idx = x_test_2d_step.columns.get_loc(feature)


In [None]:
# Keep 2-D shapes: a one-column slice of the SHAP array and a one-column DataFrame.
feature_shap_values_2D_step = shap_values_2D_step[:,feature_idx:feature_idx+1]
feature_x_test_2d_step = x_test_2d_step[[feature]]

In [None]:
# Both should have a single column.
feature_shap_values_2D_step.shape, feature_x_test_2d_step.shape

In [None]:
# show=False so that tight_layout is applied before the figure is displayed.
shap.summary_plot(feature_shap_values_2D_step, feature_x_test_2d_step, show=False, max_display=1)
plt.tight_layout()
plt.show()

## GradientExplainer


In [None]:
# GradientExplainer (expected gradients) with the full training set as background data;
# a smaller background sample would be faster.
explainer_2 = shap.GradientExplainer(model, train_X_np)
# Explain every test instance. For GradientExplainer the cost per instance is set by the
# `nsamples` argument of shap_values (left at its default here), not by 2 * background size.
shap_values_2 = explainer_2.shap_values(test_X_np)
# Inject shap's JS again so the force plot below renders.
shap.initjs()

In [None]:
################# Plot AVERAGE shap values for ALL subjects  #####################
## Consider ABSOLUTE of SHAP values ##
# GradientExplainer: mean of |SHAP| over the subject axis, per (time step, feature).
shap_average_abs_value_2 = np.mean(np.abs(shap_values_2[0]), axis=0)

x_average_value = pd.DataFrame(test_X_np.mean(axis=0), columns=features)
shap.force_plot(0, shap_average_abs_value_2, x_average_value)

### Importance for each training instance (computed with the DeepExplainer `explainer`)


In [None]:
################# Plot AVERAGE shap values for ALL training subjects  #####################
## Consider ABSOLUTE of SHAP values ##
# NOTE(review): this uses `explainer` (the DeepExplainer), not `explainer_2`
# (GradientExplainer) — confirm which one is intended. The training set is both the
# background and the explained data here, so the cost grows with the square of the
# training set size.
shap.initjs()
shap_values_train = explainer.shap_values(train_X_np)

# Mean |SHAP| over training subjects, per (time step, feature).
shap_average_abs_value_train = np.abs(shap_values_train[0]).mean(axis=0)

x_average_value_train = pd.DataFrame(data=train_X_np.mean(axis=0), columns = features)
shap.force_plot(0, shap_average_abs_value_train, x_average_value_train)

In [None]:
# Flatten subjects and time steps into rows (subject-major), then draw a summary plot.
shap_values_train_2D = np.reshape(shap_values_train[0], (-1, n_channels))
X_train_2D = np.reshape(train_X_np, (-1, n_channels))


shap.summary_plot(shap_values_train_2D, X_train_2D, feature_names=features)

In [None]:
# COLOR: https://seaborn.pydata.org/tutorial/color_palettes.html
import seaborn as sns
import matplotlib.pyplot as plt

# One heatmap per feature of its training-set SHAP values: rows are subjects,
# columns are time steps.
for i, feature in enumerate(features):
    print(feature)

    plt.figure(figsize = (8,6))
    # [:, :, i] already yields (n_subjects, n_time_steps); no reshape needed.
    feature_shap = shap_values_train[0][:, :, i]
    print(feature_shap.shape)
    sns.heatmap(feature_shap, cmap="coolwarm")
    # plt.show() takes no figure/Axes argument: a positional argument lands in its
    # `block` parameter, or raises where that parameter is keyword-only.
    plt.show()
    print("-----------")