In [None]:
import shap
import numpy as np
import pandas as pd
import os
import pickle
from prediction.outcome_prediction.LSTM.testing.shap_helper_functions import check_shap_version_compatibility
import seaborn as sns
import matplotlib.pyplot as plt

Requirements:
- TensorFlow 1.14
- Python 3.7
- Protobuf downgrade to 3.20: `pip install protobuf==3.20`
- downgrade h5py to 2.10: `pip install h5py==2.10`
- turn off masking in the LSTM (done below via `override_masking_value = False`)

In [None]:
# Shap values require very specific versions (see the requirements listed above), so the environment
# is checked before any work is done.
# NOTE(review): whether the helper raises or only warns on a mismatch is not visible here - confirm.
check_shap_version_compatibility()

In [None]:
# Print the JS visualization code to the notebook (call made once, before any SHAP plotting below)
shap.initjs()

In [None]:
# Locations of the trained model weights and of the preprocessed features / outcomes.
# Each path can be overridden through an environment variable, so that the notebook can be run on
# another machine without editing this cell; the defaults are the original local paths.
model_weights_path = os.environ.get(
    'OPSUM_MODEL_WEIGHTS_PATH',
    '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS02/2023_01_02_1057/test_LSTM_sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3/sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3.hdf5')
features_path = os.environ.get(
    'OPSUM_FEATURES_PATH',
    '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv')
labels_path = os.environ.get(
    'OPSUM_LABELS_PATH',
    '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv')

In [None]:
# Hyperparameters of the trained model whose weights are loaded further down.
# NOTE(review): these appear to mirror the values encoded in the weights file name - confirm when
# switching to another model.
outcome = '3M mRS 0-2'
masking = True  # not passed to the model in this notebook: `override_masking_value` is used instead
units = 16
activation = 'sigmoid'
dropout = 0.2
layers = 2
optimizer = 'RMSprop'
# seed and test_size are passed to train_test_split below (as random_state and test_size)
seed = 42
test_size = 0.20

In [None]:
# Masking has to be turned off in the LSTM for the SHAP computation (see requirements above);
# this value is passed as `masking` when the model is rebuilt, instead of `masking` defined earlier.
override_masking_value = False

In [None]:
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time

# Read the preprocessed features (X) and the selected outcome (y) from disk
X, y = format_to_2d_table_with_time(
    feature_df_path=features_path,
    outcome_df_path=labels_path,
    outcome=outcome,
)

# Input dimensions of the model:
# - time steps: highest hourly time category + 1
# - channels: number of distinct sample labels
n_time_steps = X['relative_sample_date_hourly_cat'].max() + 1
n_channels = len(X['sample_label'].unique())

In [None]:
from sklearn.model_selection import train_test_split
from prediction.outcome_prediction.data_loading.data_formatting import features_to_numpy, \
    link_patient_id_to_outcome, numpy_to_lookup_table

def _outcomes_for_cases(outcome_df, case_admission_ids):
    """Return the outcome of every case admission, in the order of `case_admission_ids`.

    For every case admission the first outcome row of `outcome_df` is used. The lookup table is
    built once, instead of filtering the whole dataframe again for every single case admission.

    Parameters:
        outcome_df: dataframe with (at least) the columns `case_admission_id` and `outcome`
        case_admission_ids: iterable of case admission ids

    Returns:
        float32 numpy array holding one outcome per given case admission id

    Raises:
        KeyError: if a case admission id has no row in `outcome_df`
    """
    first_outcome_by_case = (outcome_df
                             .drop_duplicates(subset='case_admission_id', keep='first')
                             .set_index('case_admission_id')['outcome']
                             .to_dict())
    return np.array([first_outcome_by_case[cid] for cid in case_admission_ids]).astype('float32')


# Reduce every patient to a single outcome (to avoid duplicates)
all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)
# Split on patient level (stratified by outcome), so that a patient is never part of both sets
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(all_pids_with_outcome.patient_id.tolist(),
                                                                all_pids_with_outcome.outcome.tolist(),
                                                                stratify=all_pids_with_outcome.outcome.tolist(),
                                                                test_size=test_size,
                                                                random_state=seed)

test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]
train_X_df = X[X.patient_id.isin(pid_train)]
train_y_df = y[y.patient_id.isin(pid_train)]

# Columns kept along the last axis of the numpy arrays (the value comes last)
feature_columns = ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value']
train_X_np = features_to_numpy(train_X_df, feature_columns)
test_X_np = features_to_numpy(test_X_df, feature_columns)

