In [None]:
import shap
import numpy as np
import pandas as pd
import os
import pickle
from prediction.utils.shap_helper_functions import check_shap_version_compatibility
import seaborn as sns
import matplotlib.pyplot as plt

Requirements:
- TensorFlow 1.14
- Python 3.7
- Downgrade protobuf to 3.20: `pip install protobuf==3.20`
- Downgrade h5py to 2.10: `pip install h5py==2.10`
- Turn off masking in the LSTM

In [None]:
# SHAP explanations depend on tightly pinned package versions; verify the environment first
check_shap_version_compatibility()

In [None]:
# Load SHAP's JavaScript visualization code into the notebook output
shap.initjs()

In [None]:
# Pre-computed SHAP values, one file per outcome
death_shap_values_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/transformer_explainer_shap_values_over_ts_death_captum_n1431_all_ts_cv1.pkl'
mrs_shap_values_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/transformer_20230402_184459_test_set_evaluation/explanability/transformer_explainer_shap_values_over_ts_3m_mrs02_captum_n1449_all_72_cv2.pkl'
death_in_hosp_shap_values_path = '/Users/jk1/temp/opsum_prediction_output/transformer/Death_in_hospital/inference/death_in_hospital_shap_values_2/transformer_explainer_shap_values_over_ts.pkl'

# Preprocessed model inputs and outcomes
features_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv'
labels_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv'

# The categorical encoding sits in a `logs_<a>_<b>` folder next to the features file, where
# <a>_<b> are the last two underscore-separated tokens of the features file name
# (e.g. `01012023_233050`).
_features_name_tokens = os.path.basename(features_path).split(".")[0].split("_")
cat_encoding_path = os.path.join(
    os.path.dirname(features_path),
    f'logs_{_features_name_tokens[-2]}_{_features_name_tokens[-1]}/categorical_variable_encoding.csv',
)

In [None]:
output_dir = "/Users/jk1/Downloads"  # destination for exported figures / data

In [None]:
# Analysis configuration (all values are passed to load_data below)
outcome: str = '3M Death'  # one of '3M mRS 0-2', '3M Death', 'Death in hospital'
seed: int = 42
test_size: float = 0.20
n_splits: int = 5

In [None]:
# Pick the SHAP values file that belongs to the selected outcome
_shap_values_path_by_outcome = {
    '3M mRS 0-2': mrs_shap_values_path,
    '3M Death': death_shap_values_path,
    'Death in hospital': death_in_hosp_shap_values_path,
}
if outcome not in _shap_values_path_by_outcome:
    raise ValueError(f'Outcome {outcome} not supported')
shap_values_path = _shap_values_path_by_outcome[outcome]

## Load the data

In [None]:
from prediction.outcome_prediction.data_loading.data_loader import load_data

# Split the preprocessed data exactly as during training/evaluation
(pids, train_data, test_data, train_splits, test_features_lookup_table) = load_data(
    features_path, labels_path, outcome, test_size, n_splits, seed
)


In [None]:
# Load the pre-computed SHAP values.
# NOTE: pickle.load can execute arbitrary code — only load files produced by this project.
with open(shap_values_path, 'rb') as handle:
    original_shap_values = pickle.load(handle)

In [None]:
only_last_timestep = True

if not only_last_timestep:
    # For every entry of original_shap_values keep only the slice at the last index of axis 1
    # and stack these slices along a new axis 1.
    shap_values = [np.stack([ts_shap[:, -1, :] for ts_shap in original_shap_values], axis=1)]
else:
    # Use only the explanation of the last timestep (the model also produces outputs
    # for the earlier timesteps).
    shap_values = [original_shap_values[-1]]

In [None]:
features = [*test_features_lookup_table['sample_label'].keys()]  # feature names, used as column labels below

In [None]:
# Dimensions of the selected SHAP array
shap_values[0].shape

### Create working data frame

Join the SHAP values and the feature values into a common long-format dataframe (one row per case and feature).

In [None]:
# Sum the SHAP values over axis 1 (time) and reshape to long format:
# one row per (case index, feature) pair.
selected_shap_values = shap_values[0].sum(axis=1)
selected_shap_values_df = (
    pd.DataFrame(data=selected_shap_values, columns=np.array(features))
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')
)


In [None]:
(test_X_np, test_y_np) = test_data  # test-set inputs and labels


