# SignXAI with PyTorch - Basic Usage (VGG16)

This tutorial demonstrates how to use the SignXAI package with PyTorch to explain a VGG16 image classification model. We'll walk through:

1. Setting up the environment
2. Loading a pre-trained VGG16 model and sample image
3. Generating explanations using different methods
4. Visualizing and comparing the results

## Setup Requirements

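For this PyTorch tutorial, you'll need to install SignXAI with its PyTorch dependencies. The requirement paths below are relative, so run the commands from the directory that contains this notebook: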

```bash
# For conda users
conda create -n signxai-pytorch python=3.8
conda activate signxai-pytorch
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/pytorch.txt

# Or for venv users
python -m venv signxai_pytorch_env
source signxai_pytorch_env/bin/activate  # On Windows: signxai_pytorch_env\Scripts\activate
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/pytorch.txt
```

Let's get started!

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# SignXAI imports
from signxai.torch_signxai.methods import SIGN, GradCAM, GuidedBackprop, LRPZ, LRPEpsilon
from signxai.common.visualization import visualize_attribution
from signxai.torch_signxai.utils import remove_softmax

## 2. Set Up Paths

In [None]:
# Set up data and model paths
# "__file__" is deliberately a string literal here: abspath() resolves it against
# the current working directory, so dirname() yields the directory the notebook
# is run from (presumably because notebooks do not define __file__).
_THIS_DIR = os.path.dirname(os.path.abspath("__file__"))
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
VGG16_MODEL_PATH = os.path.join(DATA_DIR, "models", "pytorch", "VGG16", "vgg16_torch.pt")
IMAGE_PATH = os.path.join(DATA_DIR, "images", "example.jpg")

# Check that files exist
# The example image is needed by every later step, so fail early with a clear
# error (an explicit raise also survives `python -O`, unlike assert).
if not os.path.exists(IMAGE_PATH):
    raise FileNotFoundError(f"Image not found at {IMAGE_PATH}")

# A missing model file is not fatal: the model-loading cell falls back to the
# torchvision VGG16 when the saved model cannot be loaded.
if not os.path.exists(VGG16_MODEL_PATH):
    print(f"VGG16 model not found at {VGG16_MODEL_PATH}; the torchvision fallback will be used")

## 3. Load VGG16 Model and Image

In [None]:
# Load the VGG16 model
try:
    # First try loading our pre-saved model.
    # map_location="cpu" keeps the model on the CPU, matching the input tensor
    # (which is never moved to another device), even if the file was saved on a GPU.
    # NOTE: torch.load unpickles the file - only load model files you trust.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects fully pickled models; in that case the fallback below takes over.
    vgg16_model = torch.load(VGG16_MODEL_PATH, map_location="cpu")
    print("Loaded pre-saved VGG16 model")
except Exception as e:
    # If that fails, load the model from torchvision
    print(f"Could not load saved model: {e}\nLoading from torchvision instead")
    import torchvision.models as models
    try:
        # Current torchvision API: select the pre-trained weights explicitly
        vgg16_model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
    except AttributeError:
        # Older torchvision releases without the weights enum
        vgg16_model = models.vgg16(pretrained=True)

vgg16_model.eval()  # Set to evaluation mode
print(vgg16_model)

In [None]:
# Remove softmax from the VGG16 model - crucial step for many explainability methods
# The explainers created later are all built on this softmax-free model, while the
# prediction cell keeps calling the original `vgg16_model`.
# NOTE(review): whether remove_softmax modifies the model in place or returns a
# copy is not visible here - confirm against the SignXAI utilities.
vgg16_model_no_softmax = remove_softmax(vgg16_model)

In [None]:
# Function to load and preprocess image for VGG16
def load_image(image_path, resize_dim=(224, 224)):
    """Load an image, display it, and turn it into a VGG16 input batch.

    Args:
        image_path: Path of the image file to open.
        resize_dim: Size passed to ``transforms.Resize`` before the image is
            converted to a tensor.

    Returns:
        A tuple ``(img, input_batch)``: the RGB PIL image and the resized,
        normalized tensor with a leading batch dimension of size 1.
    """
    # Load image
    # Force three RGB channels: the Normalize step below uses three-channel
    # statistics, so grayscale, palette or RGBA files would otherwise fail.
    img = Image.open(image_path).convert('RGB')

    # Display original image
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.axis('off')
    plt.title('Original Image')
    plt.show()

    # Define preprocessing for VGG16
    preprocess = transforms.Compose([
        transforms.Resize(resize_dim),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Preprocess image
    input_tensor = preprocess(img)

    # Create a mini-batch as expected by the model
    input_batch = input_tensor.unsqueeze(0)

    return img, input_batch

# Load and preprocess the image
# original_img is the PIL image returned by load_image (resized later for the
# plots); preprocessed_img is the batched, normalized tensor that is passed to
# the model and to the explainers.
original_img, preprocessed_img = load_image(IMAGE_PATH)

## 4. Predict the Class with VGG16

In [None]:
# Make a prediction with VGG16
import json

# Inference only: no gradients are needed for this forward pass
with torch.no_grad():
    output = vgg16_model(preprocessed_img)

# Get the predicted class index
# torch.max over dim 1 returns (values, indices); .item() requires a single
# element, i.e. a batch containing one image.
_, predicted_idx = torch.max(output, 1)
predicted_class = predicted_idx.item()

# Load class labels (ImageNet classes for VGG16)
try:
    with open(os.path.join(DATA_DIR, "imagenet_class_index.json"), encoding="utf-8") as f:
        class_idx = json.load(f)
    # Each entry is indexed with [1] to obtain the label used for display
    class_name = class_idx[str(predicted_class)][1]
except (OSError, ValueError, KeyError, IndexError, TypeError):
    # Label file missing/unreadable (OSError), not valid JSON (ValueError), or
    # without a usable entry for this index: just use the class index
    class_name = f"Class {predicted_class}"

print(f"Predicted class: {class_name} (index: {predicted_class})")

## 5. Generate Explanations with SignXAI

Now let's use SignXAI to explain the VGG16 model's prediction using different methods.

In [None]:
# Define the methods we want to use for VGG16 explanation
# GradCAM needs a target layer: walk the direct children of vgg16_model.features
# from the end and take the first Conv2d found, i.e. the last convolutional
# layer (None if there is no such layer).
# NOTE(review): the layer is looked up on vgg16_model while the explainers use
# vgg16_model_no_softmax - confirm both share the same layer objects.
target_layer = next(
    (
        layer
        for layer in reversed(list(vgg16_model.features.children()))
        if isinstance(layer, nn.Conv2d)
    ),
    None,
)

# Initialize explainers (insertion order determines the order of the plots)
methods = {}
methods['Gradient'] = SIGN(vgg16_model_no_softmax)
# Same explainer type as 'Gradient'; the product with the input is formed later
methods['Gradient × Input'] = SIGN(vgg16_model_no_softmax)
methods['GradCAM'] = GradCAM(vgg16_model_no_softmax, target_layer)
methods['Guided Backprop'] = GuidedBackprop(vgg16_model_no_softmax)
methods['LRP-Z'] = LRPZ(vgg16_model_no_softmax)
methods['LRP-Epsilon'] = LRPEpsilon(vgg16_model_no_softmax, epsilon=0.1)

# Storage for explanations, filled by the generation loop (method name -> array)
explanations = {}

In [None]:
# Generate explanations for VGG16
for method_name, explainer in methods.items():
    print(f"Generating {method_name} explanation...")

    # Attribute the predicted class to the input pixels.
    attribution = explainer.attribute(preprocessed_img, target=predicted_class)

    # detach()/cpu() before numpy(): Tensor.numpy() raises for tensors that
    # still require grad or do not live on the CPU.
    explanation = attribution.detach().cpu().numpy()

    if method_name == 'Gradient × Input':
        # Special case for gradient × input: this entry holds the same explainer
        # type as 'Gradient', so the element-wise product is formed here.
        explanation = explanation * preprocessed_img.detach().cpu().numpy()

    explanations[method_name] = explanation

print("All explanations generated!")

## 6. Visualize VGG16 Explanations

In [None]:
# Utility function for prettier visualization
def preprocess_explanation(explanation):
    """Turn a raw attribution array into a normalized magnitude map.

    A 4-D input is reduced to its first element and rearranged from
    channels-first to channels-last; a 3-D input whose first axis has length 3
    is rearranged the same way. Any other input keeps its layout. The absolute
    values are then divided by their maximum, unless that maximum is not
    positive.

    Args:
        explanation: Attribution array.

    Returns:
        Array of absolute attribution values; when the maximum absolute value
        is positive, the result is scaled so that its largest entry is 1.
    """
    chw_to_hwc = (1, 2, 0)

    if explanation.ndim == 4:
        # [B, C, H, W]: keep the first sample and move channels last -> [H, W, C]
        arranged = np.transpose(explanation[0], chw_to_hwc)
    elif explanation.ndim == 3 and explanation.shape[0] == 3:
        # [C, H, W] -> [H, W, C]
        arranged = np.transpose(explanation, chw_to_hwc)
    else:
        arranged = explanation

    # Only the magnitude is kept; the sign of the attribution is discarded
    magnitude = np.abs(arranged)

    # Scale for visualization; an all-zero map is returned unscaled to avoid
    # dividing by zero
    peak = magnitude.max()
    if peak > 0:
        return magnitude / peak
    return magnitude

In [None]:
# Convert PIL image to numpy for visualization
# The background for every overlay is the original picture resized to 224x224.
resized_img = original_img.resize((224, 224))
np_img = np.array(resized_img)

# Create figure for visualization: a 2x3 grid, flattened so the panels can be
# addressed with a single index
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, method_name in enumerate(explanations):
    ax = axes[idx]
    heatmap = preprocess_explanation(explanations[method_name])

    # Use SignXAI's visualization utility to draw the attribution on this panel
    visualize_attribution(image=np_img, attribution=heatmap, ax=ax)
    ax.set_title(method_name)
    ax.axis('off')

plt.suptitle(f"VGG16 Explanations for class: {class_name}", fontsize=16)
plt.tight_layout()
plt.subplots_adjust(top=0.9)  # Leave room for the suptitle
plt.show()

## 7. Interpret the VGG16 Results

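Let's interpret what these explanation methods show for the VGG16 model. Keep in mind that the plots display normalized absolute attribution values (see `preprocess_explanation`), so they indicate how strongly a pixel matters, not whether it speaks for or against the predicted class: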

- **Gradient**: Shows pixel-level importance via the gradient of the output with respect to input. In VGG16, this often highlights edges and textures.

- **Gradient × Input**: Multiplies the gradient element-wise by the input values. This emphasizes regions where both the gradient and the input values are large in magnitude.

- **GradCAM**: Uses the last convolutional layer of VGG16 to produce a coarse localization map highlighting important regions for the prediction.

- **Guided Backprop**: Creates sharper feature visualizations by modifying the backpropagation signal through ReLU layers. It's particularly effective for VGG16 which has many ReLU activations.

- **LRP-Z**: Layer-wise Relevance Propagation with the Z-rule propagates the prediction backward through the network to identify relevant input features.

- **LRP-Epsilon**: A variant of LRP that adds a stabilizing term (epsilon) to avoid division by zero, producing slightly different attribution maps.

Each method highlights different aspects of how VGG16 processes the image to make its prediction.

## 8. Using SignXAI's TensorFlow-style API with PyTorch

One of the advantages of SignXAI is its compatibility API, which allows you to use the TensorFlow-style API even with PyTorch models. This is useful if you're migrating from TensorFlow to PyTorch or if you're more familiar with the TensorFlow API.

In [None]:
# Import the TensorFlow-style API
from signxai.torch_signxai import tf_calculate_relevancemap

# Use the TensorFlow-style API with PyTorch model
# Methods are selected by string identifier instead of by explainer class.
# NOTE(review): 'lrp_epsilon_0_1' presumably corresponds to epsilon=0.1 - confirm.
methods_tf_style = [
    'gradient',
    'gradient_x_input',
    'guided_backprop',
    'lrp_z',
    'lrp_epsilon_0_1',
]
explanations_tf_style = {}

for method in methods_tf_style:
    print(f"Generating explanation using TF-style API: {method}...")

    # One call per method name; the result is stored under that name
    explanations_tf_style[method] = tf_calculate_relevancemap(
        method, preprocessed_img, vgg16_model_no_softmax
    )

In [None]:
# Visualize explanations generated with TensorFlow-style API
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# zip() stops at the shorter sequence, so explanations beyond the number of
# available panels are skipped instead of indexing past the end of `axes`.
for ax, (method_name, explanation) in zip(axes, explanations_tf_style.items()):
    processed_explanation = preprocess_explanation(explanation)

    # Use SignXAI's visualization utility
    visualize_attribution(image=np_img, attribution=processed_explanation, ax=ax)
    ax.set_title(f"TF-API: {method_name}")
    ax.axis('off')

# Hide any unused axes
# Sliced by the number of explanations rather than by a leftover loop index, so
# this also works when there are no explanations at all.
for ax in axes[len(explanations_tf_style):]:
    ax.axis('off')

plt.suptitle("TensorFlow-style API with PyTorch VGG16", fontsize=16)
plt.tight_layout()
plt.subplots_adjust(top=0.9)  # Adjust for the suptitle
plt.show()

## 9. Conclusion

In this tutorial, we've learned how to:
- Set up SignXAI with PyTorch
- Load and prepare a pre-trained VGG16 model for explanation
- Apply various explainability methods to understand VGG16 predictions
- Visualize and interpret the results
- Use SignXAI's TensorFlow-style API with PyTorch models

VGG16 is an excellent model for explainability demonstrations because of its straightforward architecture. The clear convolutional structure makes it easier to interpret how different parts of the image influence the model's predictions.

For more advanced techniques and detailed explanations of other models, check out the other tutorials in this series.