# Outcomes in the same order as the cases along the first axis of the feature arrays
# ([:, 0, 0, 0] reads the case_admission_id of every case)
train_y_np = _outcomes_for_cases(train_y_df, train_X_np[:, 0, 0, 0])
test_y_np = _outcomes_for_cases(test_y_df, test_X_np[:, 0, 0, 0])


# create look-up table for case_admission_ids, sample_labels and relative_sample_date_hourly_cat
test_features_lookup_table = numpy_to_lookup_table(test_X_np)
train_features_lookup_table = numpy_to_lookup_table(train_X_np)

# Remove the case_admission_id, sample_label, and time_step_label columns from the data
test_X_np = test_X_np[:, :, :, -1].astype('float32')
train_X_np = train_X_np[:, :, :, -1].astype('float32')

In [None]:
from prediction.utils.scoring import precision, recall, matthews
from prediction.outcome_prediction.LSTM.LSTM import lstm_generator

# Rebuild the network with the hyperparameters defined above; masking is taken from
# `override_masking_value` and not from `masking`
model = lstm_generator(
    x_time_shape=n_time_steps,
    x_channels_shape=n_channels,
    masking=override_masking_value,
    n_units=units,
    activation=activation,
    dropout=dropout,
    n_layers=layers,
)

evaluation_metrics = ['accuracy', precision, recall, matthews]
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=evaluation_metrics)

# Restore the trained weights into the freshly built model
model.load_weights(model_weights_path)

In [None]:
# Use the training data as background data for the deep explainer
# NOTE: the complete training set is passed here, although fewer instances could be used
explainer = shap.DeepExplainer(model, train_X_np)
# Explain the testing instances (fewer instances could be used here as well):
# explaining each prediction requires 2 * background dataset size runs
shap_values = explainer.shap_values(test_X_np)

In [None]:
# save the shap values
# Save the shap values, so that they do not have to be computed again.
# The output directory defaults to the original location and can be overridden through the
# OPSUM_SHAP_OUTPUT_DIR environment variable (the directory has to exist).
shap_values_output_dir = os.environ.get('OPSUM_SHAP_OUTPUT_DIR', '/Users/jk1/Downloads')
with open(os.path.join(shap_values_output_dir, 'temp_shap_values.pkl'), 'wb') as handle:
    pickle.dump(shap_values, handle)

In [None]:
# Names of the features: the keys of the sample_label entry of the look-up table
features = [sample_label for sample_label in test_features_lookup_table['sample_label'].keys()]

In [None]:
# Shape of the first array of SHAP values
# presumably (n_test_cases, n_time_steps, n_features), matching test_X_np - TODO confirm
shap_values[0].shape

# Feature importance
Find most important features by mean absolute SHAP value

In [None]:
# Mean absolute SHAP value of every feature, averaged over the first two axes
mean_abs_shap_per_feature = np.abs(shap_values[0]).mean(axis=(0, 1))
# Indices of the features, ordered by decreasing mean absolute SHAP value
# NOTE: despite its name, this variable holds the indices of the 13 highest-ranked features
ten_most_important_features_by_mean_abs_shap = mean_abs_shap_per_feature.argsort()[::-1][:13]

In [None]:
# Names of the selected features, most important first
np.array(features)[ten_most_important_features_by_mean_abs_shap]

Plot the SHAP values summed over all time steps for each feature (points are color-coded by the mean feature value over time)


In [None]:
# Names of the selected features
top_feature_names = np.array(features)[ten_most_important_features_by_mean_abs_shap]
# SHAP values summed over axis 1, restricted to the selected features
summed_shap_values_of_top_features = shap_values[0].sum(axis=1)[:, ten_most_important_features_by_mean_abs_shap]
# Feature values averaged over axis 1, restricted to the selected features (used for the coloring)
mean_feature_values_of_top_features = pd.DataFrame(data=test_X_np.mean(axis=1), columns=features)[top_feature_names]

shap.summary_plot(summed_shap_values_of_top_features, mean_feature_values_of_top_features,
                  max_display=13, show=True)

Recreate this plot in seaborn

In [None]:
# SHAP values summed over axis 1, one column per selected feature
top_feature_names = np.array(features)[ten_most_important_features_by_mean_abs_shap]
selected_shap_values = shap_values[0].sum(axis=1)[:, ten_most_important_features_by_mean_abs_shap]
selected_shap_values_df = pd.DataFrame(selected_shap_values, columns=top_feature_names)

