In [None]:
import shap
import numpy as np
import pandas as pd
import os
import json
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
from prediction.utils.shap_helper_functions import check_shap_version_compatibility

Requirements:
- TensorFlow 1.14
- Python 3.7
- Protobuf downgraded to 3.20: `pip install protobuf==3.20`
- h5py downgraded to 2.10: `pip install h5py==2.10`
- Masking turned off in the LSTM

In [None]:
# SHAP values require very specific dependency versions (see the requirements listed above).
# NOTE(review): check_shap_version_compatibility is a project helper not shown here — presumably it
# verifies the installed versions; confirm how it reports a mismatch.
check_shap_version_compatibility()

In [None]:
# Print SHAP's JavaScript visualisation code into the notebook output
# (presumably required before interactive SHAP plots can render — confirm)
shap.initjs()

In [None]:
# Root folder shared by every path below; change it here to relocate all inputs at once.
transformer_output_root = '/Users/jk1/temp/opsum_prediction_output/transformer'

# Inputs loaded when outcome == '3M mrs02'
mrs02_run_dir = f'{transformer_output_root}/3M_mrs02/transformer_20230402_184459_test_set_evaluation'
mrs02_test_features_lookup_table_path = f'{mrs02_run_dir}/test_lookup_dict.json'
mrs02_shap_values_path = f'{mrs02_run_dir}/explanability/transformer_explainer_shap_values_over_ts_3m_mrs02_captum_n1449_all_72_cv2.pkl'

# Inputs loaded when outcome == '3M Death'
death_run_dir = f'{transformer_output_root}/3M_Death/testing'
death_test_features_lookup_table_path = f'{death_run_dir}/test_lookup_dict.json'
death_shap_values_path = f'{death_run_dir}/transformer_explainer_shap_values_over_ts_death_captum_n1431_all_ts_cv1.pkl'

# Inputs loaded when outcome == 'Death in hospital'
death_in_hospital_dir = f'{transformer_output_root}/Death_in_hospital'
death_in_hospital_test_features_lookup_table_path = f'{death_in_hospital_dir}/testing/15_02_23/test_lookup_dict.json'
death_in_hospital_shap_values_path = f'{death_in_hospital_dir}/inference/death_in_hospital_shap_values_2/transformer_explainer_shap_values_over_ts.pkl'

In [None]:
# Directory the figure produced by this notebook is saved to
output_dir = '/Users/jk1/Downloads'

In [None]:
# Outcome whose SHAP values are analysed; supported values are '3M mrs02', '3M Death' and 'Death in hospital'
outcome = 'Death in hospital'
# NOTE(review): seed, test_size and n_splits are not used by any cell visible here — presumably they
# mirror the settings of the evaluated model run; confirm or remove.
seed = 42
test_size = 0.2
n_splits = 5
# When True, the data behind the SHAP-over-time figure is also pickled to disk
save_plot_data = True

In [None]:
# Map every supported outcome to its (feature lookup table JSON, SHAP values pickle) paths.
outcome_input_paths = {
    '3M mrs02': (mrs02_test_features_lookup_table_path, mrs02_shap_values_path),
    '3M Death': (death_test_features_lookup_table_path, death_shap_values_path),
    'Death in hospital': (death_in_hospital_test_features_lookup_table_path, death_in_hospital_shap_values_path),
}
if outcome not in outcome_input_paths:
    raise ValueError('Outcome not supported')
selected_lookup_table_path, selected_shap_values_path = outcome_input_paths[outcome]

# load the test features lookup table from json as dict
# (opened in a `with` block so the file handle is closed instead of leaked)
with open(selected_lookup_table_path) as handle:
    test_features_lookup_table = json.load(handle)

# load the shap values
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open(selected_shap_values_path, 'rb') as handle:
    original_shap_values = pickle.load(handle)

In [None]:
# The loaded SHAP values hold one entry per model output; choose which part to analyse.
only_last_timestep = True
if not only_last_timestep:
    # keep, for every entry, only the slice at the last index of its second axis,
    # then move the entry axis behind the first axis
    shap_values = [
        np.swapaxes(
            np.array([per_output[:, -1, :] for per_output in original_shap_values]),
            0, 1,
        )
    ]
else:
    # use predictions from last timestep (as it also produces output for other timesteps)
    shap_values = [original_shap_values[-1]]

In [None]:
# Feature names are the keys of the 'sample_label' entry of the lookup table
features = list(test_features_lookup_table['sample_label'])

In [None]:
# Sanity check: display the array shape (the cells below read axis 0 as subjects and axis 1 as timesteps)
shap_values[0].shape

In [None]:
# First two axes of the SHAP array: number of subjects, number of timesteps
n_subj, n_timesteps = shap_values[0].shape[:2]

In [None]:
from tqdm import tqdm