In [None]:
# Average every feature over axis 1 (assumed to be time, matching the SHAP sum above)
# and reshape to the same long format as the SHAP values.
selected_feature_values_df = (
    pd.DataFrame(data=test_X_np.mean(axis=1), columns=features)
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='feature_value')
)


In [None]:
features_with_shap_values_df = selected_shap_values_df.merge(selected_feature_values_df, on=['case_admission_id_idx', 'feature'])

In [None]:
reverse_categorical_encoding = True

# Collapse the one-hot columns of each categorical variable back into a single feature.
# NOTE: this cell mutates features_with_shap_values_df in place; re-running it without
# re-running the merge cell applies the numerical re-mapping below a second time.
if reverse_categorical_encoding:
    cat_encoding_df = pd.read_csv(cat_encoding_path)
    relabelled_any = False
    for i in range(len(cat_encoding_df)):
        cat_label = cat_encoding_df.sample_label[i]
        cat_basename = cat_label.lower().replace(' ', '_')
        # other_categories is stored as a stringified list, e.g. "['a', 'b']"
        cat_item_list = cat_encoding_df.other_categories[i].replace('[', '').replace(']', '').replace('\'', '').split(', ')
        cat_item_list = [cat_basename + '_' + item.replace(' ', '_').lower() for item in cat_item_list]
        for cat_item_idx, cat_item in enumerate(cat_item_list):
            is_cat_item = features_with_shap_values_df.feature == cat_item
            # scale the one-hot value by the category code (position + 1) so that, after summing,
            # each subject's value is the code of its category (0 being the default category)
            features_with_shap_values_df.loc[is_cat_item, 'feature_value'] *= cat_item_idx + 1
            features_with_shap_values_df.loc[is_cat_item, 'feature'] = cat_label
            relabelled_any = True

    if relabelled_any:
        # Sum the shap and feature values for each subject. Done once after all relabelling
        # instead of re-grouping the whole frame after every single category item.
        features_with_shap_values_df = features_with_shap_values_df.groupby(['case_admission_id_idx', 'feature']).sum().reset_index()

    # Give a numerical encoding to the categorical features: maps the category codes produced
    # above to the values used for colouring (presumably an ordinal order — TODO confirm).
    # Codes missing from a mapping become NaN (Series.map semantics).
    cat_to_numerical_encoding = {
        'Prestroke disability (Rankin)': {0:0, 1:5, 2:4, 3:2, 4:1, 5:3},
        'categorical_onset_to_admission_time': {0:1, 1:2, 2:3, 3:4, 4:0},
        'categorical_IVT': {0:2, 1:3, 2:4, 3:1, 4:0},
        'categorical_IAT': {0:1, 1:0, 2:3, 3:2}
    }

    for cat_feature, cat_encoding in cat_to_numerical_encoding.items():
        is_cat_feature = features_with_shap_values_df.feature == cat_feature
        features_with_shap_values_df.loc[is_cat_feature, 'feature_value'] = features_with_shap_values_df.loc[is_cat_feature, 'feature_value'].map(cat_encoding)


In [None]:
pool_hourly_split_values = True

# Features downsampled to hourly values exist as several columns (e.g. median / min / max
# variants); give every column that contains the base name one readable label.
# NOTE(review): the relabelled rows are not aggregated here, so each subject keeps one row per
# variant for these features — confirm whether they should be summed per subject instead.
if pool_hourly_split_values:
    hourly_split_features = ['NIHSS', 'systolic_blood_pressure', 'diastolic_blood_pressure', 'mean_blood_pressure', 'heart_rate', 'respiratory_rate', 'temperature', 'oxygen_saturation']
    for hourly_feature in hourly_split_features:
        readable_name = (hourly_feature[0].upper() + hourly_feature[1:]).replace('_', ' ')
        contains_feature = features_with_shap_values_df.feature.str.contains(hourly_feature)
        features_with_shap_values_df.loc[contains_feature, 'feature'] = readable_name



Replace feature names with their English names

In [None]:
# abspath('__file__') resolves the literal name against the working directory; the first dirname
# strips it again, the remaining five climb five directory levels above the working directory.
_base_dir = os.path.abspath('__file__')
for _ in range(6):
    _base_dir = os.path.dirname(_base_dir)
