In [None]:
import pandas as pd
import pickle
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import matplotlib.patches as mpatches
from sklearn.utils import resample
from tqdm import tqdm
from matplotlib.legend_handler import HandlerTuple
import matplotlib.lines as mlines

In [None]:
# Base directory holding the pre-computed figure data. It can be redirected without
# editing the individual paths by setting the OPSUM_FIGURE_DATA_DIR environment
# variable; the default is the original hard-coded location.
FIGURE_DATA_DIR = os.environ.get('OPSUM_FIGURE_DATA_DIR', '/Users/jk1/temp/opsum_figure_temp_data')

# Pickles feeding the "top features" SHAP panels.
mrs02_top_features_data_path = os.path.join(
    FIGURE_DATA_DIR, 'shap_top_features', '3M_mRS_0-2_top_shap_features_figure_data.pkl')
death_top_features_data_path = os.path.join(
    FIGURE_DATA_DIR, 'shap_top_features', '3M_Death_top_shap_features_figure_data.pkl')

# Pickles feeding the "SHAP over time" panels.
mrs02_shap_over_time_data_path = os.path.join(
    FIGURE_DATA_DIR, 'shap_over_time', '3M_mrs02_shap_along_time_figure_data.pkl')
death_shap_over_time_data_path = os.path.join(
    FIGURE_DATA_DIR, 'shap_over_time', '3M_Death_shap_along_time_figure_data.pkl')

In [None]:
# NOTE(security): pickle.load can execute arbitrary code stored in the file, so only
# load figure-data pickles that come from a trusted source.

# Top-feature data: each pickle unpacks into (dataframe with SHAP values, selected features).
with open(mrs02_top_features_data_path, 'rb') as f:
    mrs02_selected_features_with_shap_values_df, mrs02_selected_features = pickle.load(f)

with open(death_top_features_data_path, 'rb') as f:
    death_selected_features_with_shap_values_df, death_selected_features = pickle.load(f)

# SHAP-over-time data: each pickle unpacks into (summed SHAP values, n_timesteps, n_subj).
with open(mrs02_shap_over_time_data_path, 'rb') as f:
    mrs02_summed_shap_along_features, mrs02_n_timesteps, mrs02_n_subj = pickle.load(f)

with open(death_shap_over_time_data_path, 'rb') as f:
    death_summed_shap_along_features, death_n_timesteps, death_n_subj = pickle.load(f)

# Plotting functions

In [None]:
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple
import matplotlib
from colormath.color_objects import LabColor


