# SignXAI with TensorFlow - Advanced Usage

This tutorial demonstrates advanced techniques for using SignXAI with TensorFlow models. It builds on the basic tutorial and explores more sophisticated explainability methods and customizations.

## Setup Requirements

For this TensorFlow tutorial, you'll need to install SignXAI with TensorFlow dependencies:

```bash
# For conda users
conda create -n signxai-tensorflow python=3.8
conda activate signxai-tensorflow
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/tensorflow.txt

# Or for pip users
python -m venv signxai_tensorflow_env
source signxai_tensorflow_env/bin/activate  # On Windows: signxai_tensorflow_env\Scripts\activate
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/tensorflow.txt
```
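After installing, you can quickly confirm that the key packages resolve in the active environment by checking whether they are importable, without actually importing them. This is a minimal standard-library sketch. `check_packages` is a hypothetical helper, and the module names (`tensorflow`, `signxai`) are assumed to match what the requirement files install.

```python
import importlib.util


def check_packages(names):
    """Return {module_name: True/False} depending on whether each top-level module can be found."""
    # find_spec locates a top-level module without executing it and returns None if it is missing
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages(["tensorflow", "signxai", "numpy", "matplotlib"]).items():
        print(f"{name:12s} {'OK' if found else 'MISSING'}")
```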

## Overview

In this tutorial, we'll cover:
1. Customizing LRP rules with TensorFlow
2. Creating custom explainability methods
3. Working with custom model architectures
4. Advanced visualization techniques
5. Integrating with TensorFlow's GradientTape API

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.vgg16 import preprocess_input

# SignXAI imports - including advanced modules
from signxai.tf_signxai.methods import SIGN, GradCAM, GuidedBackprop
from signxai.tf_signxai.methods.innvestigate.analyzer import LRPZPlus, LRPEpsilon, LRPAlpha1Beta0
from signxai.common.visualization import visualize_attribution
from signxai.utils.utils import remove_softmax, load_image

## 2. Set Up Paths and Load Data

In [None]:
# Set up data and model paths
_THIS_DIR = os.getcwd()  # notebooks have no __file__, so resolve paths from the working directory
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
VGG16_MODEL_PATH = os.path.join(DATA_DIR, "models", "tensorflow", "VGG16", "vgg16_tf.h5")
IMAGE_PATH = os.path.join(DATA_DIR, "images", "example.jpg")

# Verify the image path; the model file is optional because we fall back to Keras Applications
assert os.path.exists(IMAGE_PATH), f"Image not found at {IMAGE_PATH}"

# Load model: prefer the pre-saved file, otherwise download the ImageNet weights
try:
    model = tf.keras.models.load_model(VGG16_MODEL_PATH)
    print("Loaded pre-saved VGG16 model")
except Exception as e:
    print(f"Could not load saved model: {e}\nLoading from Keras applications instead")
    model = tf.keras.applications.VGG16(weights='imagenet', include_top=True)
    
# Remove softmax for explanation methods
model_no_softmax = remove_softmax(model)

# Load and display image
original_img, preprocessed_img = load_image(IMAGE_PATH, expand_dims=True)

plt.figure(figsize=(6, 6))
plt.imshow(original_img)
plt.axis('off')
plt.title('Sample Image')
plt.show()
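Why remove the softmax? The derivative of a softmax probability with respect to its own logit is p(1 - p), which becomes tiny once the model is confident. Gradient-based explanations of the probability therefore fade even when the logit itself responds strongly to the input, whereas the logit's derivative with respect to itself is always 1. A small NumPy sketch of this saturation effect (independent of SignXAI; `softmax` and `prob_grad_wrt_own_logit` are local helpers):

```python
import numpy as np


def softmax(z):
    # Subtract the max for numerical stability before exponentiating
    e = np.exp(z - np.max(z))
    return e / e.sum()


def prob_grad_wrt_own_logit(z, k):
    """d softmax(z)[k] / d z[k] = p_k * (1 - p_k)."""
    p = softmax(z)[k]
    return p * (1.0 - p)


# As the winning logit's margin grows, p approaches 1 and dp/dz collapses toward 0
for margin in [0.0, 2.0, 5.0, 10.0]:
    logits = np.array([margin, 0.0, 0.0])
    print(f"margin={margin:4.1f}  p={softmax(logits)[0]:.5f}  dp/dz={prob_grad_wrt_own_logit(logits, 0):.2e}")
```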

## 3. Customizing LRP Rules with TensorFlow

SignXAI provides extensive customization options for Layer-wise Relevance Propagation (LRP) with TensorFlow models. Let's explore different LRP rules and configurations.
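LRP passes relevance backward from the explained output one layer at a time. The *rule* determines how a layer splits the relevance of each output among its inputs. The NumPy sketch below applies the textbook z, epsilon, and alpha=1/beta=0 rules to a single dense layer, which makes the differences between them concrete. `lrp_dense` is a hypothetical helper, not part of SignXAI, and the iNNvestigate-based analyzers may handle biases and stabilizers differently. The alpha-beta branch assumes non-negative inputs, such as ReLU outputs.

```python
import numpy as np


def lrp_dense(a, W, b, R_out, rule="epsilon", eps=0.01):
    """Redistribute output relevance R_out of a dense layer z = a @ W + b onto its inputs a.

    a: (n_in,) input activations, W: (n_in, n_out), b: (n_out,), R_out: (n_out,).
    """
    if rule == "z":
        denom = a @ W + b                                   # plain pre-activations
    elif rule == "epsilon":
        z = a @ W + b
        denom = z + eps * np.where(z >= 0, 1.0, -1.0)       # stabilizer pushes z away from 0
    elif rule == "alpha1beta0":
        W = np.maximum(W, 0.0)                              # keep only positive weights
        denom = a @ W + 1e-12                               # bias ignored; tiny term avoids 0/0
    else:
        raise ValueError(f"unknown rule: {rule}")
    s = R_out / denom                                       # relevance per unit of pre-activation
    return a * (W @ s)                                      # R_in[i] = a_i * sum_j W_ij * s_j


# Toy layer: explain output unit 1 (all relevance placed on that unit)
a = np.array([1.0, 2.0, 0.5])
W = np.array([[1.0, -1.0],
              [0.5, 2.0],
              [-2.0, 1.0]])
b = np.zeros(2)
R_out = np.array([0.0, 1.0])

for rule in ["z", "epsilon", "alpha1beta0"]:
    R_in = lrp_dense(a, W, b, R_out, rule=rule)
    print(rule, np.round(R_in, 4), "sum =", round(R_in.sum(), 4))
```

With the z-rule, the input relevances sum exactly to the explained output (conservation, here with a zero bias). The epsilon-rule absorbs a little relevance in exchange for numerical stability. Alpha=1, beta=0 keeps only positive contributions, so the first input, which pushes output unit 1 down, receives zero rather than negative relevance.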

In [None]:
# Get a prediction for reference
predictions = model.predict(preprocessed_img)
predicted_class = np.argmax(predictions[0])
print(f"Predicted class index: {predicted_class}")

# Let's import the specialized analyzer creation function
from signxai.tf_signxai.methods.innvestigate.analyzer import create_analyzer

# Create LRP analyzers with different configurations
# 1. LRP-Z
lrp_z_analyzer = create_analyzer('lrp.z', model_no_softmax)

# 2. LRP-Epsilon with custom epsilon
epsilon = 0.01
lrp_epsilon_analyzer = create_analyzer('lrp.epsilon', model_no_softmax, epsilon=epsilon)

# 3. LRP-Alpha-Beta with fixed alpha=1, beta=0
lrp_alpha_beta_analyzer = create_analyzer('lrp.alpha_1_beta_0', model_no_softmax)

# 4. LRP with composite rules - different rules for different layers
lrp_composite_analyzer = create_analyzer('lrp.sequential_composite_a', model_no_softmax)

# 5. LRP with custom input layer rule
lrp_custom_input = create_analyzer('lrp.z', model_no_softmax, input_layer_rule='SIGN')
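A composite analyzer applies different rules to different layers. The exact mapping behind `'lrp.sequential_composite_a'` is defined in the iNNvestigate-based code and is not restated here. The sketch below only illustrates the idea, using a hypothetical mapping in the spirit of LRP best-practice recommendations: epsilon for dense layers, alpha=1/beta=0 for convolutional layers, plus an optional dedicated rule for the first layer, similar to the `input_layer_rule` argument above. `assign_rules` and its rule labels are hypothetical.

```python
def assign_rules(layer_types, dense_rule="epsilon", conv_rule="alpha1beta0", input_rule=None):
    """Map each layer (given by its type name, ordered input to output) to an LRP rule label.

    Hypothetical illustration of a composite configuration, not SignXAI's implementation.
    """
    rules = []
    for i, layer_type in enumerate(layer_types):
        if i == 0 and input_rule is not None:
            rules.append(input_rule)            # dedicated rule for the first layer
        elif layer_type == "Conv2D":
            rules.append(conv_rule)
        elif layer_type == "Dense":
            rules.append(dense_rule)
        else:
            rules.append("pass-through")        # parameter-free layers (pooling, reshaping) use their own propagation
    return rules


vgg_like = ["Conv2D", "Conv2D", "MaxPooling2D", "Conv2D", "Flatten", "Dense", "Dense"]
print(list(zip(vgg_like, assign_rules(vgg_like, input_rule="SIGN"))))
```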

In [None]:
# Function to get explanations from analyzers
def get_lrp_explanation(analyzer, input_img, class_idx=None):
    # Default to predicted class if none provided
    if class_idx is None:
        class_idx = predicted_class
        
    # Use the analyzer to get explanation
    explanation = analyzer.analyze([input_img[0]], neuron_selection=class_idx)
    
    # Ensure we get the actual explanation (might be in a dict)
    if isinstance(explanation, dict):
        explanation = explanation[list(explanation.keys())[0]][0]
    else:
        explanation = explanation[0]
    
    return explanation

# Generate explanations for each LRP variant
lrp_explanations = {
    "LRP-Z": get_lrp_explanation(lrp_z_analyzer, preprocessed_img),
    f"LRP-Epsilon (ε={epsilon})": get_lrp_explanation(lrp_epsilon_analyzer, preprocessed_img),
    "LRP-Alpha1-Beta0": get_lrp_explanation(lrp_alpha_beta_analyzer, preprocessed_img),
    "LRP-Composite": get_lrp_explanation(lrp_composite_analyzer, preprocessed_img),
    "LRP-Z with SIGN input": get_lrp_explanation(lrp_custom_input, preprocessed_img)
}

In [None]:
# Utility function for visualization
def preprocess_explanation(explanation):
    # Use absolute values: this shows attribution magnitude but discards the sign (supporting vs. opposing evidence)
    abs_explanation = np.abs(explanation)
    
    # Normalize for visualization
    if abs_explanation.max() > 0:
        abs_explanation = abs_explanation / abs_explanation.max()
    
    return abs_explanation

# Visualize all LRP variants
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

# Get numpy array of original image for visualization
np_img = np.array(original_img)

# Plot the original image in the first position
axes[0].imshow(np_img)
axes[0].set_title("Original Image")
axes[0].axis('off')

# Plot all LRP variants in the remaining positions
for i, (method_name, explanation) in enumerate(lrp_explanations.items(), 1):
    if i < len(axes):
        visualize_attribution(np_img, preprocess_explanation(explanation), ax=axes[i])
        axes[i].set_title(method_name)
        axes[i].axis('off')

plt.tight_layout()
plt.show()
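`preprocess_explanation` above takes absolute values. That shows where attribution is large, but it hides whether the evidence supports or opposes the predicted class. Below is a minimal signed alternative, assuming explanations of shape (H, W, C). `signed_heatmap` is a hypothetical helper, and whether `visualize_attribution` renders values in [-1, 1] as intended is an assumption worth checking.

```python
import numpy as np


def signed_heatmap(explanation, channel_axis=-1):
    """Collapse the channel axis by summing, then scale to [-1, 1] while keeping the sign."""
    heatmap = np.sum(explanation, axis=channel_axis)
    max_abs = np.max(np.abs(heatmap))
    # Dividing by the largest magnitude keeps 0 at 0, so a diverging colormap stays centered
    return heatmap / max_abs if max_abs > 0 else heatmap
```

For display, pair it with a diverging colormap and a symmetric range, for example `plt.imshow(signed_heatmap(e), cmap='RdBu_r', vmin=-1, vmax=1)`, so that zero maps to the neutral middle color.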

## 4. Creating Custom Explainability Methods

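Beyond SignXAI's built-in methods, you can write your own attribution methods directly in TensorFlow and compare them with SignXAI's output. Let's implement a simple custom method, a gradient masked by input magnitude, and visualize it next to the built-in results.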

In [None]:
# A custom method that keeps plain gradients only where the input magnitude exceeds a threshold
class ThresholdedGradient:
    def __init__(self, model, threshold=0.2):
        self.model = model
        self.threshold = threshold
        
    def attribute(self, input_tensor, target_class=None):
        # Cast to a float32 tensor (works for NumPy arrays and for tensors of other float dtypes)
        input_tensor = tf.cast(input_tensor, tf.float32)
        
        with tf.GradientTape() as tape:
            # Watch the input tensor
            tape.watch(input_tensor)
            
            # Forward pass
            predictions = self.model(input_tensor)
            
            # If target class not specified, use the predicted class
            if target_class is None:
                target_class = tf.argmax(predictions[0])
                
            # Get the target output
            target_output = predictions[:, target_class]
            
        # Calculate gradients
        gradients = tape.gradient(target_output, input_tensor)
        
        # Convert to numpy for further processing
        gradients_np = gradients.numpy()
        input_np = input_tensor.numpy()
        
        # Apply the custom thresholding logic: keep gradients only where the input's magnitude,
        # relative to the largest absolute input value, exceeds the threshold
        normalized_input = input_np / (np.max(np.abs(input_np)) + 1e-10)
        mask = (np.abs(normalized_input) > self.threshold).astype(float)
        
        # Apply the mask to the gradients
        thresholded_gradients = gradients_np * mask
        
        return thresholded_gradients

# Create our custom explainer
custom_explainer = ThresholdedGradient(model_no_softmax, threshold=0.2)

# Generate an explanation
custom_explanation = custom_explainer.attribute(preprocessed_img, target_class=predicted_class)

In [None]:
# Let's compare our custom method with standard gradient and gradient × input
# Standard gradient method
gradient_explainer = SIGN(model_no_softmax)
gradient_explanation = gradient_explainer.attribute(preprocessed_img, target_class=predicted_class).numpy()

# Gradient × input method (manually implemented); np.asarray handles NumPy arrays and tensors alike
gradient_x_input = gradient_explanation * np.asarray(preprocessed_img)[0]

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Plot standard gradient
visualize_attribution(np_img, preprocess_explanation(gradient_explanation), ax=axes[0])
axes[0].set_title("Standard Gradient")
axes[0].axis('off')

# Plot gradient × input
visualize_attribution(np_img, preprocess_explanation(gradient_x_input), ax=axes[1])
axes[1].set_title("Gradient × Input")
axes[1].axis('off')

# Plot our custom thresholded gradient
visualize_attribution(np_img, preprocess_explanation(custom_explanation), ax=axes[2])
axes[2].set_title("Custom Thresholded Gradient")
axes[2].axis('off')

plt.tight_layout()
plt.show()
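Here is a useful sanity check for gradient × input. For a linear score f(x) = w·x with no bias, the gradient is w everywhere, so the elementwise products w * x sum exactly to f(x). For nonlinear networks with biases, the products no longer sum to the output in general. A NumPy sketch with a made-up weight vector:

```python
import numpy as np

# Hypothetical linear "model": score(x) = w . x
w = np.array([2.0, -1.0, 0.5])
x = np.array([1.0, 3.0, 4.0])

score = w @ x                  # 2 - 3 + 2 = 1.0
gradient = w                   # d score / d x is constant for a linear model
attribution = gradient * x     # gradient x input, elementwise

print("attribution:", attribution)
print("sum:", attribution.sum(), "score:", score)
```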

## 5. Working with Custom Model Architectures

SignXAI works with custom TensorFlow model architectures, not just pre-defined ones. Let's create a simple custom model and explain its predictions.

In [None]:
# Define a simple custom CNN model
def create_custom_model(input_shape=(224, 224, 3), num_classes=1000):
    inputs = keras.Input(shape=input_shape)
    
    # Feature extraction
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    # Classification
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(512, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    return keras.Model(inputs=inputs, outputs=outputs)

# Create the custom model
custom_model = create_custom_model()
custom_model.compile(optimizer='adam', loss='categorical_crossentropy')

# Display the model architecture
custom_model.summary()

# Remove softmax for explainability
custom_model_no_softmax = remove_softmax(custom_model)

In [None]:
# Make a prediction with the custom model
# (The model is untrained and randomly initialized, so its predictions are meaningless;
# this cell only shows that the methods run on a custom architecture)
custom_prediction = custom_model.predict(preprocessed_img)
custom_predicted_class = np.argmax(custom_prediction[0])
print(f"Custom model prediction: Class {custom_predicted_class}")

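# Apply different explainability methods to the custom model
# Find the last Conv2D layer by type: auto-generated names such as 'conv2d_2' change
# whenever other models have been built in the same session
last_conv_name = [layer.name for layer in custom_model.layers
                  if isinstance(layer, layers.Conv2D)][-1]
methods = {
    "Gradient": SIGN(custom_model_no_softmax),
    "GradCAM": GradCAM(custom_model, layer_name=last_conv_name),  # Last conv layer
    "GuidedBackprop": GuidedBackprop(custom_model_no_softmax),
}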

custom_explanations = {}
for name, method in methods.items():
    explanation = method.attribute(preprocessed_img, target_class=custom_predicted_class).numpy()
    custom_explanations[name] = explanation

In [None]:
# Visualize explanations for the custom model
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for i, (name, explanation) in enumerate(custom_explanations.items()):
    visualize_attribution(np_img, preprocess_explanation(explanation), ax=axes[i])
    axes[i].set_title(f"Custom Model: {name}")
    axes[i].axis('off')

plt.tight_layout()
plt.show()
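For reference, the core Grad-CAM computation works as follows. Each feature map of the chosen convolutional layer is weighted by its spatially averaged gradient, and only the positive part of the weighted sum is kept. The NumPy sketch below assumes you already have the layer activations and the gradients of the class score with respect to them, both of shape (h, w, K). `grad_cam_from_maps` is a hypothetical helper, not SignXAI's `GradCAM` implementation, which may differ in details such as upsampling and normalization.

```python
import numpy as np


def grad_cam_from_maps(activations, gradients):
    """Grad-CAM heatmap from one layer's activations and gradients, both shaped (h, w, K)."""
    # One weight per channel: the gradient averaged over all spatial positions
    weights = gradients.mean(axis=(0, 1))          # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep evidence for the class
    cam = np.maximum(activations @ weights, 0.0)   # shape (h, w)
    # Scale to [0, 1] for display (an all-zero map is returned unchanged)
    return cam / cam.max() if cam.max() > 0 else cam
```

The resulting map has the layer's spatial resolution and is usually upsampled to the image size before overlaying. For the last convolutional layer of the custom model above, which follows two 2 × 2 poolings of the 224 × 224 input, that resolution is 56 × 56.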

## 6. Advanced Visualization Techniques

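Presentation choices such as the colormap change how an attribution map reads. Perturbation-based techniques such as occlusion sensitivity provide a model-agnostic check on gradient- and LRP-based maps.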

In [None]:
# Get a high-quality explanation to visualize
lrp_explanation = get_lrp_explanation(lrp_z_analyzer, preprocessed_img)
lrp_processed = preprocess_explanation(lrp_explanation)

# 1. Visualization with different color maps
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Default Red-Blue colormap
visualize_attribution(np_img, lrp_processed, ax=axes[0], cmap='RdBu_r')
axes[0].set_title("Red-Blue Colormap")
axes[0].axis('off')

# Heat colormap
visualize_attribution(np_img, lrp_processed, ax=axes[1], cmap='hot')
axes[1].set_title("Heat Colormap")
axes[1].axis('off')

# Viridis colormap
visualize_attribution(np_img, lrp_processed, ax=axes[2], cmap='viridis')
axes[2].set_title("Viridis Colormap")
axes[2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# 2. Advanced: Creating an occlusion sensitivity map
def occlusion_sensitivity(model, image, target_class, patch_size=20, stride=10, occlusion_value=0.0):
    """Generate an occlusion sensitivity map by sliding a constant patch over the image.

    `image` is one preprocessed image of shape (H, W, C). Assuming VGG16 "caffe"-style
    preprocessing (mean-subtracted, no scaling), 0.0 is the ImageNet mean color, i.e. a
    neutral gray patch; a raw pixel value such as 128 would not be gray in that space.
    """
    h, w, _ = image.shape
    
    # Score of the unoccluded image for the target class
    baseline_score = model(np.expand_dims(image, axis=0), training=False).numpy()[0, target_class]
    
    # Accumulate score drops and count how many patches cover each pixel
    sensitivity_map = np.zeros((h, w))
    counts = np.zeros((h, w))
    
    # Slide the patch over the image
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            # Occlude one patch in a copy of the image
            occluded_image = image.copy()
            occluded_image[y:y+patch_size, x:x+patch_size, :] = occlusion_value
            
            # A large drop in the target score means the occluded region was important
            occluded_score = model(np.expand_dims(occluded_image, axis=0), training=False).numpy()[0, target_class]
            sensitivity_map[y:y+patch_size, x:x+patch_size] += baseline_score - occluded_score
            counts[y:y+patch_size, x:x+patch_size] += 1
    
    # Average over overlapping patches (pixels that no patch covered stay at 0)
    sensitivity_map = np.divide(sensitivity_map, counts, out=np.zeros_like(sensitivity_map), where=counts > 0)
    
    # Normalize to [0, 1] for display
    if sensitivity_map.max() != sensitivity_map.min():
        sensitivity_map = (sensitivity_map - sensitivity_map.min()) / (sensitivity_map.max() - sensitivity_map.min())
    
    return sensitivity_map

# Generate occlusion sensitivity map (this may take a while)
print("Generating occlusion sensitivity map (this may take a few minutes)...")
# Use a larger patch and stride than the defaults to reduce the number of forward passes
occlusion_map = occlusion_sensitivity(model, preprocessed_img[0], predicted_class, 
                                     patch_size=40, stride=20)
print("Done!")

In [None]:
# Visualize the occlusion sensitivity map
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Original image
axes[0].imshow(np_img)
axes[0].set_title("Original Image")
axes[0].axis('off')

# Occlusion sensitivity map
axes[1].imshow(occlusion_map, cmap='hot')
axes[1].set_title("Occlusion Sensitivity Map")
axes[1].axis('off')

# Overlay on original image
visualize_attribution(np_img, occlusion_map, ax=axes[2])
axes[2].set_title("Overlay on Original Image")
axes[2].axis('off')

plt.tight_layout()
plt.show()
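To confirm that the occlusion procedure behaves as intended, it helps to run it on a toy "model" whose answer is known. The sketch below reimplements the same sliding-patch and averaging logic in NumPy for a plain Python scoring function. `occlusion_map_np` and `toy_score` are hypothetical names. The toy score simply sums a fixed image region, so the sensitivity should concentrate there.

```python
import numpy as np


def occlusion_map_np(score_fn, image, patch_size, stride, occlusion_value=0.0):
    """Average score drop per pixel when square patches are replaced by occlusion_value."""
    h, w = image.shape[:2]
    base = score_fn(image)
    total = np.zeros((h, w))
    counts = np.zeros((h, w))
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            occluded = image.copy()
            occluded[y:y + patch_size, x:x + patch_size] = occlusion_value
            total[y:y + patch_size, x:x + patch_size] += base - score_fn(occluded)
            counts[y:y + patch_size, x:x + patch_size] += 1
    # Average over overlapping patches; uncovered pixels stay 0
    return np.divide(total, counts, out=np.zeros_like(total), where=counts > 0)


def toy_score(img):
    # Hypothetical "model": only the 2x2 block at rows 4-5, columns 4-5 matters
    return img[4:6, 4:6].sum()


image = np.ones((8, 8, 3))
sens = occlusion_map_np(toy_score, image, patch_size=2, stride=2)      # non-overlapping patches
sens_overlap = occlusion_map_np(toy_score, image, patch_size=4, stride=2)  # overlapping patches
print(np.round(sens, 2))
```

With overlapping patches, averaging spreads some sensitivity to neighboring pixels that share a patch with the important region. This is one reason the map looks blurrier as `patch_size` grows relative to `stride`.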

## 7. Integrating with TensorFlow's GradientTape API

You can create your own custom explainability methods using TensorFlow's GradientTape for more flexibility.

In [None]:
# Let's implement Integrated Gradients manually using GradientTape
def integrated_gradients(model, input_image, target_class=None, steps=50, baseline=None):
    """Integrated Gradients, approximated with a right Riemann sum over `steps` path points."""
    # Convert to a float32 tensor of shape (1, H, W, C)
    input_image = tf.cast(input_image, tf.float32)
    
    # Default baseline: an all-zero input (assuming VGG16 preprocessing, the ImageNet mean color rather than black)
    if baseline is None:
        baseline = tf.zeros_like(input_image)
    baseline = tf.cast(baseline, tf.float32)
    
    # Choose the target class from the prediction on the actual input, not on a point along the path
    if target_class is None:
        target_class = int(tf.argmax(model(input_image)[0]))
    
    # Compute gradients at each point along the straight path from baseline to input
    gradients = []
    for i in range(1, steps + 1):
        step_input = baseline + (i / steps) * (input_image - baseline)
        with tf.GradientTape() as tape:
            tape.watch(step_input)
            output = model(step_input)[:, target_class]
        gradients.append(tape.gradient(output, step_input))
    
    # Average the gradients (right Riemann sum approximation of the path integral)
    avg_gradients = tf.reduce_mean(tf.stack(gradients), axis=0)
    
    # Multiply by (input - baseline) for the final attribution
    attribution = avg_gradients * (input_image - baseline)
    
    return attribution.numpy()

# Generate integrated gradients explanation
ig_explanation = integrated_gradients(model_no_softmax, preprocessed_img, target_class=predicted_class, steps=20)
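Integrated Gradients satisfies *completeness*: the attributions sum to f(input) - f(baseline), up to the error of the numerical integral. The loop above samples the right endpoint of each step, which is biased for a finite number of steps. A NumPy sketch on f(x) = sum of x_i², whose gradient 2x is known exactly, makes that error visible. It also shows that midpoint sampling removes the error for this function, because its gradient is linear along the path. `ig_attributions`, `f`, and `grad_f` are hypothetical names.

```python
import numpy as np


def ig_attributions(grad_fn, x, baseline, steps, scheme="right"):
    """Integrated Gradients along the straight line from baseline to x.

    scheme="right" samples alpha = k/steps (k = 1..steps), like the TensorFlow loop;
    scheme="midpoint" samples alpha = (k - 0.5)/steps.
    """
    k = np.arange(1, steps + 1)
    alphas = k / steps if scheme == "right" else (k - 0.5) / steps
    grads = np.stack([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    return grads.mean(axis=0) * (x - baseline)


def f(v):
    # Toy model output
    return np.sum(v ** 2)


def grad_f(v):
    # Exact gradient of f
    return 2.0 * v


x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros_like(x)
for scheme in ["right", "midpoint"]:
    attr = ig_attributions(grad_f, x, baseline, steps=4, scheme=scheme)
    print(scheme, attr, "sum =", attr.sum(), "target =", f(x) - f(baseline))
```

For this function the right-endpoint sum overshoots by a factor of (steps + 1) / steps, so the error shrinks as `steps` grows. For a real network, comparing `ig_explanation.sum()` with the change in the class score between baseline and input is a cheap way to judge whether `steps` is large enough.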

In [None]:
# Visualize integrated gradients explanation compared to standard gradient
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Standard gradient
visualize_attribution(np_img, preprocess_explanation(gradient_explanation), ax=axes[0])
axes[0].set_title("Standard Gradient")
axes[0].axis('off')

# Integrated gradients
visualize_attribution(np_img, preprocess_explanation(ig_explanation), ax=axes[1])
axes[1].set_title("Integrated Gradients (20 steps)")
axes[1].axis('off')

plt.tight_layout()
plt.show()

## 8. Conclusion

In this advanced tutorial, we explored several sophisticated aspects of using SignXAI with TensorFlow:

1. **Customizing LRP Rules**: We demonstrated how to configure different LRP variants including Z-rule, epsilon-rule, alpha-beta rules, and composite LRP configurations.

2. **Creating Custom Explainability Methods**: We built a thresholded-gradient method and a manual implementation of Integrated Gradients.

3. **Working with Custom Models**: We applied SignXAI's SIGN, Grad-CAM, and guided backpropagation methods to a custom (untrained) CNN architecture.

4. **Advanced Visualization**: We compared colormaps and used occlusion sensitivity analysis as a perturbation-based check.

5. **Integrating with GradientTape**: We demonstrated how to use TensorFlow's GradientTape API to create custom explainability methods.

These techniques give complementary views of how a model reaches its decisions and can help reveal biases or weaknesses in its reasoning.

For more information and to contribute to the project, visit the SignXAI documentation and GitHub repository.