In [1]:
import shap
from pathlib import Path
from anndata import read_h5ad
import sys
import scanpy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import nbinom, pearsonr
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
import random
from scipy import stats
import time
sys.path.insert(0, '/scratch/imb/Xiao/STimage')
from tqdm import tqdm
import pickle
from scipy.stats import nbinom
from stimage._utils import gene_plot, Read10X, ReadOldST, tiling, ensembl_to_id
from stimage._model import CNN_NB_multiple_genes, negative_binomial_layer, negative_binomial_loss, PrinterCallback
from stimage._data_generator import DataGenerator

In [2]:
from shap._explanation import Explanation
from typing import Optional
from shap.plots import colors

def image_plot_shap(shap_values: Explanation or np.ndarray,
          pixel_values: Optional[np.ndarray] = None,
          labels: Optional[list or np.ndarray] = None,
          true_labels: Optional[list] = None,
          width: Optional[int] = 20,
          aspect: Optional[float] = 0.2,
          hspace: Optional[float] = 0.2,
          labelpad: Optional[float] = None,
          cmap: Optional[str or Colormap] = colors.red_transparent_blue,
          show: Optional[bool] = True):
    """Plots SHAP values for image inputs.

    Parameters
    ----------
    shap_values : [numpy.array]
        List of arrays of SHAP values. Each array has the shape
        (# samples x width x height x channels), and the
        length of the list is equal to the number of model outputs that are being
        explained.

    pixel_values : numpy.array
        Matrix of pixel values (# samples x width x height x channels) for each image.
        It should be the same
        shape as each array in the ``shap_values`` list of arrays.

    labels : list or np.ndarray
        List or ``np.ndarray`` (# samples x top_k classes) of names for each of the
        model outputs that are being explained.

    true_labels: list
        List of a true image labels to plot.

    width : float
        The width of the produced matplotlib plot.

    labelpad : float
        How much padding to use around the model output labels.

    show : bool
        Whether ``matplotlib.pyplot.show()`` is called before returning.
        Setting this to ``False`` allows the plot
        to be customized further after it has been created.

    Examples
    --------

    See `image plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/image.html>`_.

    """

    # support passing an explanation object
    if str(type(shap_values)).endswith("Explanation'>"):
        shap_exp = shap_values
        # feature_names = [shap_exp.feature_names]
        # ind = 0
        if len(shap_exp.output_dims) == 1:
            shap_values = [shap_exp.values[..., i] for i in range(shap_exp.values.shape[-1])]
        elif len(shap_exp.output_dims) == 0:
            shap_values = shap_exp.values
        else:
            raise Exception("Number of outputs needs to have support added!! (probably a simple fix)")
        if pixel_values is None:
            pixel_values = shap_exp.data
        if labels is None:
            labels = shap_exp.output_names

    # multi_output = True
    if not isinstance(shap_values, list):
        # multi_output = False
        shap_values = [shap_values]

    if len(shap_values[0].shape) == 3:
        shap_values = [v.reshape(1, *v.shape) for v in shap_values]
        pixel_values = pixel_values.reshape(1, *pixel_values.shape)

    # labels: (rows (images) x columns (top_k classes) )
    # if labels is not None:
    #     if isinstance(labels, list):
    #         labels = np.array(labels).reshape(1, -1)

    # if labels is not None:
    #     labels = np.array(labels)
    #     if labels.shape[0] != shap_values[0].shape[0] and labels.shape[0] == len(shap_values):
    #         labels = np.tile(np.array([labels]), shap_values[0].shape[0])
    #     assert labels.shape[0] == shap_values[0].shape[0], "Labels must have same row count as shap_values arrays!"
    #     if multi_output:
    #         assert labels.shape[1] == len(shap_values), "Labels must have a column for each output in shap_values!"
    #     else:
    #         assert len(labels[0].shape) == 1, "Labels must be a vector for single output shap_values."

    label_kwargs = {} if labelpad is None else {'pad': labelpad}

    # plot our explanations
    x = pixel_values
    fig_size = np.array([6 * (len(shap_values) + 1), 5 * (x.shape[0] + 1)])
    if fig_size[0] > width:
        fig_size *= width / fig_size[0]
    fig, axes = plt.subplots(nrows=x.shape[0], ncols=len(shap_values) + 1, figsize=fig_size)
    if len(axes.shape) == 1:
        axes = axes.reshape(1, axes.size)
    for row in range(x.shape[0]):
        x_curr = x[row].copy()

        # make sure we have a 2D array for grayscale
        if len(x_curr.shape) == 3 and x_curr.shape[2] == 1:
            x_curr = x_curr.reshape(x_curr.shape[:2])

        # if x_curr.max() > 1:
        #     x_curr /= 255.

        # get a grayscale version of the image
        if len(x_curr.shape) == 3 and x_curr.shape[2] == 3:
            x_curr_gray = (
                    0.2989 * x_curr[:, :, 0] + 0.5870 * x_curr[:, :, 1] + 0.1140 * x_curr[:, :, 2])  # rgb to gray
            x_curr_disp = x_curr
        elif len(x_curr.shape) == 3:
            x_curr_gray = x_curr.mean(2)

            # for non-RGB multi-channel data we show an RGB image where each of the three channels is a scaled k-mean center
            flat_vals = x_curr.reshape([x_curr.shape[0] * x_curr.shape[1], x_curr.shape[2]]).T
            flat_vals = (flat_vals.T - flat_vals.mean(1)).T
            means = kmeans(flat_vals, 3, round_values=False).data.T.reshape([x_curr.shape[0], x_curr.shape[1], 3])
            x_curr_disp = (means - np.percentile(means, 0.5, (0, 1))) / (
                    np.percentile(means, 99.5, (0, 1)) - np.percentile(means, 1, (0, 1)))
            x_curr_disp[x_curr_disp > 1] = 1
            x_curr_disp[x_curr_disp < 0] = 0
        else:
            x_curr_gray = x_curr
            x_curr_disp = x_curr

        axes[row, 0].imshow(x_curr_disp, cmap=plt.get_cmap('gray'))
        if true_labels:
            axes[row, 0].set_title(true_labels[row], **label_kwargs)
        axes[row, 0].axis('off')
        if len(shap_values[0][row].shape) == 2:
            abs_vals = np.stack([np.abs(shap_values[i]) for i in range(len(shap_values))], 0).flatten()
        else:
            abs_vals = np.stack([np.abs(shap_values[i].sum(-1)) for i in range(len(shap_values))], 0).flatten()
        max_val = np.nanpercentile(abs_vals, 99)
        for i in range(len(shap_values)):
            if labels is not None:
                axes[row, i + 1].set_title(labels[row], **label_kwargs)
            sv = shap_values[i][row] if len(shap_values[i][row].shape) == 2 else shap_values[i][row].sum(-1)
            axes[row, i + 1].imshow(x_curr_gray, cmap=plt.get_cmap('gray'), alpha=0.15,
                                    extent=(-1, sv.shape[1], sv.shape[0], -1))
            im = axes[row, i + 1].imshow(sv, cmap=cmap, vmin=-max_val, vmax=max_val)
            axes[row, i + 1].axis('off')
    if hspace == 'auto':
        fig.tight_layout()
    else:
        fig.subplots_adjust(hspace=hspace)
    cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value", orientation="horizontal",
                      aspect=fig_size[0] / aspect)
    cb.outline.set_visible(False)
    if show:
        plt.show()
    else:
        return fig