def plot_top_features_shap(selected_features_with_shap_values_df, selected_features,
                           ax,
                           plot_shap_direction_label=True,
                           plot_legend=True,
                           plot_colorbar=True,
                           plot_feature_value_along_y=False,
                           reverse_outcome_direction=False,
                           tick_label_size=11,
                           label_font_size=13,
                           row_height=0.4,
                           alpha=0.8,
                           random_state=None):
    """Draw a beeswarm-style SHAP summary of the selected features on ``ax``.

    One horizontal row is drawn per entry of ``selected_features`` (the first entry
    ends up at the top). Every dataframe row belonging to a feature becomes one dot,
    positioned at its SHAP value on the x axis and coloured by its feature value,
    with the colour range clipped to the 5th-95th percentile of that feature.

    Parameters
    ----------
    selected_features_with_shap_values_df : pandas.DataFrame
        Long-format table with the columns ``feature``, ``shap_value`` and
        ``feature_value``. ``feature_value`` must be convertible to float.
    selected_features : sequence
        Feature names to plot; each must occur in the ``feature`` column.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    plot_shap_direction_label : bool
        If True, the first and last x tick labels are replaced by
        'Toward worse outcome' / 'Toward better outcome'.
    plot_legend : bool
        If True, add a legend entry explaining what a single dot represents.
    plot_colorbar : bool
        If True, add a 'Low'/'High' feature-value colorbar to the right of ``ax``.
    plot_feature_value_along_y : bool
        If True, the vertical offsets within a row start from the feature values
        instead of from zero.
    reverse_outcome_direction : bool
        If True, the direction labels are swapped so that the last (right-most)
        tick reads 'Toward worse outcome'.
    tick_label_size, label_font_size : float
        Font sizes for tick labels / legend entries and for axis labels / titles.
    row_height : float
        Scales the vertical spread of the dots within one feature row.
    alpha : float
        Opacity of the dots.
    random_state : int or None
        Optional seed for the tiny tie-breaking jitter used when stacking dots.
        ``None`` (default) keeps the previous behaviour of drawing from NumPy's
        global random state.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was passed in.
    """
    # Colour ramp used to encode the feature value of every dot.
    start_color = '#012D98'
    end_color = '#f61067'
    number_of_colors = 50
    start_rgb = hex_to_rgb_color(start_color)
    end_rgb = hex_to_rgb_color(end_color)
    palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)

    # Seeded generator when requested, otherwise the global NumPy random state (as before).
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # Rows, tick positions and tick labels are all derived from selected_features so
    # that they cannot disagree when the dataframe holds additional features.
    n_rows = len(selected_features)

    for pos, feature in enumerate(selected_features[::-1]):
        # Filter once per feature and reuse the result for SHAP and feature values.
        feature_rows = selected_features_with_shap_values_df[
            selected_features_with_shap_values_df.feature.isin([feature])]
        shaps = feature_rows.shap_value.values
        values = feature_rows.feature_value
        ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        values = np.array(values, dtype=np.float64)  # make sure this can be numeric

        # Bin the SHAP values into ~100 bins; dots sharing a bin are stacked
        # alternately above and below the row's centre line.
        N = len(shaps)
        nbins = 100
        quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
        # The tiny jitter only breaks ties between dots that fall into the same bin.
        inds = np.argsort(quant + rng.randn(N) * 1e-6)
        layer = 0
        last_bin = -1

        if plot_feature_value_along_y:
            ys = values.copy()
            cluster_factor = 0.1
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                ys[ind] += cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))
                layer += 1
                last_bin = quant[ind]

        else:
            ys = np.zeros(N)
            cluster_factor = 1
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                ys[ind] = cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))
                layer += 1
                last_bin = quant[ind]

        # Rescale the offsets so that the tallest stack fits within the row height.
        ys *= 0.9 * (row_height / np.max(ys + 1))

        # trim the color range, but prevent the color range from collapsing
        vmin = np.nanpercentile(values, 5)
        vmax = np.nanpercentile(values, 95)
        if vmin == vmax:
            vmin = np.nanpercentile(values, 1)
            vmax = np.nanpercentile(values, 99)
            if vmin == vmax:
                vmin = np.min(values)
                vmax = np.max(values)
        if vmin > vmax:  # fixes rare numerical precision issues
            vmin = vmax

        # Clip the colour values to [vmin, vmax]; NaNs are replaced by the mid-point
        # only for the comparison, so they are never clipped themselves.
        cvals = values.astype(np.float64)
        cvals_imp = cvals.copy()
        cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
        cvals[cvals_imp > vmax] = vmax
        cvals[cvals_imp < vmin] = vmin
        ax.scatter(shaps, pos + ys,
                   cmap=ListedColormap(palette), vmin=vmin, vmax=vmax, s=16,
                   c=cvals, alpha=alpha, linewidth=0,
                   zorder=3, rasterized=len(shaps) > 500)

    axis_color = "#333333"
    if plot_colorbar:
        m = matplotlib.cm.ScalarMappable(cmap=ListedColormap(palette))
        m.set_array([0, 1])

        # Helper axes appended to the right of the plot. The colorbar is given
        # ax=cax (not cax=cax), so matplotlib takes the space for the colorbar from
        # this helper axes, whose own grid, spines and ticks are hidden below.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.5)

        # get fig from ax
        fig = ax.get_figure()
        cb = fig.colorbar(m, ticks=[0, 1], aspect=10, shrink=0.2, ax=cax)
        cb.set_ticklabels(['Low', 'High'])
        cb.ax.tick_params(labelsize=tick_label_size, length=0)
        cb.set_label('Feature value', size=label_font_size)
        cb.ax.yaxis.set_label_position('left')
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        # turn off grid and spines on cax
        cax.grid(False)
        cax.spines['right'].set_visible(False)
        cax.spines['top'].set_visible(False)
        cax.spines['left'].set_visible(False)
        cax.spines['bottom'].set_visible(False)
        cax.set_xticks([])
        cax.set_yticks([])

    if plot_legend:
        legend_markers = []
        legend_labels = []
        # Proxy artist: a single dot in the middle colour of the palette.
        single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',
                                   markersize=10)
        single_dot_label = 'Single Patient\n(summed over time)'
        legend_markers.append(single_dot)
        legend_labels.append(single_dot_label)

        ax.legend(legend_markers, legend_labels, title='SHAP/Feature values', fontsize=tick_label_size,
                  title_fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)},
                  loc='upper left', frameon=True)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('none')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color)

    # One labelled tick per plotted row, in the same (reversed) order as the loop above.
    yticklabels = selected_features[::-1]
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(yticklabels, fontsize=label_font_size)
    ax.tick_params('y', length=20, width=0.5, which='major')
    ax.tick_params('x', labelsize=tick_label_size)
    ax.set_ylim(-1, n_rows)
    ax.set_xlabel('SHAP Value \n(impact on model output)', fontsize=label_font_size)
    ax.grid(color='white', axis='y')

    # Plot additional explanation with the shap value X axis
    if plot_shap_direction_label:
        x_ticks_coordinates = ax.get_xticks()

        # Inner ticks show the formatted coordinate; the first and last tick are
        # replaced by a textual description of the direction of the effect.
        # NOTE(review): the reversed branch formats with 0 decimals while the default
        # branch uses 1 decimal - confirm that this difference is intended.
        if reverse_outcome_direction:
            x_ticks_labels = [f'{x_ticks_coordinate:.0f}' for x_ticks_coordinate in x_ticks_coordinates]
            x_ticks_labels[0] = f'Toward better\noutcome'
            x_ticks_labels[-1] = f'Toward worse\noutcome'
        else:
            x_ticks_labels = [f'{x_ticks_coordinate:.1f}' for x_ticks_coordinate in x_ticks_coordinates]
            x_ticks_labels[0] = f'Toward worse\noutcome'
            x_ticks_labels[-1] = f'Toward better\noutcome'

        ax.set_xticks(x_ticks_coordinates)
        ax.set_xticklabels(x_ticks_labels)

    return ax

