In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Import explainability toolkit
from explainers import GradCAM, GradCAMPlusPlus, IntegratedGradients
from utils import load_model, get_target_layer_name, load_image, compare_methods

# Pick the compute device: use the GPU when PyTorch reports one, else the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")

## Load Model

Load a pre-trained ResNet50 from torchvision and select the target layer used by GradCAM and GradCAM++.

In [None]:
# Load a pre-trained ResNet50 and choose the layer GradCAM/GradCAM++ will use.
MODEL_NAME = 'resnet50'
model = load_model(MODEL_NAME, pretrained=True, device=device)

# Force inference mode (dropout off, BatchNorm on running statistics) so the
# prediction and the explanations below are deterministic for a batch of one.
# Calling eval() is idempotent, so this is harmless if load_model already does it.
# NOTE(review): assumes load_model returns a torch.nn.Module — confirm in utils.
model.eval()
target_layer = get_target_layer_name(model)

print(f"Model loaded: {MODEL_NAME}")
print(f"Target layer for GradCAM: {target_layer}")

## Load Image

Load and preprocess an image for classification. Set `image_path` in the cell below to an image file on disk before running.

In [None]:
from pathlib import Path

# Load image - update this path!
image_path = 'path/to/your/image.jpg'

# Fail fast with an actionable message instead of an opaque error from inside
# load_image when the placeholder path above has not been replaced.
if not Path(image_path).is_file():
    raise FileNotFoundError(
        f"Image not found: {image_path!r}. Set `image_path` to an image file on disk."
    )
image = load_image(image_path)

# Display image. The permute moves a (C, H, W) tensor to the (H, W, C) layout
# imshow expects.
# NOTE(review): if load_image returns an ImageNet-normalized tensor, colours will
# look off/clipped here — confirm against utils.load_image.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(image.permute(1, 2, 0))
ax.axis('off')
ax.set_title('Input Image')
plt.show()

## Get Prediction

Run the model on the image, list the top-5 predicted classes, and use the top-1 class as the target for all explanations.

In [None]:
# Batch of one, placed on the selected device.
input_tensor = image.unsqueeze(0).to(device)

# Inference only: no autograd graph is needed to read off the probabilities.
with torch.no_grad():
    output = model(input_tensor)
    probs = output.softmax(dim=1)
    top5_probs, top5_indices = probs[0].topk(5)

# List the five most probable classes, rank 1 being the most probable.
print("Top 5 Predictions:")
for rank, (prob, idx) in enumerate(zip(top5_probs, top5_indices), start=1):
    print(f"  {rank}. Class {idx}: {prob:.4f}")

# The top-1 class is the one every explainer below will attribute.
target_class = int(top5_indices[0])
print(f"\nTarget class for explanation: {target_class}")

## Generate Explanations

Generate an explanation for the target class with each method: GradCAM, GradCAM++, and Integrated Gradients.

In [None]:
# One explainer per method. GradCAM and GradCAM++ take `target_layer` (the layer
# selected for GradCAM when the model was loaded); Integrated Gradients does not.
explainers = {
    'GradCAM': GradCAM(model, target_layer, device),
    'GradCAM++': GradCAMPlusPlus(model, target_layer, device),
    'Integrated Gradients': IntegratedGradients(model, device)
}

# Compute each method's explanation for the same input and target class.
explanations = {}
for method_name, explainer in explainers.items():
    print(f"Generating {method_name}...")
    explanations[method_name] = explainer.explain(input_tensor, target_class)

# Fixed mojibake: the original literal was the UTF-8 bytes of "✅" decoded as cp1252.
print("✅ All explanations generated!")

## Visualize Comparisons

Compare the different explanation methods side by side.

In [None]:
# Hand the input image and every method's heatmap to compare_methods to build a
# single side-by-side comparison figure.
comparison_fig = compare_methods(image, explanations, colormap='jet', figsize=(15, 4))
plt.show()

## Individual Explanations

View each explanation in detail.

In [None]:
from utils import overlay_heatmap

for method_name, heatmap in explanations.items():
    # Blend the heatmap onto the input image (50% opacity).
    overlaid = overlay_heatmap(image, heatmap, alpha=0.5)

    # Two panels per method: the raw heatmap (jet colormap), then the overlay.
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (heatmap, {'cmap': 'jet'}, 'Heatmap'),
        (overlaid, {}, 'Overlay'),
    )
    for ax, (picture, imshow_kwargs, label) in zip(axes, panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(f'{method_name} - {label}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## Quantitative Evaluation

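Evaluate explanations using deletion and insertion metrics. Deletion removes pixels in the order the explanation ranks them (most important first) and tracks how fast the target-class score drops, so a lower AUC is better. Insertion adds the pixels back in the same order and tracks how fast the score recovers, so a higher AUC is better.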

In [None]:
from metrics import DeletionInsertion

# Deletion/insertion metric. n_steps=50 is presumably the number of points on
# each curve — see metrics.DeletionInsertion.
di_metric = DeletionInsertion(model, device, n_steps=50)

# Score every explanation and keep the full result dicts for the plots below.
results = {}
for method_name, heatmap in explanations.items():
    print(f"Evaluating {method_name}...")
    results[method_name] = di_metric.evaluate(input_tensor, heatmap, target_class)
    for label, key in (("Deletion AUC", "deletion_auc"), ("Insertion AUC", "insertion_auc")):
        print(f"  {label}: {results[method_name][key]:.3f}")

## Plot Evaluation Curves

Visualize the deletion and insertion curves.

In [None]:
from utils import plot_deletion_insertion_curves

# Plot the deletion/insertion curves collected in `results` for every method.
curves_fig = plot_deletion_insertion_curves(results, figsize=(15, 5))
plt.show()

## Summary

Compare all methods based on their metrics.

In [None]:
import pandas as pd

# Summary table with one row per method. AUCs stay numeric (formatted only when
# printed) so the frame can still be sorted or compared programmatically.
# Explicit columns keep the header even if `results` is empty.
summary_df = pd.DataFrame(
    [
        {
            'Method': method_name,
            'Deletion AUC': result['deletion_auc'],
            'Insertion AUC': result['insertion_auc'],
        }
        for method_name, result in results.items()
    ],
    columns=['Method', 'Deletion AUC', 'Insertion AUC'],
)

print("\nSummary of Results:")
print(summary_df.to_string(index=False, float_format="{:.3f}".format))

# Fixed mojibake: these literals were UTF-8 bytes of "📊" / "•" decoded as cp1252.
print("\n📊 Interpretation:")
print("  • Deletion AUC: Lower is better (explanation captures important features)")
print("  • Insertion AUC: Higher is better (explanation is sufficient for prediction)")

## Next Steps

- Try different models (ResNet18, DenseNet, EfficientNet)
- Experiment with medical imaging datasets
- Add ground truth masks for plausibility evaluation
- Test perturbation-based methods (RISE, Occlusion)
- Explore attention-based methods for Vision Transformers