In [None]:
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from random import randint
import pandas as pd
import pickle
import torch as ch
from tqdm import tqdm
from matplotlib.lines import Line2D
from prediction.utils.utils import smooth, filter_consecutive_numbers
from prediction.utils.visualisation_helper_functions import reverse_normalisation_for_subj, LegendTitle
from preprocessing.preprocessing_tools.normalisation.reverse_normalisation import reverse_normalisation
from prediction.utils.shap_helper_functions import check_shap_version_compatibility
from sklearn.model_selection import train_test_split
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time, \
    link_patient_id_to_outcome, features_to_numpy, numpy_to_lookup_table
from prediction.utils.visualisation_helper_functions import density_jitter
from matplotlib.legend_handler import HandlerTuple

In [None]:
# Root of the local OPSUM working tree. Set the OPSUM_ROOT environment variable to run
# the notebook on another machine; the default reproduces the original absolute paths.
opsum_root = os.environ.get('OPSUM_ROOT', '/Users/jk1/temp/opsum_end')
preprocessing_dir = os.path.join(opsum_root, 'preprocessing/gsu_Extraction_20220815_prepro_09052025_220520')
preprocessing_logs_dir = os.path.join(preprocessing_dir, 'logs_09052025_220520')

shap_values_path = os.path.join(opsum_root, 'training/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517/checkpoints_short_opsum_xgb_20250518_001112_cv_1/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl')
test_data_path = os.path.join(preprocessing_dir, 'early_neurological_deterioration_train_data_splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth')
cat_encoding_path = os.path.join(preprocessing_logs_dir, 'categorical_variable_encoding.csv')

normalisation_parameters_path = os.path.join(preprocessing_logs_dir, 'normalisation_parameters.csv')
predictions_path = os.path.join(opsum_root, 'testing/test_gt_and_pred_cv_1.pkl')

In [None]:
# Number of timesteps per subject; ground truth and predictions are reshaped to (-1, n_time_steps) below.
n_time_steps = 72

In [None]:
# Load the precomputed SHAP explanations over time (one entry per timestep).
# NOTE: pickle.load can execute arbitrary code; only load trusted artefacts.
with open(shap_values_path, 'rb') as handle:
    original_shap_values = pickle.load(handle)

In [None]:
# If True, keep only the explanation computed at the last timestep (it also produces
# output for the other timesteps).
# NOTE(review): this branch yields a Python list, while the cells below index
# shap_values as an array with [:, ts]; only only_last_timestep=False works downstream.
only_last_timestep = False
if only_last_timestep:
    # use predictions from last timestep (as it also produces output for other timesteps)
    shap_values = [original_shap_values[-1]]

else:
    # Stack the per-timestep explanations and swap the first two axes so that
    # shap_values[:, ts] (used below) selects all subjects at timestep ts.
    shap_values = np.array(original_shap_values).swapaxes(0, 1)

In [None]:
# Normalisation parameters, passed to reverse_normalisation_for_subj below to undo the feature normalisation.
normalisation_parameters_df = pd.read_csv(normalisation_parameters_path)

In [None]:
# Ground truth and model predictions over time, reshaped so that each row is one
# test subject and each column one timestep.
with open(predictions_path, 'rb') as predictions_file:
    gt_over_time, predictions_over_time = pickle.load(predictions_file)

gt_over_time, predictions_over_time = (
    values_over_time.reshape(-1, n_time_steps)
    for values_over_time in (gt_over_time, predictions_over_time)
)

In [None]:
# Test split. X_test is indexed below as X_test[subject, timestep, feature, channel]:
# channel 2 holds the feature names and the last channel the feature values.
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may refuse
# to unpickle non-tensor objects; pass weights_only=False if loading fails — confirm.
X_test, full_y_test = ch.load(test_data_path)

In [None]:
# Numeric feature values (last channel of X_test) as float32: [subject, timestep, feature].
test_X_np = X_test[:, :, :, -1].astype('float32')

In [None]:
# Raw feature names, read from channel 2 of the first subject/timestep.
features = X_test[0, 0, :, 2]

# SHAP columns are laid out as [features, avg_*, min_*, max_*, base_value]; build the
# prefixed names for the three aggregate groups, then the full column list.
avg_features, min_features, max_features = (
    [f'{prefix}_{feature_name}' for feature_name in features]
    for prefix in ('avg', 'min', 'max')
)
aggregated_feature_names = [*features.tolist(), *avg_features, *min_features, *max_features, 'base_value']

## Create working data frame
Build long-format dataframes of the SHAP values and of the corresponding feature values (one row per subject, timestep and feature).

In [None]:
reverse_categorical_encoding = True  # merge one-hot categorical columns back into one feature per category
pool_hourly_split_values = True  # merge hourly-split columns (NIHSS, vital signs) into one feature each
only_keep_current_value_shap = True # (do not use aggregated features)
# pool_time_aggregated_features = True


In [None]:
# Sanity check; given the indexing below, expected (n_subjects, n_time_steps, len(aggregated_feature_names)).
shap_values.shape

In [None]:
# Long-format SHAP table: one row per (subject index, timestep, feature).
# Per-timestep frames are collected and concatenated once: DataFrame.append was removed
# in pandas 2.0 and re-copied the accumulated frame on every iteration.
ts_shap_frames = []
for ts in tqdm(range(n_time_steps)):
    ts_shap_values_df = (
        pd.DataFrame(data=shap_values[:, ts], columns=np.array(aggregated_feature_names))
        .reset_index()
        .rename(columns={'index': 'case_admission_id_idx'})
        .melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')
    )
    ts_shap_values_df['time_step'] = ts
    ts_shap_frames.append(ts_shap_values_df)
shap_values_df = pd.concat(ts_shap_frames)

In [None]:
# do not use aggregated features (avg, min, max)
# Keep only SHAP rows of the current-value features; this drops the avg_/min_/max_
# columns as well as 'base_value'.
if only_keep_current_value_shap:
    shap_values_df = shap_values_df[shap_values_df['feature'].isin(features)]

In [None]:
# Long-format table of de-normalised feature values: one row per (subject index, timestep, feature).
# Frames are collected and concatenated once (DataFrame.append was removed in pandas 2.0).
subj_feature_frames = []
for subj_idx in tqdm(range(test_X_np.shape[0])):
    subj_feature_values_df = pd.DataFrame(data=test_X_np[subj_idx, :, :], columns=np.array(features))
    subj_feature_values_df = reverse_normalisation_for_subj(subj_feature_values_df, normalisation_parameters_df)
    subj_feature_values_df = (
        subj_feature_values_df
        .reset_index()
        .rename(columns={'index': 'time_step'})
        .assign(case_admission_id_idx=subj_idx)
        .melt(id_vars=['case_admission_id_idx', 'time_step'], var_name='feature', value_name='feature_value')
    )
    subj_feature_frames.append(subj_feature_values_df)
feature_values_df = pd.concat(subj_feature_frames)

In [None]:
shap_aggregation_func = 'sum' # 'median' or 'sum' (sum makes more sense: SHAP values are additive)

if reverse_categorical_encoding:
    if shap_aggregation_func not in ('sum', 'median'):
        raise ValueError(f"shap_aggregation_func must be 'sum' or 'median', got {shap_aggregation_func!r}")
    group_keys = ['case_admission_id_idx', 'feature', 'time_step']
    # Fresh unique index so the boolean .loc updates below never align duplicate labels
    # (the concatenated frames repeat index labels).
    feature_values_df = feature_values_df.reset_index(drop=True)
    shap_values_df = shap_values_df.reset_index(drop=True)

    cat_encoding_df = pd.read_csv(cat_encoding_path)
    for i in tqdm(range(len(cat_encoding_df))):
        cat_label = cat_encoding_df.sample_label[i]
        cat_basename = cat_label.lower().replace(' ', '_')
        cat_item_list = cat_encoding_df.other_categories[i].replace('[', '').replace(']', '').replace('\'', '').split(', ')
        cat_item_list = [cat_basename + '_' + item.replace(' ', '_').lower() for item in cat_item_list]
        for cat_item_idx, cat_item in enumerate(cat_item_list):
            # Scale each one-hot column by (index + 1): once the columns of a category are
            # summed, the value identifies the active category (0 being default category).
            is_cat_item_value = feature_values_df.feature == cat_item
            feature_values_df.loc[is_cat_item_value, 'feature_value'] *= cat_item_idx + 1
            feature_values_df.loc[is_cat_item_value, 'feature'] = cat_label
            shap_values_df.loc[shap_values_df.feature == cat_item, 'feature'] = cat_label

    # Merge the renamed one-hot rows per (subject, feature, timestep). A single groupby after
    # all renames gives the same sums as grouping after every column, with one pass per frame.
    feature_values_df = feature_values_df.groupby(group_keys).sum().reset_index()
    # The previous `if shap_aggregation_func:` test was always true (both options are non-empty
    # strings), so 'median' was never applied; dispatch on the value explicitly.
    if shap_aggregation_func == 'median':
        shap_values_df = shap_values_df.groupby(group_keys).median().reset_index()
    else:
        shap_values_df = shap_values_df.groupby(group_keys).sum().reset_index()

    # give a numerical encoding to the categorical features
    cat_to_numerical_encoding = {
        'Prestroke disability (Rankin)': {0:0, 1:3, 2:4, 3:2, 4:1, 5:5},
        'categorical_onset_to_admission_time': {0:2, 1:1, 2:0, 3:3, 4:5, 5:4},
        'categorical_IVT': {0:2, 1:3, 2:4, 3:1, 4:0},
        'categorical_IAT': {0:1, 1:2, 2:3, 3:0}
    }

    for cat_feature, cat_encoding in cat_to_numerical_encoding.items():
        is_cat_feature = feature_values_df.feature == cat_feature
        feature_values_df.loc[is_cat_feature, 'feature_value'] = feature_values_df.loc[is_cat_feature, 'feature_value'].map(cat_encoding)


