# Deepfake Detection Analysis & XAI

This notebook consolidates model loading, inference, and Explainable AI (XAI) analysis.
It implements **GradCAM** and **LIME** to visualize which parts of the image contribute to the deepfake detection decision.

**Author:** Team Member 3 (Analysis & Detection)

## Contents
1. Setup & Dependencies
2. Model Architecture
3. Preprocessing Pipeline
4. XAI Methods (GradCAM, LIME)
5. Visualization Utilities
6. Model Loading & Initialization
7. Single Image Analysis
8. Batch Analysis
9. Results Export


In [None]:
# Install required packages 
# %pip install torch torchvision timm opencv-python matplotlib pillow lime numpy

In [None]:
# Cell 1: Setup & Dependencies
import os
import json
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image
import timm
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Tuple, Optional, Dict
import warnings
warnings.filterwarnings('ignore')

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Project root = parent of the current working directory (the notebook is
# presumably run from a subfolder of the project — verify if paths break).
PROJECT_ROOT = Path("..").resolve()
print(f"Project root: {PROJECT_ROOT}")


Project root: /Users/aslanbayli/Documents/nyu/advanced-cv/forensic-bind/src


In [None]:
# Model configuration. The values are written inline here; per the original
# note they come from v1/efficientnet/model_config.json — keep the two in sync.
CONFIG = {
    "num_classes": 3,
    "backbone_name": "efficientnet_b0",
    "imagenet_mean": [0.485, 0.456, 0.406],
    "imagenet_std": [0.229, 0.224, 0.225],
    "class_names": ["Real", "FaceSwap", "Face2Face"],
    "input_size": (224, 224),
}

# Module-level shortcuts used throughout the notebook.
NUM_CLASSES, BACKBONE_NAME = CONFIG["num_classes"], CONFIG["backbone_name"]
IMAGENET_MEAN, IMAGENET_STD = CONFIG["imagenet_mean"], CONFIG["imagenet_std"]
CLASS_NAMES = CONFIG["class_names"]

print("Configuration loaded:")
print(f"  Classes: {CLASS_NAMES}")
print(f"  Backbone: {BACKBONE_NAME}")


## 2. Model Architecture

EfficientNet-based deepfake classifier using the `timm` library.


In [None]:
class EfficientNetDeepfake(nn.Module):
    """
    Deepfake classifier: a timm backbone with its own head removed, followed
    by a single linear layer that produces ``num_classes`` logits.

    Args:
        num_classes: Number of output classes.
        backbone_name: Any model name accepted by ``timm.create_model``.
    """
    def __init__(self, num_classes: int = NUM_CLASSES, backbone_name: str = BACKBONE_NAME):
        super().__init__()
        # num_classes=0 asks timm to drop the classification head, so the
        # backbone's output is a feature vector rather than logits.
        self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0, in_chans=3)

        # Work out the width of that feature vector.
        if hasattr(self.backbone, "num_features"):
            width = self.backbone.num_features
        else:
            head = getattr(self.backbone, "classifier", None)
            # 1280 is the fallback used for efficientnet_b0.
            width = head.in_features if hasattr(head, "in_features") else 1280

        # Assigned after the backbone so the submodule registration order
        # (and therefore state_dict key order) is unchanged.
        self.classifier = nn.Linear(width, num_classes)
        self.feat_dim = width

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.classifier(self.backbone(x))

    def get_features(self, x):
        """Return the backbone's feature embeddings (no classification layer)."""
        return self.backbone(x)

print("EfficientNetDeepfake model class defined.")


## 3. Preprocessing Pipeline


In [None]:
def get_eval_transform():
    """
    Build the deterministic evaluation pipeline for PIL images.

    Steps: resize to 256x256, centre-crop to 224x224, convert to a tensor,
    then normalize with the ImageNet mean/std. The original note says this
    matches training — verify against the training script if it changes.
    """
    steps = [
        T.Resize((256, 256)),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def load_image(path: str) -> Image.Image:
    """
    Load the image at ``path`` as a 3-channel RGB PIL image.

    ``Image.open`` is lazy and keeps the file handle open. The context manager
    closes it deterministically. ``convert`` returns a new, fully loaded image,
    so the result stays usable after the file is closed.
    """
    with Image.open(path) as img:
        return img.convert("RGB")

def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Apply the evaluation transform to ``image`` and prepend a batch axis of size 1."""
    tensor = get_eval_transform()(image)
    # Indexing with None adds the leading batch dimension (same as unsqueeze(0)).
    return tensor[None]

def denormalize(tensor: torch.Tensor) -> np.ndarray:
    """
    Convert an ImageNet-normalized image tensor back to a displayable RGB array.

    Args:
        tensor: Normalized image shaped (C, H, W) or (B, C, H, W). For a batch,
            only the first image is used. The tensor may be on any device and
            may require grad: it is detached first, so ``.numpy()`` cannot fail
            on an autograd-tracked tensor.

    Returns:
        uint8 array of shape (H, W, C) with values in [0, 255].
    """
    img = tensor.detach().cpu().clone()
    if img.dim() == 4:
        # Drop the batch axis explicitly. A bare squeeze() would also remove
        # spatial axes of size 1 and break the permute below.
        img = img[0]
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    img = torch.clamp(img * std + mean, 0, 1)
    return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

print("Preprocessing functions defined.")


## 4. XAI Methods

### Base Explainer Class


In [None]:
class BaseExplainer(ABC):
    """
    Common base for the explanation methods in this notebook.

    On construction it moves ``model`` to ``device`` and switches it to eval
    mode. Subclasses must implement :meth:`explain`; :meth:`predict` is shared.
    """
    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model
        self.device = device
        model.to(device)
        model.eval()

    @abstractmethod
    def explain(self, input_tensor: torch.Tensor, target_class: int = None) -> np.ndarray:
        """Return an explanation heatmap for ``target_class`` (defined by subclasses)."""
        ...

    def predict(self, input_tensor: torch.Tensor) -> Tuple[int, np.ndarray]:
        """
        Classify the first sample of ``input_tensor`` without tracking gradients.

        Returns:
            ``(predicted_class_index, softmax_probabilities)``; the probabilities
            are a 1-D NumPy array with one entry per class.
        """
        with torch.no_grad():
            scores = self.model(input_tensor.to(self.device))
            probabilities = F.softmax(scores, dim=1)[0].cpu().numpy()
        return int(probabilities.argmax()), probabilities

print("BaseExplainer class defined.")


### GradCAM Explainer

Gradient-weighted Class Activation Mapping highlights regions that contribute most to the prediction.


In [None]:
class GradCAMExplainer(BaseExplainer):
    """
    Gradient-weighted Class Activation Mapping (GradCAM).

    Hooks ``target_layer`` to record its activations on the forward pass and
    the gradient of the chosen class score with respect to them on the
    backward pass. Each activation channel is weighted by its spatially
    averaged gradient to highlight the regions that drive the prediction.
    """
    def __init__(self, model: nn.Module, target_layer: nn.Module, device: str = 'cpu'):
        super().__init__(model, device)
        self.target_layer = target_layer
        self.gradients = None    # gradient w.r.t. target_layer output, set by the backward hook
        self.activations = None  # target_layer output, set by the forward hook

        # Keep the handles so the hooks can be removed. Otherwise every new
        # explainer built on the same layer leaves its hooks attached forever.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self._save_activation),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self) -> None:
        """Detach this explainer's hooks from ``target_layer`` (safe to call twice)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, input, output):
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def explain(self, input_tensor: torch.Tensor, target_class: Optional[int] = None) -> np.ndarray:
        """
        Compute a GradCAM heatmap for the first sample of ``input_tensor``.

        Args:
            input_tensor: Preprocessed image tensor (1, C, H, W). It is not
                modified.
            target_class: Class index to explain (default: predicted class).

        Returns:
            Heatmap of shape (H, W) matching the input's spatial size,
            min-max normalized to [0, 1].

        Raises:
            RuntimeError: if the hooks captured nothing (e.g. after
                ``remove_hooks()`` or when ``target_layer`` is not used).
        """
        # Work on a private copy. Setting requires_grad on the caller's tensor
        # would make later .numpy() conversions of it fail (e.g. denormalize
        # for display or for LIME).
        x = input_tensor.detach().to(self.device).clone().requires_grad_(True)
        self.activations = None
        self.gradients = None

        # enable_grad lets this work even when called inside a no_grad block.
        with torch.enable_grad():
            output = self.model(x)
            if target_class is None:
                target_class = int(output[0].argmax().item())
            self.model.zero_grad()
            output[0, target_class].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks did not capture activations/gradients; "
                "were the hooks removed or is target_layer unused by the model?"
            )

        # Channel weights = spatial mean of the gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))

        # Upsample to the input's own spatial size.
        cam = F.interpolate(cam, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False)
        cam = cam[0, 0].cpu().numpy()

        # Min-max normalize; the epsilon keeps an all-zero map at zero.
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam

print("GradCAMExplainer class defined.")


### LIME Explainer

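Local Interpretable Model-agnostic Explanations (LIME) explains a single prediction by perturbing regions of the input image, observing how the model's output changes, and fitting a simple local surrogate model to those responses.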


In [None]:
class LIMEExplainer(BaseExplainer):
    """
    Local Interpretable Model-agnostic Explanations (LIME).

    Perturbs the denormalized input image, scores the perturbed copies with the
    model via ``_batch_predict``, and lets ``lime`` fit a local surrogate
    model. The resulting region mask is returned as a heatmap.
    """
    def __init__(self, model: nn.Module, device: str = 'cpu', num_samples: int = 1000):
        super().__init__(model, device)
        self.num_samples = num_samples
        self.explainer = None

        # LIME is optional: without it, explain() returns an all-zero heatmap.
        try:
            from lime import lime_image
            self.explainer = lime_image.LimeImageExplainer()
            print("LIME initialized successfully.")
        except ImportError:
            print("Warning: LIME not installed. Run `pip install lime`.")

    def _batch_predict(self, images: np.ndarray) -> np.ndarray:
        """
        Classifier callback handed to LIME.

        Args:
            images: Perturbed images, (N, H, W, C) with values in [0, 255].

        Returns:
            Softmax probabilities, shape (N, num_classes).
        """
        # LIME returns perturbed copies of the array that ``explain`` built with
        # ``denormalize``. That array is already the resized/cropped model input
        # (assuming the tensor came from ``preprocess_image``). Re-applying
        # Resize(256)+CenterCrop(224) would zoom into the centre again, so the
        # model would not see the image LIME is perturbing. Only rescale and
        # normalize here, like the ToTensor + Normalize steps of the eval
        # transform, and do it for the whole batch at once.
        pixels = np.asarray(images).astype(np.uint8)
        batch = torch.from_numpy(pixels).permute(0, 3, 1, 2).contiguous().float().div(255.0)
        mean = torch.tensor(IMAGENET_MEAN, dtype=batch.dtype).view(1, 3, 1, 1)
        std = torch.tensor(IMAGENET_STD, dtype=batch.dtype).view(1, 3, 1, 1)
        batch = ((batch - mean) / std).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            probs = F.softmax(logits, dim=1).cpu().numpy()

        return probs

    def explain(self, input_tensor: torch.Tensor, target_class: Optional[int] = None) -> np.ndarray:
        """
        Compute a LIME explanation.

        Args:
            input_tensor: Preprocessed image tensor (1, C, H, W)
            target_class: Class index to explain (default: predicted class)

        Returns:
            Heatmap (H, W): LIME's mask for the top 10 regions of
            ``target_class`` (``positive_only=False``), min-max rescaled to
            [0, 1]. All zeros if LIME is not installed.
        """
        if self.explainer is None:
            print("LIME not available. Returning empty heatmap.")
            return np.zeros(tuple(input_tensor.shape[-2:]))

        # Detach so the numpy conversion works even if some earlier step
        # enabled requires_grad on this tensor.
        image_np = denormalize(input_tensor.detach())

        if target_class is None:
            target_class, _ = self.predict(input_tensor)

        # top_labels=NUM_CLASSES makes every class available to get_image_and_mask.
        explanation = self.explainer.explain_instance(
            image_np,
            self._batch_predict,
            top_labels=NUM_CLASSES,
            hide_color=0,
            num_samples=self.num_samples,
        )

        _, mask = explanation.get_image_and_mask(
            target_class,
            positive_only=False,
            num_features=10,
            hide_rest=False,
        )

        mask = mask.astype(np.float32)
        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)

        return mask

print("LIMEExplainer class defined.")


## 5. Visualization Utilities


In [None]:
def apply_heatmap(image: np.ndarray, heatmap: np.ndarray, 
                  alpha: float = 0.5, colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """
    Overlay a colour-mapped heatmap on an RGB image.

    Args:
        image: Original image (H, W, 3), uint8 in [0, 255]. It must be uint8 to
            match the colour-mapped heatmap in ``cv2.addWeighted``.
        heatmap: Heatmap (h, w), expected in [0, 1]. It is resized to the
            image size if needed. NaNs become 0 and values are clipped to
            [0, 1], so the uint8 conversion below cannot wrap around.
        alpha: Opacity of heatmap overlay
        colormap: OpenCV colormap

    Returns:
        Blended RGB image with the same size as ``image``.
    """
    heat = np.clip(np.nan_to_num(np.asarray(heatmap, dtype=np.float32)), 0.0, 1.0)
    if heat.shape[:2] != image.shape[:2]:
        # OpenCV takes the target size as (width, height).
        heat = cv2.resize(heat, (image.shape[1], image.shape[0]))

    heatmap_uint8 = (heat * 255).astype(np.uint8)
    heatmap_colored = cv2.applyColorMap(heatmap_uint8, colormap)
    # applyColorMap produces BGR; the images in this notebook are RGB.
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    return cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)

def plot_single_explanation(image: np.ndarray, heatmap: np.ndarray, 
                            title: str = "Explanation", pred_info: str = None):
    """
    Show three panels in one row: the original image, the heatmap (with a
    colorbar), and the heatmap blended over the image.

    ``pred_info``, when given, is added as a second line of the figure title.
    """
    blended = apply_heatmap(image, heatmap)

    fig, (ax_orig, ax_heat, ax_blend) = plt.subplots(1, 3, figsize=(12, 4))

    ax_orig.imshow(image)
    ax_orig.set_title("Original")
    ax_orig.axis('off')

    heat_artist = ax_heat.imshow(heatmap, cmap='jet')
    ax_heat.set_title("Heatmap")
    ax_heat.axis('off')
    plt.colorbar(heat_artist, ax=ax_heat, fraction=0.046, pad=0.04)

    ax_blend.imshow(blended)
    ax_blend.set_title("Overlay")
    ax_blend.axis('off')

    heading = f"{title}\n{pred_info}" if pred_info else title
    fig.suptitle(heading, fontsize=12)

    plt.tight_layout()
    plt.show()

def plot_method_comparison(image: np.ndarray, heatmaps: Dict[str, np.ndarray], 
                           pred_info: str = None, save_path: str = None):
    """
    Compare multiple XAI methods side by side.

    Column 0 shows the original image. Each following column shows one
    method, with the raw heatmap in the top row and the overlay on the
    image in the bottom row.

    Args:
        image: Original image (uint8 RGB, as expected by ``apply_heatmap``)
        heatmaps: Dict mapping method name to heatmap (may be empty)
        pred_info: Optional figure title (prediction info string)
        save_path: Optional path to save the figure before it is shown
    """
    n_methods = len(heatmaps)
    # squeeze=False keeps `axes` 2-D even with a single column (no methods),
    # so the [row, col] indexing below always works.
    fig, axes = plt.subplots(2, n_methods + 1, figsize=(4 * (n_methods + 1), 8), squeeze=False)

    # Original image; the cell under it is left blank.
    axes[0, 0].imshow(image)
    axes[0, 0].set_title("Original", fontsize=12)
    axes[0, 0].axis('off')
    axes[1, 0].axis('off')

    for i, (method_name, heatmap) in enumerate(heatmaps.items(), 1):
        axes[0, i].imshow(heatmap, cmap='jet')
        axes[0, i].set_title(f"{method_name}\nHeatmap", fontsize=12)
        axes[0, i].axis('off')

        overlay = apply_heatmap(image, heatmap)
        axes[1, i].imshow(overlay)
        axes[1, i].set_title(f"{method_name}\nOverlay", fontsize=12)
        axes[1, i].axis('off')

    if pred_info:
        fig.suptitle(pred_info, fontsize=14, fontweight='bold')

    plt.tight_layout()

    # Save before show(): some backends release the figure once it is shown.
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.show()

def plot_prediction_bars(probs: np.ndarray, class_names: List[str] = CLASS_NAMES):
    """
    Draw a bar chart of class probabilities on a 0-1 axis.

    The most probable class is drawn in green, the others in steel blue, and
    each bar is labelled with its value to three decimals.
    """
    best = probs.argmax()
    colors = ['green' if idx == best else 'steelblue' for idx in range(len(probs))]

    plt.figure(figsize=(8, 4))
    bars = plt.bar(class_names, probs, color=colors, edgecolor='black')
    plt.ylabel("Probability", fontsize=12)
    plt.xlabel("Class", fontsize=12)
    plt.title("Prediction Probabilities", fontsize=14)
    plt.ylim(0, 1)

    for bar, prob in zip(bars, probs):
        x_center = bar.get_x() + bar.get_width()/2
        plt.text(x_center, bar.get_height() + 0.02, f'{prob:.3f}', ha='center', fontsize=10)

    plt.tight_layout()
    plt.show()

print("Visualization functions defined.")


## 6. Model Loading & Initialization


In [None]:
# Exported inference checkpoint for the v1 EfficientNet model, under PROJECT_ROOT.
MODEL_PATH = PROJECT_ROOT.joinpath("models", "v1", "efficientnet", "efficientnet_deepfake_inference.pth")

# Build the architecture directly on the target device.
model = EfficientNetDeepfake(num_classes=NUM_CLASSES).to(DEVICE)

if not MODEL_PATH.exists():
    # No checkpoint: keep the freshly initialised weights so the rest of the
    # notebook still runs (predictions will be meaningless).
    print(f"✗ Model not found at {MODEL_PATH}")
    print("  Using random weights for demonstration.")
else:
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"✓ Model loaded from {MODEL_PATH}")

model.eval()
print(f"\nModel architecture: {BACKBONE_NAME}")
print(f"Number of classes: {NUM_CLASSES}")
print(f"Feature dimension: {model.feat_dim}")


In [None]:
# GradCAM hooks onto `model.backbone.conv_head`; per the original note this is
# the last convolutional layer of EfficientNet-B0 in timm.
target_layer = model.backbone.conv_head
print(f"GradCAM target layer: {type(target_layer).__name__}")

gradcam_explainer = GradCAMExplainer(model, target_layer=target_layer, device=DEVICE)
# 500 perturbation samples (the class default is 1000) to keep LIME runs shorter.
lime_explainer = LIMEExplainer(model, device=DEVICE, num_samples=500)

print("\n✓ Explainers initialized")


## 7. Single Image Analysis


In [None]:
def analyze_image(image_path: str, run_lime: bool = False):
    """
    Complete analysis pipeline for a single image.

    Loads and preprocesses the image, prints and plots the class
    probabilities, computes GradCAM (and optionally LIME) for the predicted
    class, and plots the explanations side by side.

    Args:
        image_path: Path to the image file
        run_lime: Whether to run LIME (slower)

    Returns:
        ``(pred_class, probs, heatmaps)``. ``heatmaps`` maps "GradCAM" (and
        "LIME" when requested) to its heatmap.
    """
    print(f"Analyzing: {image_path}")
    print("=" * 50)

    img = load_image(image_path)
    input_tensor = preprocess_image(img).to(DEVICE)
    display_img = denormalize(input_tensor)

    pred_class, probs = gradcam_explainer.predict(input_tensor)
    pred_label = CLASS_NAMES[pred_class]
    confidence = probs[pred_class]

    print(f"\nPrediction: {pred_label}")
    print(f"Confidence: {confidence:.4f}")
    print("\nAll probabilities:")
    for name, prob in zip(CLASS_NAMES, probs):
        marker = "←" if name == pred_label else ""
        print(f"  {name}: {prob:.4f} {marker}")

    plot_prediction_bars(probs)

    heatmaps = {}

    print("\nGenerating GradCAM...")
    # Give GradCAM a detached copy. Nothing it does to its input (such as
    # enabling requires_grad) can then leak into `input_tensor`, which LIME
    # converts to NumPy below.
    heatmaps["GradCAM"] = gradcam_explainer.explain(input_tensor.detach().clone(), target_class=pred_class)

    if run_lime:
        print("Generating LIME (this may take a minute)...")
        heatmaps["LIME"] = lime_explainer.explain(input_tensor, target_class=pred_class)

    pred_info = f"Prediction: {pred_label} ({confidence:.2%})"
    plot_method_comparison(display_img, heatmaps, pred_info=pred_info)

    return pred_class, probs, heatmaps


In [None]:
# Example: set `image_path` to a real file, then uncomment the call.
# image_path = "../data/raw/test_image.jpg"
# results = analyze_image(image_path, run_lime=True)

print("Ready to analyze images!\nUsage: results = analyze_image('path/to/image.jpg', run_lime=True)")


## 8. Batch Analysis


In [None]:
def batch_analyze(image_paths: List[str], output_dir: str = None, run_lime: bool = False):
    """
    Analyze multiple images and collect results.

    Images are processed independently. A failure on one image is printed
    and recorded as ``{'image_path': ..., 'error': ...}`` rather than aborting
    the batch. Every input path produces exactly one entry.

    Args:
        image_paths: List of image file paths
        output_dir: Directory for one comparison figure per image
            (``analysis_000.png``, ...). It is created if missing; falsy skips saving.
        run_lime: Whether to run LIME analysis (slow)

    Returns:
        List of result dictionaries in the same order as ``image_paths``.
        ``confidence`` is a plain float and ``probabilities`` a list of
        floats, so the scalar fields are JSON-serializable.
    """
    results = []

    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    total = len(image_paths)
    for i, img_path in enumerate(image_paths):
        print(f"\n[{i+1}/{total}] Processing: {img_path}")

        try:
            img = load_image(img_path)
            input_tensor = preprocess_image(img).to(DEVICE)
            display_img = denormalize(input_tensor)

            pred_class, probs = gradcam_explainer.predict(input_tensor)

            # Detached copy, so nothing GradCAM does to its input (e.g.
            # enabling requires_grad) leaks into the tensor LIME reuses below.
            gradcam_heatmap = gradcam_explainer.explain(input_tensor.detach().clone(), target_class=pred_class)

            lime_heatmap = None
            if run_lime:
                lime_heatmap = lime_explainer.explain(input_tensor, target_class=pred_class)

            result = {
                'image_path': img_path,
                'prediction': CLASS_NAMES[pred_class],
                # float(): probs[...] is a NumPy scalar, which json cannot encode.
                'confidence': float(probs[pred_class]),
                'probabilities': probs.tolist(),
                'gradcam_heatmap': gradcam_heatmap,
                'lime_heatmap': lime_heatmap,
            }

            if output_dir:
                heatmaps = {'GradCAM': gradcam_heatmap}
                if lime_heatmap is not None:
                    heatmaps['LIME'] = lime_heatmap

                save_path = output_path / f"analysis_{i:03d}.png"
                pred_info = f"{CLASS_NAMES[pred_class]} ({probs[pred_class]:.2%})"
                try:
                    plot_method_comparison(display_img, heatmaps, pred_info=pred_info,
                                           save_path=str(save_path))
                finally:
                    # Release the figure even if plotting/saving failed.
                    plt.close()

            # Append only after every step succeeded. A late failure then
            # yields a single error entry instead of a success plus an error.
            results.append(result)

        except Exception as e:
            # Deliberate best-effort boundary: report and continue with the next image.
            print(f"  Error: {e}")
            results.append({'image_path': img_path, 'error': str(e)})

    print(f"\n✓ Processed {len(results)} images")
    return results

print("Batch analysis function defined.")


In [None]:
def summarize_results(results: List[dict]):
    """
    Print a text summary of ``batch_analyze`` output.

    Entries containing an ``'error'`` key count as failures. For the rest, the
    summary shows how often each class in CLASS_NAMES was predicted and
    mean/min/max/std statistics of the confidences.
    """
    ok = [entry for entry in results if 'error' not in entry]

    if not ok:
        print("No valid results to summarize.")
        return

    n_total, n_ok = len(results), len(ok)
    banner = "=" * 50
    print("\n" + banner)
    print("BATCH ANALYSIS SUMMARY")
    print(banner)
    print(f"Total images: {n_total}")
    print(f"Successfully analyzed: {n_ok}")
    print(f"Errors: {n_total - n_ok}")

    counts = dict.fromkeys(CLASS_NAMES, 0)
    confs = []
    for entry in ok:
        counts[entry['prediction']] += 1
        confs.append(entry['confidence'])

    print("\nPrediction Distribution:")
    for label, n in counts.items():
        share = n / n_ok * 100
        print(f"  {label}: {n} ({share:.1f}%)")

    print("\nConfidence Statistics:")
    for heading, stat in (("Mean:", np.mean), ("Min: ", np.min), ("Max: ", np.max), ("Std: ", np.std)):
        print(f"  {heading} {stat(confs):.4f}")

print("Summary function defined.")


## 9. Results Export


In [None]:
def export_results_to_json(results: List[dict], output_path: str):
    """
    Export the scalar fields of batch-analysis results to a JSON file.

    Only ``image_path``, ``prediction``, ``confidence``, ``probabilities`` and
    ``error`` are written; heatmap arrays are omitted. Missing keys become
    ``null``. Values are converted to JSON-native types first, because ``json``
    rejects them as-is: NumPy scalars (e.g. an ``np.float32`` confidence)
    become Python numbers, NumPy arrays become lists, and ``os.PathLike``
    paths become strings.

    Args:
        results: Result dicts as produced by ``batch_analyze``.
        output_path: Destination file path (written as UTF-8).
    """
    def _plain(value):
        """Recursively convert NumPy/path values into JSON-serializable builtins."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, os.PathLike):
            return os.fspath(value)
        if isinstance(value, (list, tuple)):
            return [_plain(v) for v in value]
        return value

    fields = ('image_path', 'prediction', 'confidence', 'probabilities', 'error')
    export_data = [{key: _plain(r.get(key)) for key in fields} for r in results]

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(export_data, f, indent=2)

    print(f"Results exported to {output_path}")

# Example:
# export_results_to_json(results, '../reports/analysis_results.json')


---

## Quick Reference

### Single Image Analysis
```python
results = analyze_image('path/to/image.jpg', run_lime=True)
```

### Batch Analysis
```python
image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']
results = batch_analyze(image_paths, output_dir='../reports/figures', run_lime=False)
summarize_results(results)
export_results_to_json(results, '../reports/analysis_results.json')
```

### Manual Explanation
```python
img = load_image('path/to/image.jpg')
input_tensor = preprocess_image(img).to(DEVICE)
display_img = denormalize(input_tensor)  # build the display image before running GradCAM
heatmap = gradcam_explainer.explain(input_tensor)
plot_single_explanation(display_img, heatmap)
```