# Build one wide frame with a row per (subject, timestep): a 'timestep' column, one integer-labelled
# column per entry of the last SHAP axis, and a 'subj_idx' column.
# The per-subject frames are collected in a list and concatenated once: DataFrame.append was removed
# in pandas 2.0 and growing a frame inside a loop is quadratic.
subj_frames = []
for subj_idx in tqdm(range(shap_values[0].shape[0])):
    subj_df = (
        pd.DataFrame(shap_values[0][subj_idx])
        .reset_index()
        .rename(columns={'index': 'timestep'})
    )
    subj_df['subj_idx'] = subj_idx
    subj_frames.append(subj_df)

# pd.concat raises on an empty list, so fall back to an empty frame when there are no subjects
shap_values_df = pd.concat(subj_frames, ignore_index=True) if subj_frames else pd.DataFrame()


In [None]:
# Reshape to long format: one row per (subject, timestep, feature index) holding the matching SHAP value.
# The integer column labels of the wide frame end up in 'feature_idx'.
# Note: this overwrites shap_values_df, so the cell is not meant to be run twice in a row.
shap_values_df = shap_values_df.melt(id_vars=['subj_idx', 'timestep'], var_name='feature_idx', value_name='shap_value')

# Palette Creation
Create color palette for feature values

In [None]:

# Four-colour palette; the bare last line displays its swatches as the cell output
all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
all_colors_palette

In [None]:
# Two-colour palette; the bare last line displays its swatches as the cell output
base_colors = sns.color_palette(['#f61067', '#012D98'], n_colors=2)
base_colors

In [None]:
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import sRGBColor, HSVColor, LabColor, LCHuvColor, XYZColor, LCHabColor, LuvColor

# Alternative gradient endpoints (currently disabled):
# start_color = '#012D98'
# end_color = '#f61067'

# Gradient endpoints currently in use
start_color= '#049b9a'
end_color= '#012D98'

# One colour per (subject, timestep) pair
number_of_colors = n_timesteps * n_subj

start_rgb = hex_to_rgb_color(start_color)
end_rgb = hex_to_rgb_color(end_color)

# NOTE(review): create_palette is a project helper not shown here — presumably it interpolates between
# the two colours in Lab colour space; confirm what extrapolation_length does.
palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)
# The same colours requested as a colormap object (as_cmap=True)
custom_cmap = sns.color_palette(palette, n_colors=number_of_colors, as_cmap=True)
# Display the palette swatches as the cell output
sns.color_palette(palette, n_colors=number_of_colors)

# Time importance
Find most important timepoints

### Summed absolute SHAP values across features, along time

In [None]:
# Collapse the last axis: per (subject, timestep), add up the absolute SHAP values
summed_shap_along_features = np.sum(np.abs(shap_values[0]), axis=-1)

In [None]:
from matplotlib.colors import ListedColormap
import matplotlib.cm as cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import LabColor