In [None]:
# Features split into several hourly columns (NIHSS, vital signs) are pooled into one
# feature with a readable name: SHAP values are combined with shap_aggregation_func,
# feature values with the median.
if pool_hourly_split_values:
    hourly_split_features = ['NIHSS', 'systolic_blood_pressure', 'diastolic_blood_pressure', 'mean_blood_pressure', 'heart_rate', 'respiratory_rate', 'temperature', 'oxygen_saturation']
    for feature in tqdm(hourly_split_features):
        pooled_feature_name = (feature[0].upper() + feature[1:]).replace('_', ' ')
        # plain substring match (the names contain no regex metacharacters)
        shap_values_df.loc[shap_values_df.feature.str.contains(feature, regex=False), 'feature'] = pooled_feature_name
        feature_values_df.loc[feature_values_df.feature.str.contains(feature, regex=False), 'feature'] = pooled_feature_name

    # Aggregate once after all renames. Groups of different pooled features never overlap,
    # and every non-pooled (subject, feature, timestep) key is assumed to hold a single row
    # (true after the groupby/melt above), so this matches aggregating after each rename
    # while needing one groupby per frame instead of one per feature.
    group_keys = ['case_admission_id_idx', 'feature', 'time_step']
    if shap_aggregation_func == 'median':
        shap_values_df = shap_values_df.groupby(group_keys).median().reset_index()
    elif shap_aggregation_func == 'sum':
        shap_values_df = shap_values_df.groupby(group_keys).sum().reset_index()
    feature_values_df = feature_values_df.groupby(group_keys).median().reset_index()

In [None]:
# Replace feature names with their English names.
# os.path.abspath('__file__') (a literal string) resolved to <cwd>/__file__, so the lookup
# table was always read relative to the current working directory; state that explicitly.
feature_to_english_name_correspondence_path = os.path.join(
    os.getcwd(), 'preprocessing/preprocessing_tools/feature_name_to_english_name_correspondence.xlsx')
feature_to_english_name_correspondence = pd.read_excel(feature_to_english_name_correspondence_path)

# feature_name -> english_name; keep the first entry per feature, as the row-wise lookup did
english_name_lookup = (feature_to_english_name_correspondence
                       .drop_duplicates(subset='feature_name', keep='first')
                       .set_index('feature_name')['english_name'])


def _to_english_names(feature_column):
    """Map names present in english_name_lookup to English; leave other names unchanged.

    Done in one vectorised pass, so a renamed value is never renamed a second time.
    """
    has_english_name = feature_column.isin(english_name_lookup.index)
    return feature_column.map(english_name_lookup).where(has_english_name, feature_column)


shap_values_df['feature'] = _to_english_names(shap_values_df['feature'])
feature_values_df['feature'] = _to_english_names(feature_values_df['feature'])

In [None]:
use_simplified_shap_values = True

if use_simplified_shap_values:
    # Rebuild a dense [timestep, subject, feature] array from the long table.
    # One groupby buckets the rows once, instead of a boolean scan of the whole frame for
    # every (timestep, subject) pair. Row order inside each group is preserved, so the
    # features keep the order that reduced_feature_names relies on.
    shap_by_ts_and_subj = {
        ts_subj_key: ts_subj_shap.to_numpy()
        for ts_subj_key, ts_subj_shap in shap_values_df.groupby(['time_step', 'case_admission_id_idx'], sort=False)['shap_value']
    }
    # (timestep, subject) pairs without rows give an empty array, as the boolean filter did
    no_shap_values = shap_values_df['shap_value'].to_numpy()[:0]
    shap_values_over_time = np.array([
        np.array([shap_by_ts_and_subj.get((ts, subj_idx), no_shap_values) for subj_idx in range(len(test_X_np))])
        for ts in tqdm(range(n_time_steps))
    ])
# NOTE(review): with use_simplified_shap_values = False, shap_values_over_time is never
# defined here and the following cells fail.

In [None]:
# Feature names in first-appearance order of the long SHAP table; index j here is meant to
# match shap_values_over_time[..., j].
# NOTE(review): assumes every (timestep, subject) slice lists features in this same order and
# that no two features share a display name — TODO confirm.
reduced_feature_names = shap_values_df.feature.unique()

In [None]:
# Temporally smooth each subject's SHAP trajectory, feature by feature, and reassemble the
# result with the timestep axis first, like shap_values_over_time.
smoothing_window = 15
smoothed_per_subject = [
    np.moveaxis(
        [smooth(shap_values_over_time[:, subject_idx, feat_idx], smoothing_window)
         for feat_idx in range(shap_values_over_time.shape[2])],
        0, -1,
    )
    for subject_idx in range(shap_values_over_time.shape[1])
]
smoothed_shap_values_over_time = np.moveaxis(smoothed_per_subject, 0, 1)

## Choose a subject and load its predictions

In [None]:
# Pick a random test subject. random.randint includes both bounds, so the upper bound
# must be len - 1 to stay a valid row index.
subj = randint(0, len(test_X_np) - 1)
subj

In [None]:
# Fixed subject for the figures below (overrides the random pick above).
subj = 96

In [None]:
# Prediction and ground-truth trajectories of the selected subject (one value per timestep).
subj_pred_over_ts, subj_gt_over_ts = predictions_over_time[subj, :], gt_over_time[subj, :]
subj_pred_over_ts.shape, subj_gt_over_ts.shape

## Plot overall subject prediction & explanation

In [None]:
# Number of top positive and top negative SHAP features shown in the plots below.
n_features = 3

In [None]:
plot_legend = True
plot_shap_direction_label = True

tick_label_size = 13
label_font_size = 15

fig1, ax1 = plt.subplots(figsize=(10, 7.5))

# Bar plot of the features with the biggest SHAP impact on the prediction at the last timestep.
last_ts_subj_shap_values = np.squeeze(shap_values_over_time[-1])[subj]
shap_ranking = last_ts_subj_shap_values.argsort()
# n_features features with the biggest positive SHAP impact (largest first)
selected_positive_features = shap_ranking[-n_features:][::-1]
# n_features features with the biggest negative SHAP impact (most negative last)
selected_negative_features = shap_ranking[:n_features][::-1]
selected_features = np.concatenate((selected_positive_features, selected_negative_features))

# Subject's feature values: one row per timestep, one column per feature.
non_norm_subj_df = feature_values_df[feature_values_df.case_admission_id_idx == subj].drop(columns=['case_admission_id_idx']).pivot(index='time_step', columns='feature', values='feature_value')
# NOTE(review): hard-coded display label for encoded IVT timing value 1 — confirm it matches cat_to_numerical_encoding.
if 'IVT timing' in non_norm_subj_df.columns:
    # cast to object first: writing a string into a float column is deprecated in pandas
    non_norm_subj_df['IVT timing'] = non_norm_subj_df['IVT timing'].astype(object)
    non_norm_subj_df.loc[non_norm_subj_df['IVT timing'] == 1, 'IVT timing'] = '91-270min'

# y-axis labels "<feature>: <value at the last timestep>"
y_labels = []
for feature in selected_features:
    feature_name = reduced_feature_names[feature]
    feature_value = non_norm_subj_df[feature_name].iloc[-1]
    if feature_name in ['NIHSS', 'Prestroke disability (Rankin)']:
        y_labels.append(f'{feature_name}: {int(feature_value)}')
    elif isinstance(feature_value, str):
        y_labels.append(f'{feature_name}: {feature_value}')
    else:
        y_labels.append(f'{feature_name}: {feature_value:.1f}')

