# SignXAI2 with PyTorch - Basic Usage

This tutorial demonstrates how to use SignXAI2 with PyTorch models using the new dynamic method parsing approach.

## Key Features:
- **Dynamic Method Parsing**: Parameters are embedded directly in method names
- **Unified API**: Same interface for both TensorFlow and PyTorch
- **No Wrappers**: Direct method calls without wrapper functions

## Setup Requirements

```bash
# Install SignXAI2 with PyTorch support (quoted so shells like zsh do not
# treat the [pytorch] extras marker as a filename glob)
pip install "signxai2[pytorch]"
```

## 1. Import Libraries

In [None]:
import warnings
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms

# Silence every UserWarning (not only deprecation notices) for cleaner tutorial
# output. NOTE(review): this also hides unrelated UserWarnings; narrow the filter
# (e.g. with `message=`) if you need to debug warning-producing code.
warnings.filterwarnings('ignore', category=UserWarning)

# Import the unified SignXAI API
from signxai.api import explain

## 2. Load Pre-trained Model

In [None]:
# Load a pre-trained VGG16 model (ImageNet weights).
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum API; IMAGENET1K_V1 is the checkpoint the legacy flag resolved to.
# Older torchvision releases have no weights enum, so fall back to the flag there.
print("Loading VGG16 model...")
vgg16_weights = getattr(models, "VGG16_Weights", None)
if vgg16_weights is not None:
    model = models.vgg16(weights=vgg16_weights.IMAGENET1K_V1)
else:
    model = models.vgg16(pretrained=True)
model.eval()  # inference mode (disables dropout) before computing explanations
print("Model loaded successfully!")

## 3. Load and Preprocess Image

In [None]:
# Load an image
img_path = 'examples/data/images/example.jpg'  # Update with your image path
# Open inside a context manager so the underlying file handle is released as
# soon as the RGB copy has been made.
with Image.open(img_path) as raw_img:
    img = raw_img.convert('RGB')

# Display the original image
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.title('Original Image')
plt.axis('off')
plt.show()

# Preprocess for VGG16: fixed 224x224 resize (the visualization helpers overlay
# heatmaps on a resized copy of `img`), then ImageNet mean/std normalization.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(img).unsqueeze(0)  # Add batch dimension
print(f"Input tensor shape: {input_tensor.shape}")

## 4. Get Model Prediction

In [None]:
# Forward pass without autograd bookkeeping; only the scores are needed here.
with torch.no_grad():
    output = model(input_tensor)
    probabilities = output[0].softmax(dim=0)

# Highest-scoring class for the (single) batch item, kept as a tensor.
predicted_idx = output.argmax(dim=1)
top_probs, top_idxs = probabilities.topk(5)

print("Top 5 predictions:")
ranked = zip(top_idxs.tolist(), top_probs.tolist())
for rank, (class_id, score) in enumerate(ranked, start=1):
    print(f"  {rank}. Class {class_id}: {score*100:.2f}%")

print(f"\nUsing class {predicted_idx.item()} for explanation")

## 5. Generate Explanations with Dynamic Method Parsing

### New Feature: Dynamic Method Parsing
Parameters are now embedded directly in the method name, and numeric values use `_` in place of the decimal point (e.g. `0_25` means 0.25). For example:
- `gradient` - Basic gradient
- `gradient_x_input` - Gradient × Input
- `gradient_x_input_x_sign_mu_neg_0_5` - Gradient × Input × Sign with parameter μ = -0.5 (`neg` marks a negative value)
- `smoothgrad_noise_0_3_samples_50` - SmoothGrad with noise level 0.3 and 50 samples
- `lrp_epsilon_0_25` - LRP with epsilon=0.25

In [None]:
# Example 1: basic gradient attribution for the predicted class
method = "gradient"
print(f"Calculating explanation using: {method}")

explanation_gradient = explain(model, input_tensor,
                               method_name=method,
                               target_class=predicted_idx.item())

print(f"Explanation shape: {explanation_gradient.shape}")

In [None]:
# Example 2: operations chained in the name (x_input, x_sign) plus a
# parameter value (mu_neg_0_5, i.e. mu = -0.5)
method = "gradient_x_input_x_sign_mu_neg_0_5"
print(f"Calculating explanation using: {method}")