In [None]:
# Stand-alone top-features SHAP figure for the 3-month mRS 0-2 model.
fig, ax = plt.subplots(figsize=(10, 10))

plot_top_features_shap(
    selected_features_with_shap_values_df=mrs02_selected_features_with_shap_values_df,
    selected_features=mrs02_selected_features,
    ax=ax,
    plot_colorbar=True,
)

In [None]:
# Stand-alone top-features SHAP figure for the 3-month death model; the direction
# labels on the x axis are swapped via reverse_outcome_direction.
fig, ax = plt.subplots(figsize=(10, 10))

plot_top_features_shap(
    selected_features_with_shap_values_df=death_selected_features_with_shap_values_df,
    selected_features=death_selected_features,
    ax=ax,
    reverse_outcome_direction=True,
)

In [None]:
from matplotlib.colors import ListedColormap
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import LabColor


def plot_shap_along_time(summed_shap_along_features, n_timesteps, n_subj,
                         ax,
                         plot_legend=True,
                         plot_colorbar=True,
                         plot_median=True,
                         tick_label_size=11,
                         label_font_size=13,
                         alpha=0.9,
                         random_state=None):
    """Scatter the summed absolute SHAP values against time on a log-scaled y axis.

    Every value of ``summed_shap_along_features`` becomes one dot. Its x position is
    the timestep index plus a uniform random jitter in [0, 1); its y position is the
    value itself. Optionally the per-timestep median is overlaid as a line.

    Parameters
    ----------
    summed_shap_along_features : numpy.ndarray
        Values to plot. The flattened array is paired with the timestep index
        repeated ``n_subj`` times, and the median is taken along axis 0, so the
        expected layout is ``(n_subj, n_timesteps)``.
    n_timesteps : int
        Number of timesteps per subject.
    n_subj : int
        Number of subjects.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    plot_legend : bool
        If True, add a legend for the dots (and the median line when drawn).
    plot_colorbar : bool
        If True, add a small/large impact colorbar to the right of ``ax``.
    plot_median : bool
        If True, overlay the median over subjects for every timestep.
    tick_label_size, label_font_size : float
        Font sizes for tick labels / legend entries and for axis labels / titles.
    alpha : float
        Opacity of the median line (the dots use a fixed opacity of 0.05).
    random_state : int or None
        Optional seed for the horizontal jitter, making the figure reproducible.
        ``None`` (default) keeps the previous behaviour of drawing from NumPy's
        global random state.

    Returns
    -------
    matplotlib.axes.Axes
        The axes returned by seaborn after plotting onto ``ax``.
    """
    # create palette with one colour per plotted value
    start_color = '#049b9a'
    end_color = '#012D98'
    number_of_colors = n_timesteps * n_subj
    start_rgb = hex_to_rgb_color(start_color)
    end_rgb = hex_to_rgb_color(end_color)
    palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)

    # Seeded generator when requested, otherwise the global NumPy random state (as before).
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    flat_shap = summed_shap_along_features.flatten()
    # Timestep index of every flattened value, spread uniformly over the hour slot
    # so that dots of one timestep do not fall on a single vertical line.
    jittered_timesteps = np.tile(np.arange(0, n_timesteps), n_subj) + rng.rand(n_subj * n_timesteps)

    # NOTE(review): hue is numeric but the palette is a list sized to the number of
    # points; presumably seaborn assigns one colour per distinct hue value - confirm
    # the behaviour when values repeat or are zero (log of zero is -inf).
    ax = sns.scatterplot(x=jittered_timesteps, y=flat_shap,
                         hue=np.log(flat_shap), ax=ax,
                         alpha=0.05, legend=False, palette=palette)

    if plot_median:
        # plot median summed shap value on top
        median_color = '#f61067'
        ax = sns.lineplot(x=np.arange(0, n_timesteps), y=np.median(summed_shap_along_features, axis=0), ax=ax,
                          color=median_color, markers='.', lw=2, alpha=alpha)

    ax.set(yscale="log")

    ax.set_xlabel('Time from admission (hours)', fontsize=label_font_size)
    ax.set_ylabel('Sum of absolute SHAP values', fontsize=label_font_size)
    ax.tick_params('x', labelsize=tick_label_size)
    ax.tick_params('y', labelsize=tick_label_size)

    if plot_colorbar:
        m = matplotlib.cm.ScalarMappable(cmap=ListedColormap(palette))
        m.set_array([0, 1])
        # Helper axes appended to the right of the plot. The colorbar is given
        # ax=cax (not cax=cax), so matplotlib takes the space for the colorbar from
        # this helper axes, whose own grid, spines and ticks are hidden below.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=1.2)

        # get fig from ax
        fig = ax.get_figure()

        cb = fig.colorbar(m, ticks=[0, 1], aspect=10, shrink=0.2, ax=cax)
        cb.set_ticklabels(['Small impact on \nmodel output', 'Large impact on \nmodel output'])
        cb.ax.tick_params(labelsize=tick_label_size, length=0)
        cb.set_label('SHAP values', size=label_font_size)
        cb.ax.yaxis.set_label_position('left')
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        # turn off grid and spines on cax
        cax.grid(False)
        cax.spines['right'].set_visible(False)
        cax.spines['top'].set_visible(False)
        cax.spines['left'].set_visible(False)
        cax.spines['bottom'].set_visible(False)
        cax.set_xticks([])
        cax.set_yticks([])

    if plot_legend:
        legend_markers = []
        legend_labels = []
        # Proxy artist: a single dot in the middle colour of the palette.
        single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',
                                   markersize=10)
        single_dot_label = 'Single patient at timepoint t'
        legend_markers.append(single_dot)
        legend_labels.append(single_dot_label)

        if plot_median:
            # median_color is defined above under the same plot_median condition.
            median_line = mlines.Line2D([], [], color=median_color, linestyle='-')
            median_line_label = 'Median'
            legend_markers.append(median_line)
            legend_labels.append(median_line_label)

        ax.legend(legend_markers, legend_labels, title='Summed absolute SHAP', fontsize=tick_label_size,
                  title_fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})

    return ax

