# SignXAI2 with PyTorch - Advanced Usage

This tutorial demonstrates advanced techniques for using SignXAI2 with PyTorch models. It builds on the basic tutorial and explores more sophisticated explainability methods and customizations.

## Setup Requirements

**Important**: SignXAI2 requires Python 3.9 or 3.10 (Python 3.11+ is not supported).

Since you're running this tutorial, you should already have cloned the signxai2 repository. From the repository root directory:

### Using conda:
```bash
# Create environment with Python 3.10
conda create -n signxai2 python=3.10
conda activate signxai2

# Install SignXAI2 with PyTorch support
pip install "signxai2[pytorch]"  # quotes stop shells such as zsh from expanding the brackets

# Download models and example data
git lfs install
bash ./prepare.sh
```

### Using venv:
```bash
# Create virtual environment
python3.10 -m venv signxai2_env
source signxai2_env/bin/activate  # On Windows: signxai2_env\Scripts\activate

# Install SignXAI2 with PyTorch support
pip install "signxai2[pytorch]"  # quotes stop shells such as zsh from expanding the brackets

# Download models and example data
git lfs install
bash ./prepare.sh
```

## Overview

In this tutorial, we'll cover:
1. Using different LRP variants with the unified API
2. Advanced method parameters and configurations
3. Working with custom model architectures
4. Combining multiple explanation methods
5. Advanced visualization techniques

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# SignXAI2 unified API imports
from signxai import explain, list_methods, get_method_info

## 2. Set Up Paths and Load Data

In [None]:
# Set up data and model paths
_THIS_DIR = os.getcwd()  # notebooks have no __file__; assumes the notebook's folder is the working directory
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
VGG16_MODEL_PATH = os.path.join(DATA_DIR, "models", "pytorch", "VGG16", "vgg16_ported_weights.pt")
IMAGE_PATH = os.path.join(DATA_DIR, "images", "example.jpg")

# Verify paths
assert os.path.exists(VGG16_MODEL_PATH), f"VGG16 model not found at {VGG16_MODEL_PATH}"
assert os.path.exists(IMAGE_PATH), f"Image not found at {IMAGE_PATH}"

# Import the VGG16 model definition
sys.path.append(os.path.join(DATA_DIR, "models", "pytorch", "VGG16"))
from VGG16 import VGG16_PyTorch

# Load model
try:
    # Initialize the model architecture
    model = VGG16_PyTorch(num_classes=1000)
    # Load the pre-trained weights
    model.load_state_dict(torch.load(VGG16_MODEL_PATH, map_location=torch.device('cpu')))
    print("Loaded pre-saved VGG16 model weights")
except Exception as e:
    print(f"Could not load saved model: {e}\nLoading from torchvision instead")
    import torchvision.models as models
    model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)  # `pretrained=True` is deprecated
    
model.eval()  # Set to evaluation mode

# Helper function to load and preprocess image
def preprocess_image(image_path, size=(224, 224)):
    img = Image.open(image_path).convert('RGB')  # ensure 3 channels (e.g. for grayscale or RGBA files)
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return img, transform(img).unsqueeze(0)  # Add batch dimension

# Load and display image
original_img, preprocessed_img = preprocess_image(IMAGE_PATH)

plt.figure(figsize=(6, 6))
plt.imshow(original_img)
plt.axis('off')
plt.title('Sample Image')
plt.show()

## 3. Using Different LRP Variants with the Unified API

SignXAI2 provides various LRP (Layer-wise Relevance Propagation) methods through the unified API. Let's explore different LRP rules and configurations.
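
All of these variants share the same backward pass and differ only in the rule that redistributes relevance through each layer. As a minimal NumPy sketch (not SignXAI2's implementation; `lrp_epsilon_dense` is a hypothetical helper), here is the epsilon rule for a single dense layer `z = a @ W + b`. Setting `eps=0` gives the plain z-rule, while a larger `eps` absorbs part of the relevance of weakly activated outputs, which typically yields sparser, less noisy maps.

```python
import numpy as np


def lrp_epsilon_dense(a, W, b, R_out, eps=0.0):
    """Epsilon-rule LRP for one dense layer z = a @ W + b.

    R_j = sum_k a_j * W[j, k] / (z_k + eps * sign(z_k)) * R_out[k]
    eps=0 recovers the z-rule.
    """
    z = a @ W + b
    # Push the denominator away from zero in the direction of its sign
    denom = z + eps * np.where(z >= 0, 1.0, -1.0)
    s = R_out / denom          # relevance per unit of pre-activation, shape (out,)
    return a * (W @ s)         # R_j = a_j * sum_k W[j, k] * s_k, shape (in,)


rng = np.random.default_rng(0)
a = rng.random(4)                    # non-negative inputs, e.g. ReLU activations
W = rng.normal(size=(4, 3))
b = np.zeros(3)
R_out = np.array([0.0, 1.0, 0.0])    # explain output neuron 1 only

R_z = lrp_epsilon_dense(a, W, b, R_out, eps=0.0)
R_eps = lrp_epsilon_dense(a, W, b, R_out, eps=0.1)
print("z-rule total relevance:", R_z.sum())          # equals R_out.sum() up to rounding (no bias)
print("epsilon-rule total relevance:", R_eps.sum())  # smaller in magnitude: eps absorbs some
```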

In [None]:
# Get a prediction for reference
with torch.no_grad():
    output = model(preprocessed_img)
    _, predicted_idx = torch.max(output, 1)
    predicted_class = predicted_idx.item()
    
print(f"Predicted class index: {predicted_class}")

# Different LRP variants available in SignXAI2
lrp_methods = [
    'lrp_z',
    'lrp_epsilon_0_001',
    'lrp_epsilon_0_01', 
    'lrp_epsilon_0_1',
    'lrp_epsilon_0_2',
    'lrp_alpha_1_beta_0',
    'lrp_alpha_2_beta_1',
    'lrpsign_z',
    'lrpsign_epsilon_0_1'
]

# Generate explanations using different LRP variants
lrp_explanations = {}
for method in lrp_methods:
    if method in list_methods():
        print(f"Generating {method} explanation...")
        explanation = explain(
            model=model,
            x=preprocessed_img,
            method_name=method,
            target_class=predicted_class
        )
        lrp_explanations[method] = explanation
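
The `lrp_alpha_1_beta_0` and `lrp_alpha_2_beta_1` entries in the list above refer to the alpha-beta rule, which splits each contribution `a_j * W[j, k]` into its positive and negative parts, normalizes them separately, and weights them by `alpha` and `beta` (with `alpha - beta = 1`). A minimal NumPy sketch for one dense layer, assuming biases are ignored (implementations differ on this point); `lrp_alpha_beta_dense` is a hypothetical helper, not part of SignXAI2:

```python
import numpy as np


def lrp_alpha_beta_dense(a, W, R_out, alpha=1.0, beta=0.0):
    """Alpha-beta rule for a dense layer z = a @ W (bias ignored in this sketch).

    R_j = sum_k (alpha * c+_jk / sum_j c+_jk - beta * c-_jk / sum_j c-_jk) * R_out[k]
    """
    assert np.isclose(alpha - beta, 1.0), "alpha - beta must equal 1"
    c = a[:, None] * W                      # contributions, shape (in, out)
    c_pos = np.clip(c, 0, None)
    c_neg = np.clip(c, None, 0)
    pos_sum = c_pos.sum(axis=0)
    neg_sum = c_neg.sum(axis=0)
    # Guard against outputs with no positive (or no negative) contributions
    pos_share = np.divide(c_pos, pos_sum, out=np.zeros_like(c), where=pos_sum != 0)
    neg_share = np.divide(c_neg, neg_sum, out=np.zeros_like(c), where=neg_sum != 0)
    return (alpha * pos_share - beta * neg_share) @ R_out


a = np.array([1.0, 2.0, 0.5])
W = np.array([[1.0, -1.0],
              [-0.5, 2.0],
              [2.0, 1.0]])
R_out = np.array([1.0, 0.0])                # explain output neuron 0

R_a1b0 = lrp_alpha_beta_dense(a, W, R_out, alpha=1.0, beta=0.0)  # positive evidence only
R_a2b1 = lrp_alpha_beta_dense(a, W, R_out, alpha=2.0, beta=1.0)  # also assigns negative relevance
print(R_a1b0, R_a2b1)
```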

In [None]:
# Visualize the different LRP variants
def preprocess_for_visualization(explanation):
    # Convert PyTorch tensors to NumPy
    if isinstance(explanation, torch.Tensor):
        explanation = explanation.detach().cpu().numpy()
    # Drop the batch dimension: [B, C, H, W] -> [C, H, W] (or [B, H, W, C] -> [H, W, C])
    if explanation.ndim == 4:
        explanation = explanation[0]
    # Move channels last: [C, H, W] -> [H, W, C] for 1- or 3-channel maps
    if explanation.ndim == 3 and explanation.shape[0] in (1, 3):
        explanation = np.transpose(explanation, (1, 2, 0))
    
    # Convert single-channel maps to 3 identical channels
    if explanation.ndim == 2:
        explanation = np.repeat(explanation[..., np.newaxis], 3, axis=-1)
    elif explanation.shape[-1] == 1:
        explanation = np.repeat(explanation, 3, axis=-1)
    
    # Use absolute values so positive and negative attributions both register
    # (this discards the sign of LRP relevance)
    abs_explanation = np.abs(explanation)
    
    # Normalize for visualization
    if abs_explanation.max() > 0:
        abs_explanation = abs_explanation / abs_explanation.max()
    
    return abs_explanation

# Create a figure for visualization
n_methods = len(lrp_explanations)
n_cols = 3
n_rows = (n_methods + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6*n_rows))
axes = np.asarray(axes).flatten()  # works for both single-row and multi-row grids

# Get numpy array of original image
np_img = np.array(original_img.resize((224, 224)))

# Plot the different LRP variants
for i, (method_name, explanation) in enumerate(lrp_explanations.items()):
    if i < len(axes):
        processed_explanation = preprocess_for_visualization(explanation)
        axes[i].imshow(np_img)
        # imshow ignores cmap for (H, W, 3) arrays, so pass a 2D map
        axes[i].imshow(processed_explanation.mean(axis=2), alpha=0.5, cmap='hot')
        axes[i].set_title(method_name.replace('_', ' ').upper())
        axes[i].axis('off')

# Hide unused subplots
for i in range(len(lrp_explanations), len(axes)):
    axes[i].axis('off')

plt.tight_layout()
plt.show()
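
`preprocess_for_visualization` takes absolute values, which hides whether a pixel spoke for or against the predicted class, and LRP relevance in particular is signed. If you want to keep that information, a minimal alternative (a hypothetical helper, not a SignXAI2 function) sums over color channels and scales by the largest magnitude, so the result lies in [-1, 1] and can be shown with a diverging colormap, e.g. `plt.imshow(signed_heatmap(h), cmap='bwr', vmin=-1, vmax=1)`. It assumes the explanation has already been converted to a NumPy array in (H, W) or (H, W, C) layout.

```python
import numpy as np


def signed_heatmap(explanation):
    """Collapse an (H, W, C) or (H, W) relevance map to 2D and scale it to [-1, 1].

    Positive values mark evidence for the target class, negative values evidence against it.
    """
    heat = np.asarray(explanation, dtype=float)
    if heat.ndim == 3:
        heat = heat.sum(axis=-1)          # total relevance per pixel across channels
    max_abs = np.abs(heat).max()
    return heat / max_abs if max_abs > 0 else heat


# Toy (2, 2, 3) relevance map: one pixel against the class, two in favor
toy = np.array([[[0.2, 0.1, 0.1], [-0.8, 0.0, 0.0]],
                [[0.0, 0.0, 0.0], [0.1, 0.1, 0.2]]])
signed = signed_heatmap(toy)
print(signed)
```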

## 4. Advanced Method Parameters and Configurations

SignXAI2's unified API allows you to customize explanation methods with various parameters. Let's explore some advanced configurations.
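
Before varying the parameters, it helps to see what they control. SmoothGrad averages gradients over `num_samples` noisy copies of the input. The minimal NumPy sketch below assumes the noise standard deviation is `noise_level * (x.max() - x.min())`, the scaling used in the original SmoothGrad paper; SignXAI2 may define `noise_level` differently. The hypothetical `smoothgrad` helper accepts any gradient function; here it is the gradient of a ReLU, whose hard step at 0 turns into a smooth ramp after averaging.

```python
import numpy as np


def smoothgrad(grad_fn, x, noise_level=0.1, num_samples=25, seed=0):
    """Average grad_fn over noisy copies of x (SmoothGrad).

    Assumption: noise std = noise_level * (x.max() - x.min()).
    """
    rng = np.random.default_rng(seed)
    sigma = noise_level * (x.max() - x.min())
    grads = [grad_fn(x + rng.normal(0.0, sigma, size=x.shape)) for _ in range(num_samples)]
    return np.mean(grads, axis=0)


def relu_grad(x):
    """Gradient of sum(relu(x)) with respect to x: 1 where x > 0, else 0."""
    return (x > 0).astype(float)


x = np.linspace(-1.0, 1.0, 5)              # [-1, -0.5, 0, 0.5, 1]
plain = relu_grad(x)                       # hard step: [0, 0, 0, 1, 1]
smooth = smoothgrad(relu_grad, x, noise_level=0.1, num_samples=2000)
print("plain gradient:", plain)
print("SmoothGrad:    ", smooth.round(3))  # values near the kink fall between 0 and 1
```

More samples reduce the variance of the average; a larger `noise_level` widens the region that gets smoothed.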

In [None]:
# Example 1: SmoothGrad with different noise levels
smoothgrad_params = [
    {'noise_level': 0.05, 'num_samples': 10},
    {'noise_level': 0.1, 'num_samples': 25},
    {'noise_level': 0.2, 'num_samples': 50}
]

smoothgrad_explanations = {}
for params in smoothgrad_params:
    label = f"SmoothGrad (noise={params['noise_level']}, n={params['num_samples']})"
    print(f"Generating {label}...")
    explanation = explain(
        model=model,
        x=preprocessed_img,
        method_name='smoothgrad',
        target_class=predicted_class,
        **params
    )
    smoothgrad_explanations[label] = explanation

# Example 2: Integrated Gradients with different step counts
ig_steps = [10, 25, 50]
ig_explanations = {}
for steps in ig_steps:
    label = f"Integrated Gradients ({steps} steps)"
    print(f"Generating {label}...")
    explanation = explain(
        model=model,
        x=preprocessed_img,
        method_name='integrated_gradients',
        target_class=predicted_class,
        ig_steps=steps  # PyTorch uses 'ig_steps' parameter
    )
    ig_explanations[label] = explanation

# Example 3: Grad-CAM with different layers
# Find convolutional layers in VGG16
conv_layers = []
for name, module in model.features.named_children():
    if isinstance(module, nn.Conv2d):
        conv_layers.append(f"features.{name}")

# Select a few layers for comparison
selected_layers = [conv_layers[-3], conv_layers[-1]]  # Third-to-last and last conv layers
gradcam_explanations = {}
for layer in selected_layers:
    label = f"Grad-CAM ({layer})"
    print(f"Generating {label}...")
    explanation = explain(
        model=model,
        x=preprocessed_img,
        method_name='grad_cam',
        target_class=predicted_class,
        target_layer=layer  # PyTorch uses 'target_layer'
    )
    gradcam_explanations[label] = explanation
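
The `ig_steps` values above set how many points approximate the path integral in Integrated Gradients. A minimal NumPy sketch (a hypothetical helper using a midpoint Riemann sum and an all-zero baseline by default; SignXAI2's discretization and baseline may differ) shows the effect: by the completeness property the attributions should sum to `F(x) - F(baseline)`, and the gap shrinks as `steps` grows. The toy function `F(x) = sum(x**3)` has gradient `3 * x**2`.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline=None, steps=50):
    """Integrated Gradients approximated with a midpoint Riemann sum.

    IG_i = (x_i - x0_i) * mean over alpha of dF/dx_i evaluated at x0 + alpha * (x - x0)
    """
    if baseline is None:
        baseline = np.zeros_like(x)                  # all-zero baseline
    alphas = (np.arange(steps) + 0.5) / steps        # midpoints of `steps` equal intervals
    grads = [grad_fn(baseline + alpha * (x - baseline)) for alpha in alphas]
    return (x - baseline) * np.mean(grads, axis=0)


def F(x):
    return np.sum(x ** 3)


def F_grad(x):
    return 3 * x ** 2


x = np.array([1.0, 2.0])
for steps in (10, 25, 50):
    attributions = integrated_gradients(F_grad, x, steps=steps)
    gap = (F(x) - F(np.zeros_like(x))) - attributions.sum()
    print(f"steps={steps:3d}  sum(IG)={attributions.sum():.4f}  completeness gap={gap:.5f}")
```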

In [None]:
# Visualize parameter variations
fig, axes = plt.subplots(3, 3, figsize=(18, 18))

# Overlays are collapsed to 2D with .mean(axis=2): imshow ignores cmap for (H, W, 3) arrays

# Row 1: SmoothGrad variations
for i, (label, explanation) in enumerate(smoothgrad_explanations.items()):
    processed = preprocess_for_visualization(explanation)
    axes[0, i].imshow(np_img)
    axes[0, i].imshow(processed.mean(axis=2), alpha=0.5, cmap='hot')
    axes[0, i].set_title(label)
    axes[0, i].axis('off')

# Row 2: Integrated Gradients variations
for i, (label, explanation) in enumerate(ig_explanations.items()):
    processed = preprocess_for_visualization(explanation)
    axes[1, i].imshow(np_img)
    axes[1, i].imshow(processed.mean(axis=2), alpha=0.5, cmap='hot')
    axes[1, i].set_title(label)
    axes[1, i].axis('off')

# Row 3: Grad-CAM variations
for i, (label, explanation) in enumerate(gradcam_explanations.items()):
    processed = preprocess_for_visualization(explanation)
    axes[2, i].imshow(np_img)
    axes[2, i].imshow(processed.mean(axis=2), alpha=0.5, cmap='hot')
    axes[2, i].set_title(label)
    axes[2, i].axis('off')

# Hide unused subplot
axes[2, 2].axis('off')

plt.tight_layout()
plt.show()
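
The layer choice matters because Grad-CAM weights the chosen layer's activations by the spatially averaged gradient of the target score, so the map inherits that layer's spatial resolution and level of abstraction. A minimal NumPy sketch of the core computation, assuming the activations (shape K x H x W) and their gradients have already been captured, for example with hooks; `grad_cam_from_arrays` is a hypothetical helper, not SignXAI2's implementation:

```python
import numpy as np


def grad_cam_from_arrays(activations, gradients):
    """Grad-CAM for one layer.

    activations, gradients: arrays of shape (K, H, W), where gradients are
    d(target score) / d(activations).
    """
    weights = gradients.mean(axis=(1, 2))                        # alpha_k: pooled gradient per channel
    cam = np.tensordot(weights, activations, axes=([0], [0]))    # sum_k alpha_k * A_k, shape (H, W)
    cam = np.maximum(cam, 0)                                     # ReLU: keep positive evidence only
    return cam / cam.max() if cam.max() > 0 else cam


# Two 2x2 channels: channel 0 supports the class, channel 1 opposes it
A = np.array([[[1.0, 0.0], [0.0, 0.0]],
              [[0.0, 0.0], [0.0, 1.0]]])
G = np.stack([np.full((2, 2), 1.0), np.full((2, 2), -1.0)])
cam = grad_cam_from_arrays(A, G)
print(cam)
```

In practice the (H, W) map is upsampled to the input size before it is overlaid on the image.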

## 5. Working with Custom Model Architectures

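SignXAI2's unified API is not limited to the pre-trained VGG16: the same `explain` calls work on custom PyTorch architectures, and the only model-specific input below is the Grad-CAM target layer name. Let's create a simple custom model and explain its predictions.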

In [None]:
# Define a simple custom CNN model
class CustomCNN(nn.Module):
    def __init__(self, num_classes=1000):
        super(CustomCNN, self).__init__()
        # Feature extraction
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classification
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, num_classes),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
    
# Initialize custom model
custom_model = CustomCNN()
custom_model.eval()

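# For demo purposes, we'll use random weights.
# In a real scenario, you'd train this model first.

# Remove the final Softmax so explanations are computed on the logits
# (the argmax prediction is unchanged because softmax is monotonic)
custom_model.classifier = nn.Sequential(
    *[layer for layer in custom_model.classifier if not isinstance(layer, nn.Softmax)]
)
custom_model.eval()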

In [None]:
# Generate a forward pass to get a random prediction
with torch.no_grad():
    custom_output = custom_model(preprocessed_img)
    _, custom_predicted_idx = torch.max(custom_output, 1)
    custom_predicted_class = custom_predicted_idx.item()
    
print(f"Custom model prediction (random weights): Class {custom_predicted_class}")

# Apply explainability methods to the custom model using the unified API
# Find the name of the last conv layer for Grad-CAM ('features.6' in this model)
custom_conv_layers = [f"features.{name}" for name, module in custom_model.features.named_children()
                      if isinstance(module, nn.Conv2d)]
custom_last_conv = custom_conv_layers[-1]

custom_methods = {
    'gradient': {},
    'grad_cam': {'target_layer': custom_last_conv},
    'guided_backprop': {},
    'lrp_z': {},
    'smoothgrad': {'noise_level': 0.1, 'num_samples': 25}
}

custom_explanations = {}
for method_name, params in custom_methods.items():
    print(f"Generating {method_name} explanation for custom model...")
    explanation = explain(
        model=custom_model,
        x=preprocessed_img,
        method_name=method_name,
        target_class=custom_predicted_class,
        **params
    )
    custom_explanations[method_name] = explanation

In [None]:
# Visualize explanations for the custom model
n_methods = len(custom_explanations)
fig, axes = plt.subplots(1, n_methods, figsize=(4*n_methods, 4))

if n_methods == 1:
    axes = [axes]

for i, (name, explanation) in enumerate(custom_explanations.items()):
    processed = preprocess_for_visualization(explanation)
    axes[i].imshow(np_img)
    axes[i].imshow(processed.mean(axis=2), alpha=0.5, cmap='hot')  # 2D input so the colormap applies
    axes[i].set_title(f"Custom Model: {name.replace('_', ' ').title()}")
    axes[i].axis('off')

plt.tight_layout()
plt.show()
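
A note on the Softmax removal in the `CustomCNN` cell: gradient-based explanations of a softmax probability are scaled by the softmax Jacobian, whose diagonal entry for the predicted class is `p_c * (1 - p_c)`. For a confident prediction this factor is close to zero, so explanations computed on the probability rather than the logit can become vanishingly small. A small NumPy check with made-up logits:

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()


logits = np.array([8.0, 1.0, 0.0])   # a confident prediction for class 0
p = softmax(logits)

# Diagonal of the softmax Jacobian: d p_0 / d z_0 = p_0 * (1 - p_0)
d_prob_d_logit = p[0] * (1 - p[0])
print(f"p_0 = {p[0]:.4f}, d p_0 / d z_0 = {d_prob_d_logit:.5f}  (vs. 1.0 for the logit itself)")
```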

## 6. Combining Multiple Explanation Methods

Sometimes it's useful to combine insights from multiple explanation methods. Let's create ensemble explanations.

In [None]:
# Generate multiple explanations for ensemble
ensemble_methods = ['gradient', 'guided_backprop', 'lrp_epsilon_0_1', 'smoothgrad']
ensemble_explanations = {}

for method in ensemble_methods:
    print(f"Generating {method} explanation...")
    params = {}
    if method == 'smoothgrad':
        params = {'noise_level': 0.1, 'num_samples': 25}
    
    explanation = explain(
        model=model,
        x=preprocessed_img,
        method_name=method,
        target_class=predicted_class,
        **params
    )
    ensemble_explanations[method] = preprocess_for_visualization(explanation)

# Create ensemble explanation by averaging
ensemble_avg = np.mean(list(ensemble_explanations.values()), axis=0)

# Create ensemble explanation by taking maximum
ensemble_max = np.max(list(ensemble_explanations.values()), axis=0)

# Visualize individual and ensemble explanations
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Row 1: Individual methods
for i, (method, explanation) in enumerate(list(ensemble_explanations.items())[:3]):
    axes[0, i].imshow(np_img)
    axes[0, i].imshow(explanation.mean(axis=2), alpha=0.5, cmap='hot')  # 2D input so the colormap applies
    axes[0, i].set_title(method.replace('_', ' ').title())
    axes[0, i].axis('off')

# Row 2: Last method and ensemble results
axes[1, 0].imshow(np_img)
axes[1, 0].imshow(list(ensemble_explanations.values())[3].mean(axis=2), alpha=0.5, cmap='hot')
axes[1, 0].set_title(list(ensemble_explanations.keys())[3].replace('_', ' ').title())
axes[1, 0].axis('off')

axes[1, 1].imshow(np_img)
axes[1, 1].imshow(ensemble_avg.mean(axis=2), alpha=0.5, cmap='hot')
axes[1, 1].set_title("Ensemble (Average)")
axes[1, 1].axis('off')

axes[1, 2].imshow(np_img)
axes[1, 2].imshow(ensemble_max.mean(axis=2), alpha=0.5, cmap='hot')
axes[1, 2].set_title("Ensemble (Maximum)")
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()
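
Both combinations above operate on maps that were each divided by their own maximum, so one extreme pixel in a single method compresses the rest of that method's map toward zero before averaging. A hedged alternative, not part of SignXAI2, is to rank-normalize each map first so every method contributes on the same scale; `rank_normalize` and `rank_ensemble` are hypothetical helpers. Applied here, the call would be `rank_ensemble(list(ensemble_explanations.values()))`.

```python
import numpy as np


def rank_normalize(heatmap):
    """Replace each value by its percentile rank in [0, 1] (ties broken by position)."""
    flat = np.asarray(heatmap, dtype=float).ravel()
    ranks = np.empty(flat.size)
    ranks[np.argsort(flat, kind="stable")] = np.arange(flat.size)
    return (ranks / max(flat.size - 1, 1)).reshape(np.shape(heatmap))


def rank_ensemble(heatmaps):
    """Average rank-normalized heatmaps so no single method's scale dominates."""
    return np.mean([rank_normalize(h) for h in heatmaps], axis=0)


# Method A has one outlier; method B orders the pixels the opposite way
method_a = np.array([[0.0, 1.0], [2.0, 100.0]])
method_b = np.array([[3.0, 2.0], [1.0, 0.0]])

max_avg = (method_a / method_a.max() + method_b / method_b.max()) / 2
rank_avg = rank_ensemble([method_a, method_b])
print("max-normalized average:", max_avg.round(3))   # shaped by A's outlier
print("rank-normalized average:", rank_avg)          # flat: the methods fully disagree
```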

## 7. Advanced Visualization Techniques

Explanations are plain arrays, so they can be rendered in many ways with standard matplotlib tools. Let's explore several visualization options.

In [None]:
# Compute an LRP-epsilon explanation to visualize
lrp_explanation = explain(
    model=model,
    x=preprocessed_img,
    method_name='lrp_epsilon_0_1',
    target_class=predicted_class
)
lrp_processed = preprocess_for_visualization(lrp_explanation)

# 1. Visualization with different color maps and overlays
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Row 1: Different colormaps
colormaps = ['hot', 'jet', 'RdBu_r']
for i, cmap in enumerate(colormaps):
    axes[0, i].imshow(np_img)
    axes[0, i].imshow(lrp_processed.mean(axis=2), alpha=0.5, cmap=cmap)  # 2D input so the colormap applies
    axes[0, i].set_title(f"{cmap} colormap")
    axes[0, i].axis('off')

# Row 2: Different overlay techniques
# Pure heatmap
axes[1, 0].imshow(lrp_processed.mean(axis=2), cmap='hot')
axes[1, 0].set_title("Pure Heatmap")
axes[1, 0].axis('off')

# Masked overlay (only show high relevance areas)
threshold = 0.3
mask = (lrp_processed.mean(axis=2) > threshold)
masked_img = np_img.copy()
masked_img[~mask] = masked_img[~mask] * 0.3  # Dim irrelevant areas
axes[1, 1].imshow(masked_img)
axes[1, 1].set_title(f"Masked (threshold={threshold})")
axes[1, 1].axis('off')

# Contour overlay
from skimage import measure  # requires scikit-image (pip install scikit-image)
contours = measure.find_contours(lrp_processed.mean(axis=2), 0.3)
axes[1, 2].imshow(np_img)
for contour in contours:
    axes[1, 2].plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')
axes[1, 2].set_title("Contour Overlay")
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# 2. Advanced: Multiple transparency levels and thresholding
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Different transparency levels
alphas = [0.3, 0.5, 0.7]
for i, alpha in enumerate(alphas):
    axes[0, i].imshow(np_img)
    axes[0, i].imshow(lrp_processed.mean(axis=2), alpha=alpha, cmap='hot')  # 2D input so the colormap applies
    axes[0, i].set_title(f"Alpha = {alpha}")
    axes[0, i].axis('off')

# Different thresholding approaches
# Top-k pixels
k_percent = 20  # Show top 20% of pixels
flat_values = lrp_processed.mean(axis=2).flatten()
threshold_k = np.percentile(flat_values, 100 - k_percent)
top_k_mask = lrp_processed.mean(axis=2) > threshold_k

axes[1, 0].imshow(np_img)
axes[1, 0].imshow(np.ma.masked_where(~top_k_mask, lrp_processed.mean(axis=2)), alpha=0.7, cmap='hot')  # mask and data must share the same 2D shape
axes[1, 0].set_title(f"Top {k_percent}% pixels")
axes[1, 0].axis('off')

# Binary threshold
binary_threshold = 0.5
binary_mask = (lrp_processed.mean(axis=2) > binary_threshold).astype(float)
axes[1, 1].imshow(np_img)
axes[1, 1].imshow(binary_mask, alpha=0.5, cmap='Reds')
axes[1, 1].set_title(f"Binary (threshold={binary_threshold})")
axes[1, 1].axis('off')

# Smooth gradient overlay
smooth_mask = lrp_processed.mean(axis=2)
axes[1, 2].imshow(np_img)
axes[1, 2].imshow(smooth_mask, alpha=0.6, cmap='plasma')
axes[1, 2].set_title("Smooth Gradient")
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Inspect intermediate layers at several depths with forward hooks.
# Note: forward hooks capture each layer's activations (outputs), not its
# LRP relevance; per-layer relevance would require access to the LRP backend.
class LayerActivationTracker:
    def __init__(self, model, tracked_layers):
        self.activations = {}
        self.hooks = []
        for name, module in model.named_modules():
            # Exact match, so 'features.1' does not also match 'features.15'
            if name in tracked_layers:
                self.hooks.append(module.register_forward_hook(self._hook_fn(name)))

    def _hook_fn(self, name):
        def hook(module, inputs, output):
            # clone(): an in-place ReLU after a conv layer would otherwise overwrite the stored output
            self.activations[name] = output.detach().clone()
        return hook

    def remove(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

# For VGG16, track layers at different depths (torchvision VGG16 layout)
tracked_layers = [
    'features.1',    # Early layer (ReLU after the first conv)
    'features.15',   # Middle layer
    'features.28',   # Deep layer (last conv layer)
    'classifier.3'   # Second fully connected layer
]

tracker = LayerActivationTracker(model, tracked_layers)
with torch.no_grad():
    model(preprocessed_img)   # forward pass fills tracker.activations
tracker.remove()              # detach hooks so later forward passes are unaffected

for name, activation in tracker.activations.items():
    print(f"{name}: activation shape {tuple(activation.shape)}")

# LRP explanation for the same input via the unified API
tracking_explanation = explain(
    model=model,
    x=preprocessed_img,
    method_name='lrp_z',
    target_class=predicted_class
)

## 8. Conclusion

In this advanced tutorial, we explored several sophisticated aspects of using SignXAI2 with PyTorch:

1. **LRP Variants**: We demonstrated how to use different LRP methods (Z-rule, epsilon-rule, alpha-beta rules, and SIGN variants) through the unified API.

2. **Advanced Parameters**: We showed how to customize explanation methods with various parameters like noise levels for SmoothGrad, step counts for Integrated Gradients, and layer selection for Grad-CAM.

3. **Custom Models**: We applied the same unified-API calls to a custom CNN; the only model-specific setting was the Grad-CAM target layer.

4. **Ensemble Methods**: We combined several explanation methods: averaging emphasizes regions the methods agree on, while taking the maximum shows everything any single method highlights.

5. **Advanced Visualization**: We explored different visualization techniques including various colormaps, transparency settings, masking, and thresholding approaches.

The unified API in SignXAI2 makes it easy to:
- Switch between different explanation methods
- Compare results across methods
- Work with both pre-trained and custom models
- Customize parameters for fine-tuned explanations
- Use the same interface for both TensorFlow and PyTorch

For more information and to contribute to the project, visit the [SignXAI2 GitHub repository](https://github.com/IRISlaboratory/signxai2).