ax = sns.barplot(y=y_labels, x=last_ts_subj_shap_values[selected_features], palette="RdBu_r", ax=ax1)

# Replace the outermost x tick labels with the direction of the SHAP effect.
if plot_shap_direction_label:
    x_ticks_coordinates = ax.get_xticks()
    # let x tick label be the coordinate with 2 decimals
    x_ticks_labels = [f'{x_ticks_coordinate:.02f}' for x_ticks_coordinate in x_ticks_coordinates]
    x_ticks_labels[0] = 'Toward better\noutcome'
    x_ticks_labels[-1] = 'Toward worse\noutcome'
    ax.set_xticks(x_ticks_coordinates)
    ax.set_xticklabels(x_ticks_labels)

ax.tick_params(axis='y', labelsize=label_font_size)
ax.tick_params(axis='x', labelsize=tick_label_size)

ax.set_xlabel('SHAP Value \n(impact on model output)', fontsize=label_font_size)
plt.show()

## Plot evolution of prediction & explanation over time

In [None]:
overall_prevailing_features = False  # True: "prevailing features over cumulative time"; False: prevailing at the last timepoint
weigh_by_feature_value = False  # scale the SHAP shading by the feature value (see next cell)
use_smoothed_shap_values = True  # use smoothed_shap_values_over_time instead of shap_values_over_time

In [None]:
# SHAP trajectories indexed [timestep, subject, feature].
working_shap_values = smoothed_shap_values_over_time if use_smoothed_shap_values else shap_values_over_time

# Cumulative SHAP of each feature over all timesteps, indexed [subject, feature].
# (Previously the sum ran over features, so the "overall" ranking sorted timesteps.)
cumulative_shap_values_over_time = working_shap_values.sum(axis=0)

# Rank the subject's features, then keep the n_features most positive and most negative.
if overall_prevailing_features:
    # prevailing features over cumulative time
    feature_ranking = cumulative_shap_values_over_time[subj].argsort()
else:
    # prevailing features at last timepoint
    feature_ranking = np.squeeze(working_shap_values[-1])[subj].argsort()
selected_positive_features = feature_ranking[-n_features:][::-1]
selected_negative_features = feature_ranking[:n_features][::-1]
selected_features = np.concatenate((selected_positive_features, selected_negative_features))


def _feature_value_weight(feature_idx):
    """Return the subject's values of a reduced feature over time, divided by the cohort maximum.

    Looks the feature up by name in feature_values_df.
    NOTE(review): assumes one row per (subject, timestep), i.e. n_time_steps values — TODO confirm.
    """
    feature_rows = feature_values_df[feature_values_df.feature == reduced_feature_names[feature_idx]]
    subj_rows = feature_rows[feature_rows.case_admission_id_idx == subj].sort_values('time_step')
    return subj_rows.feature_value.to_numpy(dtype=float) / feature_rows.feature_value.max()


fig3, ax = plt.subplots(figsize=(15, 10))

k = 0.1  # vertical scale of the SHAP values stacked on the probability curve
alpha = 0.3

timestep_axis = np.arange(n_time_steps)

positive_color_palette = sns.color_palette("mako", n_colors=len(selected_positive_features))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(selected_negative_features))

sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='probability', linewidth=2, ax=ax)

# Stack the positive SHAP parts above and the negative parts below the predicted probability.
pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
feature_color_dict = {}

for i, feature in enumerate(selected_features):
    subj_feature_shap_value_over_time = working_shap_values[:, subj, feature]
    pos_function = np.clip(subj_feature_shap_value_over_time, 0, None)
    neg_function = np.clip(subj_feature_shap_value_over_time, None, 0)

    if weigh_by_feature_value:
        # Scale only the SHAP contribution. The previous code scaled baseline + contribution
        # and indexed the 4-D X_test with a reduced-feature index.
        feature_weight = _feature_value_weight(feature)
        pos_function = pos_function * feature_weight
        neg_function = neg_function * feature_weight

    # the first len(selected_positive_features) entries are the positive features
    if i < len(selected_positive_features):
        feature_color = positive_color_palette[i]
    else:
        feature_color = negative_color_palette[i - len(selected_positive_features)]
    feature_color_dict[feature] = feature_color

    positive_feature = pos_baseline + k * pos_function
    ax.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha, label=reduced_feature_names[feature])
    pos_baseline = positive_feature

    negative_feature = neg_baseline + k * neg_function
    ax.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
    neg_baseline = negative_feature

ax.legend(fontsize='x-large')

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

plt.show()



## Identify features driving changes in prediction over time

In [None]:
# For the selected subject, find the features driving significant changes in the predicted
# probability and plot them as SHAP shades stacked on the prediction curve.
threshold = 0.035  # minimal one-step change in predicted probability treated as significant
n_features_selection = 0  # non-zero: also show the top n_features SHAP features of the last timestep
k = 0.25  # vertical scale of the SHAP values stacked on the probability curve
alpha = 0.3
n_features = 2
only_non_static_features = True
use_smoothed_shap_values = True
plot_ground_truth = True

display_significant_slopes = True
n_slope_steps = 5
slope_threshold = 1.5 * threshold

display_text_labels = True
display_legend = True
display_title = False
plot_NIHSS_continuously = True
ts_marker_level = 'baseline' # 'baseline' (marker on predicted probability function) or 'shap' (marker on SHAP value function)

fig3 = plt.figure(figsize=(15,10))

tick_label_size = 13
label_font_size = 16

working_shap_values = smoothed_shap_values_over_time if use_smoothed_shap_values else shap_values_over_time
subj_shap_over_time = working_shap_values[:, subj]  # [timestep, feature] for the selected subject

# identify significant changes in prediction over time by a change in threshold X% in prediction
significant_positive_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) > threshold)[0])
significant_negative_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) < -threshold)[0])
# astype(int): concatenating an empty list promotes the result to float, which then
# cannot be used as an index below
significant_timesteps = np.concatenate((significant_positive_timesteps, significant_negative_timesteps)).astype(int)

non_norm_subj_df = feature_values_df[feature_values_df.case_admission_id_idx == subj].drop(columns=['case_admission_id_idx']).pivot(index='time_step', columns='feature', values='feature_value')

# Candidate features for explaining the changes.
if only_non_static_features:
    # features whose non-normalised value varies for this subject
    subj_feature_std = non_norm_subj_df.std()
    if use_simplified_shap_values:
        non_static_features = np.where(np.isin(reduced_feature_names, np.array(subj_feature_std[subj_feature_std > 0.01].index)))[0]
    else:
        non_static_features = np.where(subj_feature_std > 0.01)[0]
else:
    # all features (previously left undefined here, so the slope analysis raised a NameError)
    non_static_features = np.arange(working_shap_values.shape[2])

# for each significant timestep, the candidate feature whose SHAP value rose (fell) the most
shap_step_change = np.diff(subj_shap_over_time[:, non_static_features], axis=0)
selected_positive_features_by_impact = non_static_features[shap_step_change[significant_positive_timesteps].argmax(axis=1)]
selected_negative_features_by_impact = non_static_features[shap_step_change[significant_negative_timesteps].argmin(axis=1)]

selected_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_negative_features_by_impact))

if display_significant_slopes:
    # Gentler but sustained changes: change over n_slope_steps timesteps, excluding the
    # timesteps already flagged above (then filter out consecutive timesteps).
    slope_change = subj_pred_over_ts[n_slope_steps:] - subj_pred_over_ts[:-n_slope_steps]
    significant_positive_slope = filter_consecutive_numbers(set(np.where(slope_change > slope_threshold)[0]).difference(set(significant_positive_timesteps)))
    significant_negative_slope = filter_consecutive_numbers(set(np.where(slope_change < -slope_threshold)[0]).difference(set(significant_negative_timesteps)))

    shap_slope_change = subj_shap_over_time[n_slope_steps:, non_static_features] - subj_shap_over_time[:-n_slope_steps, non_static_features]
    selected_positive_features_by_slope = non_static_features[shap_slope_change[significant_positive_slope].argmax(axis=1)]
    selected_negative_features_by_slope = non_static_features[shap_slope_change[significant_negative_slope].argmin(axis=1)]

    # keep selected_features_by_impact and significant_timesteps aligned entry by entry
    selected_features_by_impact = np.concatenate((selected_features_by_impact, selected_positive_features_by_slope, selected_negative_features_by_slope))
    significant_timesteps = np.concatenate((significant_timesteps, significant_positive_slope, significant_negative_slope)).astype(int)
    selected_positive_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_positive_features_by_slope))
    selected_negative_features_by_impact = np.concatenate((selected_negative_features_by_impact, selected_negative_features_by_slope))

