For SHAP, the feature axis is ordered as: features, avg_features, min_features, max_features (followed by a final base value column).

In [None]:
import pandas as pd
import pickle
import numpy as np
import torch as ch
import os
import seaborn as sns
import matplotlib.pyplot as plt
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import LabColor



In [None]:
# Shared roots of the (machine-specific, absolute) input locations below.
_opsum_root = '/Users/jk1/temp/opsum_end'
_preprocessing_dir = _opsum_root + '/preprocessing/gsu_Extraction_20220815_prepro_09052025_220520'
_xgb_checkpoint_dir = _opsum_root + '/training/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517/checkpoints_short_opsum_xgb_20250518_001112_cv_1'

# Pickled SHAP values, test split and categorical encoding table read by the cells below.
shap_values_path = _xgb_checkpoint_dir + '/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl'
test_data_path = _preprocessing_dir + '/early_neurological_deterioration_train_data_splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'
cat_encoding_path = _preprocessing_dir + '/logs_09052025_220520/categorical_variable_encoding.csv'

In [None]:
# Deserialize the SHAP values stored at `shap_values_path`.
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
with open(shap_values_path, 'rb') as shap_file:
    original_shap_values = pickle.load(shap_file)

In [None]:
# `original_shap_values` is indexed by timestep. The cells below reduce over
# axis 0 ("over subjects") and axis 1 ("over time"), so both branches must yield
# the same axis order.
# NOTE(review): assumes every per-timestep entry is a 2-D (subject, feature) array - confirm.
only_last_timestep = False
if only_last_timestep:
    # use predictions from last timestep (as it also produces output for other timesteps);
    # a length-1 axis is inserted at position 1 so the layout matches the branch below
    shap_values = np.asarray(original_shap_values[-1])[:, np.newaxis]
else:
    # stack all timesteps, then move the timestep axis from position 0 to position 1
    shap_values = np.array([timestep_values for timestep_values in original_shap_values]).swapaxes(0, 1)

In [None]:
# The stored test split unpacks into the input array and the labels.
loaded_test_split = ch.load(test_data_path)
X_test, y_test = loaded_test_split

In [None]:
# Names are taken from slot 2 of the last axis, using the first subject and first timestep.
features = X_test[0, 0, :, 2]

# The aggregated variants reuse the raw names behind an 'avg_' / 'min_' / 'max_' prefix.
avg_features, min_features, max_features = (
    [f'{aggregation}_{name}' for name in features]
    for aggregation in ('avg', 'min', 'max')
)
# Final column order: features, avg_features, min_features, max_features, base value.
aggregated_feature_names = [*features.tolist(), *avg_features, *min_features, *max_features, 'base_value']

In [None]:
# Reduce the SHAP values along their first axis (subjects) and second axis (time).
shap_values_array = np.asarray(shap_values)
median_over_subj_shap_values = np.median(shap_values_array, axis=0)
median_over_time_shap_values = np.median(shap_values_array, axis=1)
# Largest SHAP value along axis 1 and the position at which it occurs.
max_over_time_shap_values = shap_values_array.max(axis=1)
idx_max_over_time_shap_values = shap_values_array.argmax(axis=1)

In [None]:
# Wrap each reduction in a DataFrame with one column per aggregated feature name.
shap_column_names = np.array(aggregated_feature_names)
median_over_subj_shap_values_df = pd.DataFrame(median_over_subj_shap_values, columns=shap_column_names)
median_over_time_shap_values_df = pd.DataFrame(median_over_time_shap_values, columns=shap_column_names)
max_over_time_shap_values_df = pd.DataFrame(max_over_time_shap_values, columns=shap_column_names)

In [None]:
# Convert each wide frame to long format: one row per (row index, feature) pair,
# with the former row index stored in 'case_admission_id_idx'.
median_over_time_shap_values_df = (
    median_over_time_shap_values_df
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')
)

# NOTE: this frame was reduced over axis 0, so its rows index the remaining
# leading axis rather than subjects, although the same id column name is used.
median_over_subj_shap_values_df = (
    median_over_subj_shap_values_df
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')
)

max_over_time_shap_values_df = (
    max_over_time_shap_values_df
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')
)

In [None]:
# Feature values sit in the last slot of the final axis; cast them to float32.
test_X_np = X_test[:, :, :, -1].astype('float32')
# Average along axis 1, then reshape to long format keyed by row index and feature.
avg_over_time_feature_values_df = (
    pd.DataFrame(test_X_np.mean(axis=1), columns=features)
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='feature_value')
)

In [None]:
# For every subject and raw feature, pick the feature value at the axis-1
# position where that feature's SHAP value is largest. Only the first
# `test_X_np.shape[2]` columns of the argmax table are read.
n_subjects, n_raw_features = test_X_np.shape[0], test_X_np.shape[2]
feature_at_max_shap_list = [
    np.array([
        test_X_np[subj_i, idx_max_over_time_shap_values[subj_i, feature_i], feature_i]
        for feature_i in range(n_raw_features)
    ])
    for subj_i in range(n_subjects)
]
feature_at_max_shap = np.array(feature_at_max_shap_list)

In [None]:
# Long-format table of the feature values picked at the maximum SHAP value.
feature_at_max_shap_df = (
    pd.DataFrame(feature_at_max_shap, columns=features)
    .reset_index()
    .rename(columns={'index': 'case_admission_id_idx'})
    .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='feature_value')
)


In [None]:
# Attach feature values to the SHAP rows. The left join keeps every SHAP row,
# so rows without a matching feature value get NaN in 'feature_value'.
merge_keys = ['case_admission_id_idx', 'feature']
avg_over_time_features_with_shap_values_df = median_over_time_shap_values_df.merge(
    avg_over_time_feature_values_df, on=merge_keys, how='left'
)
features_at_max_shap_values = max_over_time_shap_values_df.merge(
    feature_at_max_shap_df, on=merge_keys, how='left'
)

In [None]:
# Work on a copy: later cells assign into `features_with_shap_values` in place,
# which would otherwise also rewrite `features_at_max_shap_values`.
features_with_shap_values = features_at_max_shap_values.copy()

## Category preprocessing

In [None]:
pool_time_aggregated_features = True
if pool_time_aggregated_features:
    prefixes = ['avg_', 'min_', 'max_']
    # Map every aggregated name to its base feature. A prefix is removed only
    # when the remainder is itself a known feature name, and at most one prefix
    # is removed. Stripping the prefixes one after another would also shorten
    # raw names that merely start with 'min_' / 'max_' and could remove two
    # prefixes from one name (e.g. 'avg_min_x' -> 'x').
    known_features = set(features_with_shap_values['feature'])
    pooled_feature_name = {
        prefix + base_name: base_name
        for base_name in known_features
        for prefix in prefixes
        if prefix + base_name in known_features
    }
    pooled_features = features_with_shap_values['feature'].map(
        lambda name: pooled_feature_name.get(name, name)
    )
    # sum the shap and feature values for each subject
    # (built on a new frame so the input frame is not modified in place)
    features_with_shap_values = (
        features_with_shap_values
        .assign(feature=pooled_features)
        .groupby(['case_admission_id_idx', 'feature'])
        .sum()
        .reset_index()
    )



In [None]:
reverse_categorical_encoding = True

if reverse_categorical_encoding:
    prefixes = ['', 'avg_', 'min_', 'max_']
    cat_encoding_df = pd.read_csv(cat_encoding_path)
    for sample_label, other_categories in zip(cat_encoding_df.sample_label, cat_encoding_df.other_categories):
        # rebuild the per-category column names: '<label>_<category>', lower-cased, spaces -> underscores
        cat_basename = sample_label.lower().replace(' ', '_')
        cat_item_list = other_categories.replace('[', '').replace(']', '').replace('\'', '').split(', ')
        cat_item_list = [cat_basename + '_' + item.replace(' ', '_').lower() for item in cat_item_list]
        for cat_item_idx, cat_item in enumerate(cat_item_list):
            for prefix in prefixes:
                # retrieve the dominant category for this subject (0 being default category):
                # scale the category column by its 1-based position and rename it to the variable label
                is_cat_item = features_with_shap_values.feature == prefix + cat_item
                features_with_shap_values.loc[is_cat_item, 'feature_value'] *= cat_item_idx + 1
                features_with_shap_values.loc[is_cat_item, 'feature'] = prefix + sample_label
        # sum the shap and feature values for each subject; once per categorical variable is
        # enough because the sums do not depend on how the renamed rows are grouped in between
        features_with_shap_values = features_with_shap_values.groupby(['case_admission_id_idx', 'feature']).sum().reset_index()

    # give a numerical encoding to the categorical features (manually extracted from the categorical_variable_encoding.csv file)
    cat_to_numerical_encoding = {
        'Prestroke disability (Rankin)': {0:0, 1:3, 2:4, 3:2, 4:1, 5:5},
        'categorical_onset_to_admission_time': {0:2, 1:1, 2:0, 3:3, 4:5, 5:4},
        'categorical_IVT': {0:2, 1:3, 2:4, 3:1, 4:0},
        'categorical_IAT': {0:1, 1:2, 2:3, 3:0}
    }

    for cat_feature, cat_encoding in cat_to_numerical_encoding.items():
        for prefix in prefixes:
            # build the prefixed name in its own variable; reassigning `cat_feature`
            # here would stack the prefixes ('min_avg_...') on later iterations
            prefixed_cat_feature = prefix + cat_feature
            is_cat_feature = features_with_shap_values.feature == prefixed_cat_feature
            features_with_shap_values.loc[is_cat_feature, 'feature_value'] = features_with_shap_values.loc[is_cat_feature, 'feature_value'].map(cat_encoding)

In [None]:
pool_hourly_split_values = True

# For features that are downsampled to hourly values, give all variants of a
# feature (every name containing the base name) one shared display name.
# NOTE: rows are only renamed here, not aggregated, so a subject can keep
# several rows for the same display name.

if pool_hourly_split_values:
    hourly_split_features = ['NIHSS', 'systolic_blood_pressure', 'diastolic_blood_pressure', 'mean_blood_pressure', 'heart_rate', 'respiratory_rate', 'temperature', 'oxygen_saturation']

    for feature in hourly_split_features:
        # display name: first letter upper-cased, underscores replaced by spaces
        display_name = (feature[0].upper() + feature[1:]).replace('_', ' ')
        # Rename once per feature. Looping over aggregation prefixes without using
        # them to select rows re-renamed any name that still matched afterwards
        # (only 'NIHSS'), leaving it as 'max NIHSS', while every other feature
        # ended up unprefixed.
        is_hourly_variant = features_with_shap_values.feature.str.contains(feature, regex=False)
        features_with_shap_values.loc[is_hourly_variant, 'feature'] = display_name

In [None]:
# Replace feature names with their english names.
# The quoted '__file__' is resolved relative to the current working directory,
# so the table is looked up below the directory the notebook is run from.
working_dir = os.path.dirname(os.path.abspath('__file__'))
feature_to_english_name_correspondence_path = os.path.join(
    working_dir,
    'preprocessing/preprocessing_tools/feature_name_to_english_name_correspondence.xlsx',
)
feature_to_english_name_correspondence = pd.read_excel(feature_to_english_name_correspondence_path)

for raw_name in features_with_shap_values.feature.unique():
    if raw_name not in feature_to_english_name_correspondence.feature_name.values:
        continue
    # the first matching table row provides the translation
    is_translation_row = feature_to_english_name_correspondence.feature_name == raw_name
    english_name = feature_to_english_name_correspondence[is_translation_row].english_name.values[0]
    features_with_shap_values.loc[features_with_shap_values.feature == raw_name, 'feature'] = english_name

## Feature selection

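Rank the features by their mean SHAP value and keep the top 10. Rankings by mean absolute, lowest mean and highest mean SHAP value are computed below; the plot uses the ranking by highest mean SHAP value (`top_10_features_by_mean_max_shap`).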

In [None]:
# identify the top 10 most important features by mean absolute shap value
features_with_shap_values['absolute_shap_value'] = features_with_shap_values['shap_value'].abs()
# Remove 'base_value' before truncating, so that exactly 10 features remain
# whether or not the base value ranks among the largest entries.
mean_abs_shap_by_feature = (
    features_with_shap_values.groupby('feature')['absolute_shap_value'].mean()
    .drop('base_value', errors='ignore')
)
top_10_features_by_mean_abs_summed_shap = mean_abs_shap_by_feature.nlargest(10).index.values
top_10_features_by_mean_abs_summed_shap

In [None]:
# Mean shap value per feature; 'base_value' is removed before ranking so that
# both selections below contain 10 features (fewer only if fewer exist).
mean_shap_by_feature = (
    features_with_shap_values.groupby('feature')['shap_value'].mean()
    .drop('base_value', errors='ignore')
)

# the 10 features with the lowest mean shap value
top_10_features_by_mean_min_shap = mean_shap_by_feature.nsmallest(10).index.values

# the 10 features with the highest mean shap value
top_10_features_by_mean_max_shap = mean_shap_by_feature.nlargest(10).index.values

top_10_features_by_mean_min_shap, top_10_features_by_mean_max_shap



In [None]:
# Restrict the table to the features that will be plotted.
selected_features = top_10_features_by_mean_max_shap
is_selected_feature = features_with_shap_values.feature.isin(selected_features)
selected_features_with_shap_values_df = features_with_shap_values.loc[is_selected_feature]

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple
import matplotlib.cm as cm


def _beeswarm_offsets(shaps, values, row_height, plot_feature_value_along_y):
    """Return the vertical offset of every point within one feature row."""
    n_points = len(shaps)
    nbins = 100
    # bin the shap values; points sharing a bin are stacked around the row centre
    quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
    # tiny random noise breaks ties between points of the same bin
    inds = np.argsort(quant + np.random.randn(n_points) * 1e-6)

    if plot_feature_value_along_y:
        # start from the feature value itself and add a small stacking offset
        ys = values.copy()
        cluster_factor = 0.1
    else:
        ys = np.zeros(n_points)
        cluster_factor = 1

    layer = 0
    last_bin = -1
    for ind in inds:
        if quant[ind] != last_bin:
            layer = 0
        # alternate below / above the centre: 0, +1, -1, +2, -2, ...
        ys[ind] += cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))
        layer += 1
        last_bin = quant[ind]

    # scale the offsets relative to the row height
    ys *= 0.9 * (row_height / np.max(ys + 1))
    return ys


def _trimmed_color_values(values):
    """Return the colour values clipped to a percentile range, plus that range."""
    # trim the color range, but prevent the color range from collapsing
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        if vmin == vmax:
            vmin = np.min(values)
            vmax = np.max(values)
    if vmin > vmax:  # fixes rare numerical precision issues
        vmin = vmax

    cvals = values.astype(np.float64)
    # NaNs are treated as the mid value for the comparisons only; they stay NaN in cvals
    cvals_imp = cvals.copy()
    cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
    cvals[cvals_imp > vmax] = vmax
    cvals[cvals_imp < vmin] = vmin
    return cvals, vmin, vmax


def plot_top_features_shap(selected_features_with_shap_values_df, selected_features,
                           ax,
                           plot_shap_direction_label=True,
                           plot_legend=True,
                           plot_colorbar=True,
                           plot_feature_value_along_y=False,
                           reverse_outcome_direction=False,
                           tick_label_size=11,
                           label_font_size=13,
                           row_height=0.4,
                           alpha=0.8,
                           xlim: tuple = None
                           ):
    """Draw one row of SHAP values per selected feature, coloured by feature value.

    Args:
        selected_features_with_shap_values_df: long-format frame with the columns
            'feature', 'shap_value' and 'feature_value'.
        selected_features: feature names to plot; the first entry is drawn at the top.
        ax: matplotlib axes to draw on.
        plot_shap_direction_label: replace the first and last x tick label with an
            outcome direction label.
        plot_legend: add the single-patient legend entry.
        plot_colorbar: add a 'Low'/'High' feature value colorbar to the right.
        plot_feature_value_along_y: offset points by their feature value instead of
            stacking them around the row centre.
        reverse_outcome_direction: if True, the lowest x tick is labelled as the
            better outcome and the highest as the worse outcome; otherwise the reverse.
        tick_label_size: font size of tick and legend labels.
        label_font_size: font size of axis labels, feature names and legend title.
        row_height: vertical extent used to scale the point offsets of a row.
        alpha: opacity of the scatter points.
        xlim: optional (min, max) limits of the x axis.

    Returns:
        The axes that were drawn on.
    """
    # Define the color palette
    start_color = '#012D98'
    end_color = '#f61067'
    number_of_colors = 50
    start_rgb = hex_to_rgb_color(start_color)
    end_rgb = hex_to_rgb_color(end_color)
    palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)

    # one row per selected feature; tick positions and labels are both derived
    # from this count so they cannot get out of sync
    n_rows = len(selected_features)

    for pos, feature in enumerate(selected_features[::-1]):
        feature_rows = selected_features_with_shap_values_df[
            selected_features_with_shap_values_df.feature.isin([feature])
        ]
        shaps = feature_rows.shap_value.values
        values = np.array(feature_rows.feature_value, dtype=np.float64)  # make sure this can be numeric
        ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        if len(shaps) == 0:
            # nothing to draw for this feature; keep its (empty) row
            continue

        ys = _beeswarm_offsets(shaps, values, row_height, plot_feature_value_along_y)

        # plot the non-nan values colored by the trimmed feature value
        cvals, vmin, vmax = _trimmed_color_values(values)
        ax.scatter(shaps, pos + ys,
                   cmap=ListedColormap(palette), vmin=vmin, vmax=vmax, s=16,
                   c=cvals, alpha=alpha, linewidth=0,
                   zorder=3, rasterized=len(shaps) > 500)

    axis_color = "#333333"
    if plot_colorbar:
        m = cm.ScalarMappable(cmap=ListedColormap(palette))
        m.set_array([0, 1])

        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.5)

        # get fig from ax
        fig = ax.get_figure()
        # the colorbar is created with `ax=cax`; cax itself is then blanked below
        cb = fig.colorbar(m, ticks=[0, 1], aspect=10, shrink=0.2, ax=cax)
        cb.set_ticklabels(['Low', 'High'])
        cb.ax.tick_params(labelsize=tick_label_size, length=0)
        cb.set_label('Feature value', size=label_font_size, backgroundcolor="white")
        cb.ax.yaxis.set_label_position('left')
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        # turn off grid and spines on cax
        cax.grid(False)
        for side in ('right', 'top', 'left', 'bottom'):
            cax.spines[side].set_visible(False)
        cax.set_xticks([])
        cax.set_yticks([])

    if plot_legend:
        single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',
                                   markersize=10)
        single_dot_label = 'Single Patient\n(summed over time)'
        legend_markers = [single_dot]
        legend_labels = [single_dot_label]

        ax.legend(legend_markers, legend_labels, title='SHAP/Feature values', fontsize=tick_label_size, title_fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)},
                  loc='lower right', frameon=True)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('none')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color)

    yticklabels = selected_features[::-1]
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(yticklabels, fontsize=label_font_size)
    ax.tick_params('y', length=20, width=0.5, which='major')
    ax.tick_params('x', labelsize=tick_label_size)
    ax.set_ylim(-1, n_rows)
    ax.set_xlabel('SHAP Value \n(impact on model output)', fontsize=label_font_size)
    ax.grid(color='white', axis='y')

    if xlim:
        ax.set_xlim(xlim[0], xlim[1])

    # Plot additional explanation with the shap value X axis
    if plot_shap_direction_label:
        x_ticks_coordinates = ax.get_xticks()
        # let x tick label be the coordinate with 1 decimal
        x_ticks_labels = [f'{x_ticks_coordinate:.1f}' for x_ticks_coordinate in x_ticks_coordinates]

        # the outermost ticks carry the direction labels instead of numbers
        if reverse_outcome_direction:
            lowest_label, highest_label = 'Toward better\noutcome', 'Toward worse\noutcome'
        else:
            lowest_label, highest_label = 'Toward worse\noutcome', 'Toward better\noutcome'
        x_ticks_labels[0] = lowest_label
        x_ticks_labels[-1] = highest_label

        ax.set_xticks(x_ticks_coordinates)
        ax.set_xticklabels(x_ticks_labels)

    return ax

In [None]:
# Lowest x tick is labelled as the better outcome, highest as the worse outcome.
reverse_outcome_direction = True

fig, ax = plt.subplots(figsize=(10, 10))
plot_top_features_shap(
    selected_features_with_shap_values_df,
    selected_features,
    ax,
    reverse_outcome_direction=reverse_outcome_direction,
    xlim=(-0.6, 0.6),
)

In [None]:
# fig.savefig('/Users/jk1/temp/opsum_end/testing/features_at_max_shap_value.png', bbox_inches='tight', dpi=600)