# Exploring shap value visualisation over time

Note: the valid version is the last one in this notebook; the earlier sections are kept as exploratory attempts.

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from random import randint
import pandas as pd
import pickle
from tqdm import tqdm
from prediction.mrs_outcome_prediction.LSTM.testing.shap_helper_functions import check_shap_version_compatibility

In [None]:

# SHAP values require very specific package versions. Run the project helper that
# checks compatibility before going further (what it does on a mismatch is defined
# in shap_helper_functions, not here).
check_shap_version_compatibility()

In [None]:
# Hard-coded local input paths: trained model weights, preprocessed features and
# outcomes, normalisation parameters and pre-computed SHAP values.
# Adapt these to your own environment before running the notebook.
model_weights_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/2022_09_07_1744/test_LSTM_sigmoid_all_unchanged_0.2_2_True_RMSprop_3M mRS 0-2_128_4/sigmoid_all_unchanged_0.2_2_True_RMSprop_3M mRS 0-2_128_4.hdf5'
features_path = '/Users/jk1/temp/opsum_prepro_output/preprocessed_features_02092022_083046.csv'
labels_path = '/Users/jk1/temp/opsum_prepro_output/preprocessed_outcomes_02092022_083046.csv'
normalisation_parameters_path = '/Users/jk1/temp/opsum_prepro_output/logs_02092022_083046/normalisation_parameters.csv'
shap_values_path = "/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/2022_09_07_1744/deep_explainer_shap_values.pkl"


In [None]:
# Experiment configuration used by the cells below.
# NOTE(review): these presumably have to match the settings the weights in
# model_weights_path were trained with — confirm against the training run.
outcome = '3M mRS 0-2'  # outcome column selected when loading the labels
masking = True  # passed to lstm_generator
units = 128  # passed to lstm_generator as n_units
activation = 'sigmoid'  # passed to lstm_generator
dropout = 0.2  # passed to lstm_generator
layers = 2  # passed to lstm_generator as n_layers
optimizer = 'RMSprop'  # passed to model.compile
seed = 42  # random_state of the train/test split
test_size = 0.20  # fraction of patients held out as test set
override_masking_value = False  # not used anywhere in the visible cells

In [None]:
from prediction.mrs_outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time

# load the dataset: X holds the features and y the outcomes for the chosen `outcome`
# (both are used as long-format tables keyed by patient_id / case_admission_id below)
X, y = format_to_2d_table_with_time(feature_df_path=features_path, outcome_df_path=labels_path,
                                    outcome=outcome)


In [None]:
# Number of time steps = largest hourly time-bin index + 1
# (assumes the bins are 0-based — TODO confirm).
n_time_steps = X.relative_sample_date_hourly_cat.max() + 1
# Number of distinct sample labels; used as the number of input channels of the model.
n_channels = X.sample_label.unique().shape[0]

In [None]:
from sklearn.model_selection import train_test_split
from prediction.mrs_outcome_prediction.data_loading.data_formatting import features_to_numpy, \
    link_patient_id_to_outcome, numpy_to_lookup_table


def _outcome_per_admission(outcome_df, admission_ids):
    """Return, as float32, the first recorded outcome of every admission id, in the given order."""
    return np.array([outcome_df[outcome_df.case_admission_id == cid].outcome.values[0]
                     for cid in admission_ids]).astype('float32')


# Split at patient level: every patient is first reduced to a single outcome so that
# the stratified split does not see duplicates.
all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)
patient_ids = all_pids_with_outcome.patient_id.tolist()
patient_outcomes = all_pids_with_outcome.outcome.tolist()
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(
    patient_ids,
    patient_outcomes,
    stratify=patient_outcomes,
    test_size=test_size,
    random_state=seed,
)

# Select the rows belonging to each side of the split
test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]
train_X_df = X[X.patient_id.isin(pid_train)]
train_y_df = y[y.patient_id.isin(pid_train)]

# Convert the long-format tables to numpy arrays (same columns for both sets)
feature_columns = ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value']
train_X_np = features_to_numpy(train_X_df, list(feature_columns))
test_X_np = features_to_numpy(test_X_df, list(feature_columns))

