# Generate all Heatmaps (Occlusion of GradCam) for a given Version

Generates a PDF with all heatmaps for a given version.

- Choose heatmap type (Occlusion or GradCam)
- Choose if heatmaps should be genereated or loaded
- Choose if pictures (mean over all axis, highest heatmap value slice, original image) should be generated or loaded
- Choose if all patients should be used or only wrongly classified ones  
  

  
- Define if only the predicted class should be visualized (default: predicted class)
- Define if only last gradcam layer should be visualized (default: last layer)

## Load Libraries and Modules

In [20]:
# !pip install tqdm
# !pip install seaborn
# !pip install fpdf

In [1]:
%matplotlib inline

import os
import numpy as np
import pandas as pd
import gc as gci

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
from fpdf import FPDF

import tensorflow as tf

print("TF  Version",tf.__version__)

TF  Version 2.2.0


In [2]:
# check and set path before loading modules
print(os.getcwd())
INPUT_DIR = "/tf/notebooks/schnemau/xAI_stroke_3d/"
OUTPUT_DIR = "/tf/notebooks/schnemau/xAI_stroke_3d/"
if os.getcwd() != OUTPUT_DIR:
    os.chdir(OUTPUT_DIR)

/tf


In [3]:
import functions_metrics as fm
import functions_read_data as rdat
import functions_model_definition as md
import functions_gradcam as gc
import functions_occlusion as oc
import functions_plot_heatmap as phm

Using TensorFlow backend.


## Load Data and Define Parameters

In [4]:
# Define the path + output path:
print(os.getcwd())
DATA_DIR = INPUT_DIR + "data/"

version = "CIBLSX" # one of:
# 10Fold_sigmoid_V0, 10Fold_sigmoid_V1, 10Fold_sigmoid_V2, 10Fold_sigmoid_V2f, 10Fold_sigmoid_V3
# 10Fold_softmax_V0, 10Fold_softmax_V1, andrea
# 10Fold_CIB, 10Fold_CIBLSX

# Define Outputs of the notebook
generate_heatmap_and_save = True # should the heatmap be generated and saved (else loaded)
generate_pictures = False # should the pictures be generated (else loaded)
only_wrong_out = False # should the generated pdf only contain the wrong predictions (else all)

# Define Model Version
model_version = 6

# define weighting
ens_mode = "wgt" # avg or wgt => average or weighted heatmap  

# define heatmap type
hm_type = "oc" # gc or oc => gradcam or occlusion
norm_hm = True # normalize heatmap (Occlusion is not normalized, gradcam is normalized over all heatmaps)
pred_hm_only = True   # if true heatmap of prediction will be generated else positive and negative heatmaps are shown
last_layer_only = True # Default = True, only last layer will be used for gradcam else once last and once all layers

# Select naming convention (for CIBLSX model_version >= 3 and CIB model_version >= 2 should be False, else True)
comp_mode = False # if True: use old naming convention

# define paths
DATA_DIR, WEIGHT_DIR, DATA_OUTPUT_DIR, PIC_OUTPUT_DIR, pic_save_name = rdat.dir_setup(
    INPUT_DIR, OUTPUT_DIR, version, model_version, 
    weight_mode = ens_mode, hm_type = hm_type, pred_hm = pred_hm_only, hm_norm = norm_hm,
    compatibility_mode=comp_mode)

/tf/notebooks/schnemau/xAI_stroke_3d


In [5]:
## load images and ids
(X_in, pat_ids, id_tab, all_results_tab, pat_orig_tab, pat_norm_tab, num_models) = rdat.version_setup(
    DATA_DIR = DATA_DIR, version = version, model_version = model_version,
    compatibility_mode=comp_mode)

## Model

In [6]:
# define model
(input_dim_img, output_dim, LOSS, layer_connection, last_activation) = md.model_setup(version)

