In [None]:
import shap
import numpy as np
import pandas as pd
import os
import pickle
from prediction.outcome_prediction.LSTM.testing.shap_helper_functions import check_shap_version_compatibility
import seaborn as sns
import matplotlib.pyplot as plt

Requirements:
- TensorFlow 1.14
- Python 3.7
- Protobuf downgrade to 3.20: `pip install protobuf==3.20`
- downgrade h5py to 2.10: `pip install h5py==2.10`
- turn off masking in LSTM

In [None]:
# Shap values require very specific versions
# NOTE(review): the helper's checks are defined in shap_helper_functions and are not visible here;
# presumably it validates the installed library versions — confirm what it does on a mismatch.
check_shap_version_compatibility()

In [None]:
# print the JS visualization code to the notebook
# (loads SHAP's JavaScript so its interactive plots can render in cell outputs)
shap.initjs()

In [None]:
# Inputs: the pickled model plus the preprocessed feature / outcome tables.
saved_model_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/with_feature_aggregration/testing/selected_xgb_model_cv3.pkl'
features_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv'
labels_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv'

# The categorical encoding table sits next to the features file, in a folder named
# `logs_<a>_<b>` where <a> and <b> are the last two underscore-separated tokens of the
# features file name (without extension).
features_dir, features_filename = os.path.split(features_path)
second_to_last_token, last_token = features_filename.split(".")[0].split("_")[-2:]
cat_encoding_path = os.path.join(features_dir,
                                 f'logs_{second_to_last_token}_{last_token}/categorical_variable_encoding.csv')

In [None]:
# Directory the final SHAP figure is written to (used by the savefig cell at the end of the notebook)
output_dir = '/Users/jk1/Downloads'

In [None]:
# Outcome column to explain; passed to the data formatting helpers and used in the output file name
outcome = '3M mRS 0-2'
# Forwarded as `moving_average` to aggregate_features_over_time
moving_time_average = False
# Fraction of patients held out as the test set in train_test_split
test_size = 0.2
# Random state of the train/test split
seed = 42

In [None]:
# NOTE(review): not referenced by any other cell visible in this notebook —
# possibly a leftover from the LSTM variant; confirm before removing.
override_masking_value = False

## Load the data

In [None]:
from prediction.utils.utils import aggregate_features_over_time
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time
from sklearn.model_selection import train_test_split
from prediction.outcome_prediction.data_loading.data_formatting import features_to_numpy, \
    link_patient_id_to_outcome, numpy_to_lookup_table

# Load the feature table (X) and the outcome table (y) for the selected outcome
X, y = format_to_2d_table_with_time(
    feature_df_path=features_path,
    outcome_df_path=labels_path,
    outcome=outcome,
)

# SPLITTING DATA
# Splitting is done by patient id (and not admission id): in the rare case of multiple
# admissions per patient, splitting by admission would risk data leakage between
# TRAIN and TEST. pid = unique patient_id.

# Reduce every patient to a single outcome (to avoid duplicates)
all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)
patient_ids = all_pids_with_outcome.patient_id.tolist()
patient_outcomes = all_pids_with_outcome.outcome.tolist()

# Stratify on the outcome so both splits keep the same outcome proportions
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(
    patient_ids,
    patient_outcomes,
    stratify=patient_outcomes,
    test_size=test_size,
    random_state=seed,
)

#  Extracting TEST data
# Keep only the rows belonging to patients assigned to the test split
test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]

# Convert the long table to a 4D array; the indexing below treats axis 0 as the admission
# and the last axis as the listed columns (value last).
# NOTE(review): assumes axes are (admission, time step, feature, column) — confirm in features_to_numpy
test_X_np = features_to_numpy(test_X_df,
                              ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value'])
# Look up one outcome per admission, in the same order as the admissions in test_X_np
# (the first entry along the last axis is the case_admission_id)
test_y_np = np.array([test_y_df[test_y_df.case_admission_id == cid].outcome.values[0] for cid in
                      test_X_np[:, 0, 0, 0]]).astype('float32')
# Built before the label columns are dropped; later used to recover the feature names ('sample_label')
test_features_lookup_table = numpy_to_lookup_table(test_X_np)

# Remove the case_admission_id, sample_label, and time_step_label columns from the data
test_X_np = test_X_np[:, :, :, -1].astype('float32')
X_test, y_test = aggregate_features_over_time(test_X_np, test_y_np, moving_average=moving_time_average)
# only keep prediction at last timepoint
# NOTE(review): 72 is hard-coded here; presumably the number of hourly time steps per admission — confirm
X_test = X_test.reshape(-1, 72, X_test.shape[-1])[:, -1, :].astype('float32')
y_test = y_test.reshape(-1, 72)[:, -1].astype('float32')


#  Extracting TRAIN data
# NOTE(review): X_train / y_train are not used by any later cell visible in this notebook
# find indexes for train admissions
X_train_df = X.loc[X.patient_id.isin(pid_train)]
y_train_df = y.loc[y.patient_id.isin(pid_train)]

# Transform dataframes to numpy arrays
X_train = features_to_numpy(X_train_df,
                                 ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label',
                                  'value'])
