In [24]:
# ==============================================================================
# FINAL SCRIPT: INTERPRETABILITY ANALYSIS WITH GRAD-CAM
# ==============================================================================
# This script loads previously saved CNN and Transformer models and applies
# Grad-CAM to visualize what parts of the image each model focuses on.
#
# INSTRUCTIONS:
# 1. Ensure you have the following saved model files in your Colab session:
#    - papaya_disease_resnet50v2.keras
#    - papaya_disease_swin_transformer.keras
# 2. You will also need the 'papaya_data_split_vit' directory from the last run.
# ==============================================================================

import os
import random
import pathlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# --- 1. RE-DEFINE THE CUSTOM SWIN LAYER FOR MODEL LOADING ---
# Keras needs this definition to know how to load the saved Swin model.
# This code is a simplified version from your successful training script.
from transformers import SwinModel
import torch

class SwinBackboneLayer(tf.keras.layers.Layer):
    def __init__(self, model_name="microsoft/swin-tiny-patch4-window7-224", **kwargs):
        super(SwinBackboneLayer, self).__init__(**kwargs)
        self.swin = SwinModel.from_pretrained(model_name)
        self.swin.eval()
        for param in self.swin.parameters():
            param.requires_grad = False

    def call(self, inputs):
        def swin_forward(inp):
            input_np = inp.numpy()
            inputs_torch = torch.from_numpy(input_np).permute(0, 3, 1, 2)
            with torch.no_grad():
                outputs = self.swin(inputs_torch).last_hidden_state
            return outputs.detach().numpy()
        output = tf.py_function(func=swin_forward, inp=[inputs], Tout=tf.float32)
        output.set_shape([None, 49, 768])
        return output

    def get_config(self):
        # Implement get_config to allow model saving/loading
        config = super(SwinBackboneLayer, self).get_config()
        # You can add model_name or other parameters to the config if needed
        return config


# --- 2. SETUP AND MODEL LOADING ---
print("\n--- Section 2: Loading Saved Models ---")

# Define necessary constants
SPLIT_BASE_DIR = pathlib.Path('papaya_data_split_vit')
test_dir = SPLIT_BASE_DIR / "test"
CLASS_NAMES = sorted([item.name for item in test_dir.glob("*") if item.is_dir()])
IMG_SIZE = (224, 224)

# Load the CNN model (ResNet50V2)
cnn_model = None  # Initialize to None
try:
    cnn_model = tf.keras.models.load_model("papaya_disease_resnet50v2.keras")
    print("Successfully loaded ResNet50V2 model.")
    cnn_model.summary()  # Print model summary
except Exception as e:
    print(f"Could not load ResNet50V2 model. Error: {e}")

# Load the Transformer model (Swin)
transformer_model = None  # Initialize to None
try:
    transformer_model = tf.keras.models.load_model(
        "papaya_disease_swin_transformer.keras",
        custom_objects={"SwinBackboneLayer": SwinBackboneLayer}
    )
    print("Successfully loaded Swin Transformer model.")
    transformer_model.summary()
except Exception as e:
    print(f"Could not load Swin Transformer model. Error: {e}")

# Build and initialize the CNN model
if cnn_model:
    if not cnn_model.built:
        cnn_model.build((None, 224, 224, 3))
        print("Built ResNet50V2 model.")
    # Force initialization by passing a dummy input
    try:
        dummy_input = tf.zeros((1, 224, 224, 3))
        output = cnn_model(dummy_input)  # Call the model to initialize all layers
        print("Initialized ResNet50V2 model with dummy input.")
        print(f"CNN model output shape: {output.shape}")
        # Verify model output is defined
        print(f"CNN model has defined output: {hasattr(cnn_model, 'output')}")
        print(f"CNN model output shape (attribute): {cnn_model.output_shape}")
        # Print layer names for verification
        try:
            base_model = cnn_model.get_layer("resnet50v2")
            print("Layers in ResNet50V2 backbone:")
            for layer in base_model.layers:
                print(layer.name)
            # Verify post_relu layer
            try:
                post_relu_layer = base_model.get_layer("post_relu")
                print(f"post_relu layer found with output shape: {post_relu_layer.output_shape}")
            except ValueError as e:
                print(f"Error: post_relu layer not found in resnet50v2. Error: {e}")
        except ValueError as e:
            print(f"Error accessing resnet50v2 layer: {e}")
    except Exception as e:
        print(f"Error initializing CNN model with dummy input: {e}")