# Labels in the same order as the admissions in the feature arrays
# (the admission id is read from position [:, 0, 0, 0] of each array)
train_y_np = _outcome_per_admission(train_y_df, train_X_np[:, 0, 0, 0])
test_y_np = _outcome_per_admission(test_y_df, test_X_np[:, 0, 0, 0])

# create look-up table for case_admission_ids, sample_labels and relative_sample_date_hourly_cat
test_features_lookup_table = numpy_to_lookup_table(test_X_np)
train_features_lookup_table = numpy_to_lookup_table(train_X_np)

# Keep only the last entry of the last axis (the value), dropping the
# case_admission_id, sample_label, and time_step_label columns from the data
test_X_np = test_X_np[:, :, :, -1].astype('float32')
train_X_np = train_X_np[:, :, :, -1].astype('float32')

## Prediction at every timepoint for a given subject

In [None]:
from prediction.utils.scoring import precision, recall, matthews
from prediction.mrs_outcome_prediction.LSTM.LSTM import lstm_generator

# Pick a random subject of the test set.
# random.randint includes BOTH endpoints, so the upper bound has to be len - 1;
# with len(test_X_np) the drawn index could fall one past the last subject.
subj = randint(0, len(test_X_np) - 1)

# Prediction for the subject when the model only sees the first ts + 1 time steps
subj_pred_over_ts = []

for ts in tqdm(range(n_time_steps)):
    modified_time_steps = ts + 1

    # Rebuild the network for the truncated sequence length, then load the trained weights
    model = lstm_generator(x_time_shape=modified_time_steps, x_channels_shape=n_channels, masking=masking, n_units=units,
                           activation=activation, dropout=dropout, n_layers=layers)

    model.compile(loss='binary_crossentropy', optimizer=optimizer,
                  metrics=['accuracy', precision, recall, matthews])

    model.load_weights(model_weights_path)

    # Slice with subj:subj+1 to keep the batch dimension expected by predict()
    subj_X_with_first_n_ts = test_X_np[subj:subj+1,0:modified_time_steps,:]

    y_pred = model.predict(subj_X_with_first_n_ts)
    subj_pred_over_ts.append(y_pred[0][0])


In [None]:
# Sanity check: number of collected predictions (one is appended per time step in the loop above)
len(subj_pred_over_ts)

## Find shap explanations for every timepoint for this subject

In [None]:
# Load the pre-computed SHAP values from disk.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
with open(shap_values_path, 'rb') as handle:
    shap_values = pickle.load(handle)

In [None]:
# Feature names = keys of the 'sample_label' look-up table
# (presumably in the same order as the feature axis of the data — TODO confirm).
features = list(test_features_lookup_table['sample_label'].keys())

In [None]:
# Sanity check: shape of the SHAP values of the selected subject at index 33 of the
# second axis (indexed by time step in the cells below)
shap_values[0][subj, 33].shape

Taking shap values from every individual timestep (in reality, the model also has access to the aggregated history before ts)

In [None]:
# For every time step keep the indices (and SHAP values) of the 3 features with the
# largest and the 3 features with the smallest SHAP value of the selected subject.
selected_features_over_ts = []
selected_features_shap_over_ts = []
for ts in tqdm(range(n_time_steps)):
    shap_at_ts = shap_values[0][subj, ts]
    # sort once (ascending) and read both ends of the ranking
    feature_ranking = shap_at_ts.argsort()

    # indices of the 3 features with the biggest positive shap impact (largest first)
    selected_positive_features = feature_ranking[-3:][::-1]

    # indices of the 3 features with the biggest negative shap impact (most negative last)
    selected_negative_features = feature_ranking[:3][::-1]

    # positions 0-2: positive features, positions 3-5: negative features
    selected_features = np.concatenate((selected_positive_features, selected_negative_features))
    selected_features_over_ts.append(selected_features)

    selected_features_shap_over_ts.append(shap_at_ts[selected_features])



In [None]:
# Number of distinct features selected at least once over all time steps.
# (`flatten` was never imported or defined, so the nested arrays are flattened inline.)
len({int(feature) for selection in selected_features_over_ts for feature in selection})

In [None]:
# Prediction over time for the subject, with the per-timestep SHAP values of the
# selected features drawn as stacked translucent bands around the prediction line.
fig = plt.figure(figsize=(15,10))

