# SignXAI2 PyTorch Advanced Tutorial - Image Classification

This advanced tutorial demonstrates sophisticated analysis techniques using SignXAI2 with PyTorch, including class-specific explanations and positive/negative contribution separation.

## Prerequisites

Complete the basic PyTorch tutorial first, and ensure you have the required data and model setup.

⚠️ **Data Requirements**: This tutorial requires example data from the GitHub repository.

## Setup and Basic Model Loading

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms
from signxai import explain, list_methods
from signxai.utils.utils import normalize_heatmap
import urllib.request

# Download an example image (requires network access; saved next to the notebook)
url = "https://farm1.staticflickr.com/148/414245159_7549a49046_z.jpg"
urllib.request.urlretrieve(url, "dog.jpg")

# Load the pre-trained model
model = models.vgg16(pretrained=True)
model.eval()

# NOTE: torchvision's VGG16 has no softmax module -- its classifier already ends
# in the Linear(4096 -> 1000) layer that emits raw logits, which is what the
# explanation methods need. The model is therefore used as-is. Replacing
# classifier[-1] would strip the classification layer itself, and every
# "class index" below would then refer to a hidden feature unit instead of an
# ImageNet class.

# Load and preprocess the image
img_path = "dog.jpg"
img = Image.open(img_path).convert('RGB')

# Resize to 224x224 and normalize with the ImageNet channel statistics
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(img).unsqueeze(0)  # Add batch dimension
img_np = np.array(img.resize((224, 224))) / 255.0  # For visualization

# Make prediction (no gradients needed for a plain forward pass)
with torch.no_grad():
    output = model(input_tensor)

# Get the predicted class: index of the largest logit along the class dimension
_, predicted_idx = torch.max(output, 1)
print(f"Predicted class index: {predicted_idx.item()}")

## Advanced Analysis

Let's compare the explanations for the model's top-3 predicted classes, computing one explanation per target class:

In [None]:
# Turn the logits of the single batch item into probabilities and keep the
# three most likely classes
probabilities = torch.softmax(output[0], dim=0)
top_probs, top_classes = probabilities.topk(3)

# Build one explanation per top class, keyed by the integer class index.
# The method name chains several steps:
# gradient (base) + x_input (multiply by input) + x_sign (apply sign) + mu_neg_0_5 (parameter)
class_explanations = {
    class_id: explain(
        model=model,
        x=input_tensor,
        method_name='gradient_x_input_x_sign_mu_neg_0_5',
        target_class=class_id,
    )
    for class_id in top_classes.tolist()
}

# Visualize: the first panel shows the image, the remaining three the heatmaps
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
image_ax, *heatmap_axes = axs

# Original image
image_ax.imshow(img_np)
image_ax.set_title('Original Image', fontsize=14)
image_ax.axis('off')

# Class-specific explanations, in descending order of predicted probability
for ax, class_id in zip(heatmap_axes, top_classes.tolist()):
    # Take the first batch item and sum over channels to get a 2-D map
    explanation = class_explanations[class_id][0].sum(axis=0)
    ax.imshow(normalize_heatmap(explanation), cmap='seismic', clim=(-1, 1))
    ax.set_title(f'Class: {class_id}', fontsize=14)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Positive and Negative Contribution Separation

We can also highlight the positive and negative contributions separately:

In [None]:
# Choose a complex method with parameter chaining and generate explanation
# This showcases: gradient (base) + x_input (multiply by input) + x_sign (apply sign) + mu_neg_0_5 (parameter)
method = 'gradient_x_input_x_sign_mu_neg_0_5'
explanation = explain(
    model=model,
    x=input_tensor,
    method_name=method,
    target_class=predicted_idx.item()
)[0].sum(axis=0)  # Sum over channels

# Separate positive and negative contributions
pos_expl = np.maximum(0, explanation)
neg_expl = np.minimum(0, explanation)

# Normalize each part by its own extreme value. Dividing the (non-positive)
# negative part by its (negative) minimum yields a non-negative magnitude map,
# so both pos_norm and neg_norm end up in [0, 1] with 1 = strongest contribution.
# If a part is all zeros it is left untouched to avoid dividing by zero.
pos_peak = np.max(pos_expl)
neg_peak = np.min(neg_expl)
pos_norm = pos_expl / pos_peak if pos_peak > 0 else pos_expl
neg_norm = neg_expl / neg_peak if neg_peak < 0 else neg_expl

# Visualize
fig, axs = plt.subplots(1, 4, figsize=(20, 5))

# Original image
axs[0].imshow(img_np)
axs[0].set_title('Original Image', fontsize=14)
axs[0].axis('off')

# Combined explanation
axs[1].imshow(normalize_heatmap(explanation), cmap='seismic', clim=(-1, 1))
axs[1].set_title(f'{method} - Combined', fontsize=14)
axs[1].axis('off')

# Positive contributions
axs[2].imshow(pos_norm, cmap='Reds')
axs[2].set_title('Positive Contributions', fontsize=14)
axs[2].axis('off')

# Negative contributions: neg_norm is already a magnitude in [0, 1], so it is
# plotted directly. Negating it again would invert the panel (strongest negative
# evidence lightest, untouched pixels darkest), the opposite of the Reds panel.
axs[3].imshow(neg_norm, cmap='Blues')
axs[3].set_title('Negative Contributions', fontsize=14)
axs[3].axis('off')

plt.tight_layout()
plt.show()

## Summary

In this advanced tutorial, we covered three techniques:

1. **Class-specific Analysis**: Generate explanations for different predicted classes to understand what features the model associates with each class
2. **Contribution Separation**: Separate positive and negative contributions to better understand how different regions support or oppose the prediction
3. **Advanced Visualization**: Create comprehensive visualizations that reveal different aspects of the model's decision-making process

These techniques provide deeper insights into model behavior and can help identify potential biases or areas for model improvement.