# SignXAI-Torch Basic Usage Example

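This notebook demonstrates how to use the SignXAI-Torch package to compute relevance (attribution) maps for a pretrained VGG16 image classifier and how to compare the visualizations produced by different explanation methods. It expects an example image named `testimage.jpg` in the working directory.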

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from signxai_torch.methods.wrapper import calculate_relevancemap
from signxai_torch.utils.visualization import normalize_attribution

%matplotlib inline
plt.rcParams['figure.figsize'] = [15, 10]

## Load and Preprocess Image

In [None]:
def load_and_preprocess_image(image_path):
    """Load an image from disk and turn it into a single-image model input batch.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A tuple ``(img, input_tensor)``:
        - ``img`` is the image converted to RGB and fully loaded into memory,
          so the underlying file is already closed.
        - ``input_tensor`` has shape (1, 3, 224, 224). It is resized (shorter
          side 256), center-cropped to 224x224, and normalized with the
          ImageNet mean/std used by torchvision's pretrained models.
    """
    # Read eagerly and close the file handle (PIL opens files lazily).
    # convert('RGB') also makes grayscale, palette and RGBA images match the
    # 3-channel normalization below.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')

    # Define transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Transform image and add the batch dimension
    input_tensor = transform(img).unsqueeze(0)

    return img, input_tensor

# Path of the example image, resolved relative to the notebook's working directory.
IMAGE_PATH = 'testimage.jpg'

# `img` is kept for display; `input_tensor` is the preprocessed model input.
img, input_tensor = load_and_preprocess_image(IMAGE_PATH)

# Display the original image using the explicit Axes API
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title('Original Image')
ax.axis('off')
plt.show()

## Load Model

In [None]:
# Load VGG16 with its ImageNet-pretrained weights.
# Prefer the `weights=` API (torchvision >= 0.13) and fall back to the
# deprecated `pretrained=True` on older torchvision releases.
try:
    model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
except AttributeError:  # torchvision < 0.13 has no VGG16_Weights
    model = models.vgg16(pretrained=True)

# Inference mode: disables the classifier's dropout layers.
model.eval()

# NOTE: torchvision's VGG16 has no softmax layer. classifier[-1] is already a
# plain Linear(4096, 1000) that outputs raw logits, which is what the
# attribution methods should explain. Do NOT re-create that layer: a new
# nn.Linear would replace its pretrained weights with random ones.

## Apply Different Attribution Methods

In [None]:
def visualize_attribution(attribution, title, ax=None, **normalize_kwargs):
    """Plot an attribution map, summed over dim 0 of its first entry, as a heatmap.

    Args:
        attribution: Tensor indexed as ``attribution[0]`` and then summed over
            dim 0. It is presumably shaped (batch, channels, H, W) — TODO confirm
            against ``calculate_relevancemap``.
        title: Title shown above the heatmap.
        ax: Optional matplotlib Axes to draw on. When omitted, a new 5x5-inch
            figure is created, as before.
        **normalize_kwargs: Extra keyword arguments forwarded to
            ``normalize_attribution`` (e.g. ``symmetric=True``).

    Returns:
        The Axes the heatmap was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))
    # detach() so .numpy() also works when the attribution still tracks gradients.
    summed = attribution[0].detach().sum(dim=0).cpu().numpy()
    ax.imshow(normalize_attribution(summed, **normalize_kwargs),
              cmap='seismic',
              clim=(-1, 1))
    ax.set_title(title)
    ax.axis('off')
    return ax

# Methods to demonstrate: method name passed to `calculate_relevancemap` -> panel title.
methods = {
    'gradient': 'Basic gradient',
    'gradient_x_input': 'Input × Gradient',
    'gradient_x_sign_mu_0': 'SIGN (μ=0)',
    'guided_backprop': 'Guided Backprop',
    'smoothgrad': 'SmoothGrad',
    'grad_cam_VGG16ILSVRC': 'Grad-CAM',
    'integrated_gradients': 'Integrated Gradients',
    'lrp_z': 'LRP-Z',
    'lrp_epsilon_0_1': 'LRP-ε (ε=0.1)',
    'lrp_alpha_1_beta_0': 'LRP-αβ (α=1,β=0)',
    'lrpsign_epsilon_0_1': 'LRP-SIGN (ε=0.1)'
}

# Grid layout: the original image first, then one panel per method, three per row.
n_cols = 3
n_panels = len(methods) + 1  # +1 for the original image panel
rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
fig, axes = plt.subplots(rows, n_cols, figsize=(5 * n_cols, 5 * rows))
axes = axes.ravel()

# Show original image in first subplot
axes[0].imshow(np.array(img))
axes[0].set_title("Original Image")
axes[0].axis('off')

# Apply each method and visualize. A failing method is reported and marked in
# its panel, but does not stop the remaining methods from being drawn.
for ax, (method_name, method_title) in zip(axes[1:], methods.items()):
    # Title every panel so a failed method is still identifiable in the grid.
    ax.set_title(method_title)
    try:
        with torch.enable_grad():
            # Fresh leaf tensor that requires grad for each method;
            # `input_tensor` itself is left untouched.
            attribution = calculate_relevancemap(
                method_name, input_tensor.detach().requires_grad_(True), model
            )
        if attribution is None:
            ax.text(0.5, 0.5, 'Failed', ha='center', va='center', transform=ax.transAxes)
        else:
            attribution_vis = normalize_attribution(
                attribution[0].detach().sum(dim=0).cpu().numpy(),
                symmetric=True  # Use symmetric normalization for better visualization
            )
            ax.imshow(attribution_vis, cmap='seismic', clim=(-1, 1))
    except Exception as e:
        print(f"Failed to compute {method_title}: {str(e)}")
        ax.text(0.5, 0.5, 'Error', ha='center', va='center', transform=ax.transAxes)
    ax.axis('off')

# Hide empty subplots
for ax in axes[n_panels:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## Compare SIGN Variants

In [None]:
# Compare different mu values for SIGN (mu = -0.5, 0 and 0.5, as encoded in the method names)
sign_methods = [
    'gradient_x_sign_mu_neg_0_5',
    'gradient_x_sign_mu_0',
    'gradient_x_sign_mu_0_5'
]

# One panel per variant in a single row. squeeze=False keeps `axes` an array
# even if the list is shortened to a single method.
fig, axes = plt.subplots(1, len(sign_methods), figsize=(5 * len(sign_methods), 5), squeeze=False)
axes = axes.ravel()

for ax, method in zip(axes, sign_methods):
    # Fresh grad-enabled input, detach() before .numpy(), and symmetric
    # normalization, so these panels are rendered like the method overview grid.
    with torch.enable_grad():
        attribution = calculate_relevancemap(method, input_tensor.detach().requires_grad_(True), model)
    if attribution is None:
        ax.text(0.5, 0.5, 'Failed', ha='center', va='center', transform=ax.transAxes)
    else:
        attribution_vis = normalize_attribution(
            attribution[0].detach().sum(dim=0).cpu().numpy(),
            symmetric=True
        )
        ax.imshow(attribution_vis, cmap='seismic', clim=(-1, 1))
    ax.set_title(method)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Compare LRP Variants

In [None]:
# Compare different LRP variants
lrp_methods = [
    'lrp_z',
    'lrp_epsilon_0_1',
    'lrp_epsilon_1',
    'lrpsign_epsilon_0_1',
    'lrp_alpha_1_beta_0',
    'lrp_sequential_composite_a'
]

# Grid with three panels per row
n_cols = 3
rows = (len(lrp_methods) + n_cols - 1) // n_cols  # ceiling division
fig, axes = plt.subplots(rows, n_cols, figsize=(5 * n_cols, 5 * rows))
axes = axes.ravel()

for ax, method in zip(axes, lrp_methods):
    # Fresh grad-enabled input, detach() before .numpy(), and symmetric
    # normalization, so these panels are rendered like the method overview grid.
    with torch.enable_grad():
        attribution = calculate_relevancemap(method, input_tensor.detach().requires_grad_(True), model)
    if attribution is None:
        ax.text(0.5, 0.5, 'Failed', ha='center', va='center', transform=ax.transAxes)
    else:
        attribution_vis = normalize_attribution(
            attribution[0].detach().sum(dim=0).cpu().numpy(),
            symmetric=True
        )
        ax.imshow(attribution_vis, cmap='seismic', clim=(-1, 1))
    ax.set_title(method)
    ax.axis('off')

# Hide unused grid cells
for ax in axes[len(lrp_methods):]:
    ax.axis('off')

plt.tight_layout()
plt.show()