Importing Libraries

In [None]:
from tf_explain.core.grad_cam import GradCAM
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from matplotlib import cm
import cv2
import tensorflow as tf
from PIL import Image
import numpy as np
import os
import glob

Loading Model and Setting Paths

In [None]:
# Trained Keras model, folder of input images, and folder for generated heatmaps.
model = load_model("path/to/trained/model/")
images_directory_path = "path/to/greyscale/session/images/"
heatmaps_directory_path = "path/to/directory/for/generated/heatmaps/"

# Grad-CAM is computed on the LAST layer whose name contains 'conv'.
# Walk the layers from the output backwards and stop at the first match.
# Without the `break`, the name would keep being overwritten and end up
# pointing at the FIRST conv layer of the network.
last_conv_layer_name = None
for layer in reversed(model.layers):
    if 'conv' in layer.name:
        last_conv_layer_name = layer.name
        break
if last_conv_layer_name is None:
    raise ValueError("No layer with 'conv' in its name was found in the model.")

# Only regular files are candidate images; subdirectories would make
# Image.open fail. Sorting makes the processing order, and therefore the
# printed progress counter, reproducible across runs.
session_images = sorted(
    entry
    for entry in os.listdir(images_directory_path)
    if os.path.isfile(os.path.join(images_directory_path, entry))
)
image_size = (32, 32)  # (width, height) passed to PIL's Image.resize
count = 0  # number of heatmaps generated so far

# True if Binary (model output thresholded at 0.5),
# False if Multi-Class Classification (argmax over the model outputs).
binary_classification = True

Generating Heatmaps

In [None]:
# One Grad-CAM explainer is reused for every image; nothing below stores
# per-image state on it.
explainer = GradCAM()
# Make sure the output folder exists before any savefig call.
os.makedirs(heatmaps_directory_path, exist_ok=True)

for image_full_name in session_images:
    # splitext only strips the final extension ("scan.01.png" -> "scan.01"),
    # so inputs that share a prefix do not overwrite each other's heatmap.
    name = os.path.splitext(image_full_name)[0]
    image_path = os.path.join(images_directory_path, image_full_name)

    try:
        # The context manager releases the file handle once the pixels are read.
        with Image.open(image_path) as source:
            pixels = np.array(source.resize(image_size))
    except (OSError, ValueError) as e:
        # Skip unreadable or non-image files instead of killing the
        # kernel/process and losing the rest of the batch.
        print(f'Error occurred while loading the image {image_full_name}: {e}; skipping.')
        continue

    # Scale to [0, 1] and add a batch axis of size 1.
    # NOTE(review): a greyscale ('L') image yields (1, 32, 32) with no channel
    # axis; confirm this matches the model's expected input shape.
    input_image = np.expand_dims(pixels / 255.0, axis=0)

    scores = model.predict(input_image)

    if binary_classification:
        # float() on the squeezed array requires a single output score,
        # matching the original `prediction >= 0.5` check.
        predicted_class = 1 if float(np.squeeze(scores)) >= 0.5 else 0
        # Output unit 0 is always explained, whatever the thresholded label.
        class_index = 0
    else:
        # int() turns the 1-element argmax array into a plain integer, so
        # the file name gets "k" rather than the array repr "[k]".
        predicted_class = int(np.argmax(scores, axis=1)[0])
        class_index = predicted_class

    heatmap = explainer.explain(
        validation_data=(input_image, None),
        model=model,
        class_index=class_index,
        layer_name=last_conv_layer_name,
    )

    # A dedicated figure per image, closed after saving. Drawing every image
    # onto one global figure would pile up artists and leak memory.
    fig, ax = plt.subplots()
    ax.imshow(heatmap)
    ax.axis('off')
    output_path = os.path.join(
        heatmaps_directory_path, f"heatmap_{name}_{predicted_class}.png"
    )
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    count += 1
    print(count)