# XAI Visual Quality Control - Demo Notebook

This notebook walks through the main components of the XAI Visual Quality Control system for radiographic defect detection, from loading an image to explaining and calibrating the detector's predictions.

## Features Demonstrated:
1. Image loading and preprocessing
2. Defect detection with Faster R-CNN
3. All 4 XAI explanation methods (Grad-CAM, SHAP, LIME, Integrated Gradients)
4. Explanation aggregation and consensus scoring
5. Uncertainty quantification with MC-Dropout
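6. Model calibration (ECE and a reliability diagram; `TemperatureScaling` is imported but not run)
7. Metrics helpers for detection, segmentation and business metrics (imported but not run)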
8. Visualization of results

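**Note**: Train a model first with `backend/scripts/train.py`. If no checkpoint is found at the path configured in Section 2, the notebook falls back to an untrained model, and the detections and explanations will not be meaningful.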

## 1. Setup and Imports

In [None]:
import sys
sys.path.append('../backend')

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

# Core modules
from core.models.detector import DefectDetector
from core.preprocessing.image_processor import ImageProcessor

# XAI modules
from core.xai.gradcam import GradCAM
from core.xai.shap_explainer import SHAPExplainer
from core.xai.lime_explainer import LIMEExplainer
from core.xai.integrated_gradients import IntegratedGradientsExplainer
from core.xai.aggregator import XAIAggregator

# Uncertainty modules
from core.uncertainty.mc_dropout import MCDropoutEstimator
from core.uncertainty.calibration import calculate_ece, TemperatureScaling

# Metrics modules
from core.metrics.business_metrics import calculate_confusion_matrix_metrics
from core.metrics.detection_metrics import calculate_map, calculate_auroc
from core.metrics.segmentation_metrics import calculate_mean_iou

print("All imports successful!")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Model and Initialize Components

In [None]:
# Configuration
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_PATH = '../backend/models/checkpoints/best_model.pth'

# Initialize defect detector
print(f"Loading model on {DEVICE}...")
model = DefectDetector(num_classes=2, device=DEVICE)

if Path(MODEL_PATH).exists():
    model.load_weights(MODEL_PATH)
    print(f"Model loaded from {MODEL_PATH}")
else:
    print("Warning: Model weights not found. Using untrained model for demonstration.")

# Initialize image processor
image_processor = ImageProcessor(target_size=(512, 512))

# Initialize XAI explainers
print("Initializing XAI explainers...")
gradcam = GradCAM(model.model)
shap_explainer = SHAPExplainer(model.model)
lime_explainer = LIMEExplainer(model.model)
ig_explainer = IntegratedGradientsExplainer(model.model)
xai_aggregator = XAIAggregator()

# Initialize uncertainty estimator
mc_dropout = MCDropoutEstimator(model.model, n_samples=10, device=DEVICE)

print("All components initialized!")

## 3. Load and Preprocess Test Image

In [None]:
# Load test image (use a sample from your test dataset)
test_image_path = '../backend/data/test/images/test_0000.jpg'

if not Path(test_image_path).exists():
    print(f"Test image not found at {test_image_path}")
    print("Please run: python backend/scripts/generate_test_dataset.py")
    print("Or provide your own test image")
else:
    # Load image
    image = image_processor.load_image(test_image_path)
    print(f"Image loaded: shape={image.shape}, dtype={image.dtype}")
    
    # Preprocess
    preprocessed = image_processor.preprocess(image)
    print(f"Preprocessed: shape={preprocessed.shape}, dtype={preprocessed.dtype}")
    
    # Convert to tensor
    image_tensor = torch.from_numpy(
        image_processor.to_tensor(preprocessed)
    ).float().unsqueeze(0).to(DEVICE)
    print(f"Tensor shape: {image_tensor.shape}")
    
    # Visualize original image
    plt.figure(figsize=(8, 6))
    plt.imshow(image, cmap='gray' if len(image.shape) == 2 else None)
    plt.title('Original Test Image')
    plt.axis('off')
    plt.show()
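
The behavior of `ImageProcessor.to_tensor` is not shown in this notebook. As a rough sketch, assuming it converts a height × width × channels array into the channels-first layout that PyTorch models expect, the shape handling could look like the following. The helper name `to_chw` is hypothetical, and NumPy stands in for the real implementation.

```python
import numpy as np

def to_chw(image: np.ndarray) -> np.ndarray:
    """Return a float32 (C, H, W) array; uint8 input is scaled to [0, 1].

    Hypothetical helper that only illustrates the layout change.
    """
    arr = image.astype(np.float32)
    if image.dtype == np.uint8:
        arr /= 255.0
    if arr.ndim == 2:                    # grayscale (H, W) -> (H, W, 1)
        arr = arr[:, :, np.newaxis]
    return np.transpose(arr, (2, 0, 1))  # (H, W, C) -> (C, H, W)

gray = np.full((512, 512), 255, dtype=np.uint8)
chw = to_chw(gray)
batch = chw[np.newaxis, ...]  # same role as .unsqueeze(0) in the cell above
print(chw.shape, batch.shape, batch.dtype)
```

The leading batch dimension is what lets a single image be passed to a model that expects a batch of inputs.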

## 4. Run Defect Detection

In [None]:
# Run detection
detections = model.predict(image_tensor)

print(f"Number of detections: {len(detections)}")
print("\nDetection results:")
for i, det in enumerate(detections):
    print(f"  Detection {i+1}:")
    print(f"    - Label: {det['label']}")
    print(f"    - Confidence: {det['score']:.3f}")
    print(f"    - Bounding Box: {det['box']}")
    
# Visualize detections
plt.figure(figsize=(12, 8))
plt.imshow(image, cmap='gray' if len(image.shape) == 2 else None)

for det in detections:
    box = det['box']
    x1, y1, x2, y2 = box
    width = x2 - x1
    height = y2 - y1
    
    # Draw bounding box
    rect = plt.Rectangle((x1, y1), width, height, 
                         fill=False, edgecolor='red', linewidth=2)
    plt.gca().add_patch(rect)
    
    # Add label
    label = f"{det['label']} ({det['score']:.2f})"
    plt.text(x1, y1-5, label, color='red', fontsize=10, 
            bbox=dict(facecolor='white', alpha=0.7))

plt.title(f'Defect Detection Results ({len(detections)} detections)')
plt.axis('off')
plt.tight_layout()
plt.show()
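
A detector can return low-confidence boxes alongside strong ones, so downstream steps often keep only detections above a score threshold. A minimal sketch, assuming each detection is a dict with the `label`, `score` and `box` keys printed above; the `filter_detections` helper, the sample values and the 0.5 threshold are illustrative and not part of `DefectDetector`.

```python
from typing import Dict, List

def filter_detections(detections: List[Dict], min_score: float = 0.5) -> List[Dict]:
    """Keep detections scoring at least min_score, highest score first."""
    kept = [d for d in detections if d["score"] >= min_score]
    return sorted(kept, key=lambda d: d["score"], reverse=True)

# Made-up detections in the same format as the loop above.
sample = [
    {"label": 1, "score": 0.35, "box": [10, 10, 50, 50]},
    {"label": 1, "score": 0.92, "box": [120, 80, 200, 160]},
    {"label": 1, "score": 0.61, "box": [300, 40, 340, 90]},
]
confident = filter_detections(sample, min_score=0.5)
print([d["score"] for d in confident])
```

The right threshold depends on how the cost of a missed defect compares with the cost of a false alarm, so it is best chosen on validation data.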

## 5. Generate XAI Explanations

Now let's generate explanations using all 4 XAI methods:

In [None]:
# Generate explanations from all 4 methods
explanations = {}

print("Generating XAI explanations...")

# Grad-CAM
print("  1/4 Grad-CAM...")
explanations['gradcam'] = gradcam.generate_heatmap(image_tensor, target_class=1)

# SHAP
print("  2/4 SHAP...")
explanations['shap'] = shap_explainer.generate_heatmap(image_tensor, target_class=1)

# LIME
print("  3/4 LIME...")
explanations['lime'] = lime_explainer.generate_heatmap(image_tensor, target_class=1)

# Integrated Gradients
print("  4/4 Integrated Gradients...")
explanations['ig'] = ig_explainer.generate_heatmap(image_tensor, target_class=1, baseline='black')

print("All explanations generated!")

# Visualize all explanations
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Original image
axes[0, 0].imshow(image, cmap='gray' if len(image.shape) == 2 else None)
axes[0, 0].set_title('Original Image', fontsize=12, fontweight='bold')
axes[0, 0].axis('off')

# Grad-CAM
axes[0, 1].imshow(explanations['gradcam'], cmap='jet')
axes[0, 1].set_title('Grad-CAM', fontsize=12, fontweight='bold')
axes[0, 1].axis('off')

# SHAP
axes[0, 2].imshow(explanations['shap'], cmap='jet')
axes[0, 2].set_title('SHAP', fontsize=12, fontweight='bold')
axes[0, 2].axis('off')

# LIME
axes[1, 0].imshow(explanations['lime'], cmap='jet')
axes[1, 0].set_title('LIME', fontsize=12, fontweight='bold')
axes[1, 0].axis('off')

# Integrated Gradients
axes[1, 1].imshow(explanations['ig'], cmap='jet')
axes[1, 1].set_title('Integrated Gradients', fontsize=12, fontweight='bold')
axes[1, 1].axis('off')

# Leave last subplot empty for now
axes[1, 2].axis('off')

plt.suptitle('XAI Explanations Comparison', fontsize=14, fontweight='bold', y=0.98)
plt.tight_layout()
plt.show()
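
The `baseline='black'` argument passed to Integrated Gradients names the reference input that attributions are measured against; a black baseline is commonly an all-zero image. The internals of `IntegratedGradientsExplainer` are not shown here, so the following is only a toy sketch of the general technique: it averages gradients along the straight path from the baseline to the input and scales the result by the input difference. The model is a hypothetical quadratic function with an analytic gradient, chosen so the example runs without a neural network.

```python
import numpy as np

WEIGHTS = np.array([[0.5, -1.0], [2.0, 0.0]])  # hypothetical per-pixel weights

def toy_model(x: np.ndarray) -> float:
    """Stand-in for a class score: weighted sum of squared pixel values."""
    return float(np.sum(WEIGHTS * x ** 2))

def toy_gradient(x: np.ndarray) -> np.ndarray:
    """Analytic gradient of toy_model with respect to x."""
    return 2.0 * WEIGHTS * x

def integrated_gradients(x, baseline, grad_fn, steps=50):
    """Approximate the path integral with a midpoint Riemann sum."""
    alphas = (np.arange(steps) + 0.5) / steps
    avg_grad = np.mean(
        [grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0
    )
    return (x - baseline) * avg_grad

pixels = np.array([[1.0, 2.0], [0.5, 3.0]])
black = np.zeros_like(pixels)
attributions = integrated_gradients(pixels, black, toy_gradient)

# Completeness: attributions should sum to the change in model output.
print(attributions.sum(), toy_model(pixels) - toy_model(black))
```

The completeness check in the last line is a useful sanity test for any Integrated Gradients implementation: a large gap between the two numbers usually means too few integration steps.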

## 6. Aggregate Explanations and Compute Consensus

In [None]:
# Aggregate all explanations
heatmaps_list = list(explanations.values())

# Try different aggregation methods
aggregation_methods = ['mean', 'median', 'weighted']
aggregated_results = {}

for method in aggregation_methods:
    if method == 'weighted':
        weights = [0.3, 0.3, 0.2, 0.2]  # Prioritize Grad-CAM and SHAP
        aggregated = xai_aggregator.aggregate(heatmaps_list, method=method, weights=weights)
    else:
        aggregated = xai_aggregator.aggregate(heatmaps_list, method=method)
    
    aggregated_results[method] = aggregated

# Compute consensus score
consensus_metrics = ['correlation', 'iou', 'dice']
consensus_scores = {}

for metric in consensus_metrics:
    score = xai_aggregator.compute_consensus_score(heatmaps_list, metric=metric)
    consensus_scores[metric] = score
    print(f"Consensus Score ({metric}): {score:.4f}")

# Visualize aggregated results
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

axes[0].imshow(image, cmap='gray' if len(image.shape) == 2 else None)
axes[0].set_title('Original', fontsize=12, fontweight='bold')
axes[0].axis('off')

for idx, (method, heatmap) in enumerate(aggregated_results.items(), 1):
    axes[idx].imshow(heatmap, cmap='jet')
    axes[idx].set_title(f'Aggregated ({method})', fontsize=12, fontweight='bold')
    axes[idx].axis('off')

plt.suptitle(f'Aggregated Explanations | Consensus: {consensus_scores["correlation"]:.3f}', 
            fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
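
`XAIAggregator` is used as a black box above. To illustrate what mean, median and weighted aggregation and a correlation-based consensus score could compute, here is a NumPy sketch. It assumes every heatmap is a 2-D array of the same shape and min-max normalizes each one first, so that methods with different value ranges contribute comparably. The function names are hypothetical and the real class may differ.

```python
import numpy as np

def normalize(heatmap: np.ndarray) -> np.ndarray:
    """Min-max scale to [0, 1]; constant maps become all zeros."""
    lo, hi = heatmap.min(), heatmap.max()
    if hi == lo:
        return np.zeros_like(heatmap, dtype=float)
    return (heatmap - lo) / (hi - lo)

def aggregate(heatmaps, method="mean", weights=None):
    """Combine normalized heatmaps pixel by pixel."""
    stack = np.stack([normalize(h) for h in heatmaps])  # (n_methods, H, W)
    if method == "mean":
        return stack.mean(axis=0)
    if method == "median":
        return np.median(stack, axis=0)
    if method == "weighted":
        w = np.asarray(weights, dtype=float)
        return np.tensordot(w / w.sum(), stack, axes=1)
    raise ValueError(f"unknown method: {method}")

def correlation_consensus(heatmaps) -> float:
    """Mean pairwise Pearson correlation between flattened heatmaps."""
    flat = np.stack([normalize(h).ravel() for h in heatmaps])
    corr = np.corrcoef(flat)
    upper = np.triu_indices(len(heatmaps), k=1)
    return float(corr[upper].mean())

rng = np.random.default_rng(0)
base = rng.random((8, 8))
maps = [base, base * 2.0 + 1.0, base + 0.05 * rng.random((8, 8))]
combined = aggregate(maps, method="weighted", weights=[0.5, 0.3, 0.2])
print(combined.shape, round(correlation_consensus(maps), 3))
```

A consensus score near 1 means the methods rank pixels similarly; a low score suggests that at least one explanation should be inspected before the aggregate is trusted.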

## 7. Uncertainty Quantification with MC-Dropout

In [None]:
# Compute predictive uncertainty using MC-Dropout
print("Computing uncertainty with MC-Dropout (10 samples)...")

# Get uncertainty map
uncertainty_map = mc_dropout.compute_predictive_entropy(image_tensor)
mean_uncertainty = uncertainty_map.mean()
max_uncertainty = uncertainty_map.max()

print(f"Mean Uncertainty: {mean_uncertainty:.4f}")
print(f"Max Uncertainty: {max_uncertainty:.4f}")

# Visualize uncertainty
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original image
axes[0].imshow(image, cmap='gray' if len(image.shape) == 2 else None)
axes[0].set_title('Original Image', fontsize=12, fontweight='bold')
axes[0].axis('off')

# Uncertainty map
im1 = axes[1].imshow(uncertainty_map, cmap='hot')
axes[1].set_title('Predictive Entropy (Uncertainty)', fontsize=12, fontweight='bold')
axes[1].axis('off')
plt.colorbar(im1, ax=axes[1], fraction=0.046)

# Overlay uncertainty on image
axes[2].imshow(image, cmap='gray' if len(image.shape) == 2 else None, alpha=0.7)
im2 = axes[2].imshow(uncertainty_map, cmap='hot', alpha=0.3)
axes[2].set_title('Uncertainty Overlay', fontsize=12, fontweight='bold')
axes[2].axis('off')

plt.suptitle(f'MC-Dropout Uncertainty | Mean: {mean_uncertainty:.4f}', 
            fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
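
MC-Dropout keeps dropout active at inference time, runs several stochastic forward passes, and treats the spread of the outputs as an uncertainty signal. `compute_predictive_entropy` is not shown in this notebook; one common definition is the entropy of the class probabilities averaged over the passes, sketched below with NumPy on made-up probabilities. The array layout `(n_samples, n_classes, H, W)` is an assumption.

```python
import numpy as np

def predictive_entropy(probs: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Entropy of the mean prediction.

    probs: (n_samples, n_classes, H, W) class probabilities from
    repeated stochastic forward passes. Returns an (H, W) map in nats.
    """
    mean_probs = probs.mean(axis=0)  # (n_classes, H, W)
    return -np.sum(mean_probs * np.log(mean_probs + eps), axis=0)

# Two passes, two classes, a 1x2 "image": the passes agree on the first
# pixel and disagree on the second.
passes = np.array([
    [[[0.99, 0.9]], [[0.01, 0.1]]],
    [[[0.99, 0.1]], [[0.01, 0.9]]],
])
entropy_map = predictive_entropy(passes)
print(entropy_map)
```

For two classes the entropy is bounded by ln 2 (about 0.693 nats), which gives a reference point for reading the mean and maximum values printed above.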

## 8. Model Calibration Metrics

Check the model's calibration using Expected Calibration Error (ECE):

In [None]:
# For demonstration, let's simulate some predictions and compute ECE
# In practice, you would evaluate on your validation set

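# Simulate predictions (replace with actual validation data).
# The labels are drawn independently of the confidences, so the ECE below
# only exercises the code path and says nothing about the real model.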
np.random.seed(42)
n_samples = 100
confidences = np.random.beta(5, 2, n_samples)  # Simulated confidence scores
predictions = (confidences > 0.5).astype(int)
labels = np.random.randint(0, 2, n_samples)  # Simulated ground truth

# Calculate ECE
ece = calculate_ece(
    torch.tensor(confidences),
    torch.tensor(predictions),
    torch.tensor(labels),
    n_bins=10
)

print(f"Expected Calibration Error (ECE): {ece:.4f}")
print(f"Interpretation: {'Well-calibrated' if ece < 0.1 else 'Needs calibration'}")

# Visualize calibration
from core.uncertainty.calibration import plot_reliability_diagram

fig = plot_reliability_diagram(
    torch.tensor(confidences),
    torch.tensor(predictions),
    torch.tensor(labels),
    n_bins=10
)
plt.show()

print("\nNote: This is simulated data. Run calibration on actual validation set for real metrics.")
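
ECE partitions predictions into confidence bins and averages the gap between accuracy and mean confidence in each bin, weighted by the fraction of samples in that bin. Temperature scaling, listed among the features, divides the logits by a single scalar T fitted on validation data to reduce that gap. The repository's `calculate_ece` and `TemperatureScaling` are not shown here, so the following NumPy sketch uses its own hypothetical function names, a binary sigmoid setting, and a simple grid search in place of a gradient-based fit.

```python
import numpy as np

def expected_calibration_error(confidences, predictions, labels, n_bins=10):
    """Sample-weighted average of |accuracy - confidence| over equal-width bins."""
    confidences = np.asarray(confidences, dtype=float)
    correct = (np.asarray(predictions) == np.asarray(labels)).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap
    return float(ece)

def fit_temperature(logits, labels):
    """Grid-search the T that minimizes binary negative log-likelihood."""
    logits = np.asarray(logits, dtype=float)
    signs = 2.0 * np.asarray(labels, dtype=float) - 1.0  # {0, 1} -> {-1, +1}
    temperatures = np.linspace(0.5, 10.0, 191)
    # NLL of sigmoid(logit / T) is log(1 + exp(-sign * logit / T)).
    nll = [np.mean(np.logaddexp(0.0, -signs * logits / t)) for t in temperatures]
    return float(temperatures[int(np.argmin(nll))])

def confidence_and_prediction(logits, temperature=1.0):
    """Return (confidence in the predicted class, predicted class)."""
    prob_pos = 1.0 / (1.0 + np.exp(-logits / temperature))
    preds = (prob_pos > 0.5).astype(int)
    return np.maximum(prob_pos, 1.0 - prob_pos), preds

# Overconfident toy classifier: |logit| = 4 (about 98% confident), 75% correct.
logits = np.array([4.0, 4.0, 4.0, 4.0, -4.0, -4.0, -4.0, -4.0])
labels = np.array([1, 1, 1, 0, 0, 0, 0, 1])

conf_raw, preds = confidence_and_prediction(logits)
T = fit_temperature(logits, labels)
conf_cal, _ = confidence_and_prediction(logits, T)
ece_before = expected_calibration_error(conf_raw, preds, labels)
ece_after = expected_calibration_error(conf_cal, preds, labels)
print(f"T={T:.2f} ECE before={ece_before:.3f} ECE after={ece_after:.3f}")
```

Dividing by T > 1 softens the probabilities without changing which class is predicted, so accuracy is unaffected while the confidence moves toward the observed accuracy.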

## 9. Summary and Next Steps

### What We Demonstrated:
1. ✅ **Model Loading**: Loaded Faster R-CNN defect detector
2. ✅ **Image Processing**: Preprocessed radiographic images
3. ✅ **Defect Detection**: Detected and visualized defects with bounding boxes
4. ✅ **XAI Explanations**: Generated heatmaps using 4 methods (Grad-CAM, SHAP, LIME, IG)
5. ✅ **Aggregation**: Combined explanations and computed consensus scores
6. ✅ **Uncertainty**: Quantified prediction uncertainty with MC-Dropout
7. ✅ **Calibration**: Calculated Expected Calibration Error (ECE)

### Key Results:
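- Detection confidence scores indicate how certain the model is about each box; they reflect reliability only to the extent that the model is calibrated
- XAI heatmaps highlight the image regions that most influence a prediction
- Consensus scores (correlation, IoU, Dice) measure agreement between the four explanation methods
- Uncertainty maps flag regions of high predictive entropy
- ECE summarizes calibration quality; the value computed above comes from simulated data and should be recomputed on a validation set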

### Next Steps:
1. **Train on Real Data**: Replace synthetic data with actual radiographic images
2. **API Integration**: Use the FastAPI endpoints (`/api/xai-qc/detect`, `/explain`)
3. **Frontend Development**: Build the Makerkit UI to consume these APIs
4. **Production Deployment**: Dockerize and deploy with `docker-compose.yml`
5. **Continuous Monitoring**: Track metrics over time with MLflow

For API usage, see: `backend/api/routes.py`  
For training: `backend/scripts/train.py`  
For deployment: `docker-compose.yml` (to be created)