else:
    print("CNN model not loaded, skipping initialization.")

# Build and initialize the Transformer model
if transformer_model:
    if not transformer_model.built:
        transformer_model.build((None, 224, 224, 3))
        print("Built Swin Transformer model.")
    try:
        dummy_input = tf.zeros((1, 224, 224, 3))
        output = transformer_model(dummy_input / 255.0)  # Normalize for Transformer
        print("Initialized Swin Transformer model with dummy input.")
        print(f"Transformer model output shape: {output.shape}")
    except Exception as e:
        print(f"Error initializing Transformer model with dummy input: {e}")
else:
    print("Transformer model not loaded, skipping initialization.")

# --- 3. GRAD-CAM IMPLEMENTATION ---
print("\n--- Section 3: Defining Grad-CAM Logic ---")

def get_img_array(img_path, size):
    """Loads and preprocesses a single image."""
    img = tf.keras.utils.load_img(img_path, target_size=size)
    array = tf.keras.utils.img_to_array(img)
    array = np.expand_dims(array, axis=0)
    return array

# --- Transformer Visualization ---
if transformer_model and last_layer_name_transformer:
    img_array_transformer = get_img_array(img_path_str, size=IMG_SIZE)
    # The Swin Transformer expects inputs normalized to [0, 1]
    img_array_transformer_norm = img_array_transformer / 255.0

    # Get the individual layers of the model
    swin_backbone_layer = transformer_model.get_layer(last_layer_name_transformer)
    global_pool_layer = transformer_model.get_layer("global_average_pooling1d")
    dense_layer = transformer_model.get_layer("dense") # Assuming the last dense layer is named 'dense'

    # Use GradientTape to manually watch the backbone's output
    with tf.GradientTape() as tape:
        # 1. Get the output from the custom Swin layer
        transformer_output = swin_backbone_layer(img_array_transformer_norm)
        # 2. IMPORTANT: Explicitly watch this tensor
        tape.watch(transformer_output)
        # 3. Manually pass the output through the rest of the model
        x = global_pool_layer(transformer_output)
        # The dropout layer is often skipped during inference/explanation
        preds = dense_layer(x)
        
        # Get the specific prediction for the top class
        pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # Calculate the gradient of the top predicted class w.r.t the backbone's output
    grads = tape.gradient(class_channel, transformer_output)

    if grads is not None:
        # --- FIX #2: Corrected averaging axis for Grad-CAM ---
        # Pool the gradients across the patch dimension (49) to get a single weight per feature channel (768)
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1))

        transformer_output = transformer_output[0]
        # Multiply each patch's feature vector by its corresponding gradient-weight
        heatmap_transformer = transformer_output @ pooled_grads[..., tf.newaxis]
        heatmap_transformer = tf.squeeze(heatmap_transformer)
        
        # Normalize and reshape the heatmap
        heatmap_transformer = tf.maximum(heatmap_transformer, 0) / tf.math.reduce_max(heatmap_transformer)
        heatmap_transformer = tf.reshape(heatmap_transformer, (7, 7))
        heatmap_transformer = heatmap_transformer.numpy()

        superimposed_transformer = save_and_display_gradcam(img_path_str, heatmap_transformer)
        transformer_pred_idx = np.argmax(transformer_model.predict(img_array_transformer_norm)[0])
        
        axes[2].imshow(superimposed_transformer)
        axes[2].set_title(f"Transformer (Swin)\nPrediction: {CLASS_NAMES[transformer_pred_idx]}")
        axes[2].axis("off")
    else:
        print("❌ Gradient for Transformer was None. Could not generate heatmap.")
        axes[2].text(0.5, 0.5, 'Failed to generate Transformer heatmap', ha='center', va='center')
        axes[2].set_title("Transformer (Swin)")
        axes[2].axis("off")
else:
    axes[2].text(0.5, 0.5, 'Transformer model not loaded', ha='center', va='center')
    axes[2].set_title("Transformer (Swin)")
    axes[2].axis("off")

