In [1]:
!pip install tf-explain


Collecting tf-explain
  Downloading tf_explain-0.3.1-py3-none-any.whl.metadata (9.3 kB)
Downloading tf_explain-0.3.1-py3-none-any.whl (43 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tf-explain
Successfully installed tf-explain-0.3.1


In [2]:
import tensorflow as tf
from tf_explain.core.grad_cam import GradCAM
from google.colab import files
from IPython.display import Image, display
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image as keras_image


In [3]:
def _load_vgg16_input(path, target_size=(224, 224)):
    """Load the image at ``path``, resize it to ``target_size`` and apply VGG16 ``preprocess_input``."""
    img = tf.keras.preprocessing.image.load_img(path, target_size=target_size)
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    return tf.keras.applications.vgg16.preprocess_input(img_array)


def _show_side_by_side(original_path, grad_cam_path, thumb_size=(100, 100)):
    """Plot a small thumbnail of the original image next to the saved Grad-CAM PNG."""
    fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(12, 6))

    thumb = keras_image.img_to_array(keras_image.load_img(original_path, target_size=thumb_size))
    ax_orig.imshow(thumb.astype("uint8"))
    ax_orig.axis("off")
    ax_orig.set_title("Original Image")

    ax_cam.imshow(plt.imread(grad_cam_path))
    ax_cam.axis("off")
    ax_cam.set_title("Grad-CAM Visualization")

    plt.show()
    # This helper runs once per image inside a loop; release each figure explicitly.
    plt.close(fig)


def grad_cam_visualization(image_paths, indices, model=None, target_layer="block5_conv3",
                           output_dir=".", verbose=False):
    """
    Generate, save and display Grad-CAM visualizations for the provided images.

    For every ``(path, class_index)`` pair the image is resized to 224x224,
    passed through ``tf.keras.applications.vgg16.preprocess_input`` and explained
    with tf-explain's ``GradCAM`` on ``target_layer``. The result is saved as
    ``<file stem>_grad_cam.png`` in ``output_dir`` and displayed next to the original.

    Args:
        image_paths: Iterable of image file paths.
        indices: Iterable of target class indices, one per image (same length as
            ``image_paths``). Python ints and numpy integer scalars are accepted.
        model: Keras model to explain. If ``None`` (default), an ImageNet VGG16
            with its classification head is loaded once and reused for all images.
        target_layer: Name of the layer Grad-CAM reads activations from.
        output_dir: Directory the ``*_grad_cam.png`` files are saved to.
        verbose: If True, print the model's layer names once before processing.

    Raises:
        ValueError: If ``image_paths`` and ``indices`` differ in length, an index
            is not an integer, or ``target_layer`` is not a layer of the model.
    """
    import numbers
    import os

    image_paths = list(image_paths)
    indices = list(indices)
    if len(image_paths) != len(indices):
        raise ValueError(
            f"Got {len(image_paths)} image paths but {len(indices)} class indices"
        )

    # Validate all indices up front, before any expensive model/image work.
    # numbers.Integral also matches numpy integer scalars; bool is rejected on purpose.
    for index in indices:
        if isinstance(index, bool) or not isinstance(index, numbers.Integral):
            raise ValueError(f"Index must be an integer, but got: {type(index)} ({index})")

    if not image_paths:
        return

    # Building VGG16 with ImageNet weights is expensive: do it once, not once per image.
    if model is None:
        model = tf.keras.applications.vgg16.VGG16(weights="imagenet", include_top=True)

    layer_names = [layer.name for layer in model.layers]
    if verbose:
        print("\nModel Layers:")
        print("\n".join(layer_names))
    if target_layer not in layer_names:
        raise ValueError(f"Target layer '{target_layer}' not found in model!")

    for each_path, index in zip(image_paths, indices):
        # NOTE(review): tf-explain appears to draw the heatmap over the array passed
        # in `data`; since that array is already VGG16-preprocessed, the overlay
        # background may look discolored — confirm against the saved PNG.
        data = ([_load_vgg16_input(each_path)], None)

        # splitext keeps inner dots ("a.b.jpg" -> "a.b"), avoiding output-name collisions.
        name = os.path.splitext(os.path.basename(each_path))[0]
        output_name = name + "_grad_cam.png"

        explainer = GradCAM()
        grid = explainer.explain(data, model, class_index=int(index), layer_name=target_layer)
        explainer.save(grid, output_dir, output_name)
        cam_path = os.path.join(output_dir, output_name)

        _show_side_by_side(each_path, cam_path)

        print(f"Original Image: {each_path}")
        display(Image(each_path))
        print(f"Grad-CAM Visualization: {cam_path}")
        display(Image(cam_path))


In [4]:
# Example usage: upload one or more images and explain each against one target class.
# ImageNet class index to explain. Presumably 817 is "sports car" in the standard
# 1000-class ImageNet ordering — confirm, and change it to match your images.
TARGET_CLASS_INDEX = 817

print("Please upload your images:")
uploaded = files.upload()

# files.upload() returns a dict keyed by uploaded file name; the code below relies on
# those names also being valid paths to the saved files in the working directory.
IMAGE_PATHS = list(uploaded.keys())
# One class index per image; replace with a per-image list if the classes differ.
indices = [TARGET_CLASS_INDEX] * len(IMAGE_PATHS)

if not IMAGE_PATHS:
    print("No images were uploaded; nothing to visualize.")
else:
    grad_cam_visualization(IMAGE_PATHS, indices)


Output hidden; open in https://colab.research.google.com to view.