# One outcome per admission, in the same order as the admissions in X_train
y_train = np.array([y_train_df[y_train_df.case_admission_id == cid].outcome.values[0] for cid in
                         X_train[:, 0, 0, 0]]).astype('float32')

# Remove the case_admission_id, sample_label, and time_step_label columns from the data
X_train = X_train[:, :, :, -1].astype('float32')
# Unlike the test data, the train data is not reduced to the last time point here
X_train, y_train = aggregate_features_over_time(X_train, y_train, moving_average=moving_time_average)


## Load the model

In [None]:
# Load the trained model from disk.
# NOTE: unpickling can execute arbitrary code — only load model files from a trusted source.
# The context manager guarantees the file handle is closed (the bare open() never closed it).
with open(saved_model_path, 'rb') as model_file:
    model = pickle.load(model_file)

In [None]:
# Tree-specific SHAP explainer for the loaded model
explainer = shap.TreeExplainer(model)
# SHAP values for the test set (last time point of every admission);
# later cells use the result as a 2D (sample, feature) table
shap_values = explainer.shap_values(X_test)

In [None]:
# Names of the raw features, in lookup-table order
features = list(test_features_lookup_table['sample_label'].keys())
# Each raw feature appears four times in the aggregated data: last time point, average, min and max.
# NOTE(review): this order must match the column order produced by aggregate_features_over_time — confirm.
# Fix: the 'min' prefix was missing the trailing underscore used by the other three prefixes.
features = np.concatenate([['last_tp_' + f for f in features], ['avg_' + f for f in features],
                           ['min_' + f for f in features], ['max_' + f for f in features]])
features.shape

In [None]:
# Sanity check: the following cells build DataFrames from shap_values and X_test with `features`
# as columns, so both need one column per entry in `features`
shap_values.shape, X_test.shape, features.shape

### Create working data frame

Join data in a common dataframe with shap values and feature values

In [None]:
# Preview only: SHAP values as a wide table (one row per test case, one column per feature);
# the result is displayed but not stored
pd.DataFrame(data=shap_values, columns = features)

In [None]:
# Long format: one row per (case, feature) holding the SHAP value.
# The positional row index of the test set serves as the case identifier.
selected_shap_values_df = (
    pd.DataFrame(data=shap_values, columns=features)
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')
)

In [None]:
# Long format: one row per (case, feature) holding the feature value.
# Uses the same positional case identifier as the SHAP value table.
selected_feature_values_df = (
    pd.DataFrame(data=X_test, columns=features)
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='feature_value')
)

In [None]:
# Pair every SHAP value with the feature value of the same case and feature
features_with_shap_values_df = selected_shap_values_df.merge(
    selected_feature_values_df,
    on=['case_admission_id_idx', 'feature'],
)

In [None]:
# CURRENT POSITION IN CODE
# (work-in-progress marker, kept as a comment: as bare text in a code cell it raised a
# SyntaxError and stopped "Restart & Run All" at this point)

In [None]:
reverse_categorical_encoding = True

# Collapse the one-hot encoded columns of each categorical variable back into a single feature:
# every one-hot column is scaled by its 1-based category index and renamed to the variable's
# original label, then values are summed per (case, feature). A case with no active one-hot
# column ends up with 0, i.e. the default category.
# NOTE(review): matching uses exact equality with `<variable>_<category>`; the feature names built
# earlier in this notebook carry aggregation prefixes (e.g. 'last_tp_', 'avg_'), so this may match
# nothing — confirm.
if reverse_categorical_encoding:
    cat_encoding_df = pd.read_csv(cat_encoding_path)
    for sample_label, other_categories in zip(cat_encoding_df.sample_label, cat_encoding_df.other_categories):
        cat_basename = sample_label.lower().replace(' ', '_')
        # `other_categories` is a stringified Python list, e.g. "['a', 'b']"
        cat_item_list = other_categories.replace('[', '').replace(']', '').replace('\'', '').split(', ')
        cat_item_list = [cat_basename + '_' + item.replace(' ', '_').lower() for item in cat_item_list]
        for cat_item_idx, cat_item in enumerate(cat_item_list):
            # rows of this one-hot column (mask computed once and reused for both updates)
            is_cat_item = features_with_shap_values_df.feature == cat_item
            #  retrieve the dominant category for this subject (0 being default category)
            features_with_shap_values_df.loc[is_cat_item, 'feature_value'] *= cat_item_idx + 1
            features_with_shap_values_df.loc[is_cat_item, 'feature'] = sample_label
    # sum the shap and feature values for each subject
    # (done once after all renames instead of once per category item: same sums, far less work)
    features_with_shap_values_df = features_with_shap_values_df.groupby(['case_admission_id_idx', 'feature']).sum().reset_index()