timestep_axis = np.array(range(n_time_steps))
ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts)

k = 5  # visual scaling factor applied to the SHAP values
alpha = 0.3
positive_color_palette = ['#d6fff6', '#231651', '#4DCCBD']
negative_color_palette = ['#EEEBD0', '#EBB3A9', '#E87EA1']

predictions = np.asarray(subj_pred_over_ts)
# one row per time step; columns 0-2 hold the positive and 3-5 the negative features
shap_per_ts = np.asarray(selected_features_shap_over_ts)

# positive features: each band spans from the prediction up to the running stacked total
stacked = predictions
for rank, color in enumerate(positive_color_palette):
    stacked = stacked + k * shap_per_ts[:, rank]
    ax.fill_between(timestep_axis, predictions, stacked, color=color, alpha=alpha)

# negative features: same stacking, below the prediction
stacked = predictions
for rank, color in enumerate(negative_color_palette):
    stacked = stacked + k * shap_per_ts[:, len(positive_color_palette) + rank]
    ax.fill_between(timestep_axis, stacked, predictions, color=color, alpha=alpha)

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

# plt.plot() without arguments draws nothing; show the figure instead
plt.show()

Taking shap values for summed timestep (individual timestep + summed history before it)


In [None]:
# Same selection as before, but on SHAP values summed over all time steps up to and
# including ts. The result is also collected in a long-format dataframe.
selected_features_over_ts = []
selected_features_shap_over_ts = []

subj_pred_df_columns = ['timestep', 'prediction', 'shap_feature_id', 'shap_value']
# Per-timestep frames are collected and concatenated once at the end:
# DataFrame.append was removed in pandas 2.0 and copied the whole frame on every call.
per_ts_frames = []

for ts in tqdm(range(n_time_steps)):
    modified_time_steps = ts + 1

    # SHAP values of every feature summed over the first ts + 1 time steps
    cumulative_shap = shap_values[0][subj, 0:modified_time_steps].sum(axis=0)
    feature_ranking = cumulative_shap.argsort()

    # indices of the 3 features with the biggest positive shap impact (largest first)
    selected_positive_features = feature_ranking[-3:][::-1]

    # indices of the 3 features with the biggest negative shap impact (most negative last)
    selected_negative_features = feature_ranking[:3][::-1]

    selected_features = np.concatenate((selected_positive_features, selected_negative_features))
    selected_features_over_ts.append(selected_features)

    subj_shap_values_at_ts = cumulative_shap[selected_features]
    selected_features_shap_over_ts.append(subj_shap_values_at_ts)

    # scalars (timestep, prediction) are broadcast to the 6 selected features
    per_ts_frames.append(pd.DataFrame({
        'timestep': ts,
        'prediction': subj_pred_over_ts[ts],
        'shap_feature_id': selected_features,
        'shap_value': subj_shap_values_at_ts,
    }))

if per_ts_frames:
    # the index is deliberately not reset: it repeats 0-5 for every time step, as before
    subj_pred_df = pd.concat(per_ts_frames)[subj_pred_df_columns]
else:
    subj_pred_df = pd.DataFrame(columns=subj_pred_df_columns)


In [None]:
# Number of distinct features selected at least once over all time steps.
# (`flatten` was never imported or defined, so the nested arrays are flattened inline.)
len({int(feature) for selection in selected_features_over_ts for feature in selection})


In [None]:
# Prediction over time for the subject, with the cumulative SHAP values of the
# selected features drawn as stacked translucent bands around the prediction line.
fig = plt.figure(figsize=(15,10))

timestep_axis = np.array(range(n_time_steps))
ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts)

k = 1  # visual scaling factor applied to the SHAP values
alpha = 0.3
positive_color_palette = ['#d6fff6', '#231651', '#4DCCBD']
negative_color_palette = ['#EEEBD0', '#EBB3A9', '#E87EA1']

predictions = np.asarray(subj_pred_over_ts)
# one row per time step; columns 0-2 hold the positive and 3-5 the negative features
shap_per_ts = np.asarray(selected_features_shap_over_ts)

# positive features: each band spans from the prediction up to the running stacked total
stacked = predictions
for rank, color in enumerate(positive_color_palette):
    stacked = stacked + k * shap_per_ts[:, rank]
    ax.fill_between(timestep_axis, predictions, stacked, color=color, alpha=alpha)

