In [None]:
# Basics
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Tensorflow
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.preprocessing import image

# Pre-trained Models
# from tensorflow.keras.applications.convnext     import ConvNeXtBase,   preprocess_input, decode_predictions
# from tensorflow.keras.applications.densenet     import DenseNet121,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.densenet     import DenseNet169,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.densenet     import DenseNet201,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.efficientnet import EfficientNetB1, preprocess_input, decode_predictions
# from tensorflow.keras.applications.efficientnet import EfficientNetB2, preprocess_input, decode_predictions
# from tensorflow.keras.applications.efficientnet import EfficientNetB3, preprocess_input, decode_predictions
# from tensorflow.keras.applications.efficientnet import EfficientNetB4, preprocess_input, decode_predictions
# from tensorflow.keras.applications.efficientnet import EfficientNetB5, preprocess_input, decode_predictions
# from tensorflow.keras.applications.efficientnet import EfficientNetB6, preprocess_input, decode_predictions
# from tensorflow.keras.applications.efficientnet import EfficientNetB7, preprocess_input, decode_predictions
# from tensorflow.keras.applications.inception_v3 import InceptionV3,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.nasnet       import NASNetMobile,   preprocess_input, decode_predictions
from tensorflow.keras.applications.resnet50     import ResNet50,       preprocess_input, decode_predictions
# from tensorflow.keras.applications.resnet_rs    import ResNetRS101,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.resnet_rs    import ResNetRS152,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.resnet_rs    import ResNetRS200,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.resnet_v2    import ResNet50V2,     preprocess_input, decode_predictions
# from tensorflow.keras.applications.resnet_v2    import ResNet101V2,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.resnet_v2    import ResNet152V2,    preprocess_input, decode_predictions
# from tensorflow.keras.applications.vgg16        import VGG16,          preprocess_input, decode_predictions
# from tensorflow.keras.applications.vgg19        import VGG19,          preprocess_input, decode_predictions
# from tensorflow.keras.applications.xception     import Xception,       preprocess_input, decode_predictions

In [None]:
# Path of the image to explain (read by both Keras and OpenCV below).
IMAGE_PATH: str = ''
# Path for `load_model` when using a saved model instead of the pretrained one.
MODEL_PATH: str = ''
# Side length in pixels of the square input the image is resized to.
INPUT_DIMS: int = 224
# Weight applied to the colored heatmap when it is added onto the original image.
INTENSITY: float = 0.5

In [None]:
# 1. Get the CNN Classifier
# Load a saved model when MODEL_PATH is set; otherwise download ResNet50 with
# ImageNet weights. This replaces toggling commented-out code by hand.
# NOTE: preprocess_input / decode_predictions come from the resnet50 module, so a
# custom model must expect the same preprocessing and produce 1000-class
# ImageNet outputs for decode_predictions to work.
if MODEL_PATH:
    model = load_model(MODEL_PATH)
else:
    model = ResNet50(weights='imagenet')

In [None]:
# 2. Find the last convolutional layer
# Walk the layers from the end and take the first Conv2D found. Only the
# top-level `model.layers` list is searched; layers inside nested sub-models
# are not visited.
last_conv_layer = next(
    (candidate for candidate in reversed(model.layers) if isinstance(candidate, Conv2D)),
    None,
)
if last_conv_layer is None:
    # Fail with a clear message instead of an AttributeError on `None.name`.
    raise ValueError('No Conv2D layer found in model.layers; Grad-CAM needs a convolutional layer.')

print(f'Last Convolutional Layer name is: {last_conv_layer.name}')

In [None]:
# 3. Load the original image
# Resized to the square network input size; this image feeds the preprocessing step.
input_size = (INPUT_DIMS, INPUT_DIMS)
img = image.load_img(IMAGE_PATH, target_size=input_size)

# Preview the full-resolution file. OpenCV decodes pixels in BGR order while
# matplotlib expects RGB, hence the color conversion before plotting.
preview_bgr = cv2.imread(IMAGE_PATH)
preview_rgb = cv2.cvtColor(preview_bgr, cv2.COLOR_BGR2RGB)
plt.imshow(preview_rgb)
plt.show()

In [None]:
# 4. Preprocess the image ACCORDING to the chosen network
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

# Find Original Preds
preds = model.predict(x)
predicted_class = decode_predictions(preds)[0][0]

# 5. Get the gradient of the class output with respect to the last convolutional layer
with tf.GradientTape() as tape:
    iterate = Model([model.inputs], [model.output, last_conv_layer.output])
    model_out, last_conv_layer = iterate(x)
    class_out = model_out[:, np.argmax(model_out[0])]
    grads = tape.gradient(class_out, last_conv_layer)
    pooled_grads = K.mean(grads, axis=(0, 1, 2))

In [None]:
# 6. Get a heatmap showing the grads of the last convolutional layer for the predicted class
heatmap = tf.reduce_mean(tf.multiply(pooled_grads, last_conv_layer), axis=-1)

# Normalize the heatmap
heatmap  = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
heatmap  = np.squeeze(heatmap) # Get rid of unnecessary dims

# Plot the heatmap
plt.matshow(heatmap)
plt.show()

In [None]:
# 7. Get Final Results
# Overlay the heatmap on the full-resolution original image. A separate name is
# used so `img` (the resized input from step 3) is not clobbered and step 4
# can be re-run.
original_rgb = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
# cv2.resize takes dsize as (width, height).
heatmap_resized = cv2.resize(heatmap, (original_rgb.shape[1], original_rgb.shape[0]))
# Guard against NaN / out-of-range values before converting to 8-bit.
heatmap_u8 = np.uint8(255 * np.clip(np.nan_to_num(heatmap_resized), 0.0, 1.0))
# applyColorMap returns a BGR image; convert it so colors match the RGB original.
heatmap_colored = cv2.cvtColor(cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
# Blend in float, then clip back into the displayable 0-255 range explicitly
# (instead of letting imshow silently clip values above 1.0).
heatmap_img = np.clip(heatmap_colored * INTENSITY + original_rgb, 0, 255).astype(np.uint8)

# Create a figure and axis for the images
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
# Plot the first image on the first axis
ax1.imshow(original_rgb)
ax1.set_title('Original Image')
# Plot the second image on the second axis
ax2.imshow(heatmap_img)
ax2.set_title('With Grads')
ax2.set_xlabel(f'Predicted Class: {predicted_class}')
# Display the figure
plt.show()