explanation_complex = explain(model, input_tensor,
                              method_name=method,
                              target_class=predicted_idx.item())

print(f"Explanation shape: {explanation_complex.shape}")

In [None]:
# Example 3: SmoothGrad configured purely via the name (noise 0.3, 50 samples)
method = "smoothgrad_noise_0_3_samples_50"
print(f"Calculating explanation using: {method}")

explanation_smoothgrad = explain(model, input_tensor,
                                 method_name=method,
                                 target_class=predicted_idx.item())

print(f"Explanation shape: {explanation_smoothgrad.shape}")

In [None]:
# Example 4: Multiple methods with dynamic parameters
methods = [
    "gradient",
    "gradient_x_input",
    "gradient_x_sign",
    "smoothgrad",
    "smoothgrad_noise_0_1_samples_25",
    "integrated_gradients",
    "integrated_gradients_steps_100",
    "guided_backprop",
    "deconvnet",
    "lrp_epsilon_0_25",
    "lrp_epsilon_50_x_sign",
    "lrp_alpha_2_beta_1"
]

explanations = {}
for method in methods:
    try:
        print(f"Calculating: {method}")
        explanation = explain(
            model,
            input_tensor,
            method_name=method,
            target_class=predicted_idx.item()
        )
        explanations[method] = explanation
    except Exception as e:
        print(f"  Failed: {e}")

print(f"\nSuccessfully calculated {len(explanations)} explanations")

## 6. Visualization Helper Functions

In [None]:
def process_explanation(explanation):
    """Collapse an explanation into a 2-D heatmap scaled to [-1, 1].

    Args:
        explanation: Attribution map as a torch tensor (any device, may
            require grad), a NumPy array, or any array-like. Accepted layouts:
            a 4-D batch (only the first item is used), a 3-D (C, H, W) or
            (H, W, C) map, or an already 2-D (H, W) map.

    Returns:
        np.ndarray: The heatmap divided by its maximum absolute value, so the
        values lie in [-1, 1] and the sign of every attribution is preserved.
        An all-zero heatmap is returned unchanged.
    """
    # Convert to numpy: tensors are detached and copied to host memory first;
    # anything else (arrays, nested lists) goes through np.asarray.
    if hasattr(explanation, 'detach'):
        explanation_np = explanation.detach().cpu().numpy()
    else:
        explanation_np = np.asarray(explanation)

    # Remove batch dimension if present (only the first sample is visualized).
    if explanation_np.ndim == 4:
        explanation_np = explanation_np[0]

    # Sum over the channel axis to create a 2-D heatmap. PyTorch maps are
    # channels-first (C, H, W); a channels-last (H, W, C) map is recognized by
    # a small trailing axis (1/3/4 channels) next to a larger leading one.
    if explanation_np.ndim == 3:
        small_axis = (1, 3, 4)
        channels_last = (explanation_np.shape[-1] in small_axis
                         and explanation_np.shape[0] not in small_axis)
        heatmap = explanation_np.sum(axis=-1 if channels_last else 0)
    else:
        heatmap = explanation_np

    # Normalize for visualization (symmetric scaling keeps zero at the center).
    abs_max = np.max(np.abs(heatmap))
    if abs_max > 0:
        normalized = heatmap / abs_max
    else:
        normalized = heatmap

    return normalized