model_3d = md.model_init(
    version = version, 
    output_dim = output_dim,
    LOSS = LOSS,
    layer_connection = layer_connection,
    last_activation = last_activation,
    C = 2,
    learning_rate = 5*1e-5,
    batch_size = 6,
    input_dim = input_dim_img,
    input_dim_tab = pat_norm_tab.drop(columns=["p_id"]).shape[1] if "LSX" in version else None,
)

In [7]:
model_3d.name

'mod_ontram'

In [8]:
# Define Model Name
generate_model_name = md.set_generate_model_name(
    model_version = model_version, 
    layer_connection = layer_connection, 
    last_activation = last_activation, 
    path = WEIGHT_DIR,
    compatability_mode=comp_mode)  

# Plot GradCams

## Calculate Heatmap and Heatmap Uncertainty

In [9]:
# select all patients
p_ids = all_results_tab["p_id"].to_numpy()

(res_table, res_images, res_model_names, res_norm_tab) = gc.get_img_and_models(
    p_ids, results = all_results_tab, pats = pat_ids, imgs = X_in,
    gen_model_name = generate_model_name, norm_tab = pat_norm_tab,
    num_models = num_models) # 10 Fold

In [10]:
res_table

Unnamed: 0,index,p_idx,p_id,mrs,unfavorable,fold0,fold1,fold2,fold3,fold4,...,threshold_avg,threshold_avg_w,y_pred_class_avg,y_pred_class_avg_w,y_pred_std,y_pred_unc,y_pred_std_w,y_pred_unc_w,pred_correct,pred_correct_w
0,0,1,1,1.0,0,val,train,train,train,train,...,0.321884,0.303965,0,0,0.000293,0.000760,0.000231,0.000720,True,True
1,1,2,2,1.0,0,val,train,train,train,train,...,0.321884,0.303965,0,0,0.010384,0.028348,0.008630,0.026852,True,True
2,2,3,3,0.0,0,train,train,train,train,val,...,0.237775,0.329759,1,1,0.060829,0.166263,0.032033,0.099670,False,False
3,3,4,5,0.0,0,train,test,train,train,train,...,0.148641,0.134165,1,1,0.082791,0.226305,0.028082,0.087377,False,False
4,4,5,6,3.0,1,train,val,train,train,train,...,0.424493,0.432574,0,0,0.009505,0.025944,0.008325,0.025903,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
402,402,403,555,0.0,0,train,test,train,train,train,...,0.148641,0.134165,1,1,0.035988,0.098349,0.024903,0.077484,False,False
403,403,404,556,0.0,0,train,train,test,train,train,...,0.230867,0.262792,0,0,0.000215,0.000547,0.000049,0.000153,True,True
404,404,405,557,1.0,0,train,train,train,train,val,...,0.237775,0.329759,0,0,0.001949,0.005287,0.000535,0.001664,True,True
405,405,406,559,1.0,0,train,train,train,val,train,...,0.087975,0.096875,1,1,0.006618,0.018053,0.000000,0.000000,False,False


Loop over all patients and generate heatmaps.

In [11]:
# lc = last conv layer
# ac = average over all conv layer

if pred_hm_only:
    pos_hm = "last" # gc
    both_directions = False # oc
    cmap = "jet" # both
    hm_positive=True # both
else:
    pos_hm = "none" # gc
    both_directions = True # oc
    cmap = "bwr" # both
    hm_positive=False # both
    
if "sigmoid" in version or "andrea_split" in version or "CI" in version:
    pred_idx = 0
elif "softmax" in version:
    pred_idx = 1
    
if hm_type == "oc":
#     occ_size = (20, 20, 16)
#     occ_stride = [6]
    occ_size = (18, 18, 4)
    occ_stride = (10, 10, 3)
    num_occlusion =  int(np.prod(((np.array(res_images.shape[1:4]) - occ_size) / occ_stride) + 1))
    print('number of occlusions per model: ', num_occlusion)
    print("number of occlusions per axis: ", ((np.array(res_images.shape[1:4]) - occ_size) / occ_stride) + 1)

if ens_mode == "avg":
    y_pred_cl = "y_pred_class_avg"
    model_mode = "mean"