In [3]:
DATA_PATH = Path("/scratch/imb/Xiao/STimage_100run/dataset_breast_cancer_9visium")
BASE_PATH = Path("/scratch/imb/Xiao/STimage/development/stimage_100run/results_1")
MODEL_PATH = BASE_PATH / "model_0.h5"
CORR_PATH = BASE_PATH / "corr_df_1160920F_0.csv"
PRED_PATH = BASE_PATH / "prediction_df_1160920F_0.csv"
OUT_PATH = Path("/scratch/imb/Xiao/STimage/development/SHAP/PLOT")
OUT_PATH.mkdir(exist_ok=True, parents=True)

In [4]:
GENE = "TRBC2"

In [5]:
df_pred = pd.read_csv(PRED_PATH, index_col=0)

In [6]:
df_corr = pd.read_csv(CORR_PATH, index_col=0)

In [7]:
df_corr = df_corr.sort_values(by='Pearson correlation', ascending=False)

In [8]:
df_corr["Gene"][-2:-1]

1453    MPPED1
Name: Gene, dtype: object

In [9]:
adata_all = read_h5ad(DATA_PATH / "all_adata.h5ad")

In [10]:
adata_all.obs["tile_path"] = adata_all.obs.tile_path.map(
    lambda x: x.replace("/clusterdata/uqxtan9/Xiao/breast_cancer_9visium",
                        "/scratch/imb/Xiao/STimage_100run/dataset_breast_cancer_9visium"))


In [11]:
df_gene_ls = pd.read_csv("/scratch/imb/Xiao/STimage/development/stimage_100run/Intersection_marker_genes.csv", sep=",")

gene_list_select = df_gene_ls["gene_name"].values.tolist()

gene_list = adata_all.var_names.intersection(gene_list_select)
n_genes = len(gene_list)

In [12]:
adata_all_train_valid = adata_all[adata_all.obs["library_id"].isin(
    adata_all.obs.library_id.cat.remove_categories(["FFPE", "1160920F"]).unique())]

training_index = adata_all_train_valid.obs.sample(frac=0.7, random_state=1).index
training_dataset = adata_all_train_valid[training_index,].copy()

valid_index = adata_all_train_valid.obs.index.isin(training_index)
valid_dataset = adata_all_train_valid[~valid_index,].copy()