def visualize_explanations(explanations, original_img, cols=3):
    """Overlay each explanation on the original image, laid out in a grid.

    Args:
        explanations: Mapping of method name -> explanation (anything
            `process_explanation` accepts). Insertion order sets grid order.
        original_img: PIL image drawn underneath every heatmap. It is resized
            to each heatmap's (H, W) so the overlay lines up with it.
        cols: Number of grid columns (values below 1 are treated as 1).

    An empty mapping prints a message and draws nothing.
    """
    n_methods = len(explanations)
    if n_methods == 0:
        # plt.subplots(0, ...) would raise; there is nothing to draw anyway.
        print("No explanations to visualize")
        return

    cols = max(1, cols)
    rows = (n_methods + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so flattening works for
    # every grid shape (including a single method in a multi-column grid).
    _, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    axes = axes.ravel()

    for ax, (method_name, explanation) in zip(axes, explanations.items()):
        heatmap = process_explanation(explanation)

        # Match the background to the heatmap's spatial size; PIL's resize
        # takes (width, height).
        img_np = np.array(original_img.resize((heatmap.shape[1], heatmap.shape[0]))) / 255.0

        ax.imshow(img_np)
        # Symmetric color limits: red = positive, blue = negative relevance.
        ax.imshow(heatmap, cmap='seismic', alpha=0.5, clim=(-1, 1))
        ax.set_title(method_name.replace('_', ' ').title())
        ax.axis('off')

    # Hide empty subplots
    for ax in axes[n_methods:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## 7. Visualize All Explanations

In [None]:
# Overlay every successfully computed explanation on the input image
if not explanations:
    print("No explanations to visualize")
else:
    visualize_explanations(explanations, img, cols=4)

## 8. Compare Different Parameter Settings

Let's compare how different parameters affect the same base method:

In [None]:
# Compare SmoothGrad with different noise levels and sample counts.
# Values are read from the name, e.g. `noise_0_3` -> 0.3, `samples_50` -> 50.
smoothgrad_variants = [
    "smoothgrad",  # Default parameters
    "smoothgrad_noise_0_1_samples_25",
    "smoothgrad_noise_0_3_samples_50",
    "smoothgrad_noise_0_5_samples_100"
]

smoothgrad_explanations = {}
for method in smoothgrad_variants:
    # Same pattern as the other comparison cells: report and skip a failing
    # variant instead of aborting the whole comparison.
    try:
        print(f"Calculating: {method}")
        explanation = explain(
            model,
            input_tensor,
            method_name=method,
            target_class=predicted_idx.item()
        )
        smoothgrad_explanations[method] = explanation
    except Exception as e:
        print(f"  Failed: {e}")

# Visualize SmoothGrad variants (skipped when every variant failed)
if smoothgrad_explanations:
    visualize_explanations(smoothgrad_explanations, img, cols=2)

In [None]:
# Compare LRP-epsilon across epsilon values from 0.01 up to 10; a variant
# that raises is reported and left out of the plot.
lrp_variants = [
    "lrp_epsilon_0_01",
    "lrp_epsilon_0_1",
    "lrp_epsilon_0_25",
    "lrp_epsilon_1",
    "lrp_epsilon_10"
]

lrp_explanations = {}
for method in lrp_variants:
    try:
        print(f"Calculating: {method}")
        lrp_explanations[method] = explain(
            model, input_tensor,
            method_name=method, target_class=predicted_idx.item(),
        )
    except Exception as e:
        print(f"  Failed: {e}")

# Only plot when at least one variant succeeded
if lrp_explanations:
    visualize_explanations(lrp_explanations, img, cols=3)

## 9. Advanced: Method Combinations

The dynamic parsing allows complex method combinations:

In [None]:
# Method names that chain several operations and parameters; failures are
# reported per method and skipped.
complex_methods = [
    "gradient_x_input_x_sign_mu_neg_0_5",
    "lrp_epsilon_50_x_sign",
    "lrpsign_epsilon_0_25_std_x"
]

complex_explanations = {}
for method in complex_methods:
    try:
        print(f"Calculating: {method}")
        complex_explanations[method] = explain(
            model, input_tensor,
            method_name=method, target_class=predicted_idx.item(),
        )
    except Exception as e:
        print(f"  Failed: {e}")

# Only plot when at least one combination succeeded
if complex_explanations:
    visualize_explanations(complex_explanations, img, cols=3)

## Summary

### Key Takeaways:

1. **Dynamic Method Parsing**: Parameters are embedded in method names (e.g., `smoothgrad_noise_0_3_samples_50`)

2. **Unified API**: Single `explain()` function works for all methods

3. **No Wrappers**: Direct method calls without intermediate wrapper functions

4. **Flexible Parameters**: Easy to experiment with different parameter values

### Method Name Format:
```
base_method[_param_value][_operation][_param_value]...
```

### Examples:
- `gradient` - Basic gradient
- `gradient_x_input` - Gradient multiplied by input
- `smoothgrad_noise_0_3_samples_50` - SmoothGrad with noise=0.3, samples=50
- `lrp_epsilon_0_25` - LRP with epsilon=0.25
- `integrated_gradients_steps_100` - Integrated Gradients with 100 steps

### Next Steps:
- Try the advanced tutorial for more complex use cases
- Explore time series explanations with ECG data
- Compare TensorFlow and PyTorch implementations