# SignXAI2 with PyTorch - Basic Usage (VGG16)

This tutorial demonstrates how to use the SignXAI2 package with PyTorch to explain a VGG16 image classification model. We'll walk through:

1. Setting up the environment
2. Loading a pre-trained VGG16 model and sample image
3. Generating explanations using the unified API
4. Visualizing and comparing the results

## Setup Requirements

**Important**: SignXAI2 requires Python 3.9 or 3.10 (Python 3.11+ is not supported)

Since you're running this tutorial, you should already have cloned the signxai2 repository. From the repository root directory:

### Using conda:
```bash
# Create environment with Python 3.10
conda create -n signxai2 python=3.10
conda activate signxai2

# Install PyTorch dependencies only
pip install -r requirements/common.txt -r requirements/pytorch.txt

# Download models and example data
git lfs install
bash ./prepare.sh
```

### Using venv:
```bash
# Create virtual environment
python3.10 -m venv signxai2_env
source signxai2_env/bin/activate  # On Windows: signxai2_env\Scripts\activate

# Install PyTorch dependencies only
pip install -r requirements/common.txt -r requirements/pytorch.txt

# Download models and example data
git lfs install
bash ./prepare.sh
```
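Before running the rest of the notebook, you can check from inside it that the active interpreter meets the Python 3.9/3.10 requirement stated above. This is a minimal sketch. `is_supported_python` is a hypothetical helper, not part of SignXAI2, and it only encodes the supported versions listed in these setup notes.

```python
import sys

# Hypothetical helper: the (major, minor) versions the setup notes list as supported
SUPPORTED_VERSIONS = {(3, 9), (3, 10)}

def is_supported_python(version_info=sys.version_info):
    """Return True if the (major, minor) version is one SignXAI2 supports."""
    return tuple(version_info[:2]) in SUPPORTED_VERSIONS

if not is_supported_python():
    print(f"Warning: running Python {sys.version.split()[0]}; "
          "SignXAI2 expects Python 3.9 or 3.10.")
```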

Let's get started!

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# SignXAI2 unified API imports
from signxai import explain, list_methods

## 2. Set Up Paths

In [None]:
# Set up data and model paths
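# __file__ is undefined in a notebook, so resolve paths relative to the
# current working directory (normally the directory containing this notebook)
_THIS_DIR = os.getcwd()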
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
VGG16_MODEL_PATH = os.path.join(DATA_DIR, "models", "pytorch", "VGG16", "vgg16_ported_weights.pt")
IMAGE_PATH = os.path.join(DATA_DIR, "images", "example.jpg")

# Check that files exist
assert os.path.exists(VGG16_MODEL_PATH), f"VGG16 model not found at {VGG16_MODEL_PATH}"
assert os.path.exists(IMAGE_PATH), f"Image not found at {IMAGE_PATH}"

## 3. Load VGG16 Model and Image

In [None]:
# Load the VGG16 model
try:
    # Import the VGG16 model definition shipped with the repository data
    sys.path.append(os.path.join(DATA_DIR, "models", "pytorch", "VGG16"))
    from VGG16 import VGG16_PyTorch

    # Initialize the model architecture
    vgg16_model = VGG16_PyTorch(num_classes=1000)
    # Load the pre-trained weights onto the CPU
    vgg16_model.load_state_dict(torch.load(VGG16_MODEL_PATH, map_location=torch.device('cpu')))
    print("Loaded pre-saved VGG16 model weights")
except Exception as e:
    # If the import or weight loading fails, fall back to torchvision's ImageNet weights
    print(f"Could not load saved model: {e}\nLoading from torchvision instead")
    import torchvision.models as models
    vgg16_model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

vgg16_model.eval()  # Evaluation mode: disables dropout
print(f"Model loaded successfully. Type: {type(vgg16_model).__name__}")

In [None]:
# Function to load and preprocess image for VGG16
def load_image(image_path, resize_dim=(224, 224)):
    # Load image and force 3-channel RGB (handles grayscale or RGBA files)
    img = Image.open(image_path).convert('RGB')
    
    # Display original image
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.axis('off')
    plt.title('Original Image')
    plt.show()
    
    # Define preprocessing for VGG16
    preprocess = transforms.Compose([
        transforms.Resize(resize_dim),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # Preprocess image
    input_tensor = preprocess(img)
    
    # Create a mini-batch as expected by the model
    input_batch = input_tensor.unsqueeze(0)
    
    return img, input_batch

# Load and preprocess the image
original_img, preprocessed_img = load_image(IMAGE_PATH)
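`ToTensor` scales pixel values from [0, 255] to [0, 1] and moves the channel axis first; `Normalize` then computes `(x - mean) / std` per channel. The NumPy sketch below reproduces that arithmetic (without the resize step) and adds the inverse mapping, which is useful for displaying a preprocessed tensor. Both function names are hypothetical helpers, not part of SignXAI2 or torchvision.

```python
import numpy as np

# ImageNet statistics used by the Normalize transform above
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_like_torchvision(rgb_uint8):
    """Hypothetical helper: map an [H, W, 3] uint8 image to a [1, 3, H, W] batch.

    Mirrors ToTensor (scale to [0, 1], channels first) followed by
    Normalize ((x - mean) / std per channel); resizing is not included.
    """
    x = rgb_uint8.astype(np.float32) / 255.0          # [H, W, 3] in [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD            # broadcasts over the channel axis
    return np.transpose(x, (2, 0, 1))[np.newaxis]     # [1, 3, H, W]

def denormalize_for_display(batch):
    """Hypothetical helper: invert the normalization, returning [H, W, 3] in [0, 1]."""
    x = np.transpose(batch[0], (1, 2, 0))             # [H, W, 3]
    x = x * IMAGENET_STD + IMAGENET_MEAN
    return np.clip(x, 0.0, 1.0)

# Example on a random 4x4 image
img = np.random.default_rng(0).integers(0, 256, size=(4, 4, 3), dtype=np.uint8)
batch = normalize_like_torchvision(img)
print(batch.shape)  # (1, 3, 4, 4)
```

Applied to `np.array(original_img.resize((224, 224)))`, the result should be close to `preprocessed_img.numpy()`, up to differences in the resize interpolation.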

## 4. Predict the Class with VGG16

In [None]:
# Make a prediction with VGG16
with torch.no_grad():
    output = vgg16_model(preprocessed_img)

# Get the predicted class index
_, predicted_idx = torch.max(output, 1)
predicted_class = predicted_idx.item()

# Load class labels (ImageNet classes for VGG16)
import json

try:
    imagenet_class_path = os.path.join(DATA_DIR, "imagenet_class_index.json")
    if os.path.exists(imagenet_class_path):
        with open(imagenet_class_path) as f:
            class_idx = json.load(f)
        class_name = class_idx[str(predicted_class)][1]
    else:
        # Fall back to the category names bundled with torchvision's weights
        from torchvision.models import VGG16_Weights
        class_name = VGG16_Weights.DEFAULT.meta["categories"][predicted_class]
except Exception:
    # If no label source is available, just use the class index
    class_name = f"Class {predicted_class}"

print(f"Predicted class: {class_name} (index: {predicted_class})")
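The arg-max alone hides how confident the model is. Below is a small NumPy sketch for listing the top-k classes with softmax probabilities. It assumes the model returns raw logits (torchvision's `vgg16` ends in a `Linear` layer with no softmax; check this for the ported `VGG16_PyTorch` model). `top_k_predictions` is a hypothetical helper.

```python
import numpy as np

def top_k_predictions(logits, k=5):
    """Hypothetical helper: (class_index, probability) pairs for the top-k classes.

    `logits` is a 1-D array of raw scores, one per class.
    """
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()                  # subtract max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()  # softmax
    top = np.argsort(probs)[::-1][:k]                # indices by descending probability
    return [(int(i), float(probs[i])) for i in top]
```

In the notebook this could be called as `top_k_predictions(output[0].numpy(), k=5)` and combined with the label lookup above.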

In [None]:
# List all available methods in SignXAI2
print("Available methods in SignXAI2:")
available_methods = list_methods()
print(f"Total methods: {len(available_methods)}")
print("\nSome common methods:")
for method in ['gradient', 'gradient_x_input', 'grad_cam', 'guided_backprop', 'lrp_z']:
    if method in available_methods:
        print(f"  - {method}")
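The cell above treats the result of `list_methods()` as a collection of method-name strings (it uses `len()` and `in`). Under that assumption, grouping names by their prefix gives a quick overview of method families, such as all `lrp_*` variants. The helper below is a hypothetical sketch, shown on a stand-in list rather than real SignXAI2 output.

```python
from collections import defaultdict

def group_methods_by_family(method_names):
    """Hypothetical helper: group method names by the text before the first underscore."""
    families = defaultdict(list)
    for name in sorted(method_names):
        families[name.split('_', 1)[0]].append(name)
    return dict(families)

# Stand-in list; in the notebook, pass `available_methods` instead
example = ['gradient', 'gradient_x_input', 'grad_cam', 'lrp_z', 'lrp_epsilon_0_1']
for family, names in group_methods_by_family(example).items():
    print(f"{family}: {', '.join(names)}")
```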

## 5. Generate Explanations with SignXAI2

Now let's use SignXAI2's unified API to explain the VGG16 model's prediction using different methods.
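For intuition about two of these methods: `gradient` is the derivative of the target-class score with respect to each input value, and `gradient_x_input` multiplies that gradient elementwise by the input. For a linear score f(x) = w·x + b the gradient is just w, so the gradient × input map sums to f(x) - b. The NumPy sketch below checks that identity on a toy linear model; it is an illustration only, not how SignXAI2 computes attributions for VGG16, where gradients come from backpropagation.

```python
import numpy as np

rng = np.random.default_rng(42)
w = rng.normal(size=(3, 4, 4))   # weights of a toy linear "model" on a 3x4x4 input
b = 0.5                          # bias
x = rng.normal(size=(3, 4, 4))   # toy input

def score(x):
    """Toy linear class score f(x) = sum(w * x) + b."""
    return float(np.sum(w * x) + b)

gradient = w                      # df/dx of a linear function is its weight tensor
gradient_x_input = gradient * x   # elementwise product, same shape as the input

# The attributions account for the whole score except the bias
print(np.isclose(gradient_x_input.sum(), score(x) - b))  # True
```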

In [None]:
# Define the methods we want to use for VGG16 explanation
# Using the unified API method names
methods_to_test = [
    'gradient',
    'gradient_x_input',
    'grad_cam',
    'guided_backprop', 
    'lrp_z',
    'lrp_epsilon_0_1'
]

# Find the target layer name for GradCAM (last convolutional layer in VGG16)
# For PyTorch VGG16, it's typically the last conv layer in features
target_layer_name = None
for name, module in vgg16_model.features.named_children():
    if isinstance(module, nn.Conv2d):
        target_layer_name = f"features.{name}"

print(f"Target layer for GradCAM: {target_layer_name}")

# Additional parameters for specific methods
method_params = {
    'grad_cam': {'target_layer': target_layer_name}  # PyTorch uses 'target_layer' instead of 'layer_name'
}

# Storage for explanations
explanations = {}

# Generate explanations using the unified API
for method_name in methods_to_test:
    print(f"Generating {method_name} explanation...")
    
    # Get method-specific parameters
    params = method_params.get(method_name, {})
    
    # Use the unified explain API
    explanation = explain(
        model=vgg16_model,
        x=preprocessed_img,
        method_name=method_name,
        target_class=predicted_class,
        **params
    )
    
    explanations[method_name] = explanation
    
print("All explanations generated!")
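The `lrp_*` methods redistribute the output score backwards through the network layer by layer. The names suggest that `lrp_z` uses the basic z-rule and `lrp_epsilon_0_1` the ε-rule with ε = 0.1; this is an inference from the naming, not checked against SignXAI2's source. For one dense layer with input activations a_j, weights w_jk and pre-activations z_k = Σ_j a_j w_jk + b_k, the ε-rule is R_j = Σ_k a_j w_jk / (z_k + ε·sign(z_k)) · R_k. A NumPy sketch of that single step, using a hypothetical helper `lrp_epsilon_dense`:

```python
import numpy as np

def lrp_epsilon_dense(a, W, bias, R_out, epsilon=0.1):
    """Hypothetical helper: one LRP-epsilon step through a dense layer z = a @ W + bias.

    a:     input activations, shape [J]
    W:     weights, shape [J, K]
    bias:  biases, shape [K]
    R_out: relevance of the K outputs, shape [K]
    Returns the relevance of the J inputs, shape [J].
    """
    z = a @ W + bias                        # pre-activations, shape [K]
    sign = np.where(z >= 0, 1.0, -1.0)      # treat z == 0 as positive to avoid 0/0
    s = R_out / (z + epsilon * sign)        # stabilized ratio per output
    return a * (W @ s)                      # R_j = a_j * sum_k W_jk * s_k

rng = np.random.default_rng(0)
a = rng.random(5)
W = rng.normal(size=(5, 3))
z = a @ W
R_in = lrp_epsilon_dense(a, W, np.zeros(3), R_out=z, epsilon=1e-9)
# With no bias and a tiny epsilon, total relevance is (almost exactly) conserved
print(np.isclose(R_in.sum(), z.sum()))
```

A larger ε absorbs more relevance in the denominator, which suppresses contributions from weak or noisy activations. A full LRP implementation also needs rules for convolutions, pooling and the input layer; this only illustrates the ε-rule's arithmetic.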

## 6. Visualize the Explanations

In [None]:
# Utility function for prettier visualization
def preprocess_explanation(explanation):
    """Turn an explanation into a 2-D map in [0, 1] for plotting as a heatmap."""
    # Convert to numpy if it's a torch tensor
    if torch.is_tensor(explanation):
        explanation = explanation.detach().cpu().numpy()
    explanation = np.asarray(explanation, dtype=np.float32)

    # Drop the batch dimension of a [B, C, H, W] or [B, H, W] output
    if explanation.ndim == 4 or (explanation.ndim == 3 and explanation.shape[0] == 1):
        explanation = explanation[0]

    # Use absolute values (magnitude only; the sign of the relevance is discarded)
    abs_explanation = np.abs(explanation)

    # Collapse channels into a single 2-D map so that `cmap` takes effect
    # (matplotlib ignores colormaps for 3-channel images)
    if abs_explanation.ndim == 3:
        channel_axis = 0 if abs_explanation.shape[0] in (1, 3) else -1
        abs_explanation = abs_explanation.sum(axis=channel_axis)

    # Normalize to [0, 1] for visualization
    if abs_explanation.max() > 0:
        abs_explanation = abs_explanation / abs_explanation.max()

    return abs_explanation

# Convert PIL image to numpy for visualization
np_img = np.array(original_img.resize((224, 224)))
height, width = np_img.shape[:2]

# Create figure for visualization
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, (method_name, explanation) in enumerate(explanations.items()):
    heatmap = preprocess_explanation(explanation)

    # Overlay the heatmap on the image; the shared `extent` stretches coarser
    # maps (e.g. a Grad-CAM map at feature-map resolution) over the full image
    axes[i].imshow(np_img, extent=(0, width, height, 0))
    axes[i].imshow(heatmap, cmap='hot', alpha=0.5, extent=(0, width, height, 0))
    axes[i].set_title(method_name.replace('_', ' ').title())
    axes[i].axis('off')

plt.suptitle(f"VGG16 Explanations for class: {class_name}", fontsize=16)
plt.tight_layout()
plt.subplots_adjust(top=0.9)  # Make room for the suptitle
plt.show()
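Taking absolute values, as `preprocess_explanation` does, discards the sign of the relevance; for LRP and gradient × input the sign separates evidence for the predicted class from evidence against it. A possible alternative is sketched below: sum the signed values over channels, scale by the largest absolute value so the map lies in [-1, 1], and plot it with a diverging colormap. `to_signed_heatmap` is a hypothetical helper; it assumes the explanation is already a NumPy array shaped [1, C, H, W], [C, H, W] or [H, W].

```python
import numpy as np

def to_signed_heatmap(explanation):
    """Hypothetical helper: collapse an explanation to a 2-D map in [-1, 1], keeping the sign.

    Accepts NumPy arrays shaped [1, C, H, W], [C, H, W] or [H, W] (channels first).
    """
    heatmap = np.asarray(explanation, dtype=np.float64)
    if heatmap.ndim == 4:          # drop the batch dimension
        heatmap = heatmap[0]
    if heatmap.ndim == 3:          # sum signed relevance over channels
        heatmap = heatmap.sum(axis=0)
    peak = np.abs(heatmap).max()
    return heatmap / peak if peak > 0 else heatmap
```

In the notebook, a tensor would first go through `.detach().cpu().numpy()`, and the result can be shown with `plt.imshow(heatmap, cmap='seismic', vmin=-1, vmax=1)` so that zero maps to the neutral middle color.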

## 7. Using the TensorFlow-Style API

In [None]:
# Import the TensorFlow-style API
from signxai.torch_signxai import tf_calculate_relevancemap

# Use the TensorFlow-style API with PyTorch model
methods_tf_style = ['gradient', 'gradient_x_input', 'guided_backprop', 'lrp_z', 'lrp_epsilon_0_1']
explanations_tf_style = {}

for method in methods_tf_style:
    print(f"Generating explanation using TF-style API: {method}...")
    
    # Using the TensorFlow-style API with the PyTorch model loaded above.
    # The model should return pre-softmax scores (logits); torchvision's VGG16 does.
    explanation = tf_calculate_relevancemap(method, preprocessed_img, vgg16_model)
    
    # Store explanation
    explanations_tf_style[method] = explanation
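Because both APIs are asked for the same methods on the same model and input, a quick numerical comparison shows whether they agree. This is a minimal sketch, assuming both return array-like maps that have the same shape once size-1 axes are removed; PyTorch tensors would first need `.detach().cpu().numpy()`. `compare_explanations` is a hypothetical helper, and the two APIs may differ in layout or scaling, so treat a mismatch as something to inspect rather than proof of a bug.

```python
import numpy as np

def compare_explanations(a, b):
    """Hypothetical helper: (max absolute difference, cosine similarity) of two maps.

    Both inputs are squeezed (size-1 axes removed) and must then have the
    same shape; transposed layouts such as [C, H, W] vs [H, W, C] are not handled.
    """
    a = np.squeeze(np.asarray(a, dtype=np.float64))
    b = np.squeeze(np.asarray(b, dtype=np.float64))
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    max_abs_diff = float(np.max(np.abs(a - b)))
    norm = np.linalg.norm(a) * np.linalg.norm(b)
    cosine = float(np.dot(a.ravel(), b.ravel()) / norm) if norm > 0 else float('nan')
    return max_abs_diff, cosine
```

In the notebook, loop over `methods_tf_style` and pass `explanations[m]` and `explanations_tf_style[m]` for each method; a cosine similarity near 1 means the two maps point in the same direction even if their scales differ.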

## 8. Conclusion

In this tutorial, we've learned how to:
- Set up SignXAI2 with PyTorch
- Load and prepare a pre-trained VGG16 model for explanation
- Apply various explainability methods using the unified API
- Visualize and interpret the results

The SignXAI2 unified API makes it easy to:
- Use the same interface for both TensorFlow and PyTorch models
- Switch between explanation methods by changing only the `method_name` argument
- Apply framework-specific optimizations automatically

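VGG16 is well suited to explainability demonstrations because its architecture is simple: a plain stack of 3×3 convolution and max-pooling layers followed by three fully connected layers, with no skip connections or branches. This sequential structure makes it easier to trace how different parts of the image influence the model's prediction and to pick a layer for methods such as Grad-CAM.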

For more advanced techniques and detailed explanations of other models, check out the other tutorials in this series.