# SignXAI2 TensorFlow Tutorial - Image Classification

This tutorial demonstrates how to use SignXAI2 for explaining image classification models with TensorFlow.

## Setup

⚠️ **Data Requirements**: This tutorial requires example data from the GitHub repository. Please ensure you have downloaded the necessary data files or cloned the repository.

First, let's install the signxai2 package and download a sample image to work with:

```python
# Install the signxai2 package if it is not already installed
# (quoted so the shell does not treat the square brackets as a glob pattern)
!pip install "signxai2[tensorflow]"

# Download an example image
import urllib.request

# Download an image of a dog (a collie from the Stanford Dogs dataset)
url = "http://vision.stanford.edu/aditya86/ImageNetDogs/images/n02106030-collie/n02106030_16370.jpg"
urllib.request.urlretrieve(url, "dog.jpg")
```
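The image URL uses plain HTTP on an external server. If the file moves, the download may save a non-image response (for example an HTML page), and the problem only surfaces later inside `load_img`. A minimal sanity check, assuming only that a valid JPEG file begins with the two-byte start-of-image marker `FF D8`; `looks_like_jpeg` is a hypothetical helper, not part of signxai2:

```python
import os

def looks_like_jpeg(path: str) -> bool:
    """Hypothetical helper: True if the file exists and starts with the JPEG SOI marker 0xFF 0xD8."""
    if not os.path.isfile(path):
        return False
    with open(path, "rb") as f:
        return f.read(2) == b"\xff\xd8"

# Check the downloaded file (if present) before handing it to Keras.
if os.path.exists("dog.jpg") and not looks_like_jpeg("dog.jpg"):
    raise RuntimeError("dog.jpg is not a JPEG file; the download may have returned an error page.")
```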

## TensorFlow Implementation

Let's load a VGG16 model pre-trained on ImageNet with TensorFlow and use it to classify the image:
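Before running the next cell, it helps to know what `preprocess_input` does. Keras' VGG16 variant uses Caffe-style preprocessing: it reorders channels from RGB to BGR and subtracts the per-channel ImageNet mean, without rescaling to [0, 1]. The NumPy sketch below reproduces that transform and its inverse; `caffe_preprocess` and `caffe_deprocess` are hypothetical helpers for illustration. This is also why the explanations are later plotted next to `img` rather than `x`: `x` lives in this shifted BGR space.

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by Caffe-style preprocessing.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_preprocess(x_rgb: np.ndarray) -> np.ndarray:
    """Convert RGB pixels in [0, 255] to BGR and subtract the per-channel ImageNet means."""
    x_bgr = x_rgb[..., ::-1].astype(np.float32)  # reversed view, cast to a new float array
    return x_bgr - IMAGENET_BGR_MEAN

def caffe_deprocess(x_bgr: np.ndarray) -> np.ndarray:
    """Undo caffe_preprocess: add the means back, return to RGB order, clip to [0, 255]."""
    x = x_bgr + IMAGENET_BGR_MEAN
    return np.clip(x[..., ::-1], 0, 255)

# Example: a batch holding one mid-gray pixel.
pixel = np.full((1, 1, 1, 3), 128.0, dtype=np.float32)
print(caffe_preprocess(pixel))  # each channel becomes 128 minus its ImageNet mean
```

To compare against Keras directly, pass a copy, e.g. `preprocess_input(arr.copy())`, since the Keras function may modify a float NumPy array in place.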

```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from signxai import explain, list_methods
from signxai.utils.utils import normalize_heatmap

# Load the model pre-trained on ImageNet
model = VGG16(weights='imagenet')

# Replace the final softmax activation with the identity (critical for explanations):
# explanations are then computed on the raw class scores (logits), not on probabilities
model.layers[-1].activation = None

# Load the image at the 224x224 resolution VGG16 expects, add a batch dimension, and preprocess it
img_path = "dog.jpg"
img = load_img(img_path, target_size=(224, 224))
x = img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

# Predict and keep the index of the top class; this is the class we will explain.
# Since softmax was removed, preds holds logits rather than probabilities.
preds = model.predict(x)
top_pred_idx = np.argmax(preds[0])
print(f"Predicted class: {decode_predictions(preds, top=1)[0][0][1]}")
```
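Why remove the softmax? One common rationale, sketched below in plain NumPy with made-up logits (this illustrates the usual motivation, not signxai2 internals): softmax preserves the ordering of the scores, so the predicted class is unchanged, but the gradient of a *probability* mixes in every other class and shrinks toward zero as the model becomes confident. The gradient of the winning *logit* with respect to the logits is simply one-hot. It also means `decode_predictions` above receives logits, so any score it reports is no longer a probability; only the class name is printed here.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_jacobian(z: np.ndarray) -> np.ndarray:
    """Jacobian d softmax(z)_i / d z_j = p_i * (delta_ij - p_j) for a 1-D logit vector."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

logits = np.array([4.0, 1.0, -2.0])  # hypothetical scores for three classes
probs = softmax(logits)

# Softmax preserves the ordering of the logits, so the top class is the same.
assert np.argmax(probs) == np.argmax(logits)

top = int(np.argmax(logits))
grad_logit = np.eye(3)[top]              # d logit_top / d z: one-hot
grad_prob = softmax_jacobian(logits)[top]  # d prob_top / d z: couples all classes
print("d logit / d z:", grad_logit)
print("d prob  / d z:", grad_prob)
```

Pushing the top logit higher (say to 20) makes `grad_prob` almost vanish, which is the saturation effect that explaining the logit avoids.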

```python
# Calculate explanations with different methods
methods = [
    'gradient',
    'gradient_x_input',
    'integrated_gradients',
    'smoothgrad',
    'grad_cam',
    'lrp_z',
    'lrp_epsilon_0_1',  # LRP-epsilon with epsilon = 0.1
    'lrpsign_z'  # The SIGN method
]

# Explain the top predicted class with every method
explanations = {}
for method in methods:
    explanations[method] = explain(
        model=model,
        x=x,
        method_name=method,
        target_class=top_pred_idx
    )
```
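The `explain` call hides what each method computes. As a rough mental model (not signxai2's implementation), the sketch below gives toy NumPy versions of the gradient-based methods in the list, applied to a hypothetical one-output model f(x) = softplus(w · x) whose gradient is known in closed form. `w`, `f`, `grad_f`, and all helper names are illustrative assumptions. Grad-CAM is shown on hand-made feature-map arrays; in practice it is typically applied to the last convolutional layer, whose maps are much coarser than the input image.

```python
import numpy as np

# Toy differentiable "model" standing in for one class logit; w is arbitrary.
w = np.array([1.5, -2.0, 0.5, 3.0])

def f(x: np.ndarray) -> float:
    return float(np.logaddexp(0.0, x @ w))  # softplus(z) = log(1 + e^z), computed stably

def grad_f(x: np.ndarray) -> np.ndarray:
    return w / (1.0 + np.exp(-(x @ w)))     # sigmoid(w . x) * w

def gradient(x):
    return grad_f(x)

def gradient_x_input(x):
    return grad_f(x) * x

def integrated_gradients(x, baseline=None, steps=256):
    """Midpoint Riemann approximation of the path integral of the gradient from baseline to x."""
    baseline = np.zeros_like(x) if baseline is None else baseline
    alphas = (np.arange(steps) + 0.5) / steps
    avg_grad = np.mean([grad_f(baseline + a * (x - baseline)) for a in alphas], axis=0)
    return (x - baseline) * avg_grad

def smoothgrad(x, noise_level=0.15, n_samples=500, seed=0):
    """Average gradient over noisy copies of x, with sigma = noise_level * (max(x) - min(x))."""
    rng = np.random.default_rng(seed)
    sigma = noise_level * (x.max() - x.min())
    noisy = x + rng.normal(0.0, sigma, size=(n_samples, x.size))
    return np.mean([grad_f(xi) for xi in noisy], axis=0)

def grad_cam(feature_maps, feature_grads):
    """feature_maps, feature_grads: arrays of shape (H, W, K) taken from one conv layer."""
    alphas = feature_grads.mean(axis=(0, 1))                   # one weight per channel
    cam = np.tensordot(feature_maps, alphas, axes=([2], [0]))  # weighted channel sum, (H, W)
    return np.maximum(cam, 0.0)                                # keep positive evidence only

x_toy = np.array([0.8, 0.1, -0.4, 0.6])
ig = integrated_gradients(x_toy)
print("gradient         ", gradient(x_toy))
print("gradient_x_input ", gradient_x_input(x_toy))
print("integrated_grads ", ig, "sum:", ig.sum(), "f(x) - f(0):", f(x_toy) - f(np.zeros_like(x_toy)))
print("smoothgrad       ", smoothgrad(x_toy))
```

The integrated-gradients attributions sum to f(x) - f(baseline) up to discretization error (the completeness property), which is a quick way to check an implementation.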

```python
# Visualize explanations: the original image plus the first 7 methods on a 2x4 grid
# ('lrpsign_z' is compared against 'lrp_z' in the next cell)
fig, axs = plt.subplots(2, 4, figsize=(20, 10))
axs = axs.flatten()

# Original image
axs[0].imshow(img)
axs[0].set_title('Original Image', fontsize=14)
axs[0].axis('off')

# Explanations ([0] selects the first, and only, image in the batch)
for i, method in enumerate(methods[:7]):
    axs[i+1].imshow(normalize_heatmap(explanations[method][0]), cmap='seismic', vmin=-1, vmax=1)
    axs[i+1].set_title(method, fontsize=14)
    axs[i+1].axis('off')

plt.tight_layout()
plt.show()
```
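The plots use a diverging colormap fixed to [-1, 1], so each heatmap must be scaled into that range with zero (no relevance) at the center. We have not checked how signxai2's `normalize_heatmap` does this; below is one common choice, as a sketch with a hypothetical `to_signed_heatmap` helper and random stand-in data: sum over color channels, then divide by the largest absolute value.

```python
import numpy as np

def to_signed_heatmap(relevance: np.ndarray) -> np.ndarray:
    """Collapse a (H, W, C) or (H, W) relevance map to (H, W) and scale it into [-1, 1].

    Dividing by the maximum absolute value keeps each pixel's sign, so 0 stays
    at the white center of a diverging colormap such as 'seismic'.
    """
    r = relevance.sum(axis=-1) if relevance.ndim == 3 else relevance
    r = r.astype(np.float64)
    scale = np.abs(r).max()
    return r / scale if scale > 0 else np.zeros_like(r)

# Random stand-in for one explanation with the batch dimension already removed.
rng = np.random.default_rng(0)
fake_relevance = rng.normal(size=(224, 224, 3))
heat = to_signed_heatmap(fake_relevance)
print(heat.shape, heat.min(), heat.max())
```

The result can be shown with `plt.imshow(heat, cmap='seismic', vmin=-1, vmax=1)`. Min-max scaling would instead shift zero away from the colormap center and make positive and negative relevance hard to tell apart.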

```python
# Highlight the difference between standard LRP and SIGN
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title('Original Image', fontsize=14)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(normalize_heatmap(explanations['lrp_z'][0]), cmap='seismic', vmin=-1, vmax=1)
plt.title('LRP-Z', fontsize=14)
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(normalize_heatmap(explanations['lrpsign_z'][0]), cmap='seismic', vmin=-1, vmax=1)
plt.title('LRP-SIGN', fontsize=14)
plt.axis('off')

plt.tight_layout()
plt.show()
```
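To see why the SIGN variant can differ from plain LRP-z, here is a minimal NumPy sketch of the input-layer rules for a single dense layer without bias. The function name and the threshold parameter `mu` are hypothetical, and the SIGN rule follows our reading of the SIGN idea (replace each input value x_i in the numerator by its sign s(x_i)), not signxai2's source. With x_i in the numerator, a feature whose value is exactly zero always receives zero relevance, whatever its weight; the sign variant does not have this property.

```python
import numpy as np

def lrp_dense_input_layer(x, W, R_out, eps=0.0, use_sign=False, mu=0.0):
    """Redistribute output relevance R_out (shape (J,)) of y = x @ W onto the inputs x (shape (I,)).

    z-rule:        R_i = sum_j x_i * W_ij / z_j * R_j, with z_j = sum_i x_i * W_ij
    epsilon-rule:  same, with z_j replaced by z_j + eps * sign(z_j)
    SIGN variant:  x_i in the numerator replaced by s(x_i) = +1 if x_i >= mu else -1
    """
    z = x @ W
    z = z + eps * np.where(z >= 0, 1.0, -1.0)  # stabilizer; eps = 0 gives the plain z-rule
    s = R_out / z                               # relevance per unit of pre-activation, (J,)
    numerator = np.where(x >= mu, 1.0, -1.0) if use_sign else x
    return numerator * (W @ s)                  # (I,)

x_toy = np.array([0.0, 0.5, -1.0])  # the first feature is exactly zero
W_toy = np.array([[2.0, 0.5],
                  [1.0, -1.0],
                  [-0.5, 1.0]])
R_out = np.array([1.0, 0.0])        # explain output unit 0 only

R_z = lrp_dense_input_layer(x_toy, W_toy, R_out)
R_sign = lrp_dense_input_layer(x_toy, W_toy, R_out, use_sign=True)
print("LRP-z:   ", R_z)     # zero-valued feature gets no relevance
print("LRP-SIGN:", R_sign)  # zero-valued feature gets relevance from its weight
```

Note the trade-off: the z-rule conserves relevance (here `R_z` sums to 1, the relevance of output unit 0), while the sign variant in this sketch does not.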