Join data in a common dataframe with shap values and feature values

In [None]:
# Expose the row index as an explicit column identifying the case
selected_shap_values_df = (
    selected_shap_values_df
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
)

In [None]:
# Wide to long format: one row per (case, feature) pair holding the shap value
selected_shap_values_df = selected_shap_values_df.melt(id_vars='case_admission_id_idx',  var_name='feature', value_name='shap_value')

In [None]:
# Feature values averaged over axis 1, one column per selected feature
top_feature_names = np.array(features)[ten_most_important_features_by_mean_abs_shap]
mean_feature_values_df = pd.DataFrame(data=test_X_np.mean(axis=1), columns=features)
selected_feature_values_df = mean_feature_values_df[top_feature_names]

In [None]:
# Expose the row index as an explicit column identifying the case
selected_feature_values_df = (
    selected_feature_values_df
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
)

In [None]:
# Wide to long format: one row per (case, feature) pair holding the feature value
selected_feature_values_df = selected_feature_values_df.melt(id_vars='case_admission_id_idx',  var_name='feature', value_name='feature_value')

In [None]:
# Join shap values and feature values on the (case, feature) pair (default inner join)
features_with_shap_values_df = pd.merge(selected_shap_values_df, selected_feature_values_df, on=['case_admission_id_idx', 'feature'])

In [None]:
# Display the joined dataframe (columns: case_admission_id_idx, feature, shap_value, feature_value)
features_with_shap_values_df

Create color palette for feature values

In [None]:

# Palette of four colors, displayed for visual inspection
all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
all_colors_palette

In [None]:
# The two colors that are also used below as end and start color of the feature value gradient
base_colors = sns.color_palette(['#f61067', '#012D98'], n_colors=2)
base_colors

In [None]:
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import sRGBColor, HSVColor, LabColor, LCHuvColor, XYZColor, LCHabColor, LuvColor

# Build the palette used to color-code the feature values, going from start_color to end_color
start_color = '#012D98'
end_color = '#f61067'

# Alternative color pair, kept for reference
# start_color= '#049b9a'
# end_color= '#012D98'

number_of_colors = 50

start_rgb = hex_to_rgb_color(start_color)
end_rgb = hex_to_rgb_color(end_color)

# NOTE(review): presumably interpolates between the two colors in the given color space (LabColor);
# the exact role of extrapolation_length is not visible here - confirm in create_palette
palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)
# Same palette as a matplotlib colormap
custom_cmap = sns.color_palette(palette, n_colors=number_of_colors, as_cmap=True)
# Display the palette for visual inspection
sns.color_palette(palette, n_colors=number_of_colors)

In [None]:
# First attempt at recreating the plot with a seaborn swarmplot
# NOTE: the swarmplot call itself is commented out, so only an empty legend is drawn on the figure
# that is shown; ax, n_features and marker_size are only referenced by the commented-out call
n_features = 2
marker_size = 5 / n_features

# ax = sns.swarmplot(data=features_with_shap_values_df, x="shap_value", y="feature", s=marker_size, hue='feature_value', alpha=1, palette=palette)
plt.legend([], [], frameon=False)
plt.show()

In [None]:
# Display the feature values averaged over axis 1 for the selected features (one row per test case)
pd.DataFrame(data=test_X_np.mean(axis=(1)), columns = features)[np.array(features)[ten_most_important_features_by_mean_abs_shap]]

In [None]:
# One swarm plot of the shap values per feature (one facet row each), hue given by the feature value
facet_grid_options = dict(row="feature", hue='feature_value', height=3, aspect=4, palette=palette)
g = sns.FacetGrid(data=features_with_shap_values_df, **facet_grid_options)
g.map(sns.swarmplot, 'shap_value')

Notes:
- Swarm plot with dodge=True could be interesting to show the distribution of feature values for each feature, but takes a lot of time to plot

## Final function

Reuses the original plotting code from SHAP (`shap.summary_plot`)

Prerequisites: a `pd.DataFrame` with the SHAP value and the feature value for each feature, along with an index for each case

In [None]:
from matplotlib.colors import ListedColormap
import matplotlib.cm as cm

# Beeswarm summary plot: one row per feature, one point per case.
# x position: shap value, color: feature value, vertical offset: stacking of points with similar shap values.
# Uses `features_with_shap_values_df` (columns: feature, shap_value, feature_value), `palette` and `seed`.