# negative features: same stacking, below the prediction
stacked = predictions
for rank, color in enumerate(negative_color_palette):
    stacked = stacked + k * shap_per_ts[:, len(positive_color_palette) + rank]
    ax.fill_between(timestep_axis, stacked, predictions, color=color, alpha=alpha)

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

# plt.plot() without arguments draws nothing; show the figure instead
plt.show()

Plot every individual feature

In [None]:
# Preview the first rows of the long-format prediction / SHAP table
subj_pred_df.head()

In [None]:
# Rank the features by the subject's SHAP values summed over all time steps and keep
# the n_features largest and the n_features smallest totals.
n_features = 3
total_shap_per_feature = shap_values[0][subj, :, :].sum(axis=0)
ascending_ranking = total_shap_per_feature.argsort()
positive_features = ascending_ranking[-n_features:]
negative_features = ascending_ranking[:n_features]
selected_features = np.concatenate((positive_features, negative_features))



In [None]:
# Display the SHAP values of the selected features for this subject over all time steps
shap_values[0][subj, :, selected_features]

In [None]:
# Prediction line with one labelled, stacked band per selected feature, using the raw
# per-timestep SHAP values of the subject.
fig = plt.figure(figsize=(15,10))

k=1
alpha=0.6

timestep_axis = np.array(range(n_time_steps))

positive_color_palette = sns.color_palette("mako", n_colors=len(positive_features))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(negative_features))

ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts)

# positive features: stack every band on top of the previous one, starting at the prediction
baseline = subj_pred_over_ts
for feature, band_color in zip(positive_features, positive_color_palette):
    positive_feature = baseline + k * shap_values[0][subj, :, feature]
    ax.fill_between(timestep_axis, baseline, positive_feature,
                    color=band_color, alpha=alpha, label=features[feature])
    baseline = positive_feature

# negative features: same stacking, again starting at the prediction
baseline = subj_pred_over_ts
for feature, band_color in zip(negative_features, negative_color_palette):
    negative_feature = baseline + k * shap_values[0][subj, :, feature]
    ax.fill_between(timestep_axis, negative_feature, baseline,
                    color=band_color, alpha=alpha, label=features[feature])
    baseline = negative_feature

ax.legend()

Normalise shap values by subtracting the value at the preceding timestep

In [None]:
# Same plot as before, but each SHAP value is replaced by its difference to the value
# at the preceding time step (the first time step is kept unchanged).
fig = plt.figure(figsize=(15,10))

k=10
alpha=0.6

timestep_axis = np.array(range(n_time_steps))

positive_color_palette = sns.color_palette("mako", n_colors=len(positive_features))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(negative_features))

ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts)

# normalise shap_values: first row as is, then the first-order differences along time
first_timestep_shap = shap_values[0][subj,0:1,:]
stepwise_shap_change = np.diff(shap_values[0][subj], n=1, axis=0)
normalised_subj_shap = np.concatenate([first_timestep_shap, stepwise_shap_change])

# positive features: stack every band on top of the previous one, starting at the prediction
baseline = subj_pred_over_ts
for feature, band_color in zip(positive_features, positive_color_palette):
    positive_feature = baseline + k * normalised_subj_shap[:, feature]
    ax.fill_between(timestep_axis, baseline, positive_feature,
                    color=band_color, alpha=alpha, label=features[feature])
    baseline = positive_feature

# negative features: same stacking, again starting at the prediction
baseline = subj_pred_over_ts
for feature, band_color in zip(negative_features, negative_color_palette):
    negative_feature = baseline + k * normalised_subj_shap[:, feature]
    ax.fill_between(timestep_axis, negative_feature, baseline,
                    color=band_color, alpha=alpha, label=features[feature])
    baseline = negative_feature

ax.legend()

Shap values should be normalised so that, at each timestep, their total has the same magnitude as the change in the prediction (measured relative to the 0.5 baseline) at that timestep

In [None]:
# Distance of each prediction from the 0.5 baseline
diff_from_baseline_prediction_at_ts = np.array(subj_pred_over_ts) - 0.5
# Change of that distance from one time step to the next (first time step kept as is)
sequential_diff_from_baseline = np.concatenate([diff_from_baseline_prediction_at_ts[0:1], np.diff(diff_from_baseline_prediction_at_ts, n=1, axis=0)])

