# SignXAI with PyTorch - Advanced Usage

This tutorial demonstrates advanced techniques for using SignXAI with PyTorch models. It builds on the basic tutorial and explores more sophisticated explainability methods and customizations.

## Setup Requirements

For this PyTorch tutorial, you'll need to install SignXAI with PyTorch dependencies:

```bash
# For conda users
conda create -n signxai-pytorch python=3.8
conda activate signxai-pytorch
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/pytorch.txt

# Or for pip users
python -m venv signxai_pytorch_env
source signxai_pytorch_env/bin/activate  # On Windows: signxai_pytorch_env\Scripts\activate
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/pytorch.txt
```

## Overview

In this tutorial, we'll cover:
1. Customizing LRP rules with PyTorch
2. Creating composite explanations
3. Working with custom model architectures
4. Advanced visualization techniques
5. Layer-wise relevance tracking

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# SignXAI imports - including advanced modules
from signxai.torch_signxai.methods import (
    SIGN, 
    GradCAM, 
    GuidedBackprop, 
    LRPZ, 
    LRPEpsilon,
    LRPAlphaBeta,  # Alpha-Beta rule for LRP
    LRPComposite   # For composite LRP rules
)
from signxai.common.visualization import visualize_attribution
from signxai.torch_signxai.utils import remove_softmax

## 2. Set Up Paths and Load Data

In [None]:
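# Set up data and model paths
# Notebooks do not define __file__, so the quoted "__file__" is resolved against
# the current working directory; _THIS_DIR is the directory the notebook runs from
_THIS_DIR = os.path.dirname(os.path.abspath("__file__"))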
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
VGG16_MODEL_PATH = os.path.join(DATA_DIR, "models", "pytorch", "VGG16", "vgg16_torch.pt")
IMAGE_PATH = os.path.join(DATA_DIR, "images", "example.jpg")

# Verify paths
assert os.path.exists(VGG16_MODEL_PATH), f"VGG16 model not found at {VGG16_MODEL_PATH}"
assert os.path.exists(IMAGE_PATH), f"Image not found at {IMAGE_PATH}"

# Load model
try:
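    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
    model = torch.load(VGG16_MODEL_PATH, map_location="cpu")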
    print("Loaded pre-saved VGG16 model")
except Exception as e:
    print(f"Could not load saved model: {e}\nLoading from torchvision instead")
    import torchvision.models as models
    model = models.vgg16(pretrained=True)
    
model.eval()  # Set to evaluation mode
model_no_softmax = remove_softmax(model)

# Helper function to load and preprocess image
def preprocess_image(image_path, size=(224, 224)):
    # Convert to RGB so grayscale or RGBA files still yield three channels
    img = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return img, transform(img).unsqueeze(0)  # Add batch dimension

# Load and display image
original_img, preprocessed_img = preprocess_image(IMAGE_PATH)

plt.figure(figsize=(6, 6))
plt.imshow(original_img)
plt.axis('off')
plt.title('Sample Image')
plt.show()

## 3. Customizing LRP Rules with PyTorch

SignXAI provides extensive customization options for Layer-wise Relevance Propagation (LRP). Let's explore how to configure different rules for different layers of a model.
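Before calling the library, it can help to see what the alpha-beta rule computes for a single layer. The following is a minimal NumPy sketch for one bias-free dense layer, not SignXAI's implementation; the function name `lrp_alpha_beta_dense` and the toy values are made up for illustration. It shows that the total relevance is preserved when `alpha - beta = 1`, provided every output receives both positive and negative contributions.

```python
import numpy as np

def lrp_alpha_beta_dense(a, W, R_out, alpha=2.0, beta=1.0, tiny=1e-12):
    """Redistribute output relevance R_out onto the inputs of a bias-free dense layer.

    a: input activations [J], W: weights [J, K], R_out: output relevance [K].
    Illustrative helper only (hypothetical name, not part of SignXAI).
    """
    z = a[:, None] * W                  # z[j, k]: contribution of input j to output k
    z_pos = np.clip(z, 0.0, None)       # positive contributions
    z_neg = np.clip(z, None, 0.0)       # negative contributions
    pos_share = z_pos / (z_pos.sum(axis=0) + tiny)   # each column sums to about 1
    neg_share = z_neg / (z_neg.sum(axis=0) - tiny)   # each column sums to about 1
    return ((alpha * pos_share - beta * neg_share) * R_out).sum(axis=1)

# Toy layer: 3 inputs, 2 outputs
a = np.array([1.0, 2.0, 0.5])
W = np.array([[0.5, -1.0],
              [-0.25, 0.75],
              [1.0, -0.5]])
R_out = np.array([1.0, 2.0])

R_in = lrp_alpha_beta_dense(a, W, R_out, alpha=2.0, beta=1.0)
print("input relevance:", R_in)
print("sum in / sum out:", R_in.sum(), R_out.sum())

# alpha=1, beta=0 keeps only positive contributions, so all relevance is non-negative
R_in_pos = lrp_alpha_beta_dense(a, W, R_out, alpha=1.0, beta=0.0)
print("alpha=1, beta=0:", R_in_pos)
```

With `alpha=2, beta=1` some inputs receive negative relevance while the total still matches the output relevance; with `alpha=1, beta=0` the map is non-negative everywhere.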

In [None]:
# Get a prediction for reference
with torch.no_grad():
    output = model(preprocessed_img)
    _, predicted_idx = torch.max(output, 1)
    predicted_class = predicted_idx.item()
    
print(f"Predicted class index: {predicted_class}")

# Example 1: Using LRP Alpha-Beta rule with custom alpha/beta values
alpha, beta = 2, 1  # Alpha-Beta parameters (alpha - beta = 1 to ensure conservation)
lrp_alpha_beta = LRPAlphaBeta(model_no_softmax, alpha=alpha, beta=beta)
explanation_alpha_beta = lrp_alpha_beta.attribute(preprocessed_img, target=predicted_class)

# Example 2: Using LRP Epsilon rule with a custom epsilon value
epsilon = 0.01  # Small epsilon for sharper attribution maps
lrp_epsilon = LRPEpsilon(model_no_softmax, epsilon=epsilon)
explanation_epsilon = lrp_epsilon.attribute(preprocessed_img, target=predicted_class)

# Example 3: Composite LRP - Different rules for different layers
# Define layer-wise rules
layer_rules = {
    # Format: 'layer_name_or_regex': ('rule_name', rule_params)
    'features.0': ('epsilon', {'epsilon': 0.1}),  # First conv layer - epsilon rule
    'features.*': ('alpha_beta', {'alpha': 1, 'beta': 0}),  # Remaining feature layers - alpha1-beta0 (this pattern also matches 'features.0', so rule order matters)
    'classifier.*': ('z_plus', {})  # Final layers - z+ rule
}

# Create composite LRP explainer
lrp_composite = LRPComposite(model_no_softmax, layer_rules=layer_rules)
explanation_composite = lrp_composite.attribute(preprocessed_img, target=predicted_class)
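The patterns in `layer_rules` overlap: `'features.*'` also matches `'features.0'`. How `LRPComposite` resolves such overlaps is not shown in this tutorial, so the sketch below only illustrates one plausible policy, first match wins in dictionary order. The helper `resolve_rule` and its default rule are hypothetical and not part of SignXAI; the dots are escaped here so that they match a literal `.`.

```python
import re

def resolve_rule(layer_name, layer_rules, default=("epsilon", {"epsilon": 1e-6})):
    """Return the first rule whose pattern fully matches layer_name.

    Hypothetical helper: first-match-wins is an assumption, not SignXAI's documented behavior.
    """
    for pattern, rule in layer_rules.items():   # dicts keep insertion order (Python 3.7+)
        if re.fullmatch(pattern, layer_name):
            return rule
    return default

layer_rules = {
    r"features\.0": ("epsilon", {"epsilon": 0.1}),
    r"features\..*": ("alpha_beta", {"alpha": 1, "beta": 0}),
    r"classifier\..*": ("z_plus", {}),
}

for name in ["features.0", "features.10", "classifier.6", "avgpool"]:
    print(name, "->", resolve_rule(name, layer_rules))
```

Under this policy the more specific pattern must come first; if the order were reversed, `features.0` would receive the alpha-beta rule instead of the epsilon rule.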

In [None]:
# Visualize the different LRP variants
def preprocess_for_visualization(explanation):
    # Handle PyTorch tensor shape [B, C, H, W]
    if isinstance(explanation, torch.Tensor):
        explanation = explanation.detach().cpu().numpy()
    if explanation.ndim == 4:
        explanation = np.transpose(explanation[0], (1, 2, 0))  # [H, W, C]
    elif explanation.ndim == 3 and explanation.shape[0] == 3:  # If it's [C, H, W]
        explanation = np.transpose(explanation, (1, 2, 0))      # [H, W, C]
    
    # Take absolute values so the plot shows attribution magnitude; the sign is discarded
    abs_explanation = np.abs(explanation)
    
    # Normalize for visualization
    if abs_explanation.max() > 0:
        abs_explanation = abs_explanation / abs_explanation.max()
    
    return abs_explanation

# Create a figure for visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Get numpy array of original image
np_img = np.array(original_img.resize((224, 224)))

# Plot the different LRP variants
visualize_attribution(np_img, preprocess_for_visualization(explanation_alpha_beta), ax=axes[0])
axes[0].set_title(f"LRP Alpha-Beta (α={alpha}, β={beta})")
axes[0].axis('off')

visualize_attribution(np_img, preprocess_for_visualization(explanation_epsilon), ax=axes[1])
axes[1].set_title(f"LRP Epsilon (ε={epsilon})")
axes[1].axis('off')

visualize_attribution(np_img, preprocess_for_visualization(explanation_composite), ax=axes[2])
axes[2].set_title("LRP Composite (Layer-wise rules)")
axes[2].axis('off')

plt.tight_layout()
plt.show()

## 4. Creating Composite Explanations

You can combine multiple explainability methods to create composite explanations that highlight different aspects of the model's decision.

In [None]:
# Find the target layer for GradCAM: the last Conv2d in model.features
target_layer = None
for name, module in model.features.named_children():
    if isinstance(module, nn.Conv2d):
        target_layer = module  # Overwritten on every match, so the last conv layer wins

# Generate explanations using different methods
gradcam = GradCAM(model_no_softmax, target_layer)
gradcam_explanation = gradcam.attribute(preprocessed_img, target=predicted_class)

guided_bp = GuidedBackprop(model_no_softmax)
guided_bp_explanation = guided_bp.attribute(preprocessed_img, target=predicted_class)

# Create Guided GradCAM as a composite of GradCAM and Guided Backpropagation
# Convert to numpy for processing
gradcam_np = gradcam_explanation.detach().cpu().numpy()
guided_bp_np = guided_bp_explanation.detach().cpu().numpy()

# Reshape GradCAM to match Guided Backprop dimensions if needed
if gradcam_np.shape != guided_bp_np.shape:
    # Assuming GradCAM is [1, 1, H, W] and Guided BP is [1, 3, H, W]
    gradcam_reshaped = np.repeat(gradcam_np, 3, axis=1)
    guided_gradcam = guided_bp_np * gradcam_reshaped
else:
    guided_gradcam = guided_bp_np * gradcam_np

# Create a gradient-weighted LRP by combining gradient and LRP
gradient = SIGN(model_no_softmax)
gradient_explanation = gradient.attribute(preprocessed_img, target=predicted_class).detach().cpu().numpy()
lrp_explanation = explanation_epsilon.detach().cpu().numpy()

# Weight LRP by gradient magnitude (absolute value, so the sign of the LRP map is kept)
gradient_weighted_lrp = np.abs(gradient_explanation) * lrp_explanation
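The composition above assumes the GradCAM map already has the spatial size of the input. Whether SignXAI's `GradCAM` returns a map at input resolution is not stated here, so the following NumPy sketch shows the general recipe under that uncertainty: compute a Grad-CAM map at feature-map resolution, upsample it, then multiply it with a guided-backprop map. Random arrays stand in for real activations and gradients, and the helper names are illustrative.

```python
import numpy as np

def grad_cam_map(activations, gradients):
    """Grad-CAM heatmap from feature maps [C, h, w] and their gradients [C, h, w]."""
    weights = gradients.mean(axis=(1, 2))               # one weight per channel
    cam = np.tensordot(weights, activations, axes=1)    # weighted channel sum -> [h, w]
    cam = np.maximum(cam, 0.0)                          # keep positive evidence only
    peak = cam.max()
    return cam / peak if peak > 0 else cam

def upsample_nearest(cam, factor):
    """Nearest-neighbour upsampling by an integer factor along both axes."""
    return np.repeat(np.repeat(cam, factor, axis=0), factor, axis=1)

rng = np.random.default_rng(0)
activations = rng.random((8, 7, 7))          # stand-in for last conv layer outputs
gradients = rng.normal(size=(8, 7, 7))       # stand-in for d(score)/d(activations)
guided_bp = rng.normal(size=(3, 224, 224))   # stand-in for a guided-backprop map

cam = upsample_nearest(grad_cam_map(activations, gradients), 32)   # 7 * 32 = 224
guided_grad_cam = guided_bp * cam[None, :, :]                      # broadcast over channels
print(cam.shape, guided_grad_cam.shape)
```

Broadcasting with `cam[None, :, :]` makes an explicit `np.repeat` over the channel axis unnecessary. Bilinear interpolation is a common alternative to nearest-neighbour upsampling when smoother maps are wanted.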

In [None]:
# Visualize the composite explanations
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

# Plot original methods and composites
visualize_attribution(np_img, preprocess_for_visualization(guided_bp_np), ax=axes[0])
axes[0].set_title("Guided Backpropagation")
axes[0].axis('off')

visualize_attribution(np_img, preprocess_for_visualization(gradcam_np), ax=axes[1])
axes[1].set_title("GradCAM")
axes[1].axis('off')

visualize_attribution(np_img, preprocess_for_visualization(guided_gradcam), ax=axes[2])
axes[2].set_title("Guided GradCAM (Composite)")
axes[2].axis('off')

visualize_attribution(np_img, preprocess_for_visualization(gradient_weighted_lrp), ax=axes[3])
axes[3].set_title("Gradient-Weighted LRP (Composite)")
axes[3].axis('off')

plt.tight_layout()
plt.show()

## 5. Working with Custom Model Architectures

SignXAI works with custom PyTorch model architectures, not just pre-defined ones. Let's create a simple custom model and explain its predictions.

In [None]:
# Define a simple custom CNN model
class CustomCNN(nn.Module):
    def __init__(self, num_classes=1000):
        super(CustomCNN, self).__init__()
        # Feature extraction
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classification
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, num_classes),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
    
# Initialize custom model
custom_model = CustomCNN()
custom_model.eval()

# For demo purposes, we'll use random weights
# In a real scenario, you'd train this model first

# Remove softmax from the custom model
custom_model_no_softmax = remove_softmax(custom_model)
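One commonly cited reason for explaining logits rather than softmax probabilities is that a confident softmax output has almost no gradient with respect to the logits, while the predicted class is unchanged. The NumPy sketch below illustrates this with made-up logits; it is independent of how `remove_softmax` is implemented.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D array of logits."""
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_grad_of_class(z, c):
    """Gradient of softmax(z)[c] with respect to the logits z: p_c * (onehot_c - p)."""
    p = softmax(z)
    onehot = np.zeros_like(p)
    onehot[c] = 1.0
    return p[c] * (onehot - p)

logits = np.array([12.0, 2.0, 1.0])   # illustrative values for a confident prediction
c = int(np.argmax(logits))

same_class = int(np.argmax(softmax(logits))) == c
max_grad = np.abs(softmax_grad_of_class(logits, c)).max()

print("same predicted class:", same_class)
print("largest |d p_c / d z|:", max_grad)   # the logit's own gradient would be exactly 1
```

Because softmax is monotonic, dropping it does not change which class is predicted; it only changes the quantity whose attribution is computed.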

In [None]:
# Generate a forward pass to get a random prediction
with torch.no_grad():
    custom_output = custom_model(preprocessed_img)
    _, custom_predicted_idx = torch.max(custom_output, 1)
    custom_predicted_class = custom_predicted_idx.item()
    
print(f"Custom model prediction (random weights): Class {custom_predicted_class}")

# Apply explainability methods to the custom model
custom_last_conv = custom_model.features[6]  # Last conv layer

custom_methods = {
    "Gradient": SIGN(custom_model_no_softmax),
    "GradCAM": GradCAM(custom_model_no_softmax, custom_last_conv),
    "LRP-Z": LRPZ(custom_model_no_softmax),
}

custom_explanations = {}
for name, method in custom_methods.items():
    explanation = method.attribute(preprocessed_img, target=custom_predicted_class)
    custom_explanations[name] = explanation.detach().cpu().numpy()

In [None]:
# Visualize explanations for the custom model
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for i, (name, explanation) in enumerate(custom_explanations.items()):
    visualize_attribution(np_img, preprocess_for_visualization(explanation), ax=axes[i])
    axes[i].set_title(f"Custom Model: {name}")
    axes[i].axis('off')

plt.tight_layout()
plt.show()

## 6. Advanced Visualization Techniques

The `visualize_attribution` helper accepts a colormap (`cmap`) and an overlay opacity (`alpha`). This section compares several settings for both and then builds a thresholded mask by hand.

In [None]:
# Get a high-quality explanation to visualize
lrp = LRPZ(model_no_softmax)
lrp_explanation = lrp.attribute(preprocessed_img, target=predicted_class).detach().cpu().numpy()
lrp_processed = preprocess_for_visualization(lrp_explanation)

# 1. Visualization with different color maps
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Default Red-Blue colormap
visualize_attribution(np_img, lrp_processed, ax=axes[0], cmap='RdBu_r')
axes[0].set_title("Red-Blue Colormap")
axes[0].axis('off')

# Heat colormap
visualize_attribution(np_img, lrp_processed, ax=axes[1], cmap='hot')
axes[1].set_title("Heat Colormap")
axes[1].axis('off')

# Viridis colormap
visualize_attribution(np_img, lrp_processed, ax=axes[2], cmap='viridis')
axes[2].set_title("Viridis Colormap")
axes[2].axis('off')

plt.tight_layout()
plt.show()
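`preprocess_for_visualization` takes absolute values, so a diverging colormap such as `RdBu_r` only ever uses one half of its range. If positive and negative relevance should be told apart, one option is a sign-preserving normalization, sketched below in NumPy. The helper `normalize_signed` and the toy array are illustrative and not part of SignXAI.

```python
import numpy as np

def normalize_signed(attribution):
    """Sum over channels, then scale to [-1, 1] while keeping the sign."""
    if attribution.ndim == 3:
        attribution = attribution.sum(axis=-1)   # [H, W, C] -> [H, W]
    peak = np.abs(attribution).max()
    return attribution / peak if peak > 0 else attribution

# Toy 2x2 attribution map with 3 channels
demo = np.array([[[0.2, -0.1, 0.0], [-0.6, -0.2, -0.2]],
                 [[0.0, 0.0, 0.0], [0.3, 0.1, 0.1]]])
signed = normalize_signed(demo)
print(signed)

# With matplotlib, symmetric limits keep zero at the center of a diverging colormap:
# plt.imshow(signed, cmap="RdBu_r", vmin=-1, vmax=1)
```

Summing over channels is one of several reasonable reductions; taking the channel with the largest magnitude is another.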

In [None]:
# 2. Visualization with different alpha (transparency) values
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Low alpha (more transparent overlay)
visualize_attribution(np_img, lrp_processed, ax=axes[0], alpha=0.3)
axes[0].set_title("Low Opacity (α=0.3)")
axes[0].axis('off')

# Medium alpha
visualize_attribution(np_img, lrp_processed, ax=axes[1], alpha=0.6)
axes[1].set_title("Medium Opacity (α=0.6)")
axes[1].axis('off')

# High alpha (less transparent overlay)
visualize_attribution(np_img, lrp_processed, ax=axes[2], alpha=0.9)
axes[2].set_title("High Opacity (α=0.9)")
axes[2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# 3. Advanced: Custom masking visualization
# We'll create a visualization that only shows areas with high attribution

# Create a binary mask based on attribution threshold
threshold = 0.5  # Threshold for significance
binary_mask = (lrp_processed > threshold).astype(float)

# Apply mask to the original image
masked_img = np_img.copy().astype(float) / 255.0
for i in range(3):  # Apply to each color channel
    if binary_mask.ndim == 3:
        channel_mask = binary_mask[:, :, i]
    else:
        channel_mask = binary_mask
    masked_img[:, :, i] = masked_img[:, :, i] * channel_mask

# Display results
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Original image
axes[0].imshow(np_img)
axes[0].set_title("Original Image")
axes[0].axis('off')

# Binary mask
axes[1].imshow(binary_mask, cmap='gray')
axes[1].set_title(f"Binary Mask (threshold={threshold})")
axes[1].axis('off')

# Masked image
axes[2].imshow(masked_img)
axes[2].set_title("Masked Image (Only Important Features)")
axes[2].axis('off')

plt.tight_layout()
plt.show()
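A fixed threshold of 0.5 on a max-normalized map can keep very few or very many pixels, depending on how peaked the attribution is. A hedged alternative is to keep a fixed fraction of the highest-scoring pixels. The sketch below uses random arrays in place of a real attribution map and image; `top_fraction_mask` is an illustrative helper, not a SignXAI function.

```python
import numpy as np

def top_fraction_mask(attribution, fraction=0.1):
    """Binary [H, W] mask keeping roughly the `fraction` of pixels with the highest score."""
    if attribution.ndim == 3:
        attribution = attribution.max(axis=-1)       # strongest channel per pixel
    cutoff = np.quantile(attribution, 1.0 - fraction)
    return (attribution >= cutoff).astype(float)

rng = np.random.default_rng(0)
scores = rng.random((224, 224, 3))    # stand-in for a normalized attribution map
image = rng.random((224, 224, 3))     # stand-in for an image scaled to [0, 1]

mask = top_fraction_mask(scores, fraction=0.1)
masked = image * mask[:, :, None]     # broadcast the mask over the color channels
print("kept fraction:", mask.mean())
```

A single-channel mask applied by broadcasting also avoids the per-channel loop and masks all three color channels consistently.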

## 7. Layer-wise Relevance Tracking

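Advanced users might want to inspect what happens at intermediate layers while an explanation is computed. The example below subclasses `LRPZ` and registers forward hooks on selected layers. Note that forward hooks record each layer's output activations, not backpropagated relevance scores.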

In [None]:
# Create a custom LRP analyzer that records the outputs of selected layers
class LayerTrackingLRP(LRPZ):
    def __init__(self, model, tracked_layers=None):
        super().__init__(model)
        self.tracked_layers = tracked_layers or []
        # Despite the name, this stores forward activations, not relevance scores
        self.layer_relevances = {}
        
        # Register forward hooks on the tracked layers
        self.hooks = []
        for name, module in model.named_modules():
            # Exact match: a substring test would make 'features.1' also match 'features.10' to 'features.19'
            if name in self.tracked_layers:
                hook = module.register_forward_hook(self._hook_fn(name))
                self.hooks.append(hook)
    
    def _hook_fn(self, name):
        def hook(module, input, output):
            # Clone so a following in-place ReLU cannot overwrite the stored tensor
            self.layer_relevances[name] = output.detach().clone()
        return hook
    
    def __del__(self):
        for hook in self.hooks:
            hook.remove()

# For VGG16, track feature layers at different depths
tracked_layers = [
    'features.1',    # Early layer (after first ReLU)
    'features.15',   # Middle layer
    'features.28',   # Deep layer (last conv layer)
    'classifier.3'   # Second fully connected layer (2-D output, no spatial map)
]

# Create the layer tracking analyzer
tracking_lrp = LayerTrackingLRP(model_no_softmax, tracked_layers=tracked_layers)

# Generate explanation
tracking_explanation = tracking_lrp.attribute(preprocessed_img, target=predicted_class)

In [None]:
# Visualize intermediate layer activations
# For this example, we'll just visualize the first few channels of each layer

# Function to visualize activation maps
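def visualize_activations(activations, max_channels=4, title="Layer Activations"):
    # Keep only activations shaped [B, C, H, W]; fully connected outputs
    # such as [B, features] cannot be drawn as images
    spatial = {name: act for name, act in activations.items() if act.ndim == 4}
    if len(spatial) == 0:
        return
    
    # Determine plot size (activations is a dict, so index by value, not by position)
    num_layers = len(spatial)
    channels_per_layer = min(max_channels, min(act.shape[1] for act in spatial.values()))
    
    # squeeze=False keeps axes two-dimensional even with a single row or column
    fig, axes = plt.subplots(num_layers, channels_per_layer,
                             figsize=(4*channels_per_layer, 3*num_layers), squeeze=False)
    
    for i, (layer_name, activation) in enumerate(spatial.items()):
        act = activation.cpu().numpy()
        for j in range(channels_per_layer):
            ax = axes[i, j]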
            channel_data = act[0, j]
            
            # Normalize for visualization
            if channel_data.max() > channel_data.min():
                channel_data = (channel_data - channel_data.min()) / (channel_data.max() - channel_data.min())
                
            ax.imshow(channel_data, cmap='viridis')
            ax.set_title(f"{layer_name}\nChannel {j}")
            ax.axis('off')
    
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)  # Adjust for suptitle
    plt.show()

# Visualize the tracked layer activations
visualize_activations(tracking_lrp.layer_relevances, title="Layer-wise Activations")
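Forward hooks capture activations. To see relevance itself at each layer, the backward relevance pass has to be recorded. The following NumPy sketch does this for a tiny two-layer, bias-free network using the epsilon rule. It is a from-scratch illustration with hand-picked weights, not SignXAI's internal mechanism, and `lrp_epsilon_dense` is a hypothetical helper.

```python
import numpy as np

def lrp_epsilon_dense(a, W, R_out, epsilon=1e-9):
    """Epsilon rule for a bias-free dense layer: output relevance -> input relevance."""
    z = a @ W                                          # pre-activations, shape [K]
    stabilizer = epsilon * np.where(z >= 0, 1.0, -1.0)
    s = R_out / (z + stabilizer)                       # relevance per unit of pre-activation
    return a * (W @ s)                                 # redistribute to the inputs

# Hand-picked toy network: 2 inputs -> 2 hidden units (ReLU) -> 2 logits
x = np.array([1.0, 2.0])
W1 = np.array([[1.0, -0.5],
               [0.5, 1.0]])
W2 = np.array([[1.0, -1.0],
               [0.5, 2.0]])

# Forward pass, keeping the input of every layer
h = np.maximum(x @ W1, 0.0)
logits = h @ W2

# Backward relevance pass, recording relevance at each layer
target = int(np.argmax(logits))
R_logits = np.zeros_like(logits)
R_logits[target] = logits[target]                      # start from the target logit

relevance = {"logits": R_logits}
relevance["hidden"] = lrp_epsilon_dense(h, W2, relevance["logits"])
relevance["input"] = lrp_epsilon_dense(x, W1, relevance["hidden"])

for name, R in relevance.items():
    print(f"{name:>6}: {R}  (sum = {R.sum():.4f})")
```

The sum of relevance stays the same from layer to layer up to the epsilon stabilizer, which is the conservation property that per-layer relevance maps are usually checked against.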

## 8. Conclusion

In this advanced tutorial, we explored several sophisticated aspects of using SignXAI with PyTorch:

1. **Customizing LRP Rules**: We demonstrated how to configure different LRP variants and create composite rules for different layers.

2. **Creating Composite Explanations**: We combined multiple explainability methods to create enhanced visualizations like Guided GradCAM.

3. **Working with Custom Models**: We applied SignXAI methods to a custom CNN defined from scratch.

4. **Advanced Visualization**: We explored different visualization options including different colormaps, transparency settings, and masking techniques.

5. **Layer-wise Relevance Tracking**: We demonstrated how to record and visualize the activations of selected layers while an LRP explanation is computed.

These advanced techniques provide deeper insights into how your models make decisions and can help identify potential biases or weaknesses in the model's reasoning process.

For more information and to contribute to the project, visit the SignXAI documentation and GitHub repository.