elif ens_mode == "wgt":
    y_pred_cl = "y_pred_class_avg_w"
    model_mode = "weighted"

if generate_heatmap_and_save:

    heatmaps_lc = []
    max_hm_slices_lc = []
    hm_mean_stds_lc = []
    all_heatmaps_lc = []

    heatmaps_ac = []
    max_hm_slices_ac = []
    hm_mean_stds_ac = []
    all_heatmaps_ac = []

    resized_imgs = []

    for i in tqdm(range(len(res_table))):  
        # define if and how heatmap should be inverted
        if pred_hm_only and hm_type == "gc":
            invert_hm = "all" if res_table[y_pred_cl][i] == 0 else "none"
        elif not pred_hm_only and hm_type == "gc":
            invert_hm = "none"
        elif pred_hm_only and hm_type == "oc":
            invert_hm = "pred_class"
        elif not pred_hm_only and hm_type == "oc":
            invert_hm = "never"

        if hm_type == "gc":
            heatmap, resized_img, max_hm_slice, hm_mean_std, all_heatmaps = gc.multi_models_grad_cam_3d(
                img = res_images[i:i+1], 
                model_names = res_model_names[i],
                cnn = model_3d,
                layers = md.get_last_conv_layer(model_3d),
                model_mode = model_mode,
                pred_index = pred_idx,
                invert_hm = invert_hm,
                # model weigths are only used when model_mode = "weighted"
                model_weights = res_table[i:i+1].reset_index(drop = True).loc[:, 
                                    res_table.columns.str.startswith("weight")].to_numpy().squeeze(),
                tabular_df = res_norm_tab,
                pos_hm = pos_hm,
                normalize=norm_hm)
        elif hm_type == "oc":
            heatmap, resized_img, max_hm_slice, hm_mean_std, all_heatmaps =  oc.volume_occlusion(
                volume = res_images[i:i+1], 
                model_names = res_model_names[i],
                res_tab = res_table[i:i+1].reset_index(drop = True),
                tabular_df = res_norm_tab,
                cnn = model_3d,
                occlusion_size = np.array(occ_size), 
                occlusion_stride = occ_stride,
                model_mode = model_mode,
                both_directions = both_directions,
                invert_hm = invert_hm,
                normalize=norm_hm)
            
        heatmaps_lc.append(heatmap)
        max_hm_slices_lc.append(max_hm_slice)
        hm_mean_stds_lc.append(hm_mean_std)
        all_heatmaps_lc.append(all_heatmaps)

        if not last_layer_only and hm_type == "gc":
            vis_layers = [i.name for i in model_3d.layers[1:-6]]
            vis_layers = [vis_layer for vis_layer in vis_layers if vis_layer.startswith("CIB_Conv")]    

            heatmap, resized_img, max_hm_slice, hm_mean_std, all_heatmaps = gc.multi_models_grad_cam_3d(
            img = res_images[i:i+1], 
            cnn = model_3d,
            model_names = res_model_names[i],
            layers = vis_layers,
            model_mode = model_mode,
            pred_index = pred_idx,
            invert_hm = invert_hm,
            tabular_df = res_norm_tab,
            pos_hm = pos_hm)

            heatmaps_ac.append(heatmap)
            max_hm_slices_ac.append(max_hm_slice)
            hm_mean_stds_ac.append(hm_mean_std)
            all_heatmaps_ac.append(all_heatmaps)

        resized_imgs.append(resized_img)
        
        gci.collect()        
        
else:
    res_table = pd.read_csv(DATA_OUTPUT_DIR + "all_tab_results_hm_unc_" + pic_save_name + ".csv",  sep = ",")
    heatmaps_lc = np.load(DATA_OUTPUT_DIR + "all_heatmaps_" + pic_save_name + ".npy")
    max_hm_slices_lc = np.load(DATA_OUTPUT_DIR + "all_max_activation_indices_" + pic_save_name + ".npy", allow_pickle = True)
    if not last_layer_only and hm_type == "gc":
        heatmaps_ac = np.load(DATA_OUTPUT_DIR + "all_heatmaps_average_conv_layer_" + pic_save_name + ".npy")
        max_hm_slices_ac = np.load(DATA_OUTPUT_DIR + "all_max_activation_indices_laverage_conv_layer_" + pic_save_name + ".npy")
    