feature_to_english_name_correspondence_path = os.path.join(
    _base_dir, 'preprocessing/preprocessing_tools/feature_name_to_english_name_correspondence.xlsx'
)
feature_to_english_name_correspondence = pd.read_excel(feature_to_english_name_correspondence_path)


In [None]:
# Lookup internal feature name -> English name; the first row of the table wins on duplicates
_english_name_by_feature = (
    feature_to_english_name_correspondence
    .drop_duplicates(subset='feature_name')
    .set_index('feature_name')['english_name']
    .to_dict()
)
for feature in features_with_shap_values_df.feature.unique():
    if feature in _english_name_by_feature:
        features_with_shap_values_df.loc[features_with_shap_values_df.feature == feature, 'feature'] = _english_name_by_feature[feature]

## Feature selection

Keep only the 10 most important features, ranked by mean absolute SHAP value.

In [None]:
# Rank features by their mean absolute (time-summed) SHAP value and keep the 10 largest
features_with_shap_values_df['absolute_shap_value'] = features_with_shap_values_df['shap_value'].abs()
_mean_abs_shap_by_feature = features_with_shap_values_df.groupby('feature')[['absolute_shap_value']].mean()
top_10_features_by_mean_abs_summed_shap = (
    _mean_abs_shap_by_feature
    .sort_values(by='absolute_shap_value', ascending=False)
    .head(10)
    .index.values
)
top_10_features_by_mean_abs_summed_shap

In [None]:
selected_features = top_10_features_by_mean_abs_summed_shap
# rows of the working frame that belong to one of the selected features
selected_features_with_shap_values_df = features_with_shap_values_df.loc[
    features_with_shap_values_df['feature'].isin(selected_features)
]

Alternatively, features could be selected before joining the categorical columns and pooling the hourly values, i.e. directly on the raw model input features.

In [None]:
# Mean |SHAP| per raw input feature over cases (axis 0) and timesteps (axis 1); top 10 indices
_mean_abs_raw_shap = np.abs(shap_values[0]).mean(axis=(0, 1))
ten_most_important_features_by_mean_abs_shap = _mean_abs_raw_shap.argsort()[::-1][:10]
np.array(features)[ten_most_important_features_by_mean_abs_shap]

## Create color palette for feature values

In [None]:

_all_hex_colors = ['#f61067', '#049b9a', '#012D98', '#a76dfe']
all_colors_palette = sns.color_palette(_all_hex_colors, n_colors=len(_all_hex_colors))
all_colors_palette

In [None]:
_base_hex_colors = ['#f61067', '#012D98']
base_colors = sns.color_palette(_base_hex_colors, n_colors=len(_base_hex_colors))
base_colors

In [None]:
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import sRGBColor, HSVColor, LabColor, LCHuvColor, XYZColor, LCHabColor, LuvColor

# Continuous palette from start_color (low feature values) to end_color (high feature values),
# built by create_palette with LabColor as the colour space
start_color, end_color = '#012D98', '#f61067'
number_of_colors = 50

start_rgb, end_rgb = hex_to_rgb_color(start_color), hex_to_rgb_color(end_color)

palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)
custom_cmap = sns.color_palette(palette, n_colors=number_of_colors, as_cmap=True)
sns.color_palette(palette, n_colors=number_of_colors)

## Plot most important features with SHAP values

Prerequisites: a `pd.DataFrame` with the SHAP value and the feature value of each feature, along with an index for each case

In [None]:
from matplotlib.colors import ListedColormap
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple
import matplotlib.cm as cm

# Beeswarm plot of the selected features: one row per feature (most important on top),
# one dot per row of the working frame (x = SHAP value, colour = feature value).

plot_shap_direction_label = True
plot_legend = True
plot_colorbar = True
plot_feature_value_along_y = False

tick_label_size = 11
label_font_size = 13

row_height = 0.4
alpha = 0.8

# For both death outcomes the direction labels are swapped (left = better, right = worse),
# matching the reverse_outcome_direction rule used with plot_top_features_shap.
reverse_outcome_direction = outcome in ['3M Death', 'Death in hospital']

plt.gcf().set_size_inches(10, 10)
ax = plt.gca()
axis_color = "#333333"
n_features = len(selected_features)