# Rescale the SHAP values of every time step so that the magnitude of their total
# equals the magnitude of the prediction change at that time step.
subj_shap = shap_values[0][subj, :, :]
abs_total_shap_per_ts = np.abs(subj_shap.sum(axis=-1))[:, None]
# Time steps whose SHAP total is exactly 0 cannot be rescaled; their contributions are
# set to 0 instead of producing NaN / inf through a division by zero.
shap_share_per_ts = np.divide(subj_shap, abs_total_shap_per_ts,
                              out=np.zeros_like(subj_shap, dtype=float),
                              where=abs_total_shap_per_ts != 0)
normalised_subj_shap = shap_share_per_ts * np.abs(sequential_diff_from_baseline)[:, None]



In [None]:
# temp = np.concatenate([shap_values[0][subj,0:1,:], np.diff(shap_values[0][subj], n=1, axis=0)])
# normalised_subj_shap = ((temp[:,:].T / temp[:,:].sum(axis=-1)) * (np.array(subj_pred_over_ts) - 0.5)).T

In [None]:
# Prediction line with one labelled, stacked band per selected feature, using the
# SHAP values rescaled to the per-timestep prediction change.
fig = plt.figure(figsize=(15,10))

k=1
alpha=0.3

timestep_axis = np.array(range(n_time_steps))

positive_color_palette = sns.color_palette("mako", n_colors=len(positive_features))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(negative_features))

ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts)

# positive features: stack every band on top of the previous one, starting at the prediction
baseline = subj_pred_over_ts
for feature, band_color in zip(positive_features, positive_color_palette):
    positive_feature = baseline + k * normalised_subj_shap[:, feature]
    ax.fill_between(timestep_axis, baseline, positive_feature,
                    color=band_color, alpha=alpha, label=features[feature])
    baseline = positive_feature

# negative features: same stacking, again starting at the prediction
baseline = subj_pred_over_ts
for feature, band_color in zip(negative_features, negative_color_palette):
    negative_feature = baseline + k * normalised_subj_shap[:, feature]
    ax.fill_between(timestep_axis, negative_feature, baseline,
                    color=band_color, alpha=alpha, label=features[feature])
    baseline = negative_feature

ax.legend()

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

# plt.plot() without arguments draws nothing; show the figure instead
plt.show()

## Use shap predictions computed with model for each timepoint

In [None]:
# Hard-coded local path to the pickled SHAP values "over ts" (per time step); adapt to your environment
shap_over_time_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/2022_09_07_1744/deep_explainer_shap_values_over_ts.pkl'

In [None]:
# Load the SHAP values that are indexed by time step below.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
with open(shap_over_time_path, 'rb') as handle:
    shap_values_over_time = pickle.load(handle)

In [None]:
# For every time step take the first element of the stored SHAP values and sum it over axis 1.
# Assumes that element is (subject, time, feature), so the result would be
# (time step, subject, feature) — TODO confirm with the shape printed below.
cumulative_shap_values_over_time = np.array([shap_values_over_time[ts][0].sum(axis=1) for ts in range(n_time_steps)])
cumulative_shap_values_over_time.shape

In [None]:
# Rank the features by the subject's cumulative SHAP values summed over all time steps.
n_features = 3
subj_total_shap_per_feature = cumulative_shap_values_over_time[:, subj].sum(axis=0)
ascending_ranking = subj_total_shap_per_feature.argsort()

# indices of the n_features features with the biggest positive shap impact
selected_positive_features = ascending_ranking[-n_features:]

# indices of the n_features features with the biggest negative shap impact
selected_negative_features = ascending_ranking[:n_features]

selected_features = np.concatenate((selected_positive_features, selected_negative_features))



In [None]:
# Prediction line with one band per selected feature. Each feature's cumulative SHAP
# curve is split by sign: its positive part is stacked above the prediction and its
# negative part below, both in the same colour.
fig = plt.figure(figsize=(15,10))

k=1
alpha=0.3

timestep_axis = np.array(range(n_time_steps))