In [None]:
# give a numerical encoding to the categorical features
# For every categorical feature, position i of the sequence is the numerical value
# assigned to category index i.
cat_to_numerical_encoding = {
    'Prestroke disability (Rankin)': dict(enumerate([0, 5, 4, 2, 1, 3])),
    'categorical_onset_to_admission_time': dict(enumerate([1, 2, 3, 4, 0])),
    'categorical_IVT': dict(enumerate([2, 3, 4, 1, 0])),
    'categorical_IAT': dict(enumerate([1, 0, 3, 2])),
}

In [None]:
# Replace the category index of every listed categorical feature by its numerical encoding.
# Values without an entry in the mapping become NaN (pandas `Series.map` semantics).
for cat_feature, cat_encoding in cat_to_numerical_encoding.items():
    is_cat_feature = features_with_shap_values_df.feature == cat_feature
    encoded_values = features_with_shap_values_df.loc[is_cat_feature, 'feature_value'].map(cat_encoding)
    features_with_shap_values_df.loc[is_cat_feature, 'feature_value'] = encoded_values

In [None]:
pool_hourly_split_values = True

# Every feature whose name contains one of the base names below is relabelled to a single
# shared display name (first letter capitalised, underscores replaced by spaces), so that
# all variants of a measurement are treated as one feature downstream.
if pool_hourly_split_values:
    hourly_split_features = ['NIHSS', 'systolic_blood_pressure', 'diastolic_blood_pressure', 'heart_rate', 'respiratory_rate', 'temperature', 'oxygen_saturation']
    for feature in hourly_split_features:
        pooled_name = (feature[0].upper() + feature[1:]).replace('_', ' ')
        contains_base_name = features_with_shap_values_df.feature.str.contains(feature)
        features_with_shap_values_df.loc[contains_base_name, 'feature'] = pooled_name



Replace feature names with their English names

In [None]:
# Locate the feature-name -> English-name table relative to the current working directory.
# '__file__' is a plain string here (not the module attribute), so abspath() resolves it against
# the working directory; the first dirname() yields that directory and the remaining five climb
# five levels above it.
# NOTE(review): this only works when the notebook is run from its expected folder — confirm.
feature_to_english_name_correspondence_path = os.path.join(os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath('__file__'))))))),
                                                           'preprocessing/preprocessing_tools/feature_name_to_english_name_correspondence.xlsx')
# The table is used below through its `feature_name` and `english_name` columns
feature_to_english_name_correspondence = pd.read_excel(feature_to_english_name_correspondence_path)


In [None]:
# Rename every feature that has an entry in the correspondence table to its English name
# (first match wins); features without an entry keep their current name.
known_feature_names = feature_to_english_name_correspondence.feature_name.values
for feature in features_with_shap_values_df.feature.unique():
    if feature not in known_feature_names:
        continue
    matching_rows = feature_to_english_name_correspondence[feature_to_english_name_correspondence.feature_name == feature]
    english_name = matching_rows.english_name.values[0]
    features_with_shap_values_df.loc[features_with_shap_values_df.feature == feature, 'feature'] = english_name

## Feature selection

Select only the features that are in the top 10 most important features by mean absolute shap value

In [None]:
# Rank features by the mean of |SHAP value| over all their rows and keep the names of the 10 best.
# Side effect: adds the `absolute_shap_value` column to the working frame.
features_with_shap_values_df['absolute_shap_value'] = np.abs(features_with_shap_values_df['shap_value'])
mean_values_per_feature = features_with_shap_values_df.groupby('feature').mean()
features_ranked_by_importance = mean_values_per_feature.sort_values(by='absolute_shap_value', ascending=False)
top_10_features_by_mean_abs_summed_shap = features_ranked_by_importance.head(10).index.values
top_10_features_by_mean_abs_summed_shap

