# XAI Analysis for Handwritten Signature Verification

This notebook performs Explainable AI (XAI) analysis on a trained model for handwritten signature verification. The goal is to understand which parts of the input image the model focuses on when making predictions (Real vs. Forged signatures).

## XAI Methods Used
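- **Grad-CAM**: Highlights the image regions that contribute most to the model's prediction by weighting a convolutional layer's activations with the gradients of the model output (this notebook uses an early MobileNetV3 block rather than the final convolutional layer).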
- **LIME**: Identifies important superpixels in the image by perturbing the input and observing changes in the prediction.
- **SHAP**: Assigns an importance score to each pixel using `shap.GradientExplainer` (expected gradients), which integrates the model's gradients along paths from a background dataset of genuine signatures to the input image.

## Steps
1. Install necessary libraries
2. Import libraries and define the model class
3. Configure paths and settings
4. Load the trained model and input image
5. Generate explanations using Grad-CAM, LIME, and SHAP
6. Visualize the explanations in a combined plot

## --- 1. Install necessary libraries ---

In [None]:
print("--- Installing XAI libraries ---")
!pip install lime shap --quiet
!pip install scikit-image --quiet # Needed by LIME
!pip install opencv-python-headless --quiet # For cv2 used in visualizations
!pip install -U albumentations --quiet
!pip install grad-cam==1.4.8 --quiet # Installing pytorch-grad-cam
print("--- Installations complete ---")

## --- 2. Imports and Model Class Definition ---

### --- Importing libraries ---

In [None]:
print("--- Importing libraries ---")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms # For basic tensor conversion if needed
import timm
import os
import numpy as np
from PIL import Image, UnidentifiedImageError
import matplotlib
import matplotlib.pyplot as plt
import cv2 # For visualization
import time
import copy
import logging
import sys
import json
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.auto import tqdm # For progress bars
import random # For random debug logging in CustomModel 
import matplotlib.cm as cm # For colormaps

# XAI Libraries
import lime
from lime import lime_image
import shap
from skimage.segmentation import mark_boundaries # For LIME visualization

### --- Setting up Logging ---

In [None]:
print("--- Setting up logging ---")
# Send every record to stdout so it shows up inline in the notebook output.
# force=True discards handlers attached earlier (e.g. by a previous run).
logging.basicConfig(
    force=True,
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - [%(levelname)s] - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

### --- Defining CustomModel Class  ---

In [None]:
class CustomModel(nn.Module):
    """
    Backbone + MLP head that emits a single logit for binary classification.

    Only the architecture is rebuilt here, so that a trained ``state_dict`` can
    be loaded into it. The timm backbone is created without pretrained weights.
    No parameters are frozen or unfrozen, because ``requires_grad`` is not
    stored in a state_dict.

    Args:
        model_name (str): "ViT_Base" (HuggingFace ``google/vit-base-patch16-224``)
            or a key of the timm name map below (currently "MobileNetV3_Large").
        dense_units (int): Hidden width of the classifier head.
        dropout (float): Dropout probability in the classifier head.
        pretrained (bool): Kept for signature compatibility with the training
            code; not used by this implementation.
        unfreeze_layers (int): Kept for signature compatibility; only logged.

    Raises:
        ValueError: If ``model_name`` is not supported.
        ImportError: If "ViT_Base" is requested and ``transformers`` is missing.
    """
    def __init__(self, model_name, dense_units, dropout, pretrained=True, unfreeze_layers=0):
        super(CustomModel, self).__init__()
        self.model_name = model_name
        self.base_model = None
        reported_features = 0  # Feature width produced by the base model

        logger.info(f"Initializing CustomModel: Base='{model_name}', DenseUnits={dense_units}, Dropout={dropout:.2f}, Unfreeze={unfreeze_layers}")

        try:
            # --- Load Base Model ---
            if model_name == "ViT_Base":
                # Imported lazily: only this branch depends on `transformers`.
                from transformers import ViTModel

                hf_model_name = 'google/vit-base-patch16-224'
                # from_pretrained loads hub weights; presumably they are
                # replaced when the caller loads the trained state_dict.
                self.base_model = ViTModel.from_pretrained(
                    hf_model_name,
                    add_pooling_layer=False,
                    ignore_mismatched_sizes=True
                )
                reported_features = self.base_model.config.hidden_size
            else:
                # Use timm for other models
                timm_model_name_map = {
                    "MobileNetV3_Large": "mobilenetv3_large_100.miil_in21k_ft_in1k",
                }
                if model_name not in timm_model_name_map:
                    raise ValueError(f"Model name '{model_name}' not found in timm map or not supported.")

                timm_name = timm_model_name_map[model_name]

                # Architecture only; weights come from the state_dict.
                # num_classes=0 drops timm's classifier so the backbone
                # returns a (N, features) tensor.
                self.base_model = timm.create_model(timm_name, pretrained=False, num_classes=0)
                # Probe the feature width with a dummy forward pass.
                self.base_model.eval()
                with torch.no_grad():
                    features = self.base_model(torch.randn(1, 3, 224, 224))
                reported_features = features.shape[1]
                logger.info(f"Loaded '{timm_name}' structure from timm. Reported features: {reported_features}")

            # The trained checkpoint fixes this width, so the reported value
            # is used as-is (a mismatch surfaces in load_state_dict).
            num_features = reported_features

            # --- Classifier head ---
            # Layer order defines the state_dict keys (classifier.0, .2, .4),
            # so it must match the training-time definition exactly.
            self.classifier = nn.Sequential(
                nn.Linear(num_features, dense_units),
                nn.ReLU(),
                nn.BatchNorm1d(dense_units),
                nn.Dropout(dropout),
                nn.Linear(dense_units, 1),  # Output 1 logit for binary classification
            )

        except Exception as e:
            logger.error(f"Error initializing model structure '{model_name}': {e}", exc_info=True)
            raise

    def forward(self, x):
        """Return raw logits of shape (N, 1) for an image batch ``x``."""
        if self.model_name == "ViT_Base":
            # First token of the last hidden state (the CLS token for ViT).
            features = self.base_model(x).last_hidden_state[:, 0]
        else:
            features = self.base_model(x)
        return self.classifier(features)

print("--- Library imports and Model class definition complete ---")

### --- Helper Functions ---

### --- Helper Function for Heatmap Generation ---

In [None]:
def generate_heatmap_overlay(image, heatmap_data, colormap='viridis', alpha=0.6):
    """
    Blend a colour-mapped heatmap onto an image.

    Args:
        image (np.ndarray): Base image of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4). uint8 in [0, 255] or float in [0, 1]. Float values outside
            [0, 1] (e.g. ImageNet-normalized pixels) are clipped, not wrapped,
            before conversion to uint8; other dtypes are clipped to [0, 255].
        heatmap_data (np.ndarray): 2-D weights. Values already in [0, 1] are used
            as-is. Otherwise negatives are clipped to 0 and the map is divided by
            its maximum. NaNs become 0. If its (H, W) differs from the image, it
            is bilinearly resized to the image size.
        colormap (str): Name of a registered matplotlib colormap.
        alpha (float): Weight of the heatmap in the blend (the image gets 1 - alpha).

    Returns:
        np.ndarray: uint8 (H, W, 3) blended image. The heatmap colours are in BGR
        order (OpenCV convention). 3-channel input images are blended as given,
        without channel reordering. On an unrecoverable shape mismatch the
        (converted) base image is returned unchanged.
    """
    image = np.asarray(image)

    # --- Base image -> uint8 ---
    if np.issubdtype(image.dtype, np.floating):
        # Clip first: casting out-of-range floats straight to uint8 wraps around.
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    elif image.dtype != np.uint8:
        image = np.clip(image, 0, 255).astype(np.uint8)

    # --- Base image -> 3 channels ---
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    # NOTE(review): 3-channel images are not reordered. Callers in this notebook
    # pass RGB arrays, while the heatmap colours below are BGR.

    # --- Heatmap -> [0, 1] if not already ---
    heatmap = np.asarray(heatmap_data, dtype=np.float64)
    if np.nanmax(heatmap) > 1.0 or np.nanmin(heatmap) < 0.0:
        heatmap = np.clip(heatmap, 0.0, None)  # positive-only heatmap
        max_val = np.nanmax(heatmap)
        # Guard against division by ~0 (all-zero or all-negative maps).
        heatmap = heatmap / max_val if max_val > 1e-6 else np.zeros_like(heatmap)
    heatmap = np.nan_to_num(heatmap, nan=0.0)

    # --- Match heatmap resolution to the image (e.g. coarse attribution maps) ---
    height, width = image.shape[:2]
    if heatmap.ndim == 2 and heatmap.shape != (height, width):
        heatmap = cv2.resize(heatmap.astype(np.float32), (width, height), interpolation=cv2.INTER_LINEAR)

    # Colormap returns RGBA floats in [0, 1]; keep RGB, convert to uint8 BGR.
    cmap = matplotlib.colormaps.get_cmap(colormap)
    heatmap_colored = (cmap(heatmap)[:, :, :3] * 255).astype(np.uint8)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_RGB2BGR)

    # cv2.addWeighted requires identical shape and dtype.
    if image.shape != heatmap_colored.shape:
        logger.error(f"Shape mismatch: Image {image.shape}, Heatmap {heatmap_colored.shape}. Cannot overlay.")
        return image  # Return original image on error

    return cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)

### --- Helper Function to Replace In-Place Hardswish ---

In [None]:
def replace_hardswish(module):
    """
    Swap, in place, every ``nn.Hardswish`` below ``module`` for a fresh
    ``nn.Hardswish(inplace=False)``.

    Presumably done because in-place activations can overwrite tensors that
    gradient-based explainers (Grad-CAM hooks, SHAP) still need.

    Args:
        module (nn.Module): Root of the sub-tree to patch. It is mutated;
            the function returns None.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Hardswish):
            # Not an activation to patch: search its own children.
            replace_hardswish(submodule)
            continue
        setattr(module, child_name, nn.Hardswish(inplace=False))

### --- Helper Function to Load Images ---

In [None]:
# --- Helper Function to Load Images ---
def load_image(image_path, transform_type='tensor'):
    """
    Load an image as RGB and resize it to ``IMG_SIZE``.

    Both outputs use the same resize, so LIME and the visualizations see
    exactly the pixels the model is fed.

    Args:
        image_path (str): Path to the image file.
        transform_type (str): 'tensor' returns the model input: a float tensor
            (3, H, W) scaled to [0, 1] and ImageNet-normalized. Any other value
            (conventionally 'numpy') returns the resized pixels as a uint8 RGB
            array (H, W, 3) in [0, 255], NOT normalized. That is the format
            LIME's predict_fn (which normalizes itself), ``mark_boundaries(x / 255)``
            and the uint8 overlay helpers expect.

    Returns:
        torch.Tensor or np.ndarray: See ``transform_type``.

    Raises:
        UnidentifiedImageError: If PIL cannot decode the file.
        OSError: Propagated from PIL for missing/unreadable paths.
    """
    try:
        # The context manager closes the file handle; convert() returns an
        # independent, fully-loaded image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        logger.info(f"Loaded image: {image_path}, size: {image.size}")
    except UnidentifiedImageError as e:
        logger.error(f"Cannot identify image file {image_path}: {e}")
        raise

    # IMG_SIZE is (height, width).
    resized = transforms.Resize(IMG_SIZE)(image)

    if transform_type == 'tensor':
        to_model_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image_tensor = to_model_input(resized)
        logger.info(f"Transformed image to tensor: shape={image_tensor.shape}")
        return image_tensor

    # 'numpy' for LIME / visualization: raw uint8 pixels, intentionally unnormalized.
    image_np = np.array(resized, dtype=np.uint8)
    logger.info(f"Transformed image to numpy: shape={image_np.shape}")
    return image_np

## --- 3. Configuration ---

### --- Model Configuration ---

In [None]:
# Architecture settings taken from the Optuna search results. They must
# describe the checkpoint being loaded: DENSE_UNITS in particular fixes the
# shapes of the classifier-head parameters.
MODEL_NAME = "MobileNetV3_Large"
DENSE_UNITS = 768      # hidden width of the classifier head
DROPOUT = 0.45         # classifier-head dropout (identity once model.eval() is called)
UNFREEZE_LAYERS = 3    # training-time setting; CustomModel only logs it

### --- Paths --- 

In [None]:
# Trained checkpoint and the single image whose prediction is explained.
MODEL_PATH = "/kaggle/input/checkpoints-xai/MobileNetV3_Large_best_val_loss.pth"
IMAGE_PATH = "/kaggle/input/handwritten-signature-verification/data/data/forged/52F192AB-A654-4778-861A-C81E555D4656.jpg/2__52F192AB-A654-4778-861A-C81E555D4656.jpg.jpg"

# All saved figures land here; created up front so later writes succeed.
OUTPUT_DIR = "/kaggle/working/xai_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# SHAP reference distribution: a sample of genuine ('real') signatures.
BACKGROUND_DATA_DIR = "/kaggle/input/handwritten-signature-verification/data/data/real"
N_SHAP_BACKGROUND_SAMPLES = 50  # upper bound on background images loaded for SHAP

### --- Other Settings ---

In [None]:
# Prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {DEVICE}")

# Input resolution (height, width) expected by each supported backbone.
model_img_sizes = {
    "MobileNetV3_Large": (224, 224)
}

try:
    IMG_SIZE = model_img_sizes[MODEL_NAME]
except KeyError:
    raise ValueError(f"Image size for model '{MODEL_NAME}' not defined in 'model_img_sizes'.") from None
logger.info(f"Using Image Size: {IMG_SIZE} for model {MODEL_NAME}")

## --- 4. Load Model and Image ---

### --- Load the Trained Model ---

In [None]:
logger.info(f"--- Loading model structure for: {MODEL_NAME} ---")
try:
    # Rebuild the training-time architecture; the weights come from MODEL_PATH.
    model = CustomModel(
        model_name=MODEL_NAME,
        dense_units=DENSE_UNITS,
        dropout=DROPOUT,
        pretrained=False,
        unfreeze_layers=UNFREEZE_LAYERS,
    )
    logger.info(f"Model structure created successfully.")

    logger.info(f"--- Loading state dict from: {MODEL_PATH} ---")
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

    # weights_only=True: deserialize tensors only, not arbitrary pickled objects.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)

    # A checkpoint saved from a DataParallel-wrapped model has every key
    # prefixed with 'module.'; strip that prefix so the keys match this model.
    if all(key.startswith('module.') for key in state_dict):
        logger.info("Removing 'module.' prefix from state_dict keys.")
        state_dict = {key[len('module.'):]: tensor for key, tensor in state_dict.items()}

    model.load_state_dict(state_dict)  # strict: every key must match
    logger.info("State dict loaded successfully.")

    # Swap in-place Hardswish activations for out-of-place ones before
    # running any gradient-based explainer.
    if MODEL_NAME == "MobileNetV3_Large":
        logger.info("Replacing all in-place Hardswish activations in MobileNetV3_Large.")
        replace_hardswish(model.base_model)

    model.to(DEVICE).eval()  # .to() returns the module itself
    logger.info("Model loaded and set to evaluation mode.")

except Exception as e:
    logger.error(f"Failed to load model: {e}", exc_info=True)
    raise  # nothing below can run without a model

### --- Load the target image ---

In [None]:
logger.info(f"--- Loading target image for explanation: {IMAGE_PATH} ---")
# load_image raises on unreadable files; the None check below is only a guard.
target_image_tensor = load_image(IMAGE_PATH, transform_type='tensor')  # model input
target_image_numpy = load_image(IMAGE_PATH, transform_type='numpy')    # for LIME / visualization

if target_image_tensor is None or target_image_numpy is None:
    logger.error("Failed to load the target image. Aborting.")
    # Raise instead of exit(): exit() would try to shut down the notebook kernel.
    raise RuntimeError(f"Failed to load the target image: {IMAGE_PATH}")
logger.info("Target image loaded successfully.")

## --- 5. Explanation Generation ---

### --- Grad-CAM Explanation ---

In [None]:
logger.info("--- Starting Grad-CAM Explanation (using pytorch-grad-cam) ---")
start_gradcam_time = time.time()


class BinaryClassifierOutputTarget:
    """
    pytorch-grad-cam target for a model that emits one logit per sample
    (positive direction = "Forged").

    Args:
        target_class (int): 1 explains "Forged" (gradient of the logit);
            0 explains "Real" (gradient of the negated logit).
    """

    def __init__(self, target_class):
        self.target_class = target_class  # 0 for "Real", 1 for "Forged"

    def __call__(self, model_output):
        # pytorch-grad-cam calls each target with one sample's output, here a
        # single logit.
        return model_output if self.target_class == 1 else -model_output


# Predict outside the Grad-CAM try-block so predicted_class_idx / _name are
# always defined for the following cells.
input_tensor = target_image_tensor.unsqueeze(0).to(DEVICE)
with torch.no_grad():
    output = model(input_tensor)
    output_prob = torch.sigmoid(output).item()
predicted_class_idx = 1 if output_prob > 0.5 else 0
predicted_class_name = "Forged" if predicted_class_idx == 1 else "Real"
logger.info(f"Predicted class: {predicted_class_name} (Probability: {output_prob:.4f})")

# Defaults used if Grad-CAM fails: show the plain image, nothing saved.
gradcam_heatmap_overlay = target_image_numpy
gradcam_output_path = None
try:
    from pytorch_grad_cam import GradCAM

    # An early MobileNetV3 block rather than the last conv stage, presumably
    # for its higher spatial resolution.
    target_layers = [model.base_model.blocks[2][0].conv_pwl]

    # The context manager releases the hooks GradCAM registers on the model,
    # so they do not linger during the LIME/SHAP forward/backward passes.
    with GradCAM(model=model, target_layers=target_layers) as grad_cam:
        cam = grad_cam(
            input_tensor=input_tensor,
            targets=[BinaryClassifierOutputTarget(predicted_class_idx)],
        )[0]  # (H, W) map for the single input

    gradcam_heatmap_overlay = generate_heatmap_overlay(
        target_image_numpy,
        cam,
        colormap='jet',
        alpha=0.6
    )
    gradcam_output_path = os.path.join(OUTPUT_DIR, f"{os.path.basename(IMAGE_PATH).split('.')[0]}_gradcam.png")
    if cv2.imwrite(gradcam_output_path, gradcam_heatmap_overlay):
        logger.info(f"Grad-CAM visualization saved to {gradcam_output_path}")
    else:
        logger.warning(f"cv2.imwrite could not write {gradcam_output_path}")
        gradcam_output_path = None
except Exception as e:
    logger.error(f"Error generating Grad-CAM with pytorch-grad-cam: {e}", exc_info=True)
    gradcam_heatmap_overlay = target_image_numpy
    gradcam_output_path = None

gradcam_duration = time.time() - start_gradcam_time
logger.info(f"Grad-CAM explanation generated in {gradcam_duration:.2f} seconds.")

### --- Grad-CAM visualization ---

In [None]:
logger.info("--- Displaying Grad-CAM Heatmap ---")
try:
    have_saved_cam = bool(gradcam_output_path) and os.path.exists(gradcam_output_path)
    if not have_saved_cam:
        logger.warning("Grad-CAM output path does not exist. Cannot display heatmap.")
    else:
        # Re-read the saved overlay; cv2 decodes to BGR, matplotlib wants RGB.
        saved_cam_rgb = cv2.cvtColor(cv2.imread(gradcam_output_path), cv2.COLOR_BGR2RGB)

        plt.figure(figsize=(6, 6))
        plt.imshow(saved_cam_rgb)
        plt.title(f"Grad-CAM Heatmap (Predicted: {predicted_class_name})")
        plt.axis('off')
        plt.show()
        logger.info("Grad-CAM heatmap displayed successfully.")
except Exception as e:
    logger.error(f"Error displaying Grad-CAM heatmap: {e}", exc_info=True)

### --- LIME Explanation ---


In [None]:
logger.info("--- Starting LIME Explanation ---")
start_lime_time = time.time()

# ImageNet statistics of the model's input normalization (same as load_image).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# LIME perturbs raw pixels, so it needs un-normalized uint8 RGB (H, W, 3).
# Recover exactly the resized pixels the model sees from the normalized model
# input; this does not depend on the value range of target_image_numpy.
lime_input_image = np.ascontiguousarray(
    (target_image_tensor.detach().cpu() * torch.tensor(IMAGENET_STD).view(3, 1, 1)
     + torch.tensor(IMAGENET_MEAN).view(3, 1, 1))
    .clamp(0, 1)
    .mul(255)
    .round()
    .to(torch.uint8)
    .permute(1, 2, 0)
    .numpy()
)

# Built once, reused for every LIME batch.
_to_tensor = transforms.ToTensor()  # uint8 HWC -> float CHW in [0, 1]
_normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)


def predict_fn(images):
    """
    LIME classifier function.

    Args:
        images: Batch of uint8 RGB images with shape (N, H, W, 3).

    Returns:
        np.ndarray: (N, 2) probabilities ordered [p(Real), p(Forged)].
    """
    batch = torch.stack([_normalize(_to_tensor(img)) for img in images]).to(DEVICE)
    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(batch)).cpu().numpy()  # (N, 1) = p(Forged)
    return np.hstack([1 - probs, probs])


try:
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        lime_input_image,
        predict_fn,
        top_labels=2,
        hide_color=0,  # switched-off superpixels are filled with 0 (black)
        num_samples=1000
    )
    # Also expose the result under the name the LIME visualization cell uses.
    explanation_lime = explanation

    # Boundaries of (up to) the 5 superpixels that most support the predicted class.
    temp, mask = explanation.get_image_and_mask(
        label=predicted_class_idx,
        positive_only=True,
        num_features=5,
        hide_rest=False
    )
    lime_heatmap_overlay = (mark_boundaries(lime_input_image / 255.0, mask) * 255).astype(np.uint8)

    lime_output_path = os.path.join(OUTPUT_DIR, f"{os.path.basename(IMAGE_PATH).split('.')[0]}_lime.png")
    cv2.imwrite(lime_output_path, cv2.cvtColor(lime_heatmap_overlay, cv2.COLOR_RGB2BGR))
    logger.info(f"LIME visualization saved to {lime_output_path}")
except Exception as e:
    logger.error(f"Error generating LIME explanation: {e}", exc_info=True)
    lime_heatmap_overlay = target_image_numpy
    lime_output_path = None

lime_duration = time.time() - start_lime_time
logger.info(f"LIME explanation generated in {lime_duration:.2f} seconds.")

### --- LIME Visualization ---

In [None]:
# --- Visualize LIME results ---
# Re-derive the model's prediction so this cell explains the predicted class.
with torch.no_grad():
    output_logit = model(target_image_tensor.unsqueeze(0).to(DEVICE))
    output_prob = torch.sigmoid(output_logit).item()
predicted_class_idx = 1 if output_prob > 0.5 else 0
predicted_class_name = "Forged" if predicted_class_idx == 1 else "Real"
logger.info(f"Model prediction: {predicted_class_name} (Probability: {output_prob:.4f})")


def normalize_for_display(img):
    """
    Min-max rescale an image to float [0, 1] for safe display.

    Args:
        img (np.ndarray): Image of any numeric dtype / value range.

    Returns:
        np.ndarray: float64 array in [0, 1]; all zeros if ``img`` is constant.
    """
    img = np.asarray(img, dtype=np.float64)
    img_min, img_max = img.min(), img.max()
    if img_min == img_max:
        return np.zeros_like(img)
    return (img - img_min) / (img_max - img_min)


# Image for the "original" panels and the heatmap background. uint8 pixels are
# used as-is. Anything else (e.g. an ImageNet-normalized float array) is
# rescaled, so imshow does not clip it and the uint8 overlay does not wrap.
display_image = (
    target_image_numpy
    if target_image_numpy.dtype == np.uint8
    else normalize_for_display(target_image_numpy)
)

# The LIME cell stores its result as `explanation`.
# Figure with 3 panels: boundaries, heatmap, original image.
plt.figure(figsize=(15, 5))

# --- 1. LIME boundary plot: top superpixels supporting the predicted class ---
try:
    temp_lime, mask_lime = explanation.get_image_and_mask(
        predicted_class_idx,
        positive_only=True,
        num_features=10,
        hide_rest=False
    )
    temp_lime_norm = normalize_for_display(temp_lime)  # avoids imshow clipping warnings

    plt.subplot(1, 3, 1)
    plt.imshow(mark_boundaries(temp_lime_norm, mask_lime))
    plt.title(f"LIME Boundaries (Pos for '{predicted_class_name}')")
    plt.axis('off')
    show_boundaries = True
except Exception as e:
    logger.error(f"Could not generate LIME boundary plot: {e}", exc_info=True)
    plt.subplot(1, 3, 1)
    plt.imshow(display_image)  # Show original if boundaries fail
    plt.title("Original Image (LIME Boundaries Failed)")
    plt.axis('off')
    show_boundaries = False

# --- 2. LIME heatmap: positive superpixel weights painted onto the image ---
try:
    segments = explanation.segments
    lime_exp = explanation.local_exp[predicted_class_idx]  # [(segment_id, weight), ...]

    # Per-segment lookup table; only positive (supporting) weights are kept.
    segment_weights = np.zeros(segments.max() + 1)
    for seg_id, weight in lime_exp:
        if weight > 0:
            segment_weights[seg_id] = weight
    lime_heatmap_weights = segment_weights[segments]

    peak = lime_heatmap_weights.max()
    if peak > 0:
        lime_heatmap_weights = lime_heatmap_weights / peak

    lime_heatmap_overlay = generate_heatmap_overlay(
        display_image,
        lime_heatmap_weights,
        colormap='viridis',  # Or 'hot', 'jet', 'Reds'
        alpha=0.7
    )

    plt.subplot(1, 3, 2)
    plt.imshow(cv2.cvtColor(lime_heatmap_overlay, cv2.COLOR_BGR2RGB))  # Convert BGR->RGB for plt
    plt.title(f"LIME Heatmap (Pos for '{predicted_class_name}')")
    plt.axis('off')
    show_heatmap = True

except KeyError:
    logger.warning(f"Predicted class index {predicted_class_idx} not found in LIME explanation weights.")
    plt.subplot(1, 3, 2)
    plt.imshow(display_image)
    plt.title("Original Image (LIME Heatmap Failed)")
    plt.axis('off')
    show_heatmap = False
except Exception as e:
    logger.error(f"Error generating LIME heatmap: {e}", exc_info=True)
    plt.subplot(1, 3, 2)
    plt.imshow(display_image)
    plt.title("Original Image (LIME Heatmap Error)")
    plt.axis('off')
    show_heatmap = False

# --- 3. Original image for reference ---
plt.subplot(1, 3, 3)
plt.imshow(display_image)
plt.title(f"Original Image - Predicted: {predicted_class_name} ({output_prob:.4f})")
plt.axis('off')

plt.tight_layout()
plt.show()

# --- BONUS: supporting vs. contradicting superpixels ---
try:
    plt.figure(figsize=(12, 4))

    # Plot 1: Original image
    plt.subplot(1, 3, 1)
    plt.imshow(display_image)
    plt.title(f"Original\nPrediction: {predicted_class_name}")
    plt.axis('off')

    # Plot 2: Positive features (supporting the prediction)
    temp_pos, mask_pos = explanation.get_image_and_mask(
        predicted_class_idx,
        positive_only=True,
        num_features=5,
        hide_rest=False
    )
    temp_pos_norm = normalize_for_display(temp_pos)
    plt.subplot(1, 3, 2)
    plt.imshow(mark_boundaries(temp_pos_norm, mask_pos, color=(0, 1, 0)))
    plt.title("Supporting Features\n(Green)")
    plt.axis('off')

    # Plot 3: Negative features (contradicting the prediction)
    try:
        temp_neg, mask_neg = explanation.get_image_and_mask(
            predicted_class_idx,
            positive_only=False,
            negative_only=True,
            num_features=5,
            hide_rest=False
        )
        temp_neg_norm = normalize_for_display(temp_neg)
        plt.subplot(1, 3, 3)
        plt.imshow(mark_boundaries(temp_neg_norm, mask_neg, color=(1, 0, 0)))
        plt.title("Contradicting Features\n(Red)")
        plt.axis('off')
    except Exception as neg_e:
        # e.g. an older lime without `negative_only`; show the original instead.
        logger.warning(f"Could not extract contradicting LIME features: {neg_e}")
        plt.subplot(1, 3, 3)
        plt.imshow(display_image)
        plt.title("No Contradicting Features Found")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

except Exception as fallback_e:
    logger.warning(f"Alternative visualization also failed: {fallback_e}")

### --- SHAP Explanation ---

In [None]:
# --- SHAP Explanation ---
logger.info("--- Starting SHAP Explanation ---")
start_shap_time = time.time()

# --- Prepare Background Data for SHAP ---
# Search recursively: images in this dataset can sit inside per-signature
# sub-folders (IMAGE_PATH is <split>/<name>.jpg/<file>.jpg). A flat
# os.listdir + isfile check would silently find nothing.
SHAP_BACKGROUND_SEED = 42  # fixed seed -> reproducible background sample
background_files = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _subdirs, filenames in os.walk(BACKGROUND_DATA_DIR)
    for filename in filenames
    if filename.lower().endswith(('.png', '.jpg', '.jpeg'))
)

if not background_files:
    msg = f"No background image files found in {BACKGROUND_DATA_DIR}. Cannot run SHAP."
    logger.error(msg)
    # Raise instead of exit(): exit() would try to shut down the notebook kernel.
    raise RuntimeError(msg)
if len(background_files) > N_SHAP_BACKGROUND_SAMPLES:
    background_files = random.Random(SHAP_BACKGROUND_SEED).sample(background_files, N_SHAP_BACKGROUND_SAMPLES)

logger.info(f"Loading {len(background_files)} background images for SHAP...")
background_tensors = []
for f in tqdm(background_files, desc="Loading Background"):
    try:
        background_tensors.append(load_image(f, transform_type='tensor'))
    except OSError as e:  # includes PIL's UnidentifiedImageError
        logger.warning(f"Skipping unreadable background image {f}: {e}")

if not background_tensors:
    msg = "Failed to load any background images. Cannot run SHAP."
    logger.error(msg)
    raise RuntimeError(msg)

background_data_tensor = torch.stack(background_tensors).to(DEVICE)
logger.info(f"Background data shape: {background_data_tensor.shape}")

# --- Target image with batch dimension: (1, C, H, W) ---
shap_input_tensor = target_image_tensor.unsqueeze(0).to(DEVICE)

# --- GradientExplainer (expected gradients) over the background set ---
explainer_shap = shap.GradientExplainer(model, background_data_tensor)

# --- Calculate SHAP values (computationally expensive) ---
logger.info("Calculating SHAP values (this may take some time)...")
shap_values = explainer_shap.shap_values(shap_input_tensor)
logger.info("SHAP values calculated.")

shap_duration = time.time() - start_shap_time
logger.info(f"SHAP explanation generated in {shap_duration:.2f} seconds.")

### --- SHAP Visualization ---

In [None]:
logger.info("--- Visualizing SHAP Results ---")

# --- Image for plotting: (1, C, H, W) -> (H, W, C), undo ImageNet normalization ---
shap_input_numpy = shap_input_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
shap_input_numpy = (shap_input_numpy * [0.229, 0.224, 0.225]) + [0.485, 0.456, 0.406]
shap_input_numpy = np.clip(shap_input_numpy, 0, 1)

# --- SHAP values: normalize container / shape across shap versions ---
if isinstance(shap_values, list):
    shap_values_processed = shap_values[0]  # one entry per model output; this model has one
else:
    shap_values_processed = shap_values
shap_values_processed = np.asarray(shap_values_processed)
if shap_values_processed.ndim == 5:
    # Some shap versions append an output axis: (1, C, H, W, n_outputs).
    shap_values_processed = shap_values_processed[..., 0]

# The explained output is the single logit, whose positive direction is
# "Forged". Negate for a "Real" prediction, so positive values always support
# the predicted class, as the labels and titles below state.
if predicted_class_idx == 0:
    shap_values_processed = -shap_values_processed

# Reshape SHAP values for plotting: (1, C, H, W) -> (1, H, W, C)
shap_values_plot = shap_values_processed.transpose(0, 2, 3, 1)

logger.info(f"Shape of image for SHAP plot: {shap_input_numpy.shape}")  # Should be (H, W, C)
logger.info(f"Shape of SHAP values for plot: {shap_values_plot.shape}")  # Should be (1, H, W, C)

# --- 1. Default SHAP image_plot ---
print("\nSHAP Default Explanation Plot:")
try:
    shap.image_plot(
        shap_values=shap_values_plot,
        pixel_values=np.expand_dims(shap_input_numpy, 0),
        labels=np.array([[predicted_class_name]]),
        show=True
    )
    show_default_shap = True
except Exception as e:
    logger.error(f"Error visualizing SHAP with default image_plot: {e}", exc_info=True)
    print("Could not generate default SHAP plot.")
    show_default_shap = False

# --- 2. Custom SHAP heatmap ---
print("\nSHAP Custom Heatmap Plot:")
try:
    # (1, C, H, W) -> (C, H, W) -> sum over channels -> (H, W)
    shap_heatmap_data = np.sum(shap_values_processed[0], axis=0)

    # Keep only contributions towards the predicted class.
    shap_heatmap_data_pos = np.clip(shap_heatmap_data, a_min=0, a_max=None)

    shap_heatmap_overlay = generate_heatmap_overlay(
        shap_input_numpy,
        shap_heatmap_data_pos,
        colormap='hot',  # 'hot', 'Reds', 'jet' are good choices
        alpha=0.6
    )

    plt.figure(figsize=(10, 5))

    # Heatmap overlay
    plt.subplot(1, 2, 1)
    plt.imshow(cv2.cvtColor(shap_heatmap_overlay, cv2.COLOR_BGR2RGB))  # Convert BGR->RGB for plt
    plt.title(f"SHAP Heatmap (Pos Contributions for '{predicted_class_name}')")
    plt.axis('off')

    # Original image
    plt.subplot(1, 2, 2)
    plt.imshow(shap_input_numpy)  # float [0, 1], suitable for plt
    plt.title("Original Image (Approx. Denormalized)")
    plt.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close()
    show_shap = True

except Exception as e:
    logger.error(f"Error generating SHAP heatmap: {e}", exc_info=True)
    plt.figure(figsize=(6, 6))
    plt.imshow(shap_input_numpy)
    plt.title("Original Image (SHAP Heatmap Failed)")
    plt.axis('off')
    plt.show()
    shap_heatmap_overlay = shap_input_numpy
    show_shap = False

## --- 6. Combine Visualizations ---

In [None]:
def _as_display_image(img):
    """
    Make an image safe for ``plt.imshow`` / ``mark_boundaries``.

    uint8 arrays are returned unchanged. Any other dtype is min-max rescaled to
    float [0, 1] (a constant image becomes all zeros), so e.g. ImageNet-normalized
    arrays are shown rather than clipped.
    """
    img = np.asarray(img)
    if img.dtype == np.uint8:
        return img
    img = img.astype(np.float64)
    lo, hi = img.min(), img.max()
    if lo == hi:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


def _bgr_to_rgb(img):
    """Reverse the channel order of a 3-channel image (any dtype); other shapes pass through."""
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[2] == 3:
        return img[..., ::-1]
    return img


# 2x2 summary of all explanations for the predicted class, saved to OUTPUT_DIR.
# Create the figure explicitly; otherwise the panels land in a default-sized
# (or leftover current) figure.
plt.figure(figsize=(12, 12))

# Grad-CAM (the overlay's heatmap colours are BGR-ordered)
plt.subplot(2, 2, 1)
plt.imshow(_as_display_image(_bgr_to_rgb(gradcam_heatmap_overlay)))
plt.title(f"Grad-CAM (Pos for '{predicted_class_name}')")
plt.axis('off')

# LIME Boundary
plt.subplot(2, 2, 2)
if show_boundaries:
    # Rescale temp_lime to a displayable range whatever its original range.
    plt.imshow(mark_boundaries(_as_display_image(temp_lime), mask_lime))
    plt.title(f"LIME Boundaries (Pos for '{predicted_class_name}')")
else:
    plt.imshow(_as_display_image(target_image_numpy))
    plt.title("Original Image (LIME Boundaries Failed)")
plt.axis('off')

# LIME Heatmap
plt.subplot(2, 2, 3)
if show_heatmap:
    plt.imshow(_as_display_image(_bgr_to_rgb(lime_heatmap_overlay)))
    plt.title(f"LIME Heatmap (Pos for '{predicted_class_name}')")
else:
    plt.imshow(_as_display_image(target_image_numpy))
    plt.title("Original Image (LIME Heatmap Failed)")
plt.axis('off')

# SHAP Heatmap
plt.subplot(2, 2, 4)
if show_shap:
    plt.imshow(_as_display_image(_bgr_to_rgb(shap_heatmap_overlay)))
    plt.title(f"SHAP Heatmap (Pos for '{predicted_class_name}')")
else:
    plt.imshow(shap_input_numpy)  # already float in [0, 1]
    plt.title("Original Image (SHAP Heatmap Failed)")
plt.axis('off')

plt.tight_layout()
combined_output_path = os.path.join(OUTPUT_DIR, f"{os.path.basename(IMAGE_PATH).split('.')[0]}_combined.png")
plt.savefig(combined_output_path, bbox_inches='tight')  # save before show() clears the figure
plt.show()
plt.close()

logger.info(f"Combined visualization saved to {combined_output_path}")

# Summary

In [None]:
# Log a run summary (lazy %-style logging arguments; same rendered messages).
logger.info("--- Explainability Analysis Summary ---")
logger.info("Model: %s", MODEL_NAME)
logger.info("Image: %s", IMAGE_PATH)
logger.info("Predicted Class: %s (Probability: %.4f)", predicted_class_name, output_prob)
logger.info("Grad-CAM Duration: %.2f seconds", gradcam_duration)
logger.info("LIME Duration: %.2f seconds", lime_duration)
logger.info("SHAP Duration: %.2f seconds", shap_duration)
logger.info("Output Directory: %s", OUTPUT_DIR)

logger.info("--- XAI Script Finished ---")