In [None]:
# Stand-alone SHAP-over-time figure for the 3-month mRS 0-2 model.
fig, ax = plt.subplots(figsize=(10, 10))

plot_shap_along_time(
    summed_shap_along_features=mrs02_summed_shap_along_features,
    n_timesteps=mrs02_n_timesteps,
    n_subj=mrs02_n_subj,
    ax=ax,
    plot_legend=True,
    plot_colorbar=True,
    plot_median=True,
)
   

In [None]:
# Stand-alone SHAP-over-time figure for the 3-month death model.
fig, ax = plt.subplots(figsize=(10, 10))

plot_shap_along_time(
    summed_shap_along_features=death_summed_shap_along_features,
    n_timesteps=death_n_timesteps,
    n_subj=death_n_subj,
    ax=ax,
    plot_legend=True,
    plot_colorbar=True,
    plot_median=True,
)

# Combined figure: top SHAP features and SHAP over time for both outcomes

In [None]:
# Combined 2 x 2 figure: one sub-figure per outcome (functional outcome, mortality),
# each holding the top-features SHAP panel (left) and the SHAP-over-time panel (right).
# sns.set_theme changes the global plotting style for every later plot in this kernel.
sns.set_theme(style="whitegrid", context="paper", font_scale = 1)

cm = 1/2.54  # centimeters in inches
main_fig = plt.figure(figsize=(18 * cm, 22 * cm))