for pos, feature in enumerate(selected_features[::-1]):
    ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
    feature_rows = selected_features_with_shap_values_df[selected_features_with_shap_values_df.feature.isin([feature])]
    shaps = feature_rows.shap_value.values
    N = len(shaps)
    if N == 0:
        # no data for this feature: keep its (empty) row instead of failing on np.min([])
        continue
    values = np.array(feature_rows.feature_value, dtype=np.float64)  # make sure this can be numeric

    # Quantise the SHAP values into bins; inside a bin, stack points alternately above and
    # below the row (offsets 0, +1, -1, +2, -2, ...). The tiny jitter randomises the order.
    nbins = 100
    quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
    inds = np.argsort(quant + np.random.randn(N) * 1e-6)
    if plot_feature_value_along_y:
        ys, cluster_factor = values.copy(), 0.1
    else:
        ys, cluster_factor = np.zeros(N), 1
    layer = 0
    last_bin = -1
    for ind in inds:
        if quant[ind] != last_bin:
            layer = 0
        ys[ind] += cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))
        layer += 1
        last_bin = quant[ind]
    ys *= 0.9 * (row_height / np.max(ys + 1))

    # trim the color range, but prevent the color range from collapsing
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        if vmin == vmax:
            vmin = np.min(values)
            vmax = np.max(values)
    if vmin > vmax:  # fixes rare numerical precision issues
        vmin = vmax

    # clip colour values to [vmin, vmax]; NaNs are imputed only for the comparison and stay NaN
    cvals = values.copy()
    cvals_imp = cvals.copy()
    cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
    cvals[cvals_imp > vmax] = vmax
    cvals[cvals_imp < vmin] = vmin
    ax.scatter(shaps, pos + ys,
               cmap=ListedColormap(palette), vmin=vmin, vmax=vmax, s=16,
               c=cvals, alpha=alpha, linewidth=0,
               zorder=3, rasterized=len(shaps) > 500)

if plot_colorbar:
    m = cm.ScalarMappable(cmap=ListedColormap(palette))
    m.set_array([0, 1])
    # pass ax explicitly: the mappable is not attached to any Axes, and newer matplotlib
    # versions no longer fall back to the current Axes in that case
    cb = plt.colorbar(m, ax=ax, ticks=[0, 1], aspect=10, shrink=0.2)
    cb.set_ticklabels(['Low', 'High'])
    cb.ax.tick_params(labelsize=tick_label_size, length=0)
    cb.set_label('Feature value', size=label_font_size)
    cb.ax.yaxis.set_label_position('left')
    cb.set_alpha(1)
    cb.outline.set_visible(False)

if plot_legend:
    single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',
                               markersize=10)
    legend_markers = [single_dot]
    legend_labels = ['Single Patient\n(summed over time)']
    ax.legend(legend_markers, legend_labels, title='SHAP/Feature values', fontsize=tick_label_size, title_fontsize=label_font_size,
              handler_map={tuple: HandlerTuple(ndivide=None)},
              loc='upper left', frameon=True)

ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.tick_params(color=axis_color, labelcolor=axis_color)

# one y tick per plotted row, labelled bottom-to-top like the loop above
ax.set_yticks(range(n_features))
ax.set_yticklabels(selected_features[::-1], fontsize=label_font_size)
ax.tick_params('y', length=20, width=0.5, which='major')
ax.tick_params('x', labelsize=tick_label_size)
ax.set_ylim(-1, n_features)
ax.set_xlabel('SHAP Value \n(impact on model output)', fontsize=label_font_size)
ax.grid(color='white', axis='y')

# Replace the outermost x tick labels by the outcome direction; the others show their value
# with one decimal (rounding to 0 decimals mislabels ticks spaced by 0.5).
if plot_shap_direction_label:
    x_ticks_coordinates = ax.get_xticks()
    x_ticks_labels = [f'{x_ticks_coordinate:.1f}' for x_ticks_coordinate in x_ticks_coordinates]
    if reverse_outcome_direction:
        x_ticks_labels[0] = f'Toward better\noutcome'
        x_ticks_labels[-1] = f'Toward worse\noutcome'
    else:
        x_ticks_labels[0] = f'Toward worse\noutcome'
        x_ticks_labels[-1] = f'Toward better\noutcome'
    ax.set_xticks(x_ticks_coordinates)
    ax.set_xticklabels(x_ticks_labels)

