In [1]:
"""
Layer-wise Relevance Propagation (LRP) for ResNet18
====================================================

LRP is a technique that explains neural network predictions by backpropagating
"relevance" from the output through all layers to the input pixels.

Key LRP Rules:
- LRP-0: Basic rule, simple redistribution
- LRP-epsilon: Adds small epsilon for numerical stability  
- LRP-gamma: Emphasizes positive contributions
- LRP-alpha-beta: Separates positive/negative contributions (α + β = 1)

For this implementation, we use the Captum library which provides robust LRP.
"""

best_path = "best_bal_model.pth"

# Install captum if not available
from captum.attr import LRP

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision import transforms, models
from torchvision.models import ResNet18_Weights

# ============================================================
# 1. LOAD MODEL
# ============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model
model_lrp = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model_lrp.fc = nn.Linear(model_lrp.fc.in_features, 1)
model_lrp.load_state_dict(torch.load(best_path, map_location=device))
model_lrp = model_lrp.to(device)
model_lrp.eval()

print(f"Model loaded from: {best_path}")
print(f"Device: {device}")

# ============================================================
# 2. DATA PREPARATION
# ============================================================
# Transform for input images
transform = val_test_transform if val_test_transform else transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Denormalization for visualization
def denormalize(tensor):
    """Convert normalized tensor back to displayable image"""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device)
    tensor = tensor * std + mean
    tensor = torch.clamp(tensor, 0, 1)
    return tensor

# ============================================================
# 3. LRP EXPLAINER CLASS
# ============================================================
class LRPExplainer:
    """
    Layer-wise Relevance Propagation explainer for binary classification.
    
    LRP redistributes the prediction score backwards through the network
    to assign relevance scores to each input pixel.
    """
    
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.lrp = LRP(model)
    
    def explain(self, input_tensor, target=None):
        """
        Compute LRP attribution for an input image.
        
        Args:
            input_tensor: Normalized input tensor (1, 3, H, W)
            target: Target class (0 or 1 for binary). If None, uses predicted class.
        
        Returns:
            attribution: LRP relevance scores (H, W)
        """
        input_tensor = input_tensor.to(self.device)
        input_tensor.requires_grad = True
        
        # Get prediction if target not specified
        if target is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target = (torch.sigmoid(output) >= 0.5).long().item()
        
        # Compute LRP attribution
        # For binary classification with single output, target=0 means the output neuron
        attribution = self.lrp.attribute(input_tensor, target=0)
        
        # Sum across channels and convert to numpy
        attr_np = attribution.squeeze().cpu().detach().numpy()
        if attr_np.ndim == 3:
            attr_np = attr_np.sum(axis=0)  # Sum RGB channels
        
        return attr_np
    
    def predict(self, input_tensor):
        """Get model prediction probability"""
        with torch.no_grad():
            output = self.model(input_tensor.to(self.device))
            prob = torch.sigmoid(output).item()
        return prob

# ============================================================
# 4. VISUALIZATION FUNCTIONS
# ============================================================
def visualize_lrp(image_path, explainer, transform, class_names, figsize=(15, 4)):
    """
    Visualize LRP explanation for a single image.
    
    Shows: Original | LRP Heatmap | Positive Relevance | Overlay
    """
    # Load and preprocess image
    img_pil = Image.open(image_path).convert('RGB')
    img_tensor = transform(img_pil).unsqueeze(0)
    
    # Get prediction
    prob = explainer.predict(img_tensor)
    pred_class = class_names[1] if prob >= 0.5 else class_names[0]
    pred_prob = prob if prob >= 0.5 else 1 - prob
    
    # Compute LRP attribution
    attr = explainer.explain(img_tensor)
    
    # Prepare original image for display
    img_display = denormalize(img_tensor.squeeze()).permute(1, 2, 0).cpu().numpy()
    
    # Create figure
    fig, axes = plt.subplots(1, 4, figsize=figsize)
    
    # 1. Original image
    axes[0].imshow(img_display)
    axes[0].set_title(f"Original\nPred: {pred_class} ({pred_prob:.1%})")
    axes[0].axis('off')
    
    # 2. LRP Heatmap (full, blue-white-red)
    max_abs = np.abs(attr).max() + 1e-10
    axes[1].imshow(attr, cmap='bwr', vmin=-max_abs, vmax=max_abs)
    axes[1].set_title("LRP Attribution\n(Red→Pneumonia, Blue→Normal)")
    axes[1].axis('off')
    
    # 3. Positive relevance only (regions supporting the prediction)
    attr_positive = np.maximum(attr, 0)
    axes[2].imshow(attr_positive, cmap='Reds')
    axes[2].set_title("Positive Relevance\n(Supporting prediction)")
    axes[2].axis('off')
    
    # 4. Overlay on original
    attr_norm = (attr - attr.min()) / (attr.max() - attr.min() + 1e-10)
    heatmap = plt.cm.jet(attr_norm)[:, :, :3]
    img_uint8 = (img_display * 255).astype(np.uint8)
    heatmap_uint8 = (heatmap * 255).astype(np.uint8)
    overlay = cv2.addWeighted(img_uint8, 0.6, heatmap_uint8, 0.4, 0)
    axes[3].imshow(overlay)
    axes[3].set_title("LRP Overlay\n(Important regions)")
    axes[3].axis('off')
    
    plt.tight_layout()
    return fig, attr

