In [None]:
import os
import warnings
import sys
import json
import glob
import types
import cv2
import time
import random
import argparse
import importlib
import numpy as np
import nibabel as nib
from pathlib import Path
from random import shuffle
from datetime import datetime
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Raise TensorFlow's C++ log threshold to '3', hide DeprecationWarnings and make
# the local ResNet50 sources importable.
# NOTE(review): tensorflow is already imported above, so this variable is set
# after import; messages emitted while TF was loading are not affected.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})  # or any {'0', '1', '2'}
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
sys.path.append(r"/local/data1/elech646/source/ResNet50")

# Define GradCAM object
'''
Grad-CAM implementation [1], following the blog post available at [2].
[1] Selvaraju RR, Cogswell M, Das A, Vedantam R, Parikh D, Batra D. Grad-CAM:
    Visual explanations from deep networks via gradient-based localization.
    In Proceedings of the IEEE International Conference on Computer Vision 2017
    (pp. 618-626).
[2] https://www.pyimagesearch.com/2020/03/09/grad-cam-visualize-class-activation-maps-with-keras-tensorflow-and-deep-learning/
'''

class gradCAM:
    '''
    Grad-CAM heatmap generator for a Keras model.

    For a chosen class index and layer, compute_heatmap() weights each feature
    map of the layer by its (guided) gradient with respect to the class score
    and returns the resulting heatmap resized to the input image size.
    '''
    def __init__(self, model, classIdx, layerName=None,
                use_image_prediction=True,
                ViT=False,
                is_3D=False,
                debug=False):
        '''
        model: model to inspect
        classIdx: index of the class to inspect
        layerName: which layer to visualize; if None, the last layer with a 4D
            output whose name contains "conv" is used
        use_image_prediction: if True, the gradient guidance depends on whether
            classIdx matches the model prediction (see compute_heatmap)
        ViT: the target layer output is flattened over space; compute_heatmap
            reshapes it back into a square grid
        is_3D: the input carries an extra slice axis (image.shape[3] is used as
            the number of slices)
        debug: print the selected layer and the gradient-model summary
        '''
        self.model = model
        self.classIdx = classIdx
        self.layerName = layerName
        self.debug = debug
        self.use_image_prediction = use_image_prediction
        self.is_ViT = ViT
        self.is_3D = is_3D

        # if the layerName is not provided, find the last conv layer in the model
        if self.layerName is None:
            self.layerName = self.find_target_layer()
        else:
            if self.debug is True:
                print('GradCAM - using layer {}'.format(self.model.get_layer(self.layerName).name))

    def find_target_layer(self):
        '''
        Return the name of the last layer (searching from the output backwards)
        whose output has 4 dimensions [batch, height, width, channels] and
        whose name contains "conv".

        Raises ValueError if no such layer exists, since Grad-CAM needs a
        spatial feature map to work with.
        '''
        for layer in reversed(self.model.layers):
            # a 4D output means the layer produces a 2D spatial feature map;
            # the name check is a heuristic for "this is a conv layer"
            if len(layer.output_shape) == 4 and 'conv' in layer.name:
                if self.debug is True:
                    print('GradCAM - using layer {}'.format(layer.name))
                return layer.name

        # No suitable layer: raise unconditionally instead of silently
        # returning None, which would only fail later in get_layer().
        raise ValueError('Could not find a 4D layer. Cannot apply GradCAM')

    def compute_heatmap(self, image, eps=1e-6):
        '''
        Compute L_Grad-CAM^c as defined in the original article, i.e. the
        weighted sum over the feature maps of the selected layer, with weights
        based on the importance of each feature map for the classification of
        the inspected class.
        This is done by supplying
        1 - an input to the pre-trained model
        2 - the output of the selected conv layer
        3 - the final activation of the model (temporarily made linear)

        image: input batch accepted by the model; only the first sample is
            used to build the heatmap
        eps: small constant avoiding a division by zero when normalizing
        Returns (heatmap_raw, heatmap_rgb): the heatmap rescaled to uint8
            [0, 255] at the input's spatial size, and its VIRIDIS colour-mapped
            version (as returned by cv2.applyColorMap) cast to float32.
        '''
        # this is a gradient model that we will use to obtain the gradients
        # with respect to an image to construct the heatmaps
        gradModel = tf.keras.Model(
                inputs=[self.model.inputs],
                outputs=[self.model.get_layer(self.layerName).output,
                self.model.output])

        # Replace the final activation (e.g. softmax) with a linear one so the
        # gradients are taken w.r.t. the class score. The layer object is
        # shared with self.model, so the original activation is restored in the
        # finally-block below; previously the caller's model stayed modified.
        output_layer = gradModel.layers[-1]
        saved_activation = getattr(output_layer, 'activation', None)
        if saved_activation is not None:
            output_layer.activation = tf.keras.activations.linear

        if self.debug is True:
            gradModel.summary()

        try:
            # use the tensorflow gradient tape to store the gradients
            with tf.GradientTape() as tape:
                # Cast the image to float32, pass it through the gradient model
                # and grab the score associated with the class index.
                inputs = tf.cast(image, tf.float32)
                (convOutputs, predictions) = gradModel(inputs)
                # a list of outputs means the model is a VAE
                if isinstance(predictions, list):
                    # NOTE(review): index 4 is assumed to hold the class
                    # prediction of that VAE - confirm against the model
                    predictions = predictions[4]
                pred = tf.argmax(predictions, axis=1)
                loss = predictions[:, self.classIdx]
        finally:
            # the forward pass is recorded on the tape, so restoring the
            # activation now does not affect the gradient computation
            if saved_activation is not None:
                output_layer.activation = saved_activation

        grads = tape.gradient(loss, convOutputs)
        # sometimes grads becomes NoneType (e.g. no path to the selected layer)
        if grads is None:
            grads = tf.zeros_like(convOutputs)

        # Compute the guided gradients.
        #  - positive gradients if classIdx matches the prediction (which values
        #    make the probability of that class high)
        #  - non-positive gradients if classIdx != the predicted class (which
        #    values pushed the probability of that class down)
        # Only the first sample is visualized, so compare with its prediction.
        if self.use_image_prediction and self.classIdx != int(pred[0]):
            castConvOutputs = tf.cast(convOutputs <= 0, tf.float32)
            castGrads = tf.cast(grads <= 0, tf.float32)
        else:
            castConvOutputs = tf.cast(convOutputs > 0, tf.float32)
            castGrads = tf.cast(grads > 0, tf.float32)
        guidedGrads = castConvOutputs * castGrads * grads

        # remove the batch dimension
        convOutputs = convOutputs[0]
        guidedGrads = guidedGrads[0]

        # compute the weight value for each feature map in the conv layer based
        # on the guided gradient
        weights = tf.reduce_mean(guidedGrads, axis=(0, 1))
        cam = tf.reduce_sum(tf.multiply(weights, convOutputs), axis=-1)

        # now that we have the activation map for the specific layer, we need
        # to resize it to be the same as the input image (cv2 wants (w, h))
        (w, h) = (image.shape[2], image.shape[1])
        if self.is_ViT:
            if self.is_3D:
                # here we take the middle slice (don't take mean or sum since the
                # channels are not conv filters, but the actual activation for
                # the different images in the sequence). This is different compared
                # to a normal conv3d, where the channels are descriptive of all
                # the images at the same time
                dim = int(np.sqrt(cam.shape[0] / image.shape[3]))
                heatmap = cam.numpy().reshape((dim, dim, image.shape[3]))
                heatmap = heatmap[:, :, heatmap.shape[-1] // 2]
                heatmap = cv2.resize(heatmap, (w, h))
            else:
                dim = int(np.sqrt(cam.shape[0]))
                heatmap = cam.numpy().reshape((dim, dim))
                heatmap = cv2.resize(heatmap, (w, h))
        else:
            if self.is_3D:
                # reshape cam to the layer input shape and average over the
                # last axis
                layer_shape = self.model.get_layer(self.layerName).input_shape
                heatmap = cam.numpy().reshape((layer_shape[1], layer_shape[2], layer_shape[3]))
                heatmap = np.mean(heatmap, axis=-1)
                heatmap = cv2.resize(heatmap, (w, h))
            else:
                heatmap = cv2.resize(cam.numpy(), (w, h))

        # normalize the heat map in [0, 1] and rescale to [0, 255]
        numer = heatmap - np.min(heatmap)
        denom = (heatmap.max() - heatmap.min()) + eps
        heatmap = (numer / denom)
        heatmap_raw = (heatmap * 255).astype('uint8')

        # create heatmap based on the colormap setting
        heatmap_rgb = cv2.applyColorMap(heatmap_raw, cv2.COLORMAP_VIRIDIS).astype('float32')

        return heatmap_raw, heatmap_rgb

    def overlay_heatmap(self, heatmap, image, alpha=0.5, colormap=cv2.COLORMAP_VIRIDIS):
        '''
        Colour-map `heatmap` and blend it over `image`.

        heatmap: single-channel uint8 map, e.g. heatmap_raw from compute_heatmap
        image: grayscale (2D or last dim 1) or 3-channel image with the same
            height/width as heatmap
        alpha: weight of `image` in the blend; the heatmap gets 1 - alpha
        colormap: OpenCV colormap id
        Returns (colour-mapped heatmap, blended image), both float32.
        '''
        # create heatmap based on the colormap setting
        heatmap = cv2.applyColorMap(heatmap, colormap).astype('float32')

        # cv2.addWeighted needs both inputs to share a depth and the heatmap is
        # float32; previously only grayscale images were cast, so e.g. a uint8
        # RGB image made addWeighted fail.
        image = np.asarray(image, dtype='float32')
        if image.ndim == 2 or image.shape[-1] == 1:
            # convert image from grayscale to RGB
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        output = cv2.addWeighted(image, alpha, heatmap, (1 - alpha), 0)

        # return both the heatmap and the overlayed output
        return (heatmap, output)

# START ACTUAL SCRIPT - GET PATHS
# Folder with the trained models, one sub-folder per model name.
MODEL_PATH = r"/local/data1/elech646/source/train_logs/gradCAM"

# Output folder for the GradCAM figures (parents are created when missing).
SAVE_PATH = os.path.join(r'/local/data1/elech646/source/Plots/new_plots/gradCAM', 'resnet50_flair_sagittal_TEST')
os.makedirs(SAVE_PATH, exist_ok=True)

# CREATE DATA GENERATOR (THIS SHOULD BE CHANGED BASED ON HOW YOUR GENERATOR WORKS)
img_height, img_width = 224, 224
batch_size = 1

# NOTE(review): the generator reads the *train* split although the variables
# and SAVE_PATH call it "test" - confirm this is intended.
test_datagen = ImageDataGenerator(samplewise_center=True, rescale=1. / 255)
test = test_datagen.flow_from_directory(
    '/local/data1/elech646/Tumor_grade_classification/dataset224_RGB/dataset224_flair_sagittal/train',
    classes=['G2', 'G3', 'G4'],
    color_mode='rgb',
    shuffle=False,  # fixed file order, so batch i is the same image on every pass
    class_mode='categorical',
    target_size=(img_height, img_width),
    batch_size=batch_size,
)
# len() of the iterator counts batches; equal to the image count while batch_size == 1
nbr_test_img = len(test)

# LOAD TRAINED MODEL OF CHOICE
m_name = 'resnet50_flair_sagittal'
model = tf.keras.models.load_model(os.path.join(MODEL_PATH, m_name, ''), compile=False)

# CHECK FOR CONVOLUTIONAL LAYERS IN THE MODEL
# This heuristic needs to change based on the model architecture as well as the
# layer names; check model.summary() to see them. Careful: every model names its
# layers differently (an alternative filter is 'conv2d' in the name).
print('Looking for 2D conv layers...')
name_layers = [lyr.name for lyr in model.layers if 'conv' in lyr.name]

print('Found {} layers -> {}'.format(len(name_layers), name_layers))
# keep only the five deepest matching layers
name_layers = name_layers[-5:]

# COMPUTE GradCAM FOR EACH IMAGE IN THE DATASET AND LAYER
# Here we loop through the images in the given dataset and:
#  - compute the model prediction on the image
#  - save the image, the ground truth (used later for plotting) and the prediction
#  - compute the activation map for the image with respect to all the layers
# The number of layers can be reduced by removing names from name_layers.
test_images = []
pred_logits = []  # raw model.predict outputs (probabilities if the model ends in softmax)
labels = []
heatmap_raw = []  # heatmap_raw[i][k]: uint8 heatmap of image i for layer name_layers[k]
heatmap_rgb = []  # same indexing, colour-mapped version

# Index the iterator explicitly instead of `for ... in enumerate(test)`: a Keras
# DirectoryIterator loops indefinitely, and the old post-hoc
# `if idx == nbr_test_img: break` processed one extra batch (the first image
# again, after wrap-around) before leaving the loop.
for idx in range(nbr_test_img):
    img, label = test[idx]
    print(f'Computing activation maps for each layer for the predicted class: {idx + 1}/{nbr_test_img} \r', end='')
    # get model prediction for this image (verbose=0 keeps the progress line readable)
    img_pred_logits = model.predict(img, verbose=0)
    # the predicted class is the one GradCAM explains
    c = np.argmax(img_pred_logits)
    # save prediction, ground truth and image for plotting later
    pred_logits.append(img_pred_logits)
    test_images.append(img)
    labels.append(np.argmax(label, axis=-1)[0])

    # compute the heatmap of this image for every selected layer
    image_raw, image_rgb = [], []
    for nl in name_layers:
        cam = gradCAM(model, c, layerName=nl)
        aus_raw, aus_rgb = cam.compute_heatmap(img)
        image_raw.append(aus_raw)
        image_rgb.append(aus_rgb)
    heatmap_raw.append(image_raw)
    heatmap_rgb.append(image_rgb)
print('\nDone.')

In [None]:
# Collect the path of every .png under the t1ce sagittal training folder
# (exploratory cell: the list is only printed).
# NOTE(review): the GradCAM run above uses the flair_sagittal dataset - confirm
# that t1ce_sagittal is intended here.
filepath = [
    os.path.join(folder, fname)
    for folder, _subdirs, fnames in os.walk("/local/data1/elech646/Tumor_grade_classification/dataset224_RGB/dataset224_t1ce_sagittal/train")
    for fname in fnames
    if fname.endswith(".png")
]

print(filepath)

In [None]:
# PLOT THE ACTIVATION MAPS (JUST FANCY PLOTTING - CAN ALWAYS BE IMPROVED)
layers_to_print = 3  # how many layers to print, counting back from the last one
n_samples_per_image = 3  # images (rows) per figure
fix_image = False  # if True, rotate images and heatmaps by 90 degrees
save_images = True  # save figures to SAVE_PATH instead of showing them

# cannot show more layers than heatmaps were computed for
layers_to_print = min(layers_to_print, len(name_layers))
# ceiling division: a final, partially filled figure is still produced (floor
# division silently dropped the last nbr_test_img % n_samples_per_image images)
n_images = -(-nbr_test_img // n_samples_per_image)
n_cols = layers_to_print + 1  # column 0: input image, then one column per layer

# for fancy LaTeX style plots
import matplotlib as mpl
mpl.rc('font', family='serif', serif='cmr10')
plt.rcParams["font.serif"] = "cmr10"
plt.rcParams["axes.formatter.use_mathtext"] = True

for i in range(n_images):
    print(f'Creating figure {i+1:3d}/{n_images}\r', end='')

    # squeeze=False always returns a 2D (rows x cols) array of axes, also when
    # there is a single row or a single column
    fig, axes = plt.subplots(nrows=n_samples_per_image, ncols=n_cols,
                             figsize=(layers_to_print*5, n_samples_per_image*2),
                             squeeze=False)

    # fill in all axes
    for j in range(n_samples_per_image):
        idx = i*n_samples_per_image + j
        if idx >= nbr_test_img:
            # last figure is only partially filled: hide the unused rows
            for ax in axes[j:].ravel():
                ax.axis('off')
            break

        # original image (first channel only)
        original_image = np.squeeze(test_images[idx][:, :, :, 0])
        if fix_image:
            original_image = np.rot90(original_image, k=1)
        axes[j, 0].imshow(original_image, cmap='gray', interpolation=None)
        pred_str = [f'{p:0.2f}' for p in pred_logits[idx][0]]
        # green title for a correct prediction, red otherwise
        title_color = 'g' if labels[idx] == np.argmax(pred_logits[idx]) else 'r'
        axes[j, 0].set_title(f'GT {labels[idx]} - Pred {pred_str}', color=title_color)
        axes[j, 0].set_xticks([])
        axes[j, 0].set_yticks([])

        # layer heatmaps, from the shallowest to the deepest selected layer
        for col, back in enumerate(reversed(range(layers_to_print)), start=1):
            heat_map_image = heatmap_raw[idx][-(back+1)]/255
            layer_name = name_layers[-(back+1)]
            if fix_image:
                heat_map_image = np.rot90(heat_map_image, k=1)
            im = axes[j, col].imshow(heat_map_image, cmap='jet', vmin=0, vmax=1, interpolation=None)
            axes[j, col].set_title(f'layer {layer_name}')
            axes[j, col].set_xticks([])
            axes[j, col].set_yticks([])

        # add colorbar as an extra axis to the right of the last heatmap
        if layers_to_print > 0:
            aspect = 20
            pad_fraction = 0.5

            divider = make_axes_locatable(axes[j, -1])
            width = axes_size.AxesY(axes[j, -1], aspect=1./aspect)
            pad = axes_size.Fraction(pad_fraction, width)
            cax = divider.append_axes("right", size=width, pad=pad)
            fig.colorbar(im, cax=cax)

    if save_images:
        fig.savefig(os.path.join(SAVE_PATH, 'activationMap_forConsecutiveLayers_%03d.png' % i), bbox_inches='tight', dpi = 400)
        plt.close(fig)
    else:
        plt.show()