# Interactive Slider for GradCAM Heatmap

Plots an interactive slider for GradCAM heatmaps for all models for a given patient.

### Import Libraries and Modules

In [None]:
%matplotlib inline

import os
import h5py
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import metrics
from sklearn.metrics import confusion_matrix, roc_curve, auc

import tensorflow as tf
from tensorflow import keras

print("TF  Version",tf.__version__)

In [None]:
# Check and set the working directory before loading the project modules.
print(os.getcwd())
INPUT_DIR = "/tf/notebooks/schnemau/xAI_stroke_3d/"
OUTPUT_DIR = "/tf/notebooks/bule/explainable_AI/"
# OUTPUT_DIR ends with a slash while os.getcwd() normally does not, so the raw
# string comparison was unequal even when already inside OUTPUT_DIR.
# Comparing normalized paths makes the guard effective.
if os.path.normpath(os.getcwd()) != os.path.normpath(OUTPUT_DIR):
    os.chdir(OUTPUT_DIR)
    
import functions_model_definition as md
import functions_read_data as rdat
import functions_slider as sl
import functions_gradcam as gc
import functions_plot_heatmap as phm
import functions_metrics as fm
import functions_occlusion as oc

## Load Data and Set Up Model

In [None]:
## ToDo:
## - hm_type should always be "gc" in this notebook
## - pred_hm_only, norm_hm and hm_mode should only be adjusted in last chunk
## - pic_save_name is not needed in this notebook: implement dictionary for paths

# --- Run configuration -------------------------------------------------------
# Version identifier handed to the rdat / md setup helpers. Listed options:
#   10Fold_sigmoid_V0, 10Fold_sigmoid_V1, 10Fold_sigmoid_V2, 10Fold_sigmoid_V2f, 10Fold_sigmoid_V3
#   10Fold_softmax_V0, 10Fold_softmax_V1, andrea
#   10Fold_CIB, 10Fold_CIBLSX
# NOTE(review): "CIB" itself is not in the list above - confirm it is accepted.
version = "CIB"

# Model version
model_version = 2

# Weighting (passed to dir_setup as weight_mode)
hm_mode = "wgt"

# Heatmap type and related flags
hm_type = "gc"
norm_hm = False  # (gradcam is normalized over all heatmaps)
pred_hm_only = True

# Naming convention (for CIBLSX model_version >= 3 should be False)
comp_mode = False  # if True: use old naming convention

# --- Paths -------------------------------------------------------------------
(DATA_DIR, WEIGHT_DIR, DATA_OUTPUT_DIR, PIC_OUTPUT_DIR, pic_save_name) = rdat.dir_setup(
    INPUT_DIR,
    OUTPUT_DIR,
    version,
    model_version,
    weight_mode=hm_mode,
    hm_type=hm_type,
    pred_hm=pred_hm_only,
    hm_norm=norm_hm,
    compatibility_mode=comp_mode,
)

In [None]:
## Load the images, patient ids and result tables for the selected version.
(
    X_in,
    pat_ids,
    id_tab,
    all_results_tab,
    pat_orig_tab,
    pat_norm_tab,
    num_models,
) = rdat.version_setup(
    DATA_DIR=DATA_DIR,
    version=version,
    model_version=model_version,
    compatibility_mode=comp_mode,
)

## Define Model

In [None]:
# Fetch the model settings that belong to the selected version.
(input_dim_img, output_dim, LOSS, layer_connection, last_activation) = md.model_setup(version)

# A tabular input width is only supplied for "LSX" versions: it is the number
# of columns of the normalized patient table without the "p_id" column.
if "LSX" in version:
    n_tab_inputs = pat_norm_tab.drop(columns=["p_id"]).shape[1]
else:
    n_tab_inputs = None

# Build the model instance that is reused by all following cells.
model_3d = md.model_init(
    version=version,
    output_dim=output_dim,
    LOSS=LOSS,
    layer_connection=layer_connection,
    last_activation=last_activation,
    C=2,
    learning_rate=5*1e-5,
    batch_size=6,
    input_dim=input_dim_img,
    input_dim_tab=n_tab_inputs,
)

In [None]:
# Create the helper that produces the model (weight file) names under WEIGHT_DIR.
# NOTE(review): the keyword is spelled "compatability_mode" here while the rdat
# helpers use "compatibility_mode" - left untouched, confirm against
# md.set_generate_model_name.
generate_model_name = md.set_generate_model_name(
    model_version=model_version,
    layer_connection=layer_connection,
    last_activation=last_activation,
    path=WEIGHT_DIR,
    compatability_mode=comp_mode,
)

In [None]:
# Patient(s) to inspect; the results, images, model names and normalized
# tabular rows for these ids are collected below.
p_ids = [297]
(res_table, res_images, res_model_names, res_norm_table) = gc.get_img_and_models(
    p_ids,
    results=all_results_tab,
    pats=pat_ids,
    imgs=X_in,
    gen_model_name=generate_model_name,
    norm_tab=pat_norm_tab,
    num_models=num_models,
)