fig = ax.get_figure()

plt.show()


In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple
import matplotlib.cm as cm


def _beeswarm_y_offsets(shaps, values, plot_feature_value_along_y, row_height, nbins=100):
    """Return vertical offsets that spread points with similar SHAP values into a swarm.

    SHAP values are quantised into ``nbins`` bins; inside a bin the points are stacked
    alternately above and below the row (offsets 0, +1, -1, +2, -2, ...). When
    ``plot_feature_value_along_y`` is True the offsets (scaled by 0.1) are added to the
    feature values instead of to zero. The result is rescaled relative to ``row_height``.
    """
    n_points = len(shaps)
    quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
    # tiny random jitter randomises the visiting order of points that share a bin
    inds = np.argsort(quant + np.random.randn(n_points) * 1e-6)

    if plot_feature_value_along_y:
        ys, cluster_factor = values.copy(), 0.1
    else:
        ys, cluster_factor = np.zeros(n_points), 1

    layer = 0
    last_bin = -1
    for ind in inds:
        if quant[ind] != last_bin:
            layer = 0
        ys[ind] += cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))
        layer += 1
        last_bin = quant[ind]

    ys *= 0.9 * (row_height / np.max(ys + 1))
    return ys


def _trimmed_color_limits(values):
    """Return (vmin, vmax) for the colour scale: 5th-95th percentile, widened if it collapses."""
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        if vmin == vmax:
            vmin = np.min(values)
            vmax = np.max(values)
    if vmin > vmax:  # fixes rare numerical precision issues
        vmin = vmax
    return vmin, vmax