if n_features_selection == 0:
    selected_positive_features = np.array([], dtype=int)
    selected_negative_features = np.array([], dtype=int)
else:
    selected_positive_features = working_shap_values[-1,subj].argsort()[-n_features:][::-1]
    selected_negative_features = working_shap_values[-1,subj].argsort()[:n_features][::-1]

selected_features = np.concatenate((selected_positive_features, selected_positive_features_by_impact, selected_negative_features, selected_negative_features_by_impact)).astype(int)


positive_color_palette = sns.color_palette("mako", n_colors=len(set(np.concatenate((selected_positive_features, selected_positive_features_by_impact)))))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(set(np.concatenate((selected_negative_features, selected_negative_features_by_impact)))))

# plot prediction over time
timestep_axis = np.arange(n_time_steps)
ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='Predicted probability', linewidth = 2)

# plot ground truth over time
if plot_ground_truth:
    # Rising (+1) and falling (-1) edges of the binary ground truth. append=0 closes a window
    # still open at the last timestep, which zip() previously dropped silently.
    changes_in_gt = np.diff(subj_gt_over_ts, prepend=0, append=0)
    window_starts = np.where(changes_in_gt == 1)[0]
    window_ends = np.minimum(np.where(changes_in_gt == -1)[0], n_time_steps - 1)
    for change_pair in zip(window_starts, window_ends):
        # mark a thick horizontal red line at the ground truth value (-1)
        ax.plot([change_pair[0], change_pair[1]], [0, 0], color="#7b002c", linewidth=10, alpha=0.8)
        ax.text(np.mean(change_pair), 0 + 0.02, '6h to END', horizontalalignment='center', verticalalignment='center', fontsize=tick_label_size)

pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count = 0
neg_count = 0
feature_color_dict = {}
for feature in set(selected_features):
    subj_feature_shap_value_over_time = working_shap_values[:, subj, feature]
    pos_function = np.clip(subj_feature_shap_value_over_time, 0, None)
    neg_function = np.clip(subj_feature_shap_value_over_time, None, 0)

    if feature in selected_features_by_impact:
        important_ts_idx = np.where(selected_features_by_impact == feature)[0]
        # set value to zero before the significant timestep (except for NIHSS if plotting continuously)
        # NOTE(review): [0] is the first entry in concatenation order (sudden changes before
        # slopes), not necessarily the earliest timestep — confirm intent.
        if not (plot_NIHSS_continuously and reduced_feature_names[feature] == 'NIHSS'):
            first_significant_ts = significant_timesteps[important_ts_idx][0]
            pos_function[:first_significant_ts + 1] = 0
            neg_function[:first_significant_ts + 1] = 0

    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    elif feature in selected_negative_features:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    elif feature in selected_negative_features_by_impact:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    elif feature in selected_positive_features_by_impact:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    feature_color_dict[feature] = feature_color

    if np.any(pos_function):
        positive_feature = pos_baseline + k * pos_function
        ax.fill_between(timestep_axis , pos_baseline, positive_feature, color=feature_color, alpha=alpha)
        pos_baseline = positive_feature

    if np.any(neg_function):
        negative_feature = neg_baseline + k * neg_function
        ax.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
        neg_baseline = negative_feature

    # add a legend entry for the feature fill
    ax.scatter([], [], color=feature_color, alpha=alpha, label=reduced_feature_names[feature],marker="s", s=200)


# marking inflection points
for feature in set(selected_features_by_impact):
    important_ts_idx = np.where(selected_features_by_impact == feature)[0]
    for ts_idx in important_ts_idx:
        # downward inflection point
        if subj_pred_over_ts[significant_timesteps[ts_idx]] > subj_pred_over_ts[significant_timesteps[ts_idx] + 1]:
            marker = 'v'
            if ts_marker_level == 'shap':
                marker_y_level = pos_baseline[significant_timesteps[ts_idx]] + 0.005
            elif ts_marker_level == 'baseline':
                marker_y_level = subj_pred_over_ts[significant_timesteps[ts_idx]] + 0.005
            text_y_level = marker_y_level + 0.01
        # upward inflection point
        else:
            marker = '^'
            if ts_marker_level == 'shap':
                marker_y_level = neg_baseline[significant_timesteps[ts_idx]] - 0.005
            elif ts_marker_level == 'baseline':
                marker_y_level = subj_pred_over_ts[significant_timesteps[ts_idx]] - 0.005
            text_y_level = marker_y_level - 0.015

        ax.scatter(significant_timesteps[ts_idx], marker_y_level, color=feature_color_dict[feature], s=100, marker=marker, alpha=1, edgecolors='white')
        # insert a label on the plot
        if display_text_labels:
            # rotate the text label by 45 degrees (up if downward inflection, down if upward inflection)
            if marker == 'v':
                ax.text(significant_timesteps[ts_idx] + 0.01, text_y_level, reduced_feature_names[feature], fontsize=12, color='black', rotation=45)
            else:
                ax.text(significant_timesteps[ts_idx] + 0.01, text_y_level, reduced_feature_names[feature], fontsize=12, color='black', rotation=-45,
                        horizontalalignment='left', verticalalignment='top', rotation_mode='anchor')


if display_title:
    ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)

ax.set_xlabel('Time from admission (hours)', fontsize=label_font_size)
ax.set_ylabel('Probability of favorable outcome', fontsize=label_font_size)
ax.tick_params(axis='both', labelsize=tick_label_size)

if display_legend:
    legend_markers, legend_labels = ax.get_legend_handles_labels()

    # shap value shades
    shap_shades_markers = legend_markers[1:]
    shap_shades_labels = legend_labels[1:]
    legend_markers = [legend_markers[0]]
    legend_labels = [legend_labels[0]]

    # add a legend entry for the timestep markers
    ts_marker_down = plt.scatter([], [], marker='v', color='grey', s=50, alpha=0.8)
    ts_marker_up = plt.scatter([], [], marker='^', color='grey', s=50, alpha=0.8)
    ts_label = 'Positive / Negative impact on inflection of prediction'
    legend_markers.append((ts_marker_up, ts_marker_down))
    legend_labels.append(ts_label)

    # Add a subtitle for shape value shades
    legend_markers.append('')
    legend_labels.append('')
    legend_markers.append('Weight & direction of influence on model prediction')
    legend_labels.append('')

    legend_markers += shap_shades_markers
    legend_labels += shap_shades_labels

    ax.legend(legend_markers, legend_labels, fontsize=label_font_size, title='Influence on model prediction', title_fontsize=label_font_size,
              handler_map={tuple: HandlerTuple(ndivide=None), str: LegendTitle({'fontsize': label_font_size})}, bbox_to_anchor=(1.05, 1), loc='upper left')

# remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.show()

In [None]:
# fig5, (ax_main, axr) = plt.subplots(
#     ncols=2,
#     figsize=(25, 10),
#     gridspec_kw=dict(width_ratios=[1.5, 1], wspace=0.1),
# )

# Two vertically stacked panels sharing the x axis: ax_main on top, axr below (height ratio 1 : 0.5).
fig5, (ax_main, axr) = plt.subplots(
    ncols=1,
    nrows=2,
    sharex=True,
    figsize=(15, 15),
    gridspec_kw=dict(hspace=0.05, height_ratios=[1, 0.5]),
)


#### FIGURE 3 ####
# NOTE(review): the selection logic below duplicates the previous cell (here with
# threshold=0.05, display_text_labels=False, ts_marker_level='shap'); consider factoring
# it into a shared function.
threshold = 0.05
n_features_selection = 0
k=0.25
alpha=0.3
only_non_static_features = True
use_smoothed_shap_values = True

display_significant_slopes = True
n_slope_steps = 5
slope_threshold = 1.5 * threshold

display_text_labels = False
display_legend = True
display_title = False
plot_NIHSS_continuously = True
ts_marker_level = 'shap' # 'baseline' (marker on predicted probability function) or 'shap' (marker on SHAP value function)

tick_label_size = 13
label_font_size = 16

if use_smoothed_shap_values:
    working_shap_values = smoothed_shap_values_over_time
else:
    working_shap_values = shap_values_over_time

# identify significant changes in prediction over time by a change in threshold X% in prediction
significant_positive_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) > threshold)[0])
significant_negative_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) < -threshold)[0])
# NOTE(review): if either list is empty, np.concatenate may promote this to float, which breaks its later use as an index.
significant_timesteps = np.concatenate((significant_positive_timesteps, significant_negative_timesteps))