In [None]:
# Names of the result-table columns for the weighted ensemble.
y_pred_cl = "y_pred_class_avg_w"
pred_co = "pred_correct_w"
y_pred_prob = "y_pred_trafo_avg_w"
y_pred_u = "y_pred_unc_w"

# Invert all heatmaps when the predicted class of the first selected patient
# is 0. The row is read by position (.iloc) so the lookup does not depend on
# the index labels of res_table.
invert_hm = "all" if res_table[y_pred_cl].iloc[0] == 0 else "none"
pos_hm = "last"
cmap = "jet"
hm_positive=True

## Plot Original


In [None]:
def plot_img(image, minima, maxima, cmap="jet", alpha=0.4):
    """Show an image array in a new figure with the axes hidden.

    Args:
        image: Array passed to ``plt.imshow``.
        minima: Lower limit of the color scale (``vmin``).
        maxima: Upper limit of the color scale (``vmax``).
        cmap: Matplotlib colormap name. Defaults to ``"jet"``, the value that
            was previously hard-coded; use e.g. ``"gray"`` for the raw image.
        alpha: Opacity of the image. Defaults to ``0.4``, the value that was
            previously hard-coded; use ``1`` for a fully opaque image.
    """
    plt.figure()
    plt.imshow(image, cmap=cmap, vmin=minima, vmax=maxima, alpha=alpha)
    plt.axis("off")


# Average the first selected image over axis 2 and plot it on its own
# min/max color scale.
avg_image = np.squeeze(np.mean(res_images[0], axis=2))
plot_img(avg_image, avg_image.min(), avg_image.max())

# GradCam

First, the ensemble GradCam heatmap (weighted over all models) is computed as a comparison.

In [None]:
# Grad-CAM over all models of the first selected patient.
(heatmap, resized_img, max_hm_slice, hm_mean_std, all_heatmaps) = gc.multi_models_grad_cam_3d(
    img=np.expand_dims(res_images[0], axis=0),
    cnn=model_3d,
    model_names=res_model_names[0],
    layers=md.get_last_conv_layer(model_3d),
    model_mode="weighted",
    pred_index=0,
    invert_hm=invert_hm,
    pos_hm=pos_hm,
    # model weights are only used when model_mode = "weighted";
    # they are taken from the "weight*" columns of the first result row
    model_weights=(
        res_table[0:1]
        .reset_index(drop=True)
        .loc[:, res_table.columns.str.startswith("weight")]
        .to_numpy()
        .squeeze()
    ),
    tabular_df=res_norm_table,
    normalize=False,
)

In [None]:
# Overlay the ensemble heatmap on the resized image.
phm.plot_heatmap(
    resized_img,
    heatmap,
    version="overlay",
    mode="avg",
    hm_colormap=cmap,
    hm_positive=hm_positive,
    colorbar=True,
)

Next, GradCam is computed for a single model, showing all needed intermediate steps.

In [None]:
# Set up the inputs for a step-by-step Grad-CAM of one model.
# NOTE: model_3d_0 is an alias of model_3d (no copy is made), so loading the
# weights below also changes model_3d.
model_3d_0 = model_3d
# Load the weights of the first model listed for the first selected patient.
model_3d_0.load_weights(res_model_names[0][0])

# NOTE(review): the trailing comma makes `img` a 1-tuple holding the image with
# an added leading axis. The later resize cell relies on this by reading
# img[0].shape[1:], so the comma must not be removed on its own.
img = np.expand_dims(res_images[0], axis = 0),
model_3d = model_3d_0
# Name of the layer whose output is used for the heatmap.
layer = md.get_last_conv_layer(model_3d)
# NOTE(review): normalize, inv_hm and relu_hm are not read by any later cell
# visible here - presumably they mirror the arguments of the gc helper.
normalize = False
pred_index=None
inv_hm=False
relu_hm=True

In [None]:
# Select the model output that is differentiated with respect to the feature
# maps of `layer`. Unsupported model names fail here with a clear message
# instead of a later NameError on `grad_model`.
if model_3d.name == "cnn_3d_":
    target_output = model_3d.output
elif model_3d.name == "mod_ontram":
    target_output = model_3d.get_layer("dense_complex_intercept").output
else:
    raise ValueError(
        "Unsupported model name for Grad-CAM: {!r}".format(model_3d.name))

# Model that returns both the feature maps of `layer` and the selected output.
grad_model = tf.keras.models.Model(
    [model_3d.inputs],
    [model_3d.get_layer(layer).output, target_output])

# Record the forward pass so gradients can be requested from the tape afterwards.
with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(img)
    # check for right model variant
    if model_3d.name == "mod_ontram":
        pred_index = 0
        predictions = predictions * -1 # ontram predicts cumulative dist therefore invert
    elif pred_index is None or model_3d.layers[-1].get_config().get("activation") == "sigmoid":
        pred_index = tf.argmax(predictions[0])
    class_channel = predictions[:, pred_index] # when sigmoid, pred_index must be None or 0