def save_and_display_gradcam(img_path, heatmap, cam_path="cam.jpg", alpha=0.4):
    """Superimposes the heatmap on the original image."""
    img = tf.keras.utils.load_img(img_path)
    img = tf.keras.utils.img_to_array(img)
    heatmap = np.uint8(255 * heatmap)
    jet = cm.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]
    jet_heatmap = tf.keras.utils.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = tf.keras.utils.img_to_array(jet_heatmap)
    superimposed_img = jet_heatmap * alpha + img
    superimposed_img = tf.keras.utils.array_to_img(superimposed_img)
    superimposed_img.save(cam_path)
    return superimposed_img

# --- 4. GENERATE AND DISPLAY VISUALIZATIONS ---
print("\n--- Section 4: Generating and Displaying Visualizations ---")

# Find a few sample images to test
sample_image_paths = []
for class_name in CLASS_NAMES:
    try:
        sample_image_paths.append(list((test_dir / class_name).glob("*.jpg"))[0])
    except IndexError:
        print(f"Warning: No images found for class {class_name} in the test set.")

# For ResNet50V2, we find the last convolutional layer name
last_conv_layer_name_cnn = None
if cnn_model:
    # Find the ResNet base model inside the loaded model
    base_model = None
    for layer in cnn_model.layers:
        if isinstance(layer, tf.keras.Model):  # ResNet50V2 is itself a Model
            base_model = layer
            break

    if base_model:
        for layer in reversed(base_model.layers):
            if len(layer.output.shape) == 4:  # Conv layer
                last_conv_layer_name_cnn = layer.name
                break
    print(f"Found last conv layer for CNN: {last_conv_layer_name_cnn}")

# For Swin Transformer, we use the output of our custom backbone layer
# Note: Grad-CAM on transformers is an active area of research. This is an approximation.
# --- Find the Transformer backbone layer name dynamically ---
last_layer_name_transformer = None # Initialize to None

if transformer_model:
    # Loop through the layers of the loaded transformer model
    for layer in transformer_model.layers:
        # Check if the current layer is your custom SwinBackboneLayer
        if isinstance(layer, SwinBackboneLayer):
            last_layer_name_transformer = layer.name # Get its actual name
            break # Exit the loop once found

# Check if the layer was found before proceeding
if last_layer_name_transformer:
    print(f"✅ Successfully found Swin backbone layer: '{last_layer_name_transformer}'")
else:
    print("❌ Warning: Could not find SwinBackboneLayer in the transformer model.")