In [None]:
# Keep only the rows of the selected top features.
# Note: this overwrites the working frame, so re-running the selection cell afterwards
# only ranks the features that survived this filter.
features_with_shap_values_df = features_with_shap_values_df[features_with_shap_values_df.feature.isin(top_10_features_by_mean_abs_summed_shap)]

Alternatively, features could also be selected before joining categories and pooling hourly values

In [None]:
# Rank the raw (un-pooled) feature columns by mean |SHAP value| over all test cases.
# `shap_values` is a 2D (case, feature) array here — it is used as DataFrame data with `features`
# as columns earlier in this notebook — so the mean is taken over axis 0. The previous
# `shap_values[0]` / `axis=(0, 1)` indexing raised an AxisError on this 2D array.
# The slice now keeps 10 entries, matching the variable name (it previously kept 13).
ten_most_important_features_by_mean_abs_shap = np.abs(shap_values).mean(axis=0).argsort()[::-1][:10]
np.array(features)[ten_most_important_features_by_mean_abs_shap]

## Create color palette for feature values

In [None]:

# Full four-colour project palette; the bare name on the last line displays it as swatches
all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
all_colors_palette

In [None]:
# The two colours that are also used as end points of the gradient built in the next cell
base_colors = sns.color_palette(['#f61067', '#012D98'], n_colors=2)
base_colors

In [None]:
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import sRGBColor, HSVColor, LabColor, LCHuvColor, XYZColor, LCHabColor, LuvColor

# Gradient end points: start_color is passed first and end_color second to create_palette.
# The plot below maps feature values onto this palette via ListedColormap(palette).
start_color = '#012D98'
end_color = '#f61067'

# start_color= '#049b9a'
# end_color= '#012D98'

# Number of discrete colours in the gradient
number_of_colors = 50

start_rgb = hex_to_rgb_color(start_color)
end_rgb = hex_to_rgb_color(end_color)

# NOTE(review): presumably interpolates between the two colours in Lab colour space —
# confirm in visualisation_helper_functions.create_palette
palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)
# NOTE(review): custom_cmap is not referenced by any later cell visible here
custom_cmap = sns.color_palette(palette, n_colors=number_of_colors, as_cmap=True)
# Display the resulting gradient as swatches
sns.color_palette(palette, n_colors=number_of_colors)

## Plot most important features with SHAP values

Prerequisites: a pd.DataFrame with the SHAP value and feature value of each feature, along with an index for each case

In [None]:
from matplotlib.colors import ListedColormap
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple

# Switches for the optional plot elements
plot_shap_direction_label = True  # relabel the outermost x ticks as "Toward worse/better outcome"
plot_legend = True  # add the single-patient marker legend
plot_colorbar = True  # add the Low/High feature value colorbar
plot_feature_value_along_y = False  # if True, offset points by their feature value instead of stacking around the row

# Font sizes
tick_label_size = 11
label_font_size = 13

# Vertical space the points of one feature may occupy, and point opacity
row_height = 0.4
alpha = 0.8

plt.gcf().set_size_inches(10, 10)


# Draw one beeswarm row per feature: x = SHAP value, colour = feature value,
# y = row position plus a small offset that spreads points with similar SHAP values.
for pos, feature in enumerate(features_with_shap_values_df.feature.unique()):
    shaps = features_with_shap_values_df[features_with_shap_values_df.feature.isin([feature])].shap_value.values
    values = features_with_shap_values_df[features_with_shap_values_df.feature.isin([feature])].feature_value
    # faint dashed guide line marking the row
    plt.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    values = np.array(values, dtype=np.float64)  # make sure this can be numeric

    N = len(shaps)
    nbins = 100
    # assign every point to one of ~nbins bins along the SHAP axis
    # (the 1e-8 avoids a division by zero when all SHAP values are equal)
    quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
    # visit points bin by bin; the tiny random noise shuffles the order within a bin
    # (unseeded, so the exact vertical layout differs between runs)
    inds = np.argsort(quant + np.random.randn(N) * 1e-6)
    layer = 0
    last_bin = -1

    if plot_feature_value_along_y:
        # start from the feature value and add a small alternating offset within each bin
        ys = values.copy()
        cluster_factor = 0.1
        for ind in inds:
            if quant[ind] != last_bin:
                layer = 0
            ys[ind] += cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))
            layer += 1
            last_bin = quant[ind]

    else:
        # stack the points of a bin alternately above and below the row: 0, +1, -1, +2, -2, ...
        ys = np.zeros(N)
        cluster_factor = 1
        for ind in inds:
            if quant[ind] != last_bin:
                layer = 0
            ys[ind] = cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))
            layer += 1
            last_bin = quant[ind]

    # rescale the offsets so that the row stays within `row_height`
    ys *= 0.9 * (row_height / np.max(ys + 1))

    # trim the color range, but prevent the color range from collapsing
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        if vmin == vmax:
            vmin = np.min(values)
            vmax = np.max(values)
    if vmin > vmax: # fixes rare numerical precision issues
        vmin = vmax

    # plot the non-nan values colored by the trimmed feature value
    cvals = values.astype(np.float64)
    # cvals_imp replaces NaN by the mid-range value for the comparisons below only;
    # NaN entries of cvals themselves are left untouched by the clipping
    cvals_imp = cvals.copy()
    cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
    cvals[cvals_imp > vmax] = vmax
    cvals[cvals_imp < vmin] = vmin
    # rows with more than 500 points are rasterized
    plt.scatter(shaps, pos + ys,
               cmap=ListedColormap(palette), vmin=vmin, vmax=vmax, s=16,
               c=cvals, alpha=alpha, linewidth=0,
               zorder=3, rasterized=len(shaps) > 500)