# Size the palettes from the features selected for THIS section
# (they were previously sized from the selection of an earlier section).
positive_color_palette = sns.color_palette("mako", n_colors=len(selected_positive_features))
negative_color_palette = sns.color_palette("flare_r", n_colors=len(selected_negative_features))

ax = sns.lineplot(x=timestep_axis, y=subj_pred_over_ts)

pos_baseline = subj_pred_over_ts
neg_baseline = subj_pred_over_ts
pos_count = 0
neg_count = 0
for feature in selected_features:
    subj_cumulative_shap_value_over_time = cumulative_shap_values_over_time[:, subj, feature]

    # positive part of the curve (negative values zeroed) and vice versa
    pos_function = np.where(subj_cumulative_shap_value_over_time < 0, 0, subj_cumulative_shap_value_over_time)
    neg_function = np.where(subj_cumulative_shap_value_over_time > 0, 0, subj_cumulative_shap_value_over_time)

    # the colour family depends on which selection the feature belongs to
    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1

    positive_feature = pos_baseline + k * pos_function
    ax.fill_between(timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha, label=features[feature])
    pos_baseline = positive_feature

    # only the positive band carries the label so that each feature appears once in the legend
    negative_feature = neg_baseline + k * neg_function
    ax.fill_between(timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha)
    neg_baseline = negative_feature

ax.legend()

ax.set_title(f'Predictions for subject {subj} of test set along time', fontsize=20)
ax.set_xlabel('Time from admission (hours)', fontsize=15)
ax.set_ylabel('Probability of favorable outcome', fontsize=15)

# plt.plot() without arguments draws nothing; show the figure instead
plt.show()

In [None]:
# Per-variable normalisation parameters; the columns variable, original_std and
# original_mean are read from it below
normalisation_parameters_df = pd.read_csv(normalisation_parameters_path)

In [None]:
def reverse_normalisation_for_subj(subj_df, normalisation_parameters_df):
    """
    Map normalised columns of a subject dataframe back to their original scale.

    For every variable listed in ``normalisation_parameters_df.variable`` that is also a
    column of ``subj_df``, the column is replaced by ``value * original_std + original_mean``,
    using the first parameter row found for that variable. Other columns are left untouched.

    :param subj_df: dataframe with one column per variable; modified in place
    :param normalisation_parameters_df: dataframe with the columns variable, original_std and original_mean
    :return: the same ``subj_df`` object, with the matching columns de-normalised
    """
    for variable_name in normalisation_parameters_df.variable.unique():
        if variable_name in subj_df.columns:
            variable_parameters = normalisation_parameters_df[
                normalisation_parameters_df.variable == variable_name]
            scale = variable_parameters.original_std.iloc[0]
            offset = variable_parameters.original_mean.iloc[0]
            subj_df[variable_name] = (subj_df[variable_name] * scale) + offset

    return subj_df

In [None]:
# De-normalised feature values of the selected subject: the subject's 2D slice of the
# test array is wrapped in a dataframe with one column per feature name
non_norm_x_subj_df = reverse_normalisation_for_subj(pd.DataFrame(data=test_X_np[subj], columns = features), normalisation_parameters_df)


In [None]:
# Grid of subplots showing the de-normalised value of every selected feature over the
# dataframe index, coloured like the corresponding band of the previous figure.
plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)
plt.suptitle("Selected features", fontsize=18, y=0.95)

# grid layout: fixed number of columns, as many rows as needed to fit every feature
ncols = 3
nrows, leftover = divmod(len(selected_features), ncols)
if leftover:
    nrows += 1

# the index becomes a regular column so that it can be used as x axis
subj_values_with_index = non_norm_x_subj_df.reset_index()

pos_count = 0
neg_count = 0
for n, feature in enumerate(selected_features, start=1):
    # one subplot per feature, filled row by row
    ax = plt.subplot(nrows, ncols, n)

    # colour family depends on which selection the feature belongs to
    if feature in selected_positive_features:
        feature_color = positive_color_palette[pos_count]
        pos_count += 1
    else:
        feature_color = negative_color_palette[neg_count]
        neg_count += 1

    feature_name = features[feature]
    sns.lineplot(y=feature_name, x='index', data=subj_values_with_index, color=feature_color, ax=ax)

    ax.set_title(feature_name)