# Generate and plot the comparisons
for img_path in sample_image_paths:
    img_path_str = str(img_path)
    true_class = os.path.basename(os.path.dirname(img_path_str))
    print(f"\n--- Processing image: {os.path.basename(img_path_str)} (True Class: {true_class}) ---")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    original_img = tf.keras.utils.load_img(img_path_str, target_size=IMG_SIZE)
    axes[0].imshow(original_img)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    # --- CNN Visualization ---
    if cnn_model and last_conv_layer_name_cnn:
        img_array_cnn = get_img_array(img_path_str, size=IMG_SIZE)
        # Preprocess for ResNet50V2 (it has a built-in rescaling layer)
        print(f"Generating Grad-CAM for image: {os.path.basename(img_path_str)}")
        heatmap_cnn = make_gradcam_heatmap(img_array_cnn, cnn_model, last_conv_layer_name_cnn)
        if heatmap_cnn is not None:
            try:
                superimposed_cnn = save_and_display_gradcam(img_path_str, heatmap_cnn)
                cnn_pred_idx = np.argmax(cnn_model.predict(img_array_cnn)[0])
                axes[1].imshow(superimposed_cnn)
                axes[1].set_title(f"CNN (ResNet50V2)\nPrediction: {CLASS_NAMES[cnn_pred_idx]}")
                axes[1].axis("off")
            except Exception as e:
                print(f"Error in save_and_display_gradcam: {e}")
                axes[1].text(0.5, 0.5, f'Failed to process CNN heatmap: {e}', ha='center', va='center')
                axes[1].set_title("CNN (ResNet50V2)")
                axes[1].axis("off")
        else:
            print("Grad-CAM heatmap is None, skipping visualization.")
            axes[1].text(0.5, 0.5, 'Failed to generate CNN heatmap', ha='center', va='center')
            axes[1].set_title("CNN (ResNet50V2)")
            axes[1].axis("off")
    else:
        print("CNN model or last_conv_layer_name_cnn not available.")
        axes[1].text(0.5, 0.5, 'CNN model not loaded or layer not found', ha='center', va='center')
        axes[1].set_title("CNN (ResNet50V2)")
        axes[1].axis("off")
        
    # --- Transformer Visualization ---
    if transformer_model:
        img_array_transformer = get_img_array(img_path_str, size=IMG_SIZE)
        # The Swin Transformer expects inputs normalized to [0, 1]
        img_array_transformer_norm = img_array_transformer / 255.0

        # For transformers, we use a different approach for the heatmap
        grad_model_transformer = tf.keras.models.Model(
            transformer_model.inputs,
            [transformer_model.get_layer(last_layer_name_transformer).output, transformer_model.output]
        )
        with tf.GradientTape() as tape:
            transformer_output, preds = grad_model_transformer(img_array_transformer_norm)
            pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]
        grads = tape.gradient(class_channel, transformer_output)
        pooled_grads = tf.reduce_mean(grads, axis=(0, 2))
        transformer_output = transformer_output[0]
        heatmap_transformer = transformer_output @ pooled_grads[..., tf.newaxis]
        heatmap_transformer = tf.squeeze(heatmap_transformer)
        heatmap_transformer = tf.maximum(heatmap_transformer, 0) / tf.math.reduce_max(heatmap_transformer)
        # The output is a sequence of 49 patches, so we reshape it to a 7x7 grid
        heatmap_transformer = tf.reshape(heatmap_transformer, (7, 7))
        heatmap_transformer = heatmap_transformer.numpy()

        superimposed_transformer = save_and_display_gradcam(img_path_str, heatmap_transformer)
        transformer_pred_idx = np.argmax(transformer_model.predict(img_array_transformer_norm)[0])
        axes[2].imshow(superimposed_transformer)
        axes[2].set_title(f"Transformer (Swin)\nPrediction: {CLASS_NAMES[transformer_pred_idx]}")
        axes[2].axis("off")
    else:
        axes[2].text(0.5, 0.5, 'Transformer model not loaded', ha='center', va='center')
        axes[2].set_title("Transformer (Swin)")
        axes[2].axis("off")

    plt.tight_layout()
    plt.show()

print("\n--- Interpretability Analysis Finished ---")


--- Section 2: Loading Saved Models ---
Successfully loaded ResNet50V2 model.


Successfully loaded Swin Transformer model.


Initialized ResNet50V2 model with dummy input.
CNN model output shape: (1, 5)
CNN model has defined output: False
CNN model output shape (attribute): (None, 5)
Layers in ResNet50V2 backbone:
input_layer
conv1_pad
conv1_conv
pool1_pad
pool1_pool
conv2_block1_preact_bn
conv2_block1_preact_relu
conv2_block1_1_conv
conv2_block1_1_bn
conv2_block1_1_relu
conv2_block1_2_pad
conv2_block1_2_conv
conv2_block1_2_bn
conv2_block1_2_relu
conv2_block1_0_conv
conv2_block1_3_conv
conv2_block1_out
conv2_block2_preact_bn
conv2_block2_preact_relu
conv2_block2_1_conv
conv2_block2_1_bn
conv2_block2_1_relu
conv2_block2_2_pad
conv2_block2_2_conv
conv2_block2_2_bn
conv2_block2_2_relu
conv2_block2_3_conv
conv2_block2_out
conv2_block3_preact_bn
conv2_block3_preact_relu
conv2_block3_1_conv
conv2_block3_1_bn
conv2_block3_1_relu
conv2_block3_2_pad
conv2_block3_2_conv
conv2_block3_2_bn
conv2_block3_2_relu
max_pooling2d
conv2_block3_3_conv
conv2_block3_out
conv3_block1_preact_bn
conv3_block1_preact_relu
conv3_block1_

ValueError: No such layer: swin_backbone_layer_21. Existing layers are: ['input_layer', 'swin_backbone_layer_22', 'global_average_pooling1d', 'dropout', 'dense'].

In [19]:
import tensorflow as tf
cnn_model = tf.keras.models.load_model("papaya_disease_resnet50v2.keras")
cnn_model.summary()
dummy_input = tf.zeros((1, 224, 224, 3))
output = cnn_model(dummy_input)
print(f"Output shape: {output.shape}")

Output shape: (1, 5)