def plot_lrp_grid(image_paths, explainer, transform, class_names, rows=4, cols=4):
    """
    Plot LRP explanations for multiple images in a grid.
    Each row shows: Original | LRP Heatmap | Overlay
    """
    n_images = min(len(image_paths), rows * cols // 3 * 3)  # Ensure divisible by 3
    n_rows = (n_images + 2) // 3  # 3 columns per image (orig, heatmap, overlay)
    
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows))
    if n_rows == 1:
        axes = axes.reshape(1, -1)
    
    for i in range(n_rows):
        if i >= len(image_paths):
            break
            
        img_path = image_paths[i]
        
        # Load and process
        img_pil = Image.open(img_path).convert('RGB')
        img_tensor = transform(img_pil).unsqueeze(0)
        
        # Prediction
        prob = explainer.predict(img_tensor)
        pred_class = class_names[1] if prob >= 0.5 else class_names[0]
        pred_prob = prob if prob >= 0.5 else 1 - prob
        
        # LRP attribution
        attr = explainer.explain(img_tensor)
        
        # Display image
        img_display = denormalize(img_tensor.squeeze()).permute(1, 2, 0).cpu().numpy()
        
        # Column 1: Original
        axes[i, 0].imshow(img_display)
        axes[i, 0].set_title(f"Original\n{pred_class} ({pred_prob:.1%})")
        axes[i, 0].axis('off')
        
        # Column 2: LRP Heatmap
        max_abs = np.abs(attr).max() + 1e-10
        axes[i, 1].imshow(attr, cmap='bwr', vmin=-max_abs, vmax=max_abs)
        axes[i, 1].set_title("LRP Attribution")
        axes[i, 1].axis('off')
        
        # Column 3: Overlay
        attr_norm = (attr - attr.min()) / (attr.max() - attr.min() + 1e-10)
        heatmap = plt.cm.jet(attr_norm)[:, :, :3]
        img_uint8 = (img_display * 255).astype(np.uint8)
        heatmap_uint8 = (heatmap * 255).astype(np.uint8)
        overlay = cv2.addWeighted(img_uint8, 0.6, heatmap_uint8, 0.4, 0)
        axes[i, 2].imshow(overlay)
        axes[i, 2].set_title("Overlay")
        axes[i, 2].axis('off')
    
    plt.suptitle("LRP Explanations for Pneumonia Classification", fontsize=14)
    plt.tight_layout()
    return fig

# ============================================================
# 5. CREATE EXPLAINER AND RUN
# ============================================================
print("\nInitializing LRP Explainer...")
lrp_explainer = LRPExplainer(model_lrp, device)
print("LRP Explainer ready!")

# ============================================================
# 6. EXPLAIN SAMPLE IMAGES
# ============================================================
print("\n" + "="*60)
print("LRP EXPLANATION FOR SAMPLE IMAGES")
print("="*60)

# Test on a few normal and pneumonia images
print("\n--- Normal Samples ---")
if len(normal_pathes) > 0:
    fig, attr = visualize_lrp(normal_pathes[0], lrp_explainer, transform, class_names)
    plt.show()

print("\n--- Pneumonia Samples ---")  
if len(pneumonia_pathes) > 0:
    fig, attr = visualize_lrp(pneumonia_pathes[0], lrp_explainer, transform, class_names)
    plt.show()

print("\n" + "="*60)
print("INTERPRETATION GUIDE:")
print("-" * 60)
print("• Red regions: Increase Pneumonia prediction (evidence FOR disease)")
print("• Blue regions: Decrease Pneumonia prediction (evidence AGAINST disease)")
print("• White regions: Neutral, little influence on prediction")
print("• Bright regions in overlay: Most important for the model's decision")
print("="*60)


Model loaded from: best_bal_model.pth
Device: cuda


NameError: name 'val_test_transform' is not defined

In [None]:
# ============================================================
# LRP VISUALIZATION FOR MULTIPLE IMAGES
# ============================================================

print("="*60)
print("LRP Explanations - Normal Cases")
print("="*60)
fig = plot_lrp_grid(normal_pathes[:6], lrp_explainer, transform, class_names, rows=6)
plt.show()

print("\n" + "="*60)
print("LRP Explanations - Pneumonia Cases")  
print("="*60)
fig = plot_lrp_grid(pneumonia_pathes[:6], lrp_explainer, transform, class_names, rows=6)
plt.show()