number of occlusions per model:  1296
number of occlusions per axis:  [12. 12.  9.]


  0%|          | 0/407 [00:00<?, ?it/s]

Model: "mod_ontram"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 128, 128, 28 0                                            
__________________________________________________________________________________________________
CIB_Conv3D0 (Conv3D)            (None, 128, 128, 28, 896         input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling3d (MaxPooling3D)    (None, 64, 64, 14, 3 0           CIB_Conv3D0[0][0]                
__________________________________________________________________________________________________
CIB_Conv3D1 (Conv3D)            (None, 64, 64, 14, 3 27680       max_pooling3d[0][0]              
_________________________________________________________________________________________

  0%|          | 0/407 [00:02<?, ?it/s]


ValueError: Data cardinality is ambiguous:
  x sizes: 1296, 527472
Please provide data which shares the same first dimension.

In [25]:
if hm_type == "gc" and not norm_hm:
    hm_min = np.min(heatmaps_lc)
    hm_max = np.max(heatmaps_lc)
    print(hm_min, hm_max)
    np.save(DATA_OUTPUT_DIR + "hm_min_max_" + pic_save_name, np.array([hm_min, hm_max]))
    
    heatmaps_lc = fm.normalize_heatmap(np.array(heatmaps_lc), both_directions=False, hm_min_max = (hm_min, hm_max))

In [26]:
if generate_heatmap_and_save:   
    res_table["heatmap_std_last_layer"] = hm_mean_stds_lc
    res_table["heatmap_unc_last_layer"] = (res_table["heatmap_std_last_layer"] - res_table.heatmap_std_last_layer.min()) / (
        res_table.heatmap_std_last_layer.max() - res_table.heatmap_std_last_layer.min())
    
    if not last_layer_only:
        res_table["heatmap_std_avg_layer"] = hm_mean_stds_ac
        res_table["heatmap_unc_avg_layer"] = (res_table["heatmap_std_avg_layer"] - res_table.heatmap_std_avg_layer.min()) / (
            res_table.heatmap_std_avg_layer.max() - res_table.heatmap_std_avg_layer.min())


ValueError: Length of values does not match length of index

#### Evaluate Metrics

Calculate heatmap uncertainty. Which is the normalized (min-max) averaged standard deviation over each pixel. 

In [None]:
if not last_layer_only and hm_type == "gc":
    print(np.corrcoef(res_table["heatmap_unc_avg_layer"], res_table["heatmap_unc_last_layer"]))
    print(np.corrcoef(res_table["y_pred_unc"], res_table["heatmap_unc_avg_layer"]))
print(np.corrcoef(res_table["y_pred_unc"], res_table["heatmap_unc_last_layer"]))

In [None]:
if not last_layer_only and hm_type == "gc":
    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize = (15, 5))
else:
    fig, (ax1, ax3) = plt.subplots(1,2, figsize = (10, 5))

sns.boxplot(x = "unfavorable",
    y = "y_pred_unc",
    data = res_table,
    ax = ax1)
sns.stripplot(x = "unfavorable",
    y = "y_pred_unc",
    hue = y_pred_cl,
    alpha = 0.75,
    palette=["C2", "C3"],
    data = res_table,
    ax = ax1)
ax1.legend(title='predicted class', loc='upper center')
ax1.set(xlabel='true class', ylabel='prediction uncertainty')

if not last_layer_only:
    sns.boxplot(x = "unfavorable",
        y = "heatmap_unc_avg_layer",
        data = res_table,
        ax = ax2)
    sns.stripplot(x = "unfavorable",
        y = "heatmap_unc_avg_layer",
        hue = y_pred_cl,
        alpha = 0.75,
        palette=["C2", "C3"],
        data = res_table,
        ax = ax2)
    ax2.legend(title='predicted class', loc='upper center')
    ax2.set(xlabel='true class', ylabel='heatmap uncertainty avg layer')

sns.boxplot(x = "unfavorable",
    y = "heatmap_unc_last_layer",
    data = res_table,
    ax = ax3)
sns.stripplot(x = "unfavorable",
    y = "heatmap_unc_last_layer",
    hue = y_pred_cl,
    alpha = 0.75,
    palette=["C2", "C3"],
    data = res_table,
    ax = ax3)
ax3.legend(title='predicted class', loc='upper center')
ax3.set(xlabel='true class', ylabel='heatmap uncertainty')

In [None]:
sns.scatterplot(
           x = "heatmap_unc_last_layer",
           y = "y_pred_unc",
            data = res_table)

#### Save Heatmaps, Images and updated Table

In [None]:
if generate_heatmap_and_save:
    res_table.to_csv(DATA_OUTPUT_DIR + "all_tab_results_hm_unc_" + pic_save_name + ".csv",  index=False)
    np.save(DATA_OUTPUT_DIR + "all_heatmaps_" + pic_save_name + ".npy", heatmaps_lc)
    np.save(DATA_OUTPUT_DIR + "all_max_activation_indices_" + pic_save_name + ".npy", max_hm_slices_lc)
    np.save(DATA_OUTPUT_DIR + "all_ensemble_heatmaps_" + pic_save_name + ".npy", all_heatmaps_lc)
    
    if not last_layer_only and hm_type == "gc":
        np.save(DATA_OUTPUT_DIR + "all_heatmaps_average_conv_layer_" + pic_save_name + ".npy", heatmaps_ac)
        np.save(DATA_OUTPUT_DIR + "all_max_activation_indices_laverage_conv_layer_" + pic_save_name + ".npy", max_hm_slices_ac)
        np.save(DATA_OUTPUT_DIR + "all_ensemble_heatmaps_average_conv_layer_" + pic_save_name + ".npy", all_heatmaps_ac)

## Plot Average Heatmaps

Plot the average heatmaps for all patients. Once for class 0, once for class 1 and once for all patients.

In [None]:
idx = np.where(res_table[y_pred_cl] == 0)

In [None]:
mean_hm_lc = np.array(np.take(heatmaps_lc, idx, axis = 0).squeeze()).mean(axis = 0)
if not last_layer_only:
    mean_hm_ac = np.array(np.take(heatmaps_ac, idx, axis = 0).squeeze()).mean(axis = 0)
mean_image = np.array(np.take(res_images, idx, axis = 0).squeeze()).mean(axis = 0)

phm.plot_heatmap(mean_image, mean_hm_lc,
            version = "overlay",
            mode = "avg",
            hm_colormap = cmap,
            hm_positive = hm_positive)
if not last_layer_only:
    phm.plot_heatmap(mean_image, mean_hm_ac,
                version = "overlay",
                mode = "avg")

In [None]:
idx = np.where(res_table[y_pred_cl] == 1)

In [None]:
mean_hm_lc = np.array(np.take(heatmaps_lc, idx, axis = 0).squeeze()).mean(axis = 0)
if not last_layer_only:
    mean_hm_ac = np.array(np.take(heatmaps_ac, idx, axis = 0).squeeze()).mean(axis = 0)
mean_image = np.array(np.take(res_images, idx, axis = 0).squeeze()).mean(axis = 0)

phm.plot_heatmap(mean_image, mean_hm_lc,
            version = "overlay",
            mode = "avg",
            hm_colormap = cmap,
            hm_positive = hm_positive)
if not last_layer_only:
    phm.plot_heatmap(mean_image, mean_hm_ac,
                version = "overlay",
                mode = "avg")

In [None]:
idx = np.arange(0,407)

In [None]:
mean_hm_lc = np.array(np.take(heatmaps_lc, idx, axis = 0).squeeze()).mean(axis = 0)
if not last_layer_only:
    mean_hm_ac = np.array(np.take(heatmaps_ac, idx, axis = 0).squeeze()).mean(axis = 0)
mean_image = np.array(np.take(res_images, idx, axis = 0).squeeze()).mean(axis = 0)

phm.plot_heatmap(mean_image, mean_hm_lc,
            version = "overlay",
            mode = "avg",
            hm_colormap = cmap,
            hm_positive = hm_positive)
if not last_layer_only:
    phm.plot_heatmap(mean_image, mean_hm_ac,
                version = "overlay",
                mode = "avg")

## Save Plots as PNG

In [None]:
if generate_pictures:
    if not last_layer_only:
        phm.plot_gradcams_last_avg_org(
            res_table = res_table, 
            vis_layers = vis_layers,
            res_images = res_images,
            res_model_names = res_model_names,
            model_3d = model_3d,
            layer_mode = "mean", 
            heatmap_mode = "avg", 
            save_path = PIC_OUTPUT_DIR, 
            save_name = pic_save_name, save = True)

        phm.plot_gradcams_last_avg_org(
            res_table = res_table, 
            vis_layers = vis_layers,
            res_images = res_images,
            res_model_names = res_model_names,
            model_3d = model_3d,
            layer_mode = "mean",
            heatmap_mode = "max", 
            save_path = PIC_OUTPUT_DIR, 
            save_name = pic_save_name, save = True)
    else:
        phm.plot_heatmaps_avg_max_org(
            pat_data = pat_orig_tab,
            res_table = res_table, 
            res_images = res_images,
            heatmaps = heatmaps_lc,
            cmap = cmap,
            hm_positive = hm_positive,
            save_path = PIC_OUTPUT_DIR, 
            save_name = pic_save_name, save = True,
            res_mode=ens_mode)
    


## Save Plots to PDF

In [None]:
# only_wrong_out = True # should already be defined above

In [None]:
if not only_wrong_out: # all ids
    pat_ids = list(res_table["p_id"])
else: # only ids with low uncertainty and wrong classified
    # pat_ids = list(res_table.query("pred_correct == False and y_pred_unc < 0.2").p_id)
    pat_ids = list(res_table.query("pred_correct == False").p_id)
    res_table[res_table.p_id.isin(pat_ids)].to_csv(
        DATA_OUTPUT_DIR + "all_tab_results_hm_unc_" + pic_save_name + "_wrong_cl.csv",  index=False)

In [None]:
res_table[res_table.p_id.isin(pat_ids)]

In [None]:
pdf = FPDF()
pdf.set_auto_page_break(0)

# imagelist is the list with all image filenames
for patient in tqdm(pat_ids):
    
    name_start = PIC_OUTPUT_DIR + "pat" + str(patient) + "_" + pic_save_name
    
    if not last_layer_only:
        pdf.add_page(orientation="L")  # Use default page size (A4) in landscape mode
        pdf.set_left_margin(10)
        pdf.set_right_margin(10)
        x, y, w, h = (0, 10, 190, 190)
        pdf.image(name_start + "_last_and_all_layers_avg.png", x, y, w, h)
        x, y, w, h = (140, 10, 190, 190)
        pdf.image(name_start + "_last_and_all_layers_max.png", x, y, w, h)
    else:
        pdf.add_page(orientation="P")  # Use default page size (A4) in portrait mode
        pdf.set_left_margin(10)
        pdf.set_right_margin(10)
        x, y, w, h = (0, 10, 205, 205)
        pdf.image(name_start + "_last_layer_avg_max_orig.png", x, y, w, h)

if only_wrong_out:
    pdf.output(PIC_OUTPUT_DIR + "0_all_heatmaps_" + pic_save_name + "_all_patients_wrong_cl.pdf", "F")
else:
    pdf.output(PIC_OUTPUT_DIR + "0_all_heatmaps_" + pic_save_name + "_all_patients.pdf", "F")


In [None]:
print("done")