# SignXAI2 with TensorFlow - Advanced Usage

This tutorial demonstrates advanced techniques for using SignXAI2 with TensorFlow models. It builds on the basic tutorial and explores more sophisticated explainability methods and customizations.

## Setup Requirements

**Important**: SignXAI2 requires Python 3.9 or 3.10 (Python 3.11+ is not supported)

Since you're running this tutorial, you should already have cloned the signxai2 repository. From the repository root directory:

### Using conda:
```bash
# Create environment with Python 3.10
conda create -n signxai2 python=3.10
conda activate signxai2

# Install SignXAI2 with TensorFlow support
pip install "signxai2[tensorflow]"

# Download models and example data
git lfs install
bash ./prepare.sh
```

### Using venv:
```bash
# Create virtual environment
python3.10 -m venv signxai2_env
source signxai2_env/bin/activate  # On Windows: signxai2_env\Scripts\activate

# Install SignXAI2 with TensorFlow support
pip install "signxai2[tensorflow]"

# Download models and example data
git lfs install
bash ./prepare.sh
```
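Only a narrow range of Python versions is supported, so it is worth checking the interpreter before you install. The snippet below is a minimal sketch that compares `sys.version_info` against the 3.9/3.10 requirement stated above. `is_supported_python` is a hypothetical helper and is not part of SignXAI2. Save it as a file and run it with the environment's interpreter, for example `python check_python.py`.

```python
import sys

# Python versions SignXAI2 supports, per the requirement stated above
SUPPORTED = {(3, 9), (3, 10)}


def is_supported_python(version_info=sys.version_info):
    """Return True if the (major, minor) version is 3.9 or 3.10."""
    return tuple(version_info[:2]) in SUPPORTED


if __name__ == "__main__":
    version = sys.version.split()[0]
    if is_supported_python():
        print(f"Python {version} is supported")
    else:
        print(f"Python {version} is not supported; use 3.9 or 3.10")
```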

## Overview

In this tutorial, we'll cover:
1. Using different LRP variants with the unified API
2. Advanced method parameters and configurations
3. Working with custom model architectures
4. Advanced visualization techniques
5. Integrating with TensorFlow's GradientTape API
6. Combining multiple explanation methods

## 1. Import Libraries

```python
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions

# SignXAI2 unified API imports
from signxai import explain, list_methods, get_method_info
from signxai.utils.utils import load_image
```

## 2. Set Up Paths and Load Data

```python
# Set up data and model paths
# (notebooks have no __file__, so this resolves to the current working directory)
_THIS_DIR = os.path.dirname(os.path.abspath("__file__"))
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
VGG16_MODEL_PATH = os.path.join(DATA_DIR, "models", "tensorflow", "VGG16", "model.h5")
IMAGE_PATH = os.path.join(DATA_DIR, "images", "example.jpg")

# The example image is required; the model falls back to Keras applications below
assert os.path.exists(IMAGE_PATH), f"Image not found at {IMAGE_PATH}"

# Load the pre-saved model if possible, otherwise use the ImageNet VGG16 from Keras
model = None
if os.path.exists(VGG16_MODEL_PATH):
    try:
        model = tf.keras.models.load_model(VGG16_MODEL_PATH)
        print("Loaded pre-saved VGG16 model")
    except Exception as e:
        print(f"Could not load saved model: {e}")
else:
    print(f"VGG16 model not found at {VGG16_MODEL_PATH}")
if model is None:
    print("Loading VGG16 from Keras applications instead")
    model = tf.keras.applications.VGG16(weights='imagenet', include_top=True)

# Load and display image
original_img, preprocessed_img = load_image(IMAGE_PATH, target_size=(224, 224), expand_dims=True)

plt.figure(figsize=(6, 6))
plt.imshow(original_img)
plt.axis('off')
plt.title('Sample Image')
plt.show()
```

## 3. Using Different LRP Variants with the Unified API

SignXAI2 provides various LRP (Layer-wise Relevance Propagation) methods through the unified API. Let's explore different LRP rules and configurations.
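The rule names become clearer once you see what they compute. The following minimal NumPy sketch shows how the epsilon and alpha-beta rules redistribute relevance through a single dense layer without bias. It illustrates the textbook rules and is not SignXAI2's implementation. `lrp_epsilon_dense` and `lrp_alpha_beta_dense` are hypothetical helpers. Two cases conserve total relevance: `epsilon=0` (the Z-rule), and any alpha-beta pair with alpha - beta = 1. A positive epsilon absorbs part of the relevance.

```python
import numpy as np


def lrp_epsilon_dense(a, W, R_out, epsilon=0.0):
    """LRP-epsilon for a bias-free dense layer z = a @ W.

    a: (n_in,) input activations, W: (n_in, n_out) weights,
    R_out: (n_out,) relevance of the layer outputs.
    epsilon=0 recovers the plain Z-rule.
    """
    z = a @ W                                        # pre-activations, shape (n_out,)
    # Stabilizer pushes denominators away from zero while keeping their sign
    denom = z + epsilon * np.where(z >= 0, 1.0, -1.0)
    s = R_out / denom                                # relevance per unit of z
    return a * (W @ s)                               # R_j = a_j * sum_k W_jk * s_k


def lrp_alpha_beta_dense(a, W, R_out, alpha=1.0, beta=0.0):
    """LRP-alpha-beta: positive and negative contributions are weighted separately."""
    contrib = a[:, None] * W                         # z_jk = a_j * W_jk, shape (n_in, n_out)
    pos = np.clip(contrib, 0, None)
    neg = np.clip(contrib, None, 0)
    pos_sum = pos.sum(axis=0)
    neg_sum = neg.sum(axis=0)
    # Shares of each input in the positive / negative part of every output
    pos_share = np.divide(pos, pos_sum, out=np.zeros_like(pos), where=pos_sum != 0)
    neg_share = np.divide(neg, neg_sum, out=np.zeros_like(neg), where=neg_sum != 0)
    return ((alpha * pos_share - beta * neg_share) * R_out).sum(axis=1)


a = np.array([1.0, 2.0, 0.5])
W = np.array([[1.0, -2.0],
              [0.5, 1.0],
              [-1.0, 3.0]])
R_out = np.array([1.0, 0.5])

R_z = lrp_epsilon_dense(a, W, R_out)
R_eps = lrp_epsilon_dense(a, W, R_out, epsilon=0.1)
R_ab = lrp_alpha_beta_dense(a, W, R_out, alpha=2.0, beta=1.0)
print("Z-rule:", R_z, "sum:", R_z.sum())
print("epsilon=0.1:", R_eps, "sum:", R_eps.sum())
print("alpha=2, beta=1:", R_ab, "sum:", R_ab.sum())
```

Inputs can receive negative relevance under the alpha-beta rule (here the third input). For this reason, a magnitude-only view of an LRP map discards information.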

```python
# Get a prediction for reference
predictions = model.predict(preprocessed_img)
predicted_class = int(np.argmax(predictions[0]))

# Decode the prediction (Keras downloads the ImageNet class index on first use)
try:
    decoded_predictions = decode_predictions(predictions, top=3)[0]
    print("Top 3 predictions:")
    for i, (imagenet_id, label, score) in enumerate(decoded_predictions):
        print(f"{i+1}: {label} ({score:.4f})")
    class_name = decoded_predictions[0][1]
except Exception:
    class_name = f"Class {predicted_class}"

print(f"\nExplaining prediction: {class_name} (class index: {predicted_class})")

# Different LRP variants available in SignXAI2
lrp_methods = [
    'lrp_z',
    'lrp_epsilon_0_001',
    'lrp_epsilon_0_01',
    'lrp_epsilon_0_1',
    'lrp_epsilon_0_2',
    'lrp_alpha_1_beta_0',
    'lrp_alpha_2_beta_1',
    'lrpsign_z',
    'lrpsign_epsilon_0_1'
]

# Generate explanations using different LRP variants, skipping unavailable ones
available_methods = list_methods()
lrp_explanations = {}
for method in lrp_methods:
    if method not in available_methods:
        print(f"Skipping {method}: not available in this SignXAI2 version")
        continue
    print(f"Generating {method} explanation...")
    lrp_explanations[method] = explain(
        model=model,
        x=preprocessed_img,
        method_name=method,
        target_class=predicted_class
    )
```

```python
# Utility function for visualization
def preprocess_explanation(explanation):
    """Convert an attribution into a 2D map in [0, 1] for heatmap overlays."""
    explanation = np.asarray(explanation)

    # Remove a leading batch dimension of size 1 if present
    if explanation.ndim in (3, 4) and explanation.shape[0] == 1:
        explanation = explanation[0]

    # Use absolute values and sum over channels to get one 2D relevance map
    # (imshow ignores `cmap` for RGB arrays, so colormaps need a 2D map)
    abs_explanation = np.abs(explanation)
    if abs_explanation.ndim == 3:
        abs_explanation = abs_explanation.sum(axis=-1)

    # Normalize for visualization
    if abs_explanation.max() > 0:
        abs_explanation = abs_explanation / abs_explanation.max()

    return abs_explanation

# Visualize all LRP variants: one panel for the original image plus one per method
n_methods = len(lrp_explanations)
n_cols = 3
n_rows = int(np.ceil((n_methods + 1) / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
axes = axes.flatten()

# Plot the original image in the first position
axes[0].imshow(original_img)
axes[0].set_title("Original Image")

# Plot all LRP variants as heatmap overlays
for i, (method_name, explanation) in enumerate(lrp_explanations.items(), 1):
    axes[i].imshow(original_img)
    axes[i].imshow(preprocess_explanation(explanation), alpha=0.5, cmap='hot')
    axes[i].set_title(method_name.replace('_', ' ').upper())

# Hide axes on all panels, including unused ones
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()
```
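`preprocess_explanation` keeps only magnitudes. That hides whether a pixel's relevance is positive (evidence for the class) or negative (evidence against it), and the distinction matters for LRP rules, whose relevance scores can be negative. Below is a minimal sketch of a sign-preserving alternative that accepts the same input shapes. `signed_heatmap` is a hypothetical helper. It sums over channels and divides by the largest absolute value, so the result lies in [-1, 1] and suits a diverging colormap with `vmin=-1, vmax=1`.

```python
import numpy as np


def signed_heatmap(explanation):
    """Collapse an attribution to a 2D map in [-1, 1], keeping the sign."""
    explanation = np.asarray(explanation, dtype=float)
    # Drop a leading batch dimension of size 1
    if explanation.ndim in (3, 4) and explanation.shape[0] == 1:
        explanation = explanation[0]
    # Sum over channels so positive and negative contributions can cancel
    if explanation.ndim == 3:
        explanation = explanation.sum(axis=-1)
    max_abs = np.abs(explanation).max()
    return explanation / max_abs if max_abs > 0 else explanation


# Example usage in this notebook (diverging colormap centred on zero):
# plt.imshow(signed_heatmap(lrp_explanations['lrp_epsilon_0_1']), cmap='bwr', vmin=-1, vmax=1)
```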

## 4. Advanced Method Parameters and Configurations

SignXAI2's unified API allows you to customize explanation methods with various parameters. Let's explore some advanced configurations.
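The parameters below correspond to simple formulas. SmoothGrad averages gradients over `augment_by_n` noisy copies of the input. Integrated Gradients averages gradients at `steps` points along a straight path from a baseline to the input, then multiplies by the input difference. The NumPy sketch below applies both to a toy function f(x) = sum(x**2) with known gradient 2x, so you can check the results by hand. It assumes the Gaussian noise has standard deviation `noise_scale` times the input range, as in the original SmoothGrad paper; SignXAI2's exact noise scaling may differ. `smoothgrad` and `integrated_gradients` here are hypothetical helpers.

```python
import numpy as np


def f(x):
    """Toy model output: f(x) = sum(x**2)."""
    return np.sum(x ** 2)


def grad_f(x):
    """Analytic gradient of the toy model: 2x."""
    return 2.0 * x


def smoothgrad(x, grad_fn, noise_scale=0.1, augment_by_n=25, seed=0):
    """Average gradients over noisy copies of x (noise std = noise_scale * input range)."""
    rng = np.random.default_rng(seed)
    sigma = noise_scale * (x.max() - x.min())
    grads = [grad_fn(x + rng.normal(0.0, sigma, size=x.shape)) for _ in range(augment_by_n)]
    return np.mean(grads, axis=0)


def integrated_gradients(x, grad_fn, steps=50, baseline=None):
    """Right Riemann sum of gradients along the straight path baseline -> x."""
    if baseline is None:
        baseline = np.zeros_like(x)
    alphas = np.arange(1, steps + 1) / steps
    avg_grad = np.mean([grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0)
    return (x - baseline) * avg_grad


x = np.array([1.0, -2.0, 3.0])
for steps in (10, 25, 50):
    ig = integrated_gradients(x, grad_f, steps=steps)
    # Completeness: attributions should sum to f(x) - f(baseline); the gap shrinks with more steps
    print(f"steps={steps}: sum={ig.sum():.3f}, target={f(x) - f(np.zeros_like(x)):.3f}")

print("SmoothGrad:", smoothgrad(x, grad_f, noise_scale=0.1, augment_by_n=50))
print("Plain gradient:", grad_f(x))
```

With this right-sum scheme the Integrated Gradients total equals f(x)(steps + 1)/steps, so the 10-step, 25-step, and 50-step runs in the cell below trade accuracy for compute in a predictable way on this toy function.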

```python
# Example 1: SmoothGrad with different noise levels
smoothgrad_params = [
    {'noise_scale': 0.05, 'augment_by_n': 10},
    {'noise_scale': 0.1, 'augment_by_n': 25},
    {'noise_scale': 0.2, 'augment_by_n': 50}
]

smoothgrad_explanations = {}
for params in smoothgrad_params:
    label = f"SmoothGrad (noise={params['noise_scale']}, n={params['augment_by_n']})"
    print(f"Generating {label}...")
    smoothgrad_explanations[label] = explain(
        model=model,
        x=preprocessed_img,
        method_name='smoothgrad',
        target_class=predicted_class,
        **params
    )

# Example 2: Integrated Gradients with different step counts
ig_steps = [10, 25, 50]
ig_explanations = {}
for steps in ig_steps:
    label = f"Integrated Gradients ({steps} steps)"
    print(f"Generating {label}...")
    ig_explanations[label] = explain(
        model=model,
        x=preprocessed_img,
        method_name='integrated_gradients',
        target_class=predicted_class,
        steps=steps
    )

# Example 3: Grad-CAM with different layers
gradcam_layers = ['block4_conv3', 'block5_conv3']  # Different VGG16 layers
gradcam_explanations = {}
for layer in gradcam_layers:
    label = f"Grad-CAM ({layer})"
    print(f"Generating {label}...")
    gradcam_explanations[label] = explain(
        model=model,
        x=preprocessed_img,
        method_name='grad_cam',
        target_class=predicted_class,
        layer_name=layer
    )

# Visualize parameter variations: one row per method
rows = [smoothgrad_explanations, ig_explanations, gradcam_explanations]
fig, axes = plt.subplots(3, 3, figsize=(18, 18))

for row, explanations in enumerate(rows):
    for col, (label, explanation) in enumerate(explanations.items()):
        axes[row, col].imshow(original_img)
        axes[row, col].imshow(preprocess_explanation(explanation), alpha=0.5, cmap='hot')
        axes[row, col].set_title(label)

# Hide axes on every panel, including the unused Grad-CAM slot
for ax in axes.flat:
    ax.axis('off')

plt.tight_layout()
plt.show()
```
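The Grad-CAM resolution depends on the chosen layer. For a 224x224 VGG16 input, `block4_conv3` outputs a 28x28 map and `block5_conv3` a 14x14 map, so the deeper layer yields coarser heatmaps. Such maps are typically resized to the input resolution (for example with `tf.image.resize`) before being overlaid on the image. The weighting step itself is small. The NumPy sketch below takes one image's activations for a layer and the gradients of the class score with respect to them, both assumed to be precomputed with shape (H, W, K), and applies the standard Grad-CAM combination. It is not SignXAI2's implementation, and `grad_cam_from_activations` is a hypothetical helper.

```python
import numpy as np


def grad_cam_from_activations(activations, gradients):
    """Standard Grad-CAM combination step.

    activations, gradients: arrays of shape (H, W, K) for one image,
    where gradients are d(class score) / d(activations).
    Returns a non-negative (H, W) map scaled to [0, 1].
    """
    # One weight per channel: global-average-pooled gradients
    weights = gradients.mean(axis=(0, 1))                        # shape (K,)
    # Weighted sum of feature maps, then ReLU to keep positive evidence only
    cam = np.maximum(np.tensordot(activations, weights, axes=([2], [0])), 0.0)
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam


# Toy example: 2 channels on a 2x2 grid; channel 0 gets weight +1, channel 1 gets -1
acts = np.array([[[1.0, 0.0], [0.0, 2.0]],
                 [[3.0, 1.0], [0.0, 0.0]]])
grads = np.array([[[1.0, -1.0], [1.0, -1.0]],
                  [[1.0, -1.0], [1.0, -1.0]]])
print(grad_cam_from_activations(acts, grads))
```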

## 5. Working with Custom Model Architectures

The unified API is not limited to pre-trained networks such as VGG16: the same `explain` calls work with custom TensorFlow model architectures. Let's create a simple custom model and explain its predictions.
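Below is a minimal sketch of such a model. It assumes a 224x224 RGB input, matching `preprocessed_img`, and 10 output classes chosen arbitrarily for illustration. `build_custom_model` is a hypothetical helper. The convolutional layers have explicit names, so `conv2d_2` is always the last convolutional layer, which is the layer the Grad-CAM call below refers to. Keras' automatic names come from session-wide counters and can change when cells are re-run.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_custom_model(input_shape=(224, 224, 3), num_classes=10):
    """Small untrained CNN; layer names are fixed so they can be referenced later."""
    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, padding="same", activation="relu", name="conv2d")(inputs)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(32, 3, padding="same", activation="relu", name="conv2d_1")(x)
    x = layers.MaxPooling2D(2)(x)
    # Last convolutional layer: the target layer for Grad-CAM
    x = layers.Conv2D(64, 3, padding="same", activation="relu", name="conv2d_2")(x)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="predictions")(x)
    return tf.keras.Model(inputs, outputs, name="custom_cnn")


custom_model = build_custom_model()
custom_model.summary()
```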

```python
# Make a prediction with the custom model
# (it is untrained, so the predicted class is arbitrary)
custom_prediction = custom_model.predict(preprocessed_img)
custom_predicted_class = int(np.argmax(custom_prediction[0]))
print(f"Custom model prediction: Class {custom_predicted_class}")

# Apply different explainability methods to the custom model using the unified API
custom_methods = {
    'gradient': {},
    'grad_cam': {'layer_name': 'conv2d_2'},  # Last conv layer
    'guided_backprop': {},
    'smoothgrad': {'noise_scale': 0.1, 'augment_by_n': 25}
}

custom_explanations = {}
for method_name, params in custom_methods.items():
    print(f"Generating {method_name} explanation for custom model...")
    custom_explanations[method_name] = explain(
        model=custom_model,
        x=preprocessed_img,
        method_name=method_name,
        target_class=custom_predicted_class,
        **params
    )

# Visualize explanations for the custom model, one panel per method
n_panels = len(custom_explanations)
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)

for ax, (name, explanation) in zip(axes[0], custom_explanations.items()):
    ax.imshow(original_img)
    ax.imshow(preprocess_explanation(explanation), alpha=0.5, cmap='hot')
    ax.set_title(f"Custom Model: {name.replace('_', ' ').title()}")
    ax.axis('off')

plt.tight_layout()
plt.show()
```

## 6. Advanced Visualization Techniques

The same attribution map can be displayed in several ways. The examples below use matplotlib and scikit-image to compare colormaps, thresholded masking, and contour overlays.
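When you divide by the maximum, as `preprocess_explanation` does, a few extreme pixels can push the rest of the map toward zero. A common alternative is to clip at a high percentile before scaling. The sketch below uses a hypothetical helper, `percentile_normalize`. It assumes a 2D or channel-last (H, W, C) map, and the 99th-percentile default is an arbitrary choice.

```python
import numpy as np


def percentile_normalize(relevance, percentile=99.0):
    """Scale |relevance| to [0, 1], clipping values above the given percentile."""
    magnitude = np.abs(np.asarray(relevance, dtype=float))
    if magnitude.ndim == 3:
        magnitude = magnitude.sum(axis=-1)  # collapse channels into one 2D map
    upper = np.percentile(magnitude, percentile)
    if upper <= 0:
        return np.zeros_like(magnitude)
    return np.clip(magnitude / upper, 0.0, 1.0)


# A single outlier dominates max-scaling but not percentile scaling
demo = np.ones((10, 10))
demo[0, 0] = 1000.0
print("max-scaled median:", np.median(demo / demo.max()))
print("percentile-scaled median:", np.median(percentile_normalize(demo, percentile=95)))
```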

```python
from skimage import measure  # scikit-image, used for contour extraction

# Compute an LRP-epsilon explanation to visualize
lrp_explanation = explain(
    model=model,
    x=preprocessed_img,
    method_name='lrp_epsilon_0_1',
    target_class=predicted_class
)
lrp_processed = preprocess_explanation(lrp_explanation)
# Single-channel relevance map (handles both 2D and RGB outputs of preprocess_explanation)
relevance_map = lrp_processed if lrp_processed.ndim == 2 else lrp_processed.mean(axis=-1)

# 1. Visualization with different color maps and overlays
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Row 1: Different colormaps
colormaps = ['hot', 'jet', 'RdBu_r']
for i, cmap in enumerate(colormaps):
    axes[0, i].imshow(original_img)
    axes[0, i].imshow(relevance_map, alpha=0.5, cmap=cmap)
    axes[0, i].set_title(f"{cmap} colormap")
    axes[0, i].axis('off')

# Row 2: Different overlay techniques
# Pure heatmap
axes[1, 0].imshow(relevance_map, cmap='hot')
axes[1, 0].set_title("Pure Heatmap")
axes[1, 0].axis('off')

# Masked overlay: dim pixels whose relevance is at or below the threshold
threshold = 0.3
mask = relevance_map > threshold
masked_img = np.array(original_img, copy=True)
masked_img[~mask] = (masked_img[~mask] * 0.3).astype(masked_img.dtype)
axes[1, 1].imshow(masked_img)
axes[1, 1].set_title(f"Masked (threshold={threshold})")
axes[1, 1].axis('off')

# Contour overlay: outline regions where relevance crosses the same threshold
contours = measure.find_contours(relevance_map, threshold)
axes[1, 2].imshow(original_img)
for contour in contours:
    axes[1, 2].plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')
axes[1, 2].set_title("Contour Overlay")
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()
```

## 7. Integrating with TensorFlow's GradientTape API

You can create your own custom explainability methods using TensorFlow's GradientTape for more flexibility.
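As a minimal sketch of such a custom method, the functions below compute a plain gradient and Integrated Gradients directly with `tf.GradientTape`. They are hypothetical helpers, not SignXAI2 API. They assume a Keras model that takes a batched input and returns class scores of shape (batch, classes). The Keras VGG16 loaded above outputs softmax probabilities, so these gradients are taken with respect to the probability, not a logit. The Integrated Gradients version uses a zero baseline in the model's input space by default and evaluates a right Riemann sum over `steps` points in a single batch.

```python
import numpy as np
import tensorflow as tf


def tape_gradient(model, x, class_idx):
    """Gradient of the class score with respect to the input."""
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(x)
        score = model(x)[:, int(class_idx)]
    return tape.gradient(score, x).numpy()


def tape_integrated_gradients(model, x, class_idx, steps=20, baseline=None):
    """Integrated Gradients for a single batched input x of shape (1, ...)."""
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    if baseline is None:
        baseline = tf.zeros_like(x)
    # Interpolation coefficients alpha_k = k / steps, shaped to broadcast over x
    alphas = tf.linspace(1.0 / steps, 1.0, steps)
    alphas = tf.reshape(alphas, [steps] + [1] * (len(x.shape) - 1))
    interpolated = baseline + alphas * (x - baseline)   # shape (steps, ...)
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = model(interpolated)[:, int(class_idx)]
    grads = tape.gradient(scores, interpolated)
    # Average path gradient times the input difference
    avg_grads = tf.reduce_mean(grads, axis=0, keepdims=True)
    return ((x - baseline) * avg_grads).numpy()


# Example usage in this notebook:
# tape_grad = tape_gradient(model, preprocessed_img, predicted_class)
# tape_ig = tape_integrated_gradients(model, preprocessed_img, predicted_class, steps=20)
```

For a linear model the path gradient is constant, so Integrated Gradients reduces to input times weight for any number of steps. This makes a convenient sanity check before you apply the functions to VGG16.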

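```python
# Compute a standard gradient and an Integrated Gradients explanation with the unified API
gradient_explanation = explain(
    model=model,
    x=preprocessed_img,
    method_name='gradient',
    target_class=predicted_class
)
ig_explanation = explain(
    model=model,
    x=preprocessed_img,
    method_name='integrated_gradients',
    target_class=predicted_class,
    steps=20
)

# Visualize integrated gradients compared to the standard gradient
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = [("Standard Gradient", gradient_explanation),
          ("Integrated Gradients (20 steps)", ig_explanation)]
for ax, (title, explanation) in zip(axes, panels):
    ax.imshow(original_img)
    ax.imshow(preprocess_explanation(explanation), alpha=0.5, cmap='hot')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()
```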

## 8. Combining Multiple Explanation Methods

Different explanation methods can highlight different regions for the same prediction. The cell below normalizes each map with `preprocess_explanation` and combines them in two ways. The pixel-wise average emphasizes regions that several methods agree on. The pixel-wise maximum keeps any region that at least one method marks as relevant.

```python
# Generate multiple explanations; only SmoothGrad receives extra parameters
ensemble_methods = {
    'gradient': {},
    'guided_backprop': {},
    'lrp_epsilon_0_1': {},
    'smoothgrad': {'noise_scale': 0.1, 'augment_by_n': 25},
}
ensemble_explanations = {}

for method, params in ensemble_methods.items():
    print(f"Generating {method} explanation...")
    explanation = explain(
        model=model,
        x=preprocessed_img,
        method_name=method,
        target_class=predicted_class,
        **params
    )
    ensemble_explanations[method] = preprocess_explanation(explanation)

# Stack the normalized maps so they can be combined pixel-wise
stacked = np.stack(list(ensemble_explanations.values()))
ensemble_avg = stacked.mean(axis=0)  # regions most methods agree on
ensemble_max = stacked.max(axis=0)   # union of highlighted regions

# Visualize the four individual methods and the two ensembles
panels = [(m.replace('_', ' ').title(), e) for m, e in ensemble_explanations.items()]
panels += [("Ensemble (Average)", ensemble_avg), ("Ensemble (Maximum)", ensemble_max)]

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
for ax, (title, heatmap) in zip(axes.flat, panels):
    ax.imshow(original_img)
    ax.imshow(heatmap, alpha=0.5, cmap='hot')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()
```

## 9. Conclusion

In this advanced tutorial, we explored several sophisticated aspects of using SignXAI2 with TensorFlow:

1. **LRP Variants**: We demonstrated how to use different LRP methods (Z-rule, epsilon-rule, alpha-beta rules, and SIGN variants) through the unified API.

2. **Advanced Parameters**: We showed how to customize explanation methods with various parameters like noise levels for SmoothGrad, step counts for Integrated Gradients, and layer selection for Grad-CAM.

3. **Custom Models**: We applied the same unified API calls to a custom model architecture.

4. **Advanced Visualization**: We explored different visualization techniques including various colormaps, masking, and contour overlays.

5. **GradientTape Integration**: We compared standard gradients with Integrated Gradients and discussed writing custom methods with TensorFlow's `GradientTape`.

6. **Ensemble Methods**: We combined multiple explanation methods by averaging their normalized maps and by taking the pixel-wise maximum.

The unified API in SignXAI2 makes it easy to:
- Switch between different explanation methods
- Compare results across methods
- Work with both pre-trained and custom models
- Customize parameters for fine-tuned explanations

For more information and to contribute to the project, visit the [SignXAI2 GitHub repository](https://github.com/IRISlaboratory/signxai2).