import matplotlib.cm as cm

# Colour used for axis ticks and tick labels
axis_color="#333333"
if plot_colorbar:
    # Stand-alone mappable: the colorbar only conveys "Low" -> "High", since every feature row
    # has its own value range.
    m = cm.ScalarMappable(cmap=ListedColormap(palette))
    m.set_array([0, 1])
    # The mappable is not attached to any Axes, so state explicitly which Axes the colorbar
    # takes its space from; recent Matplotlib versions raise an error otherwise.
    cb = plt.colorbar(m, ax=plt.gca(), ticks=[0, 1], aspect=10, shrink=0.2)
    cb.set_ticklabels(['Low', 'High'])
    cb.ax.tick_params(labelsize=tick_label_size, length=0)
    cb.set_label('Feature value', size=label_font_size)
    cb.ax.yaxis.set_label_position('left')
    cb.set_alpha(1)
    cb.outline.set_visible(False)

if plot_legend:
    # A single dot in the middle colour of the palette stands for one patient
    single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',
                               markersize=10)
    single_dot_label = 'Single Patient\n(summed over time)'
    legend_markers = [single_dot]
    legend_labels = [single_dot_label]

    plt.gca().legend(
        legend_markers,
        legend_labels,
        title='SHAP/Feature values',
        fontsize=tick_label_size,
        title_fontsize=label_font_size,
        handler_map={tuple: HandlerTuple(ndivide=None)},
        loc='upper left',
        frameon=True,
    )


# Axis cosmetics: keep only the bottom spine and put one labelled tick per feature row
ax = plt.gca()
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
for side in ('right', 'top', 'left'):
    ax.spines[side].set_visible(False)
ax.tick_params(color=axis_color, labelcolor=axis_color)

# Row order follows the order of first appearance of the features, as in the plotting loop
yticklabels = features_with_shap_values_df.feature.unique()
n_rows = len(yticklabels)
plt.yticks(range(n_rows), yticklabels, fontsize=label_font_size)
ax.tick_params('y', length=20, width=0.5, which='major')
ax.tick_params('x', labelsize=tick_label_size)
plt.ylim(-1, n_rows)
plt.xlabel('SHAP Value \n(impact on model output)', fontsize=label_font_size)
plt.grid(color='white', axis='y')

# Fixed SHAP axis range
plt.xlim(-0.25, 0.15)

# Plot additional explanation with the shap value X axis
if plot_shap_direction_label:
    x_ticks_coordinates = plt.xticks()[0]
    # let x tick label be the coordinate with 2 decimals
    # (a preceding assignment that read the current tick label texts was removed: its result
    # was overwritten by this line and never used)
    x_ticks_labels = [f'{x_ticks_coordinate:.2f}' for x_ticks_coordinate in x_ticks_coordinates]
    # the outermost ticks carry the direction hints instead of a number
    x_ticks_labels[0] = 'Toward worse \noutcome'
    x_ticks_labels[-1] = 'Toward better \noutcome'
    plt.xticks(x_ticks_coordinates, x_ticks_labels)

# keep a handle on the figure so that it can be saved in the next cell
fig = plt.gcf()

plt.show()


In [None]:
# Write the figure as SVG into output_dir; the outcome name is embedded verbatim in the file name
fig.savefig(os.path.join(output_dir, f'top_features_shap_{outcome}.svg'), bbox_inches="tight", format='svg', dpi=1200)