# Font sizes chosen for the small 18 cm x 22 cm canvas.
tick_label_size = 6
label_font_size = 7
subplot_number_font_size = 9
suptitle_font_size = 10
plot_subplot_titles = True

plot_legend = True
plot_colorbar = False

subfigs = main_fig.subfigures(2, 1, height_ratios=[1, 1])
# increase space between subfigs
# NOTE(review): this adjusts the subplot parameters of main_fig, which holds
# sub-figures rather than subplots - verify that it actually changes the spacing.
main_fig.subplots_adjust(hspace=4.5)

# MRS02
subfigs[0].suptitle('I. Prediction of functional outcome', fontsize=suptitle_font_size, horizontalalignment='left', x=0.1, y=1.)

ax1, ax2 = subfigs[0].subplots(1, 2)
subfigs[0].subplots_adjust(wspace=0.3)

plot_top_features_shap(
    mrs02_selected_features_with_shap_values_df, mrs02_selected_features,
    ax1,
    plot_legend=plot_legend,
    plot_colorbar=plot_colorbar,
    tick_label_size=tick_label_size,
    label_font_size=label_font_size
)

plot_shap_along_time(mrs02_summed_shap_along_features, mrs02_n_timesteps, mrs02_n_subj, ax2,
                     plot_legend=plot_legend, plot_colorbar=plot_colorbar, plot_median=True,
                     tick_label_size=tick_label_size,
                     label_font_size=label_font_size)

if plot_subplot_titles:
    ax1.set_title('A.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)
    ax2.set_title('B.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)


# Death
subfigs[1].suptitle('II. Prediction of mortality', fontsize=suptitle_font_size, horizontalalignment='left', x=0.1, y=1.0)

ax3, ax4 = subfigs[1].subplots(1, 2)
subfigs[1].subplots_adjust(wspace=0.3)

# The direction labels are reversed for the mortality model, exactly as in the
# stand-alone mortality figure earlier in this notebook: without the flag the
# 'Toward worse outcome' / 'Toward better outcome' labels would be swapped
# relative to that figure.
plot_top_features_shap(
    death_selected_features_with_shap_values_df, death_selected_features,
    ax3,
    plot_legend=plot_legend,
    plot_colorbar=plot_colorbar,
    reverse_outcome_direction=True,
    tick_label_size=tick_label_size,
    label_font_size=label_font_size
)

plot_shap_along_time(death_summed_shap_along_features, death_n_timesteps, death_n_subj, ax4,
                     plot_legend=plot_legend, plot_colorbar=plot_colorbar, plot_median=True,
                     tick_label_size=tick_label_size,
                     label_font_size=label_font_size)

if plot_subplot_titles:
    ax3.set_title('C.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)
    ax4.set_title('D.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)

In [None]:
# Directory the combined figure is written to. It can be redirected by setting the
# OPSUM_FIGURE_OUTPUT_DIR environment variable; the default is the original location.
output_dir = os.environ.get('OPSUM_FIGURE_OUTPUT_DIR', '/Users/jk1/Downloads')
# Saving is disabled on purpose; uncomment to export the combined figure.
# NOTE(review): the file name below mentions "comparative_performances" although this
# notebook plots SHAP values - presumably inherited from another figure, confirm before use.
# main_fig.savefig(os.path.join(output_dir, 'comparative_performances_combined.svg'), bbox_inches="tight", format='svg', dpi=1200)