plot_shap_direction_label = True  # replace the outermost x tick labels by the direction of the effect
plot_legend = False  # add a colorbar for the feature values

row_height = 0.4  # scales the vertical spread of the points of one feature
alpha = 0.8  # opacity of the points
axis_color = "#333333"
nbins = 100  # number of bins along the shap value axis in which points are stacked vertically
shap_value_x_limits = (-0.25, 0.15)  # fixed x range; points outside of this range are not visible

feature_value_cmap = ListedColormap(palette)
plotted_features = features_with_shap_values_df.feature.unique()
# The jitter only breaks ties between points of the same bin; a seeded generator keeps the figure reproducible
jitter_rng = np.random.RandomState(seed)

plt.gcf().set_size_inches(10, 10)

for pos, feature in enumerate(plotted_features):
    feature_rows = features_with_shap_values_df[features_with_shap_values_df.feature == feature]
    shaps = feature_rows.shap_value.values
    values = np.array(feature_rows.feature_value, dtype=np.float64)  # make sure this can be numeric
    plt.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    # Bin the points along the shap value axis and stack the points of a bin alternately
    # above and below the row line (offsets 0, +1, -1, +2, -2, ...)
    n_points = len(shaps)
    quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
    inds = np.argsort(quant + jitter_rng.randn(n_points) * 1e-6)
    layer = 0
    last_bin = -1
    ys = np.zeros(n_points)
    for ind in inds:
        if quant[ind] != last_bin:
            layer = 0
        ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
        layer += 1
        last_bin = quant[ind]
    # rescale the offsets, so that the swarm fits into the height of the row
    ys *= 0.9 * (row_height / np.max(ys + 1))

    # trim the color range, but prevent the color range from collapsing
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        if vmin == vmax:
            vmin = np.min(values)
            vmax = np.max(values)
    if vmin > vmax:  # fixes rare numerical precision issues
        vmin = vmax

    # color the points by the feature value clipped to the trimmed range (nan values stay nan)
    cvals = np.clip(values, vmin, vmax)
    plt.scatter(shaps, pos + ys,
                cmap=feature_value_cmap, vmin=vmin, vmax=vmax, s=16,
                c=cvals, alpha=alpha, linewidth=0,
                zorder=3, rasterized=n_points > 500)

if plot_legend:
    m = cm.ScalarMappable(cmap=feature_value_cmap)
    m.set_array([0, 1])
    # the mappable is not attached to any axes, so the axes to take the space from are given explicitly
    cb = plt.colorbar(m, ax=plt.gca(), ticks=[0, 1], aspect=500)
    cb.set_ticklabels(['FEATURE_VALUE_LOW', 'FEATURE_VALUE_HIGH'])
    cb.set_label('Feature value', size=12, labelpad=0)
    cb.ax.tick_params(labelsize=11, length=0)
    cb.set_alpha(1)
    cb.outline.set_visible(False)
    bbox = cb.ax.get_window_extent().transformed(plt.gcf().dpi_scale_trans.inverted())
    cb.ax.set_aspect((bbox.height - 0.9) * 20)

# Only keep the bottom axis line and its ticks
plt.gca().xaxis.set_ticks_position('bottom')
plt.gca().yaxis.set_ticks_position('none')
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['left'].set_visible(False)
plt.gca().tick_params(color=axis_color, labelcolor=axis_color)

# One y tick per feature, in the order in which the rows were drawn
yticklabels = plotted_features
plt.yticks(range(len(plotted_features)), yticklabels, fontsize=13)
plt.gca().tick_params('y', length=20, width=0.5, which='major')
plt.gca().tick_params('x', labelsize=11)
plt.ylim(-1, len(plotted_features))
plt.xlabel('SHAP Value \n (impact on model output)', fontsize=13)
plt.grid(color='white', axis='y')

plt.xlim(*shap_value_x_limits)

# Plot additional explanation with the shap value X axis
if plot_shap_direction_label:
    x_ticks_coordinates = plt.xticks()[0]
    # let x tick label be the coordinate with 2 decimals
    x_ticks_labels = [f'{x_ticks_coordinate:.2f}' for x_ticks_coordinate in x_ticks_coordinates]
    # the outermost ticks carry the direction of the effect instead of a number
    x_ticks_labels[0] = 'Toward worse \n outcome'
    x_ticks_labels[-1] = 'Toward better \n outcome'
    plt.xticks(x_ticks_coordinates, x_ticks_labels)

plt.show()