In [None]:
# Gradient of the selected prediction with respect to the feature maps;
# [0] keeps the first entry along the leading axis (presumably the batch axis).
grads = tape.gradient(class_channel, conv_outputs)[0]

single_input_ontram = model_3d.name == "mod_ontram" and not isinstance(model_3d.input, list)
if single_input_ontram:
    # output of CNN can be used for predictions
    sigmoid_gradient = fm.sigmoid(predictions) * (1 - fm.sigmoid(predictions))
    grads = sigmoid_gradient * grads

In [None]:
# Average the gradients over axis 2 and plot a few entries of the last axis,
# all on one shared color scale.
avg_grad_maps = tf.reduce_mean(grads, axis=2).numpy()
shared_min = np.min(avg_grad_maps)
shared_max = np.max(avg_grad_maps)

for idx in (0, 1, 2, 3, -1):
    plot_img(avg_grad_maps[:, :, idx], shared_min, shared_max)

In [None]:
# One weight per entry of the last axis: mean gradient over axes 0, 1 and 2.
weights = tf.math.reduce_mean(grads, axis=(0, 1, 2))

In [None]:
# Display the last weight for inspection.
weights[-1]

In [None]:
# Combine the feature maps of the first input into one heatmap by multiplying
# them with the weights along the last axis, then drop the size-1 axes.
output = conv_outputs[0]
weights_column = tf.expand_dims(weights, axis=-1)
heatmap = tf.squeeze(tf.matmul(output, weights_column))

In [None]:
# Display the shape of the heatmap before resizing.
heatmap.shape

In [None]:
# Plot the not yet rectified heatmap, averaged over axis 2, on its own scale.
avg_original_size_hm = tf.reduce_mean(heatmap, axis=2).numpy()
plot_img(
    avg_original_size_hm,
    avg_original_size_hm.min(),
    avg_original_size_hm.max(),
)

In [None]:
# Keep only the positive part of the heatmap and scale it by its maximum.
# When the heatmap has no positive value the maximum of the rectified map is 0;
# divide_no_nan then returns an all-zero heatmap instead of NaN/inf values.
heatmap = tf.maximum(heatmap, 0)
heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))

In [None]:
# Plot the rectified and scaled heatmap, averaged over axis 2.
avg_original_size_hm_relu = tf.reduce_mean(heatmap, axis=2).numpy()
plot_img(
    avg_original_size_hm_relu,
    avg_original_size_hm_relu.min(),
    avg_original_size_hm_relu.max(),
)

In [None]:
from skimage.transform import resize

# Resize the heatmap to the shape of the input image without its leading axis
# (img is a 1-tuple, so img[0] is the image array itself).
target_shape = img[0].shape[1:]
heatmap = resize(heatmap.numpy(), target_shape)

In [None]:
# Plot the resized heatmap, averaged over axis 2.
avg_finish = np.squeeze(tf.reduce_mean(heatmap, axis=2).numpy())
plot_img(avg_finish, avg_finish.min(), avg_finish.max())

## Occlusion

In [None]:
# Occlusion settings: passed below as occlusion_size / occlusion_stride.
occ_size = (18, 18, 4)
occ_stride = (10, 10, 3)

# Heatmap options for the occlusion run and its plot.
invert_hm = "pred_class"
both_directions = False
cmap = "jet"
hm_positive = True

In [None]:
# Occlusion heatmap for the selected patient using the weighted model mode.
(heatmap, resized_img, max_hm_slice, hm_mean_std, all_heatmaps) = oc.volume_occlusion(
    volume=res_images,
    res_tab=res_table,
    occlusion_size=np.array(occ_size),
    occlusion_stride=occ_stride,
    cnn=model_3d,
    model_names=res_model_names[0],
    model_mode="weighted",
    invert_hm=invert_hm,
    both_directions=both_directions,
    tabular_df=res_norm_table,
    normalize=False,
)

In [None]:
# Overlay the occlusion heatmap on the resized image.
phm.plot_heatmap(
    resized_img,
    heatmap,
    version="overlay",
    mode="avg",
    hm_colormap=cmap,
    hm_positive=hm_positive,
    colorbar=True,
)

In [None]:
# Display the shape of the occlusion heatmap.
heatmap.shape

In [None]:
# Plot selected positions along axis 2 of the occlusion heatmap, each on its
# own min/max color scale.
for slice_idx in (0, 1, 2, 3, 4, 6, 13, 20, -1):
    hm = heatmap[:, :, slice_idx, 0]
    hm_low, hm_high = np.min(hm), np.max(hm)
    plot_img(hm, hm_low, hm_high)

In [None]:
# Plot the occlusion heatmap averaged over axis 2.
hm_avg = np.squeeze(np.mean(heatmap, axis=2))
plot_img(hm_avg, hm_avg.min(), hm_avg.max())