# XAI Visualization for Skin Lesion Classification Model

## Import Required Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from scipy.ndimage import zoom
import random

In [None]:
sys.path.append("../src")

import config
from dataset import load_test_metadata, prepare_test_data, create_dataset
from utils import MulticlassROC_AUC

## Load the Model

In [None]:
# Resolve the saved model's location relative to this notebook's directory.
model_path = os.path.join('..', config.BEST_MODEL_PATH)
print(f"Loading model from: {model_path}")

# Custom object(s) referenced by the saved model that Keras must be told about
# in order to deserialize it.
custom_objects = dict(MulticlassROC_AUC=MulticlassROC_AUC)

model = tf.keras.models.load_model(
    model_path,
    custom_objects=custom_objects,
)
print(f"Model loaded successfully. Input shape: {model.input_shape}")


In [None]:
#model.summary()

## Load Test Images

In [None]:
def load_and_preprocess_image(image_path, target_size=None):
    """Load an image from disk and prepare it for model inference.

    Args:
        image_path: Path to the image file (``str`` or any ``os.PathLike``
            such as ``pathlib.Path``).
        target_size: Optional ``(height, width)`` to resize to. ``None`` keeps
            the image at its native size.

    Returns:
        ``float32`` RGB array scaled to ``[0, 1]``.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    # cv2.imread expects a plain string path; os.fspath also accepts Path objects
    # and returns str inputs unchanged.
    img = cv2.imread(os.fspath(image_path))
    if img is None:
        raise ValueError(f"Could not read image: {image_path}")

    # OpenCV decodes images as BGR; matplotlib (and presumably the model) use RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if target_size is not None:
        # cv2.resize takes dsize as (width, height), the reverse of target_size.
        img = cv2.resize(img, (target_size[1], target_size[0]))

    # Scale 8-bit pixel values to [0, 1].
    return img.astype(np.float32) / 255.0

In [None]:
# Define class names for the skin lesion classes (class index -> short label)
fallback_class_names = {
    0: 'akiec',  # Actinic Keratosis
    1: 'bcc',  # Basal Cell Carcinoma
    2: 'bkl',  # Benign Keratosis
    3: 'df',  # Dermatofibroma
    4: 'mel',  # Melanoma
    5: 'nv',  # Melanocytic Nevus
    6: 'vasc',  # Vascular Lesion
}

# Load test images
test_metadata_df = load_test_metadata(config.TEST_METADATA_FILE)
test_image_paths, test_labels, class_names = prepare_test_data(test_metadata_df)
print(f"Found {len(test_image_paths)} test images with {len(class_names)} classes")

# Select a random subset of images to inspect
n = 10
num_samples = min(n, len(test_image_paths))
sample_indices = random.sample(range(len(test_image_paths)), num_samples)
sample_paths = [test_image_paths[i] for i in sample_indices]
# Use None for a missing label instead of dropping it, so that sample_labels[k]
# always belongs to sample_paths[k].
sample_labels = [
    test_labels[i] if i < len(test_labels) else None for i in sample_indices
]

# Display selected images
print(f"Selected {len(sample_paths)} images: {sample_paths}")

plt.figure(figsize=(15, 3))
for i, (path, label) in enumerate(zip(sample_paths, sample_labels)):
    # Best-effort display: an unreadable image is reported and skipped rather
    # than aborting the whole cell.
    try:
        img = load_and_preprocess_image(path)
        plt.subplot(1, len(sample_paths), i + 1)
        plt.imshow(img)
        # Label of *this* sampled image; unknown labels fall back to "Class <idx>".
        label_name = fallback_class_names.get(label, f"Class {label}")
        plt.title(f"Image {i+1}, class {label_name}")
        plt.axis('off')
    except Exception as e:
        print(f"Error displaying image {path}: {e}")
plt.tight_layout()
plt.show()

## XAI: Convolutional Activation Heatmaps

In [None]:
def extract_conv_output(model, img_array, backbone_name, layer_name):
    """
    Return the activations of one layer inside a nested backbone sub-model.

    Args:
        model: The full model; it must contain a layer named ``backbone_name``
            that is itself a model exposing ``get_layer``.
        img_array: Input batch (preprocessed, with a leading batch dimension).
        backbone_name: Name of the nested backbone model (e.g., 'densenet121').
        layer_name: Name of the layer inside the backbone whose output is
            wanted (e.g., 'conv5_block16_concat').

    Returns:
        The result of ``predict`` on a sub-model ending at ``layer_name``.
    """
    backbone = model.get_layer(backbone_name)
    target_layer = backbone.get_layer(layer_name)

    # Sub-model running from the backbone's own input up to the requested layer.
    # It is rebuilt on every call.
    # NOTE(review): img_array is fed straight into the backbone, so any layers the
    # full model applies *before* the backbone (e.g. preprocessing) are skipped —
    # confirm the backbone expects the same input as `model`.
    feature_extractor = tf.keras.models.Model(
        inputs=backbone.input,
        outputs=target_layer.output,
    )
    return feature_extractor.predict(img_array)

In [None]:
def calculate_heatmap(img, conv_output, alpha=0.4):
    """
    Compute a channel-averaged activation heatmap and blend it onto ``img``.

    Args:
        img: Original image without a batch dimension; presumably float RGB in
            [0, 1] as returned by ``load_and_preprocess_image`` (confirm).
        conv_output: Convolutional output with a leading batch dimension; only
            the first batch entry is used.
        alpha: Opacity of the heatmap in the overlay (0-1).

    Returns:
        tuple: (heatmap_avg, heatmap_normalized, overlay) where
            heatmap_avg is the mean over the last (channel) axis at the
            feature-map resolution;
            heatmap_normalized is heatmap_avg resized to img's height/width
            and min-max scaled to [0, 1];
            overlay is the alpha-blend of img and the jet-colored heatmap,
            divided by its maximum so the brightest value is 1.
    """
    # Extract the first (and only) batch entry.
    feature_map = conv_output[0]

    # Mean over channels -> one 2-D activation map.
    heatmap_avg = np.mean(feature_map, axis=-1)

    # cv2.resize takes dsize as (width, height).
    heatmap_resized = cv2.resize(heatmap_avg, (img.shape[1], img.shape[0]))

    # Min-max normalize; epsilon avoids division by zero on a constant map.
    lo = np.min(heatmap_resized)
    hi = np.max(heatmap_resized)
    heatmap_normalized = (heatmap_resized - lo) / (hi - lo + 1e-8)

    # Colorize (drop the colormap's alpha channel) and blend with the image.
    heatmap_colorized = plt.cm.jet(heatmap_normalized)[:, :, :3]
    overlay = (1 - alpha) * img + alpha * heatmap_colorized

    # Rescale so the brightest value is 1. Skip when the peak is not positive
    # (e.g. alpha=0 on an all-black image); dividing there would produce NaNs.
    peak = np.max(overlay)
    if peak > 0:
        overlay = overlay / peak

    return heatmap_avg, heatmap_normalized, overlay

def visualize_conv_heatmap(img, conv_output, pred_class=None, true_class=None,
                           show_channels=False, num_channels=16, alpha=0.4,
                           class_names=None):
    """
    Plot the original image, the channel-averaged heatmap, and their overlay.

    Args:
        img: Original image without a batch dimension.
        conv_output: Convolutional output indexed as [batch, row, col, channel].
        pred_class: Index of the predicted class (optional).
        true_class: Index of the ground-truth class (optional).
        show_channels: If True, also plot individual channels in a second figure.
        num_channels: Maximum number of channels to plot when show_channels is True.
        alpha: Opacity of the heatmap in the overlay (0-1).
        class_names: Mapping from class index to label. Labels are only added
            to the titles when this is provided.

    Returns:
        tuple: (heatmap_avg, overlay) as computed by ``calculate_heatmap``.
    """
    heatmap_avg, _, overlay = calculate_heatmap(img, conv_output, alpha)

    def _label(idx):
        # Unknown indices fall back to a generic "Class <idx>" label.
        return class_names.get(idx, f"Class {idx}")

    # Title suffix with ground-truth and predicted labels, when available.
    title_suffix = ""
    if class_names is not None:
        if true_class is not None:
            title_suffix += f" | True: {_label(true_class)}"
        if pred_class is not None:
            title_suffix += f" | Pred: {_label(pred_class)}"

    plt.figure(figsize=(15, 8))

    # (data, imshow kwargs, title prefix, add colorbar?) for each of the 3 panels.
    panels = (
        (img, {}, 'Original Image', False),
        (heatmap_avg, {'cmap': 'viridis'}, 'Average Heatmap of All Channels', True),
        (overlay, {}, 'Heatmap Overlay on Image', False),
    )
    for position, (data, imshow_kwargs, heading, with_colorbar) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(data, **imshow_kwargs)
        plt.title(f'{heading}{title_suffix}')
        if with_colorbar:
            plt.colorbar(fraction=0.046, pad=0.04)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    if not show_channels:
        return heatmap_avg, overlay

    # Square grid large enough for the requested number of channels.
    plt.figure(figsize=(16, 8))
    n_show = min(num_channels, conv_output.shape[-1])
    side = int(np.ceil(np.sqrt(n_show)))
    for ch in range(n_show):
        plt.subplot(side, side, ch + 1)
        plt.imshow(conv_output[0, :, :, ch], cmap='viridis')
        plt.title(f'Channel {ch}')
        plt.axis('off')

    plt.suptitle(f'Individual Channels{title_suffix}')
    plt.tight_layout()
    plt.show()

    return heatmap_avg, overlay

In [None]:
# Layer names and the model's input size are the same for every image, so they
# are set once here instead of on every loop iteration.
backbone_name = 'densenet121'
last_conv_layer_name = 'conv5_block16_2_conv'
# presumably (height, width) for a channels-last input — confirm for this model
input_hw = model.input_shape[1:3]

for i, img_path in enumerate(sample_paths):
    # 1. Load and preprocess the image, then add a batch dimension
    img = load_and_preprocess_image(img_path, target_size=input_hw)
    img_array = np.expand_dims(img, axis=0)

    # 2. Extract the convolutional output
    conv_output = extract_conv_output(model, img_array, backbone_name, last_conv_layer_name)

    # Get prediction for the single image in the batch
    pred = model.predict(img_array)
    pred_class = int(np.argmax(pred[0]))
    # Guard against a label list shorter than the sampled paths; a missing
    # label simply omits "True: ..." from the plot titles.
    true_class = sample_labels[i] if i < len(sample_labels) else None

    # 3. Show heatmap with predicted and ground-truth class
    heatmap, overlay = visualize_conv_heatmap(
        img,
        conv_output,
        pred_class=pred_class,
        true_class=true_class,
        show_channels=True,
        num_channels=conv_output.shape[3],
        alpha=0.4,
        class_names=fallback_class_names
    )