def plot_top_features_shap(selected_features_with_shap_values_df, selected_features,
                           ax,
                           plot_shap_direction_label=True,
                           plot_legend=True,
                           plot_colorbar=True,
                           plot_feature_value_along_y=False,
                           reverse_outcome_direction=False,
                           tick_label_size=11,
                           label_font_size=13,
                           row_height=0.4,
                           alpha=0.8,
                           xlim: tuple = None
                           ):
    """Draw a SHAP beeswarm summary of the selected features on ``ax``.

    One row per feature (first entry of ``selected_features`` at the top) and one dot per
    frame row of that feature: x = SHAP value, colour = feature value (clipped to a trimmed
    percentile range).

    Parameters
    ----------
    selected_features_with_shap_values_df : pd.DataFrame
        Long-format frame with at least the columns ``feature``, ``shap_value`` and ``feature_value``.
    selected_features : sequence of str
        Features to plot, most important first. Features without rows get an empty row.
    ax : matplotlib Axes
        Axes to draw on.
    plot_shap_direction_label : bool
        Replace the outermost x tick labels with "Toward worse/better outcome".
    plot_legend, plot_colorbar : bool
        Toggle the single-patient legend and the low/high feature value colour bar.
    plot_feature_value_along_y : bool
        Offset the dots around the feature value instead of around the row centre.
    reverse_outcome_direction : bool
        If True, the left end is labelled "better" and the right end "worse"
        (used for the death outcomes).
    tick_label_size, label_font_size : int
        Font sizes of tick labels and axis labels.
    row_height : float
        Scale of the vertical spread of each feature's swarm.
    alpha : float
        Dot transparency.
    xlim : tuple, optional
        (left, right) x-axis limits.

    Returns
    -------
    The same ``ax``.
    """
    # Colour palette (same start/end colours as the palette cell of this notebook)
    start_color = '#012D98'
    end_color = '#f61067'
    number_of_colors = 50
    palette = create_palette(hex_to_rgb_color(start_color), hex_to_rgb_color(end_color),
                             number_of_colors, LabColor, extrapolation_length=1)
    cmap = ListedColormap(palette)
    n_features = len(selected_features)

    for pos, feature in enumerate(selected_features[::-1]):
        ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
        feature_rows = selected_features_with_shap_values_df[selected_features_with_shap_values_df.feature.isin([feature])]
        shaps = feature_rows.shap_value.values
        if len(shaps) == 0:
            # nothing to draw; keep the empty row instead of failing on np.min([])
            continue
        values = np.array(feature_rows.feature_value, dtype=np.float64)  # make sure this can be numeric

        ys = _beeswarm_y_offsets(shaps, values, plot_feature_value_along_y, row_height)
        vmin, vmax = _trimmed_color_limits(values)

        # clip colour values to [vmin, vmax]; NaNs are imputed only for the comparison and stay NaN
        cvals = values.copy()
        cvals_imp = cvals.copy()
        cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
        cvals[cvals_imp > vmax] = vmax
        cvals[cvals_imp < vmin] = vmin
        ax.scatter(shaps, pos + ys,
                   cmap=cmap, vmin=vmin, vmax=vmax, s=16,
                   c=cvals, alpha=alpha, linewidth=0,
                   zorder=3, rasterized=len(shaps) > 500)

    axis_color = "#333333"
    if plot_colorbar:
        m = cm.ScalarMappable(cmap=cmap)
        m.set_array([0, 1])

        # The colour bar steals its space from an axes appended to the right of ax;
        # that helper axes' own grid, spines and ticks are hidden below.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.5)
        fig = ax.get_figure()
        cb = fig.colorbar(m, ticks=[0, 1], aspect=10, shrink=0.2, ax=cax)
        cb.set_ticklabels(['Low', 'High'], backgroundcolor="white")
        cb.ax.tick_params(labelsize=tick_label_size, length=0)
        cb.set_label('Feature value', size=label_font_size, backgroundcolor="white")
        cb.ax.yaxis.set_label_position('left')
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        cax.grid(False)
        for side in ('right', 'top', 'left', 'bottom'):
            cax.spines[side].set_visible(False)
        cax.set_xticks([])
        cax.set_yticks([])

    if plot_legend:
        single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',
                                   markersize=10)
        legend_markers = [single_dot]
        legend_labels = ['Single Patient\n(summed over time)']
        ax.legend(legend_markers, legend_labels, title='SHAP/Feature values', fontsize=tick_label_size, title_fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)},
                  loc='upper left', frameon=True)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('none')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color)

    # one y tick per plotted row (sized from selected_features so ticks and labels always match)
    ax.set_yticks(range(n_features))
    ax.set_yticklabels(selected_features[::-1], fontsize=label_font_size)
    ax.tick_params('y', length=20, width=0.5, which='major')
    ax.tick_params('x', labelsize=tick_label_size)
    ax.set_ylim(-1, n_features)
    ax.set_xlabel('SHAP Value \n(impact on model output)', fontsize=label_font_size)
    ax.grid(color='white', axis='y')

    if xlim:
        ax.set_xlim(xlim[0], xlim[1])

    # Replace the outermost x tick labels by the outcome direction; the others show their
    # value with one decimal.
    if plot_shap_direction_label:
        x_ticks_coordinates = ax.get_xticks()
        x_ticks_labels = [f'{x_ticks_coordinate:.1f}' for x_ticks_coordinate in x_ticks_coordinates]
        if reverse_outcome_direction:
            left_label, right_label = 'Toward better\noutcome', 'Toward worse\noutcome'
        else:
            left_label, right_label = 'Toward worse\noutcome', 'Toward better\noutcome'
        x_ticks_labels[0] = left_label
        x_ticks_labels[-1] = right_label
        ax.set_xticks(x_ticks_coordinates)
        ax.set_xticklabels(x_ticks_labels)

    return ax

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))

# Swap the "better"/"worse" direction labels for the death outcomes
reverse_outcome_direction = outcome in ['3M Death', 'Death in hospital']
plot_top_features_shap(
    selected_features_with_shap_values_df,
    selected_features,
    ax,
    reverse_outcome_direction=reverse_outcome_direction,
    xlim=(-1.5, 1.5),
)

In [None]:
save_figure = False  # set to True to export the figure above as SVG

if save_figure:
    fig.savefig(os.path.join(output_dir, f'top_features_captum_shap_{outcome}.svg'), bbox_inches="tight", format='svg', dpi=1200)

In [None]:
save_plot_data: bool = False  # set to True to pickle the data behind the figure

In [None]:
if save_plot_data:
    # Persist the exact inputs of the figure so it can be re-drawn without recomputing SHAP values.
    # Uses the configured output_dir instead of a second hard-coded copy of the same path.
    plot_data_path = os.path.join(output_dir, f'{outcome.replace(" ", "_")}_top_shap_features_figure_data.pkl')
    with open(plot_data_path, 'wb') as f:
        pickle.dump((selected_features_with_shap_values_df, selected_features), f)