non_norm_subj_df = feature_values_df[feature_values_df.case_admission_id_idx == subj].drop(columns=['case_admission_id_idx']).pivot(index='time_step', columns='feature', values='feature_value')

# for each timestep, identify the feature that has the largest impact on the prediction
if only_non_static_features:
    # find non static columns in non normed df
    non_static_features = np.where(non_norm_subj_df.std() > 0.01)[0]
    if use_simplified_shap_values:
       non_static_features = np.where(np.isin(reduced_feature_names, np.array(non_norm_subj_df.std()[non_norm_subj_df.std() > 0.01].index)))[0]
    selected_positive_features_by_impact = np.diff(working_shap_values[:, subj, non_static_features], axis=0)[significant_positive_timesteps].argmax(axis=1)
    selected_positive_features_by_impact = non_static_features[selected_positive_features_by_impact]
    selected_negative_features_by_impact = np.diff(working_shap_values[:, subj, non_static_features], axis=0)[significant_negative_timesteps].argmin(axis=1)
    selected_negative_features_by_impact = non_static_features[selected_negative_features_by_impact]
else:
    selected_positive_features_by_impact = np.diff(working_shap_values[:, subj], axis=0)[significant_positive_timesteps].argmax(axis=1)
    selected_negative_features_by_impact = np.diff(working_shap_values[:, subj], axis=0)[significant_negative_timesteps].argmin(axis=1)

selected_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_negative_features_by_impact))

if display_significant_slopes:
    # identify features that are driving the change in prediction over time more gentle slopes (then filter out consecutive timesteps)
    significant_positive_slope = filter_consecutive_numbers(set(np.where((np.concatenate((subj_pred_over_ts[n_slope_steps:], np.zeros(n_slope_steps))) - subj_pred_over_ts)[:-n_slope_steps] > slope_threshold)[0]).difference(set(significant_positive_timesteps)))

    significant_negative_slope = filter_consecutive_numbers(set(np.where((np.concatenate((subj_pred_over_ts[n_slope_steps:], np.zeros(n_slope_steps))) - subj_pred_over_ts)[:-n_slope_steps] < -slope_threshold)[0]).difference(set(significant_negative_timesteps)))

    # NOTE(review): non_static_features is only assigned when only_non_static_features is True;
    # otherwise this line relies on a value left over from an earlier cell (NameError on a fresh kernel).
    delta_shap_by_features = np.concatenate((working_shap_values[n_slope_steps:, subj, non_static_features], np.zeros((n_slope_steps, len(non_static_features))))) - working_shap_values[:, subj, non_static_features]
    selected_positive_features_by_slope = delta_shap_by_features[:-n_slope_steps][significant_positive_slope].argmax(axis=1)
    selected_positive_features_by_slope = non_static_features[selected_positive_features_by_slope]
    selected_negative_features_by_slope = delta_shap_by_features[:-n_slope_steps][significant_negative_slope].argmin(axis=1)
    selected_negative_features_by_slope = non_static_features[selected_negative_features_by_slope]

    selected_features_by_impact = np.concatenate((selected_features_by_impact, selected_positive_features_by_slope, selected_negative_features_by_slope))
    significant_timesteps = np.concatenate((significant_timesteps, significant_positive_slope, significant_negative_slope))
    selected_positive_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_positive_features_by_slope))
    selected_negative_features_by_impact = np.concatenate((selected_negative_features_by_impact, selected_negative_features_by_slope))

if n_features_selection == 0:
    selected_positive_features = np.array([])
    selected_negative_features = np.array([])
else:
    selected_positive_features = working_shap_values[-1,subj].argsort()[-n_features:][::-1]
    selected_negative_features = working_shap_values[-1,subj].argsort()[:n_features][::-1]

selected_features = np.concatenate((selected_positive_features, selected_positive_features_by_impact, selected_negative_features, selected_negative_features_by_impact)).astype(int)


positive_color_palette = sns.color_palette("mako", n_colors=len(set(np.concatenate((selected_positive_features, selected_positive_features_by_impact)))))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(set(np.concatenate((selected_negative_features, selected_negative_features_by_impact)))))

# plot prediction over time
timestep_axis = np.array(range(n_time_steps))
sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='Predicted probability', linewidth = 2, ax=ax_main)

# Stack each selected feature's SHAP contribution (scaled by `k`) around the prediction curve:
# positive parts accumulate upwards from pos_baseline, negative parts downwards from neg_baseline.
pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count = 0
neg_count = 0
feature_color_dict = {}
for feature in set(selected_features):
    subj_feature_shap_value_over_time = working_shap_values[:, subj, feature]
    positive_portion = (subj_feature_shap_value_over_time > 0)
    negative_portion = (subj_feature_shap_value_over_time < 0)

    # split the SHAP trace into its positive-only and negative-only parts
    pos_function = subj_feature_shap_value_over_time.copy()
    neg_function = subj_feature_shap_value_over_time.copy()
    pos_function[negative_portion] = 0
    neg_function[positive_portion] = 0

    if feature in selected_features_by_impact:
        important_ts_idx = np.where(selected_features_by_impact == feature)[0]
        # Hide the band up to the EARLIEST timestep at which this feature was flagged (except for
        # NIHSS if plotting continuously). `significant_timesteps` is ordered positive jumps,
        # negative jumps, then (optionally) slopes, so its first matching entry is not necessarily
        # the earliest one; using it could hide the band under an earlier inflection marker.
        if not (plot_NIHSS_continuously and reduced_feature_names[feature] == 'NIHSS'):
            first_significant_ts = significant_timesteps[important_ts_idx].min()
            pos_function[:first_significant_ts + 1] = 0
            neg_function[:first_significant_ts + 1] = 0

    # colour: positive palette for positive drivers, negative palette for negative drivers
    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    elif feature in selected_negative_features:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    elif feature in selected_negative_features_by_impact:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    elif feature in selected_positive_features_by_impact:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    feature_color_dict[feature] = feature_color

    if np.any(pos_function):
        positive_feature = pos_baseline + k * pos_function
        ax_main.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha)
        pos_baseline = positive_feature

    if np.any(neg_function):
        negative_feature = neg_baseline + k * neg_function
        ax_main.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
        neg_baseline = negative_feature

    # add a legend entry for the feature fill
    ax_main.scatter([], [], color=feature_color, alpha=alpha, label=reduced_feature_names[feature], marker="s", s=200)


# Mark each flagged inflection with a triangle in the feature's colour: 'v' above the positive
# bands (or the curve) when the prediction drops at the next step, '^' below the negative bands
# (or the curve) otherwise.
# Validate up front: any other value would leave `marker_y_level` undefined or stale.
if ts_marker_level not in ('shap', 'baseline'):
    raise ValueError(f"ts_marker_level must be 'shap' or 'baseline', got {ts_marker_level!r}")

for feature in set(selected_features_by_impact):
    important_ts_idx = np.where(selected_features_by_impact == feature)[0]
    for ts_idx in important_ts_idx:
        marker_ts = significant_timesteps[ts_idx]
        # downward inflection point
        if subj_pred_over_ts[marker_ts] > subj_pred_over_ts[marker_ts + 1]:
            marker = 'v'
            if ts_marker_level == 'shap':
                marker_y_level = pos_baseline[marker_ts] + 0.005
            else:
                marker_y_level = subj_pred_over_ts[marker_ts] + 0.005
            text_y_level = marker_y_level + 0.01
        # upward inflection point
        else:
            marker = '^'
            if ts_marker_level == 'shap':
                marker_y_level = neg_baseline[marker_ts] - 0.005
            else:
                marker_y_level = subj_pred_over_ts[marker_ts] - 0.005
            text_y_level = marker_y_level - 0.015

        ax_main.scatter(marker_ts, marker_y_level, color=feature_color_dict[feature], s=100, marker=marker, alpha=1, edgecolors='white')
        # insert a label on the plot
        if display_text_labels:
            ax_main.text(marker_ts + 0.01, text_y_level, reduced_feature_names[feature], fontsize=12, color='black')


if display_title:
    ax_main.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)

ax_main.set_xlabel('Time from admission (hours)', fontsize=label_font_size)
# The model in this notebook predicts early neurological deterioration (END), so the label must
# match the END figure rather than read 'favorable outcome'.
ax_main.set_ylabel('Probability of END', fontsize=label_font_size)
ax_main.tick_params(axis='both', labelsize=tick_label_size)