train_gen_ = tf.data.Dataset.from_generator(
            lambda: DataGenerator(adata=training_dataset,
                          genes=gene_list, aug=False),
            output_types=(tf.float32, tuple([tf.float32]*n_genes)),
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
train_gen = train_gen_.batch(1)


In [13]:
test_index = adata_all.obs.library_id == "1160920F"
test_dataset_1 = adata_all[test_index,].copy()
test_gen_ = tf.data.Dataset.from_generator(
            lambda: DataGenerator(adata=test_dataset_1,
                          genes=gene_list),
            output_types=(tf.float32, tuple([tf.float32]*n_genes)),
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
test_gen = test_gen_.batch(1)

In [14]:
model = load_model(MODEL_PATH, 
                   custom_objects={
                       'negative_binomial_loss': negative_binomial_loss,
                   }
                  )

In [15]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
tile_input (InputLayer)         [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 305, 305, 3)  0           tile_input[0][0]                 
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 150, 150, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 150, 150, 64) 256         conv1_conv[0][0]                 
______________________________________________________________________________________________

In [16]:
gene_index = gene_list.get_loc(GENE)
print(gene_index)

629


In [17]:
layer_names = [layer.name for layer in model.layers]
layer_index = layer_names.index(f'dense_{gene_index}')
print(layer_index)

805


In [18]:
model_gene = tf.keras.Model(inputs=model.input, outputs=model.layers[layer_index].output)

In [19]:
model_gene.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
tile_input (InputLayer)         [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 305, 305, 3)  0           tile_input[0][0]                 
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 150, 150, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 150, 150, 64) 256         conv1_conv[0][0]                 
______________________________________________________________________________________________

In [20]:
def model_predict_gene(model, gene_list, gene, x):
    test_predictions = model.predict(x)
    n = test_predictions[:, 0]
    p = test_predictions[:, 1]
    y_pred = nbinom.mean(n, p)
    print(y_pred)
    return y_pred.reshape(-1,1)


In [21]:
images_bg = np.concatenate([x for x, y in train_gen.take(1000)], axis=0)
labels_bg = np.concatenate([y for x, y in train_gen.take(1000)], axis=0)

In [22]:
images_test = np.concatenate([x for x, y in test_gen], axis=0)
labels_test = np.concatenate([y for x, y in test_gen], axis=0)

In [23]:
len(test_gen)

TypeError: dataset length is unknown.

In [None]:
train_gen

<BatchDataset shapes: ((None, 299, 299, 3), ((None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1), (None, 1)

## DeepExplainer

In [None]:
e = shap.DeepExplainer(model_gene, images_bg)

In [None]:
shap_values = e.shap_values(images_test[0:5])

## 

## GradientExplainer

In [24]:
e = shap.GradientExplainer(model_gene, images_bg)

In [25]:
shap_values, indexes = e.shap_values(images_test)

: 

In [None]:
images_test[0:5].shape

In [None]:
gene_index

In [24]:
# shap_values_array = np.empty((0, 299, 299, 3))
for i, (x, y) in enumerate(test_gen):
    # print(x.shape)
    if i < 1024:
        continue
    print(i)
    e = shap.GradientExplainer(model_gene, images_bg)
    shap_values, indexes = e.shap_values(x.numpy())
    # print(shap_values.shape)
    # shap_values_array = np.concatenate((shap_values_array, shap_values), axis=0)
    ground_truth = np.squeeze(labels_test[i])
    pred = df_pred.iloc[i, gene_index]
    fig = image_plot_shap(shap_values,
                          x.numpy().astype(np.int),
                          true_labels=[f"True: {ground_truth:.2f}"],
                          labels=[f"Pred: {pred:.2f}"], show=False)
    try:
        fig.savefig(OUT_PATH / f"1000bg_SHAP_{GENE}_{i}.png")
        with open(OUT_PATH / f"1000bg_shap_values_{GENE}_{i}.pkl", 'wb') as f:
            pickle.dump(shap_values, f)
    except:
        print(f"Error in saving figure {i}")
    plt.close()

print("Done!")

1024


`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224


In [None]:
# save pickle
with open(OUT_PATH / "shap_values.pkl", 'wb') as f:
    pickle.dump(shap_values_array, f)

In [None]:
for i in range(5):
    ground_truth = np.squeeze(labels_test[i])
    pred = df_pred.iloc[i, gene_index]
    fig = image_plot_shap(np.expand_dims(shap_values[i], axis=0), 
                    np.expand_dims(images_test[i], axis=0).astype(np.int),
                    true_labels=[f"True: {ground_truth:.2f}"],
                    labels=[f"Pred: {pred:.2f}"], show=False)

In [None]:
fig.savefig(OUT_PATH / f"SHAP_{GENE}_{i}.png")

In [None]:
image_plot_shap(shap_values, images_test.astype(np.int))

In [None]:
images_test[0,:,:,:].astype(np.int)

In [None]:
?shap.GradientExplainer

In [None]:

# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[1:5])

In [None]:
?shap.Explainer

In [None]:
# create an explainer with model and image masker
explainer_blur = shap.Explainer(f, masker_blur, output_names=class_names)

# here we explain two images using 500 evaluations of the underlying model to estimate the SHAP values
shap_values_fine = explainer_blur(
    X[1:3], max_evals=5000, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]
)

In [None]:
X, y = shap.datasets.imagenet50()

In [None]:
X[0].shape