def plot_shap_along_time(summed_shap_along_features, n_timesteps, n_subj,
                         ax,
                         plot_legend=True,
                         plot_colorbar=True,
                         plot_median=True,
                         tick_label_size=11,
                         label_font_size=13,
                         alpha=0.9,
                         seed=None):
    """Scatter summed absolute SHAP values against time on a log-scaled y axis.

    Every (subject, timestep) pair is drawn as one dot. Its x position is the timestep index plus a
    uniform random jitter in [0, 1); its colour is driven by the natural log of its value.

    Parameters
    ----------
    summed_shap_along_features : numpy array, indexed as (subject, timestep)
        It is flattened for the scatter plot and its median over axis 0 is drawn as the median line,
        so it must hold ``n_subj * n_timesteps`` values.
    n_timesteps : int
        Number of timesteps (length of the x axis).
    n_subj : int
        Number of subjects.
    ax : matplotlib Axes
        Axes to draw on.
    plot_legend : bool
        Add a legend explaining the dots (and the median line, if drawn).
    plot_colorbar : bool
        Add a two-tick colorbar ('Small impact' / 'Large impact') to the right of the plot.
    plot_median : bool
        Overlay the per-timestep median as a line.
    tick_label_size : int
        Font size of tick labels, colorbar tick labels and legend entries.
    label_font_size : int
        Font size of axis labels, colorbar label and legend title.
    alpha : float
        Opacity of the median line only; the scatter dots use a fixed opacity of 0.05.
    seed : int or None
        Seed for the x jitter. ``None`` (default) draws from numpy's global random state as before;
        pass an integer to make the figure reproducible.

    Returns
    -------
    The matplotlib Axes the plot was drawn on.
    """
    # create palette: one colour per plotted point
    start_color = '#049b9a'
    end_color = '#012D98'
    number_of_colors = n_timesteps * n_subj
    start_rgb = hex_to_rgb_color(start_color)
    end_rgb = hex_to_rgb_color(end_color)
    palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)

    # horizontal jitter spreads the dots of one timestep over the interval [t, t + 1)
    n_points = n_subj * n_timesteps
    if seed is None:
        jitter = np.random.rand(n_points)
    else:
        # a dedicated generator keeps the global random state untouched
        jitter = np.random.RandomState(seed).rand(n_points)

    ax = sns.scatterplot(x=np.tile(np.arange(0, n_timesteps), n_subj) + jitter, y=summed_shap_along_features.flatten(),
                         hue=np.log(summed_shap_along_features.flatten()), ax=ax,
                         alpha=0.05, legend=False, palette=palette)

    if plot_median:
        # plot median summed shap value on top
        median_color = '#f61067'
        ax = sns.lineplot(x=np.arange(0, n_timesteps), y=np.median(summed_shap_along_features, axis=0), ax=ax,
                          color=median_color, markers='.', lw=2, alpha=alpha)

    ax.set(yscale="log")

    ax.set_xlabel('Time from admission (hours)', fontsize=label_font_size)
    ax.set_ylabel('Sum of absolute SHAP values', fontsize=label_font_size)
    ax.tick_params('x', labelsize=tick_label_size)
    ax.tick_params('y', labelsize=tick_label_size)

    if plot_colorbar:
        m = cm.ScalarMappable(cmap=ListedColormap(palette))
        m.set_array([0, 1])
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=1.2)

        # get fig from ax
        fig = ax.get_figure()

        # the colorbar is attached to cax via `ax=` (not `cax=`); cax itself is blanked out below
        cb = fig.colorbar(m, ticks=[0, 1], aspect=10, shrink=0.2, ax=cax)
        cb.set_ticklabels(['Small impact on \nmodel output', 'Large impact on \nmodel output'])
        cb.ax.tick_params(labelsize=tick_label_size, length=0)
        cb.set_label('SHAP values', size=label_font_size)
        cb.ax.yaxis.set_label_position('left')
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        # turn off grid and spines on cax
        cax.grid(False)
        cax.spines['right'].set_visible(False)
        cax.spines['top'].set_visible(False)
        cax.spines['left'].set_visible(False)
        cax.spines['bottom'].set_visible(False)
        cax.set_xticks([])
        cax.set_yticks([])

    if plot_legend:
        legend_markers = []
        legend_labels = []
        # legend dot uses the colour in the middle of the palette
        single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',
                              markersize=10)
        single_dot_label = 'Single patient at timepoint t'
        legend_markers.append(single_dot)
        legend_labels.append(single_dot_label)

        if plot_median:
            median_line = mlines.Line2D([], [], color=median_color, linestyle='-')
            median_line_label = 'Median'
            legend_markers.append(median_line)
            legend_labels.append(median_line_label)

        ax.legend(legend_markers, legend_labels, title='Summed absolute SHAP', fontsize=tick_label_size, title_fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})

    return ax

In [None]:
# Global plot styling: white grid, no top/right spines, 10x10 default figure size
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize':(10,10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale = 1)

fig, ax = plt.subplots()

# Draw the SHAP-over-time figure on the new axes with legend, colorbar and median line enabled
plot_shap_along_time(summed_shap_along_features, n_timesteps, n_subj, ax, plot_legend=True, plot_colorbar=True, plot_median=True)
    

In [None]:
# Save the figure into output_dir; spaces in the outcome name become underscores in the file name.
# The SVG export below is kept as a disabled alternative to the TIFF export.
# fig.savefig(os.path.join(output_dir, f'shap_vs_time_{outcome.replace(" ", "_")}.svg'), bbox_inches="tight", format='svg', dpi=1200)
fig.savefig(os.path.join(output_dir, f'shap_vs_time_{outcome.replace(" ", "_")}.tiff'), bbox_inches="tight", format='tiff', dpi=600)

In [None]:
# Optionally pickle the inputs of plot_shap_along_time so the figure can be re-drawn from the saved data.
# Written to output_dir (the same folder the figure is saved to) rather than a separately hard-coded path.
if save_plot_data:
    with open(os.path.join(output_dir, f'{outcome.replace(" ", "_")}_shap_along_time_figure_data.pkl'), 'wb') as f:
        pickle.dump((summed_shap_along_features, n_timesteps, n_subj), f)

### Plotting all shap values

In [None]:
# Every individual SHAP value against its timestep, with the y axis set to matplotlib's 'symlog' scale
ax = sns.scatterplot(data=shap_values_df, x='timestep', y='shap_value')
ax.set(yscale="symlog")

In [None]:
# Every individual SHAP value against its timestep, with the default y axis scale
ax = sns.scatterplot(data=shap_values_df, x='timestep', y='shap_value')