if display_legend:
    legend_markers, legend_labels = ax_main.get_legend_handles_labels()

    # entry 0 is the predicted-probability line; the remaining entries are the SHAP band swatches
    shap_shades_markers = legend_markers[1:]
    shap_shades_labels = legend_labels[1:]
    legend_markers = [legend_markers[0]]
    legend_labels = [legend_labels[0]]

    # add a legend entry for the timestep markers (up / down shown side by side via HandlerTuple)
    ts_marker_down = plt.scatter([], [], marker='v', color='grey', s=50, alpha=0.8)
    ts_marker_up = plt.scatter([], [], marker='^', color='grey', s=50, alpha=0.8)
    ts_label = 'Positive / Negative impact on inflection of prediction'
    legend_markers.append((ts_marker_up, ts_marker_down))
    legend_labels.append(ts_label)

    # Add a subtitle for SHAP value shades (string handles are rendered by LegendTitle)
    legend_markers.append('')
    legend_labels.append('')
    legend_markers.append('Weight & direction of influence on model prediction')
    legend_labels.append('')

    legend_markers += shap_shades_markers
    legend_labels += shap_shades_labels

    # Pass the assembled handles/labels explicitly. Without them, legend() rebuilds the entries from
    # the axes' labelled artists and every custom entry assembled above is discarded.
    ax_main.legend(legend_markers, legend_labels, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.,
                   fontsize=label_font_size, title='Influence on model prediction', title_fontsize=label_font_size,
                   handler_map={tuple: HandlerTuple(ndivide=None), str: LegendTitle({'fontsize': label_font_size})})


##### Figure 4 #####
# Raw (non-normalised) trajectories of the features flagged as driving inflections in Figure 3,
# each on its own y-axis (twin axes of `axr`) and coloured as in Figure 3.

smooth_values = True
plot_legend = True
smoothing_window = 2
plot_normalized_features = False
display_inflection_markers = True

legend_markers, legend_labels = [], []

timestep_axis = np.array(range(n_time_steps))

if plot_normalized_features:
    # plot all normalized features in the background (grey); of the features that are constantly
    # 0 (resp. constantly 1) only the first is drawn, to avoid overplotting identical lines
    static_feature_val0_count = 0
    static_feature_val1_count = 0
    for norm_feature_idx in range(test_X_np.shape[-1]):
        if np.all(test_X_np[subj, :, norm_feature_idx] == 0):
            if static_feature_val0_count == 0:
                static_feature_val0_count += 1
            else:
                continue
        elif np.all(test_X_np[subj, :, norm_feature_idx] == 1):
            if static_feature_val1_count == 0:
                static_feature_val1_count += 1
            else:
                continue
        axr.plot(timestep_axis, test_X_np[subj, :, norm_feature_idx], color='grey', alpha=0.075)

# plot selected non static features in color on top
non_norm_subj_df = feature_values_df[feature_values_df.case_admission_id_idx == subj].drop(columns=['case_admission_id_idx']).pivot(index='time_step', columns='feature', values='feature_value')
# Smooth ONCE, before the loop. Smoothing inside the per-feature loop re-applied the rolling mean on
# every iteration, so the k-th plotted feature was smoothed k+1 times.
if smooth_values:
    plot_subj_df = non_norm_subj_df.rolling(window=smoothing_window, min_periods=1, center=True).mean()
else:
    plot_subj_df = non_norm_subj_df
plot_subj_table = plot_subj_df.reset_index()

twin_xs = []
for fidx, feature in enumerate(set(selected_features_by_impact)):
    if plot_normalized_features or fidx != 0:
        twin_xs.append(axr.twinx())
    else:
        # first feature is plotted on the left axis if no normalized features are displayed
        twin_xs.append(axr)
    feature_ax = twin_xs[-1]
    feature_name = reduced_feature_names[feature]
    feature_color = feature_color_dict[feature]

    sns.lineplot(y=feature_name, x=plot_subj_df.index.name, data=plot_subj_table, color=feature_color, ax=feature_ax)

    # shift each additional right-hand spine outwards so the twin y-axes do not overlap
    if plot_normalized_features:
        feature_ax.spines.right.set_position(("axes", 1 + 0.1 * fidx))
    elif fidx != 0:
        feature_ax.spines.right.set_position(("axes", 1 + 0.1 * (fidx - 1)))

    feature_ax.set_ylabel(feature_name, fontsize=tick_label_size)
    feature_ax.tick_params(axis='y', colors=feature_color, labelsize=tick_label_size - 1)

    # add a legend entry for the feature
    legend_markers.append(plt.Line2D([0, 0], [0, 0], color=feature_color))
    legend_labels.append(feature_name)

    feature_ax.grid(False)

    if display_inflection_markers:
        feature_series = plot_subj_table[feature_name]
        feature_range = feature_series.max() - feature_series.min()
        important_ts_idx = np.where(selected_features_by_impact == feature)[0]
        for ts_idx in important_ts_idx:
            timestep = significant_timesteps[ts_idx]
            marker_x_base_level = timestep
            marker_y_base_level = feature_series[timestep]
            next_ts_y_level = feature_series[timestep + 1]
            previous_ts_y_level = marker_y_base_level if timestep == 0 else feature_series[timestep - 1]

            # plot marker at maximum angle to y slope
            # NOTE(review): the Temperature / t=47 clause is a hard-coded exception, presumably
            # hand-tuned for one subject's figure.
            if ((next_ts_y_level == previous_ts_y_level)
                    or ((next_ts_y_level - marker_y_base_level < feature_range / 20) and (marker_y_base_level > previous_ts_y_level))
                    or ((feature_name == 'Temperature') and (timestep == 47))):
                marker_y_level = marker_y_base_level + 0.02 * feature_range
                marker_x_level = marker_x_base_level
                marker = 'v'
            elif (next_ts_y_level > previous_ts_y_level) and (marker_y_base_level - previous_ts_y_level < feature_range / 20):
                marker_y_level = marker_y_base_level - 0.02 * feature_range
                marker_x_level = marker_x_base_level
                marker = '^'
            elif next_ts_y_level > previous_ts_y_level:
                marker_y_level = marker_y_base_level
                marker_x_level = marker_x_base_level - 0.5
                marker = '>'
            else:
                marker_y_level = marker_y_base_level
                marker_x_level = marker_x_base_level + 0.5
                marker = '<'

            feature_ax.scatter(marker_x_level, marker_y_level, color=feature_color, s=100, marker=marker, alpha=1, edgecolors='white')


axr.set_xlabel('Time from admission (hours)', fontsize=label_font_size)
if plot_normalized_features:
    axr.set_ylabel('Normalized feature value', fontsize=label_font_size)
axr.tick_params(axis='both', labelsize=tick_label_size)

if plot_legend:
    if plot_normalized_features:
        # add a legend entry for the normalized features
        legend_markers.insert(0, plt.Line2D([0, 0], [0, 0], color='grey', alpha=0.5))
        legend_labels.insert(0, 'All features (normalized)')

    if display_inflection_markers:
        # a single grey 'v' stands in for all four marker orientations used above
        ts_marker_down = plt.scatter([], [], marker='v', color='grey', s=50, alpha=0.8)
        legend_markers.append(ts_marker_down)
        legend_labels.append('Significant impact on inflection of prediction')

    # attach the legend to the top-most twin axis; fall back to axr when no feature was flagged
    legend_ax = twin_xs[-1] if twin_xs else axr
    legend_ax.legend(legend_markers, legend_labels, fontsize=label_font_size, title_fontsize=label_font_size,
                     handler_map={tuple: HandlerTuple(ndivide=None), str: LegendTitle({'fontsize': label_font_size})},
                     loc='lower right')

# turn off grid
axr.grid(False)



In [None]:
# Joint figure: Figure 3 (prediction with stacked SHAP bands) on top, Figure 4 (features) below.
fig_joint, (ax_main, ax_features) = plt.subplots(
    2, 1,
    figsize=(15, 12),
    gridspec_kw={'height_ratios': [2, 1], 'hspace': 0.3},
)


#### FIGURE 3 ####
# --- inflection detection ---
threshold = 0.04                   # min. one-step change in predicted probability to flag a timestep
display_significant_slopes = True  # also flag gentler changes over several timesteps
n_slope_steps = 5                  # window length (timesteps) of the slope criterion
slope_threshold = 1.5 * threshold  # min. change over that window

# --- feature selection ---
n_features_selection = 0           # 0 disables picking top features by last-timestep SHAP value
n_features = 1
only_non_static_features = True
use_smoothed_shap_values = True

# --- drawing ---
k = 0.25                           # scale factor applied to SHAP values when stacking bands
alpha = 0.3
plot_ground_truth = True
skip_label_at_zero = True
display_text_labels = True
display_legend = True
display_title = False
plot_NIHSS_continuously = True
ts_marker_level = 'shap' # 'baseline' (marker on predicted probability function) or 'shap' (marker on SHAP value function)

tick_label_size = 13
label_font_size = 16

working_shap_values = smoothed_shap_values_over_time if use_smoothed_shap_values else shap_values_over_time

# identify significant changes in prediction over time: a change of more than `threshold` between
# two consecutive timesteps (then filter out consecutive timesteps)
significant_positive_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) > threshold)[0])
significant_negative_timesteps = filter_consecutive_numbers(np.where(np.diff(subj_pred_over_ts) < -threshold)[0])
# Cast to int: if the helper returns an empty Python list, concatenation would otherwise yield a
# float array that can no longer be used for indexing further down.
significant_timesteps = np.concatenate((significant_positive_timesteps, significant_negative_timesteps)).astype(int)

non_norm_subj_df = feature_values_df[feature_values_df.case_admission_id_idx == subj].drop(columns=['case_admission_id_idx']).pivot(index='time_step', columns='feature', values='feature_value')

# Candidate features for "which feature drives this inflection". `non_static_features` holds
# indices into the feature axis of `working_shap_values`. It is defined in BOTH branches because
# the slope search below indexes with it (it used to be undefined / stale when
# only_non_static_features was False).
if only_non_static_features:
    # find non static columns in non normed df
    feature_std = non_norm_subj_df.std()
    non_static_features = np.where(feature_std > 0.01)[0]
    if use_simplified_shap_values:
        non_static_features = np.where(np.isin(reduced_feature_names, np.array(feature_std[feature_std > 0.01].index)))[0]
else:
    non_static_features = np.arange(working_shap_values.shape[-1])

# For each significant one-step jump, the driving feature is the candidate whose SHAP value rose
# (resp. fell) the most over that same step.
candidate_shap = working_shap_values[:, subj, non_static_features]
candidate_shap_step_delta = np.diff(candidate_shap, axis=0)
selected_positive_features_by_impact = non_static_features[candidate_shap_step_delta[significant_positive_timesteps].argmax(axis=1)]
selected_negative_features_by_impact = non_static_features[candidate_shap_step_delta[significant_negative_timesteps].argmin(axis=1)]

selected_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_negative_features_by_impact))

if display_significant_slopes:
    # identify features that are driving the change in prediction over time more gentle slopes:
    # prediction change over `n_slope_steps` timesteps (t -> t + n_slope_steps), excluding the
    # timesteps already flagged as one-step jumps (then filter out consecutive timesteps)
    pred_over_ts = np.asarray(subj_pred_over_ts)
    pred_slope = pred_over_ts[n_slope_steps:] - pred_over_ts[:len(pred_over_ts) - n_slope_steps]
    significant_positive_slope = filter_consecutive_numbers(set(np.where(pred_slope > slope_threshold)[0]).difference(set(significant_positive_timesteps)))
    significant_negative_slope = filter_consecutive_numbers(set(np.where(pred_slope < -slope_threshold)[0]).difference(set(significant_negative_timesteps)))

    # same windowed difference for each candidate feature's SHAP value
    delta_shap_by_features = candidate_shap[n_slope_steps:] - candidate_shap[:len(candidate_shap) - n_slope_steps]
    selected_positive_features_by_slope = non_static_features[delta_shap_by_features[significant_positive_slope].argmax(axis=1)]
    selected_negative_features_by_slope = non_static_features[delta_shap_by_features[significant_negative_slope].argmin(axis=1)]

    selected_features_by_impact = np.concatenate((selected_features_by_impact, selected_positive_features_by_slope, selected_negative_features_by_slope))
    significant_timesteps = np.concatenate((significant_timesteps, significant_positive_slope, significant_negative_slope)).astype(int)
    selected_positive_features_by_impact = np.concatenate((selected_positive_features_by_impact, selected_positive_features_by_slope))
    selected_negative_features_by_impact = np.concatenate((selected_negative_features_by_impact, selected_negative_features_by_slope))

if n_features_selection == 0:
    selected_positive_features = np.array([])
    selected_negative_features = np.array([])
else:
    # additionally pick the features with the largest / smallest SHAP value at the last timestep
    final_ts_shap_order = working_shap_values[-1, subj].argsort()
    selected_positive_features = final_ts_shap_order[-n_features:][::-1]
    selected_negative_features = final_ts_shap_order[:n_features][::-1]

selected_features = np.concatenate((selected_positive_features, selected_positive_features_by_impact, selected_negative_features, selected_negative_features_by_impact)).astype(int)

# one colour per distinct positive / negative feature
positive_color_palette = sns.color_palette("mako", n_colors=len(set(np.concatenate((selected_positive_features, selected_positive_features_by_impact)))))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(set(np.concatenate((selected_negative_features, selected_negative_features_by_impact)))))

# plot prediction over time
timestep_axis = np.array(range(n_time_steps))
sns.lineplot(x=timestep_axis, y=subj_pred_over_ts, label='Predicted probability', linewidth = 2, ax=ax_main)

# plot ground truth over time
if plot_ground_truth:
    # Rising (+1) / falling (-1) edges of the label series. Padding with 0 on BOTH sides guarantees
    # that every rising edge has a matching falling edge. Without `append=0`, a window whose label
    # is still 1 at the last timestep had no falling edge and was silently dropped by zip().
    # Each window spans [first positive timestep, first timestep after it] (len(gt) if still open).
    changes_in_gt = np.diff(subj_gt_over_ts, prepend=0, append=0)
    window_starts = np.where(changes_in_gt == 1)[0]
    window_ends = np.where(changes_in_gt == -1)[0]
    for window_start, window_end in zip(window_starts, window_ends):
        # mark a thick horizontal red line at y = 0 for the ground-truth window
        ax_main.plot([window_start, window_end], [0, 0], color="#7b002c", linewidth=10, alpha=0.8)
        ax_main.text(np.mean((window_start, window_end)), 0 + 0.02, '6h to END', horizontalalignment='center', verticalalignment='center', fontsize=tick_label_size)


# Stack each selected feature's SHAP contribution (scaled by `k`) around the prediction curve:
# positive parts accumulate upwards from pos_baseline, negative parts downwards from neg_baseline.
pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count = 0
neg_count = 0
feature_color_dict = {}
for feature in set(selected_features):
    subj_feature_shap_value_over_time = working_shap_values[:, subj, feature]
    positive_portion = (subj_feature_shap_value_over_time > 0)
    negative_portion = (subj_feature_shap_value_over_time < 0)

    # split the SHAP trace into its positive-only and negative-only parts
    pos_function = subj_feature_shap_value_over_time.copy()
    neg_function = subj_feature_shap_value_over_time.copy()
    pos_function[negative_portion] = 0
    neg_function[positive_portion] = 0

    if feature in selected_features_by_impact:
        important_ts_idx = np.where(selected_features_by_impact == feature)[0]
        # Hide the band up to the EARLIEST timestep at which this feature was flagged (except for
        # NIHSS if plotting continuously). `significant_timesteps` is ordered positive jumps,
        # negative jumps, then (optionally) slopes, so its first matching entry is not necessarily
        # the earliest one; using it could hide the band under an earlier inflection marker.
        if not (plot_NIHSS_continuously and reduced_feature_names[feature] == 'NIHSS'):
            first_significant_ts = significant_timesteps[important_ts_idx].min()
            pos_function[:first_significant_ts + 1] = 0
            neg_function[:first_significant_ts + 1] = 0

    # colour: positive palette for positive drivers, negative palette for negative drivers
    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    elif feature in selected_negative_features:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    elif feature in selected_negative_features_by_impact:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1
    elif feature in selected_positive_features_by_impact:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    feature_color_dict[feature] = feature_color

    if np.any(pos_function):
        positive_feature = pos_baseline + k * pos_function
        ax_main.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha)
        pos_baseline = positive_feature

    if np.any(neg_function):
        negative_feature = neg_baseline + k * neg_function
        ax_main.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
        neg_baseline = negative_feature

    # add a legend entry for the feature fill
    ax_main.scatter([], [], color=feature_color, alpha=alpha, label=reduced_feature_names[feature], marker="s", s=200)


# Mark each flagged inflection with a triangle in the feature's colour: 'v' above the positive
# bands (or the curve) when the prediction drops at the next step, '^' below the negative bands
# (or the curve) otherwise. With skip_label_at_zero, both marker and label are skipped at t = 0.
# Validate up front: any other value would leave `marker_y_level` undefined or stale.
if ts_marker_level not in ('shap', 'baseline'):
    raise ValueError(f"ts_marker_level must be 'shap' or 'baseline', got {ts_marker_level!r}")

for feature in set(selected_features_by_impact):
    important_ts_idx = np.where(selected_features_by_impact == feature)[0]
    for ts_idx in important_ts_idx:
        marker_ts = significant_timesteps[ts_idx]
        if skip_label_at_zero and marker_ts == 0:
            continue
        # downward inflection point
        if subj_pred_over_ts[marker_ts] > subj_pred_over_ts[marker_ts + 1]:
            marker = 'v'
            if ts_marker_level == 'shap':
                marker_y_level = pos_baseline[marker_ts] + 0.005
            else:
                marker_y_level = subj_pred_over_ts[marker_ts] + 0.005
            text_y_level = marker_y_level + 0.01
        # upward inflection point
        else:
            marker = '^'
            if ts_marker_level == 'shap':
                marker_y_level = neg_baseline[marker_ts] - 0.005
            else:
                marker_y_level = subj_pred_over_ts[marker_ts] - 0.005
            text_y_level = marker_y_level - 0.015

        ax_main.scatter(marker_ts, marker_y_level, color=feature_color_dict[feature], s=100, marker=marker, alpha=1, edgecolors='white')
        # insert a label on the plot (rotated above downward markers, plain below upward ones)
        if display_text_labels:
            if marker == 'v':
                ax_main.text(marker_ts + 0.01, text_y_level, reduced_feature_names[feature], fontsize=12, color='black',
                             rotation=45, ha='left', va='bottom')
            else:
                ax_main.text(marker_ts - 0.01, text_y_level, reduced_feature_names[feature], fontsize=12, color='black')


if display_title:
    ax_main.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)

# axis labels and tick styling of the top panel
ax_main.set_xlabel('Time from admission (hours)', fontsize=label_font_size)
ax_main.set_ylabel('Probability of END', fontsize=label_font_size)
ax_main.tick_params(axis='both', labelsize=tick_label_size)

if display_legend:
    existing_handles, existing_labels = ax_main.get_legend_handles_labels()
    # entry 0 is the predicted-probability line, the rest are the SHAP band swatches
    shap_shades_markers = existing_handles[1:]
    shap_shades_labels = existing_labels[1:]

    # proxy artists for the up / down inflection markers, drawn side by side via HandlerTuple
    ts_marker_down = plt.scatter([], [], marker='v', color='grey', s=50, alpha=0.8)
    ts_marker_up = plt.scatter([], [], marker='^', color='grey', s=50, alpha=0.8)
    ts_label = 'Positive / Negative impact on inflection of prediction'

    # string handles are rendered by LegendTitle: an empty row, then a sub-heading for the bands
    legend_markers = [existing_handles[0], (ts_marker_up, ts_marker_down), '',
                      'Weight & direction of influence on model prediction'] + shap_shades_markers
    legend_labels = [existing_labels[0], ts_label, '', ''] + shap_shades_labels

    ax_main.legend(
        legend_markers,
        legend_labels,
        fontsize=label_font_size,
        title='Influence on model prediction',
        title_fontsize=label_font_size,
        handler_map={tuple: HandlerTuple(ndivide=None), str: LegendTitle({'fontsize': label_font_size})},
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
    )


##### Figure 4 #####
# Small multiples replacing the bottom axis `ax_features`: one sparkline per feature flagged in
# Figure 3, in the same colour, with the flagged timesteps highlighted.
impact_features = list(set(selected_features_by_impact))
n_features_small = len(impact_features)

if n_features_small > 0:
    # Create subplot grid within the bottom subplot
    cols = min(4, n_features_small)  # Max 4 columns for better readability
    rows = (n_features_small + cols - 1) // cols  # ceiling division

    # nested grid occupying exactly the area of the bottom subplot
    features_box = ax_features.get_position()
    gs_nested = ax_features.figure.add_gridspec(rows, cols,
                                               left=features_box.x0,
                                               right=features_box.x1,
                                               bottom=features_box.y0,
                                               top=features_box.y1,
                                               hspace=0.4, wspace=0.3)

    # Remove the original ax_features
    ax_features.remove()

    for idx, feature in enumerate(impact_features):
        row, col = divmod(idx, cols)
        ax_small = fig_joint.add_subplot(gs_nested[row, col])

        feature_name = reduced_feature_names[feature]
        feature_color = feature_color_dict[feature]
        feature_data = non_norm_subj_df[feature_name]

        # Sparkline-style plot
        ax_small.plot(timestep_axis, feature_data, color=feature_color, linewidth=2)
        ax_small.fill_between(timestep_axis, feature_data, alpha=0.3, color=feature_color)

        # Add inflection points
        important_ts_idx = np.where(selected_features_by_impact == feature)[0]
        for ts_idx in important_ts_idx:
            timestep = significant_timesteps[ts_idx]
            ax_small.scatter(timestep, feature_data.iloc[timestep], color=feature_color, s=60, zorder=5, edgecolors='white', linewidth=1)

        # Styling
        ax_small.set_title(feature_name, fontsize=tick_label_size, color=feature_color, weight='bold')
        ax_small.set_xlim(0, n_time_steps)
        ax_small.spines['top'].set_visible(False)
        ax_small.spines['right'].set_visible(False)

        # Show min and max values on y-axis
        y_min, y_max = feature_data.min(), feature_data.max()
        y_ticks = [y_min] if y_min == y_max else [y_min, y_max]
        ax_small.set_yticks(y_ticks)
        ax_small.tick_params(labelsize=tick_label_size - 2)

        # Pad the y-range by 20%. A flat series has zero range, which would set identical
        # (singular) limits, so fall back to a pad proportional to the value itself.
        y_pad = 0.2 * (y_max - y_min)
        if y_pad == 0:
            y_pad = 0.5 * max(abs(y_min), 1.0)
        ax_small.set_ylim(y_min - y_pad, y_max + y_pad)

        # Only show x-label on bottom row
        if row == rows - 1:
            ax_small.set_xlabel('Time (h)', fontsize=tick_label_size - 1)
        else:
            ax_small.set_xticklabels([])

else:
    # If no features, show message
    ax_features.text(0.5, 0.5, 'No significant feature changes detected',
                     transform=ax_features.transAxes, ha='center', va='center',
                     fontsize=label_font_size, style='italic')
    ax_features.set_xlim(0, 1)
    ax_features.set_ylim(0, 1)
    ax_features.axis('off')

# remove spines
ax_main.spines['top'].set_visible(False)
ax_main.spines['right'].set_visible(False)


In [None]:
# Export the joint figure for the current subject next to the predictions file (set to True to save).
save_joint_figure = False
if save_joint_figure:
    fig_joint.savefig(os.path.join(os.path.dirname(predictions_path), f'subj_{subj}_inference_plot.png'), bbox_inches='tight', dpi=600)

#### Find interesting subjects

In [None]:
# Candidate subjects for closer inspection: the `n_top_subjects` subjects with the largest swing
# (max - min) in predicted probability over time, keeping only those with at least one positive
# ground-truth timestep. The ranking (largest swing first) is preserved; np.intersect1d would
# instead return the survivors re-sorted by subject index.
# NOTE: the `top_10_*` names are kept for downstream cells although 100 subjects are screened.
n_top_subjects = 100
prediction_swing = predictions_over_time.max(axis=1) - predictions_over_time.min(axis=1)
top_10_diff_subj = np.argsort(prediction_swing)[-n_top_subjects:][::-1]
has_positive_gt = gt_over_time.sum(axis=1) > 0
top_10_diff_subj_gt = top_10_diff_subj[has_positive_gt[top_10_diff_subj]]

In [None]:
# Candidate subject indices: large swing in predicted probability and at least one positive ground-truth timestep
top_10_diff_subj_gt

In [None]:
# Quick look at the raw prediction and ground-truth series of a single subject.
# NOTE(review): subject index 9 is hard-coded — presumably picked from the candidate list above; confirm.
predictions_over_time[9], gt_over_time[9]

In [None]:
# plot ground truth and prediction for each candidate subject (one figure per subject)
# Dedicated loop-variable names keep this cell from overwriting `subj`, `subj_pred_over_ts` and
# `subj_gt_over_ts`, which the joint-figure cell reads. Previously, re-running that cell after this
# one (without re-selecting the subject) silently plotted the last screened subject.
screen_timestep_axis = np.arange(n_time_steps)
for candidate_subj in top_10_diff_subj_gt:
    fig_screen, ax_screen = plt.subplots()
    sns.lineplot(x=screen_timestep_axis, y=predictions_over_time[candidate_subj], label='Predicted probability', linewidth=2, ax=ax_screen)
    sns.lineplot(x=screen_timestep_axis, y=gt_over_time[candidate_subj], label='Ground truth', linewidth=2, color='black', linestyle='--', ax=ax_screen)
    ax_screen.set_title(f'Predictions for subject {candidate_subj} of test set along time', fontsize=20)
    plt.show()
    plt.close(fig_screen)  # free the figure; many subjects are screened in one go