## 1. Imports

# Nuisance Aware Representation Refinement (NARR) for Robust Deepfake Detection

This notebook implements the NARR model, a novel approach for deepfake detection that focuses on refining feature representations by estimating and mitigating nuisance factors that affect detection robustness.

## Overview

NARR addresses the challenge of deepfake detection under various corruptions and cross-dataset scenarios by:

1. **Nuisance Estimation**: Learning to identify nuisance factors in feature representations
2. **Adaptive Refinement**: Using learned gates to suppress nuisance while preserving discriminative features
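3. **Domain Adversarial Training**: A domain classifier, attached through a gradient reversal layer, learns to tell clean views from synthetically corrupted views, pushing the nuisance estimate to be indistinguishable across them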
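4. **Contrastive Invariance**: An InfoNCE-style loss pulls together the representations of each image's clean and corrupted views, encouraging robustness to image corruptions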

## Key Components

- **CNN Backbone**: ResNet-34 feature extractor
- **Multi-Scale Nuisance Estimator**: Estimates nuisance at multiple scales
- **Adaptive Gates**: Channel-wise and spatial gating for feature refinement
- **Token-based Classifier**: A 2-layer Transformer encoder over pooled feature tokens, mean-pooled into a single real/fake logit
- **Training Objectives**: Classification + Invariance Contrastive + Domain Adversarial losses

## Sections

1. Imports and Dependencies
2. Configuration and Reproducibility
3. Dataset Classes
4. Data Augmentations and Corruptions
5. Corruption Functions for Training
6. Model Architecture Components
7. Tokenization and Classification Heads
8. Loss Functions
9. Training and Evaluation Functions
10. Main Training Loop
11. Model Loading
12. FF++ Test Set Evaluation
13. JPEG Compression Robustness Test
14. Cross-Dataset Evaluations (DFDC, Celeb-DF)

In [None]:
import os
import io
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score
)

# Use CUDA when a GPU is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Config & Reproducibility

In [None]:
class CFG:
    """Every hyperparameter and path used by the notebook.

    Values are plain class attributes, read as ``CFG.NAME`` without
    instantiating the class.
    """

    # --- reproducibility ---
    SEED = 42

    # --- input pipeline ---
    IMG_SIZE = 224          # images are resized to IMG_SIZE x IMG_SIZE
    BATCH_SIZE = 16
    NUM_WORKERS = 0         # 0 = load in the main process (Windows-friendly)

    # --- optimisation ---
    LR = 1e-4

    # --- weights of the auxiliary loss terms ---
    LAMBDA_INV = 0.05       # invariance contrastive loss
    LAMBDA_DOM = 0.2        # domain adversarial loss

    # --- filesystem layout ---
    DATA_ROOT = "FFPP_CViT"  # FaceForensics++ root directory
    WEIGHTS_DIR = "weights"  # model checkpoints are written here


def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: integer seed applied to all four generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(CFG.SEED)

# Resolve the compute device again: CUDA if available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

device(type='cuda')

## 3. Datasets

In [None]:
class BinaryImageFolder(Dataset):
    """Binary real/fake image dataset backed by a two-folder directory layout.

    Expected layout::

        root/real/<image files>   -> label 0
        root/fake/<image files>   -> label 1

    Only ``.jpg``, ``.jpeg`` and ``.png`` files (case-insensitive) are used.
    A missing ``real`` or ``fake`` sub-directory is skipped, so the dataset may
    end up single-class or empty; the printed sample count makes that visible.

    Args:
        root: directory containing the ``real`` and ``fake`` sub-folders.
        transform: optional callable applied to the RGB ``PIL.Image``.
    """

    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root, transform=None):
        self.samples = []
        self.transform = transform

        for label, cls in enumerate(["real", "fake"]):
            cls_dir = os.path.join(root, cls)
            # isdir (not exists): a stray *file* named real/fake would
            # otherwise make os.listdir raise.
            if not os.path.isdir(cls_dir):
                continue
            # os.listdir order is filesystem-dependent; sorting makes sample
            # indices (and therefore seeded shuffling) reproducible.
            for f in sorted(os.listdir(cls_dir)):
                if f.lower().endswith(self.IMG_EXTENSIONS):
                    self.samples.append((os.path.join(cls_dir, f), label))

        print(f"[Dataset] Loaded {len(self.samples)} samples from {root}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` with the label as a float32 scalar tensor."""
        path, label = self.samples[idx]
        # The context manager closes the file handle deterministically;
        # convert() loads the pixels before the file is closed.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

## 4. Augmentations & Corruptions

In [None]:
class JPEGCompression:
    """Transform that round-trips a PIL image through an in-memory JPEG encode.

    Used to simulate compression artefacts at a fixed quality level.

    Args:
        quality: JPEG quality passed to PIL's ``Image.save``.
    """
    def __init__(self, quality):
        self.quality = quality

    def __call__(self, img):
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P); normalise
        # to RGB first so the transform accepts any PIL input.
        if img.mode != "RGB":
            img = img.convert("RGB")
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=self.quality)
        buffer.seek(0)
        return Image.open(buffer).convert("RGB")


class RandomGamma:
    """Randomly apply gamma correction to an image.

    Args:
        gamma_range: ``(low, high)`` bounds of the uniformly sampled gamma.
        p: probability that the correction is applied on a given call.
    """
    def __init__(self, gamma_range=(0.7, 1.5), p=0.5):
        self.gamma_range = gamma_range
        self.p = p

    def __call__(self, img):
        # With probability 1 - p, pass the image through untouched.
        if not (random.random() < self.p):
            return img
        low, high = self.gamma_range
        return transforms.functional.adjust_gamma(img, random.uniform(low, high))


# Training augmentations: aggressive photometric/geometric transforms to
# improve robustness. ToTensor yields [0, 1] tensors (no Normalize step).
train_tfms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    transforms.RandomAffine(2, translate=(0.02, 0.02), scale=(0.95, 1.05), shear=2),
    transforms.ColorJitter(0.6, 0.6, 0.6, 0.15),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([
        transforms.GaussianBlur(3, sigma=(0.1, 2.0))
    ], p=0.3),
    RandomGamma(p=0.5),
    # RandomAdjustSharpness is already stochastic (its own p defaults to 0.5),
    # so nesting it in RandomApply(p=0.3) gave an effective rate of 0.15.
    # Pass the intended probability directly. sharpness_factor=0.5 (< 1)
    # softens the image rather than sharpening it.
    transforms.RandomAdjustSharpness(0.5, p=0.3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Evaluation transforms: resize + tensor conversion only, no random ops.
eval_tfms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    transforms.ToTensor(),
])


def build_jpeg_tfms(q):
    """Return the evaluation pipeline preceded by JPEG re-encoding at quality ``q``.

    The JPEG round-trip runs on the PIL image as loaded, before it is resized
    to ``CFG.IMG_SIZE`` and converted to a tensor.
    """
    side = CFG.IMG_SIZE
    steps = [
        JPEGCompression(q),
        transforms.Resize((side, side)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

## 5. Corruption Functions (Training Only)

In [None]:
def corrupt_image(x):
    """Return a randomly degraded copy of an image batch; ``x`` is not modified.

    Two independent corruptions, each applied with probability 0.5:

    * resolution loss: bilinear downscale by 0.75, then bilinear upscale
      back to the input's spatial size;
    * additive Gaussian noise with std 0.03, clamped to [0, 1].

    Bilinear ``F.interpolate`` needs a 4-D ``(B, C, H, W)`` input.
    """
    size = x.shape[-2:]
    degraded = x.clone()

    apply_resample = torch.rand(1).item() < 0.5
    if apply_resample:
        small = F.interpolate(degraded, scale_factor=0.75, mode="bilinear", align_corners=False)
        degraded = F.interpolate(small, size=size, mode="bilinear", align_corners=False)

    apply_noise = torch.rand(1).item() < 0.5
    if apply_noise:
        noise = torch.randn_like(degraded) * 0.03
        degraded = (degraded + noise).clamp(0, 1)

    return degraded


def freq_mix(x, alpha=0.15):
    """Perturb an image batch's Fourier magnitude while keeping its phase.

    Each 2-D FFT coefficient (over the last two dims) has its magnitude scaled
    by ``1 + alpha * N(0, 1)``. The signal is rebuilt with the inverse FFT and
    only the real part is kept; the perturbed spectrum is no longer
    conjugate-symmetric, so the inverse is complex.

    The result is clamped to [0, 1], the same range as the clean ``ToTensor``
    input and as ``corrupt_image``'s output. Without the clamp, magnitude noise
    (including on the DC term) can push pixels outside the valid range.

    Args:
        x: real image tensor, presumably in [0, 1].
        alpha: relative strength of the magnitude noise.

    Returns:
        Real tensor with the same shape as ``x`` and values in [0, 1].
    """
    spectrum = torch.fft.fft2(x)
    mag, phase = torch.abs(spectrum), torch.angle(spectrum)

    # Multiplicative Gaussian noise on the magnitude only.
    mag = mag * (1 + alpha * torch.randn_like(mag))

    recon = torch.real(torch.fft.ifft2(mag * torch.exp(1j * phase)))
    return recon.clamp(0, 1)

## 6. Model Components

In [None]:
class CNNBackbone(nn.Module):
    """ImageNet-pretrained ResNet-34 trunk returning its last conv feature map.

    The global average pool and FC head are dropped, so the output has
    ``out_channels = 512`` channels at reduced spatial resolution.

    BatchNorm layers are frozen: their affine parameters have
    ``requires_grad=False``, and the layers stay in eval mode (running
    statistics are not updated) even when the parent model calls ``train()``.
    """
    def __init__(self):
        super().__init__()
        # torchvision >= 0.13 replaced ``pretrained=True`` with the ``weights``
        # enum and deprecated the flag; older versions lack the enum.
        try:
            weights = models.ResNet34_Weights.DEFAULT
        except AttributeError:
            model = models.resnet34(pretrained=True)
        else:
            model = models.resnet34(weights=weights)
        self.features = nn.Sequential(*list(model.children())[:-2])  # drop avgpool + fc
        self.out_channels = 512

        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                for p in m.parameters():
                    p.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode while keeping every BatchNorm layer in eval mode.

        ``nn.Module.train()`` recurses into children and would otherwise switch
        the BN layers back to training mode. Their running mean/var would then
        keep updating on every clean and corrupted batch despite the freeze.
        """
        super().train(mode)
        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        return self

    def forward(self, x):
        return self.features(x)


class MultiScaleNuisanceEstimator(nn.Module):
    """Estimate a nuisance map from a feature map with three parallel branches.

    Each branch reduces ``channels`` to ``channels // 4``. The outputs are
    concatenated and projected back to ``channels`` by a 1x1 conv followed by
    ReLU, so the output has the input's shape and is non-negative.

    Effective receptive field of each branch on the input grid:
      * ``conv1``: 1x1
      * ``conv3``: 3x3 kernel, dilation 2 -> 5x5
      * ``conv5``: 3x3 kernel, dilation 4 -> 9x9
    Padding equals the dilation, so every branch preserves H and W.

    Attribute names are the ``state_dict`` keys of saved checkpoints.
    """
    def __init__(self, channels):
        super().__init__()
        c = channels // 4

        self.conv1 = nn.Conv2d(channels, c, 1)                          # 1x1
        self.conv3 = nn.Conv2d(channels, c, 3, padding=2, dilation=2)   # effective 5x5
        self.conv5 = nn.Conv2d(channels, c, 3, padding=4, dilation=4)   # effective 9x9

        self.proj = nn.Conv2d(3 * c, channels, 1)  # back to `channels`
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        # Concatenate the three branch outputs along channels, then project.
        f = torch.cat([self.conv1(x), self.conv3(x), self.conv5(x)], dim=1)
        return self.act(self.proj(f))

In [None]:
class GradReverse(torch.autograd.Function):
    """Gradient reversal layer (GRL) for domain-adversarial training.

    Forward is the identity. Backward multiplies the incoming gradient by
    ``-lambd``; the scalar coefficient itself receives no gradient.
    Call as ``GradReverse.apply(x, lambd)``.
    """

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.lambd
        return reversed_grad, None


class NARR(nn.Module):
    """Nuisance Aware Representation Refinement.

    Estimates a nuisance map from the feature map ``F`` and turns it into a
    combined channel-by-spatial gate ``G`` in (0, 1). ``F`` is then rescaled
    element-wise by ``1 - alpha * G + beta * (1 - G)``:

    * where ``G`` -> 1, features are attenuated towards ``1 - alpha``;
    * where ``G`` -> 0, features are amplified towards ``1 + beta``.

    ``alpha`` and ``beta`` are learnable scalars, clamped to [0, 1] when used.
    When ``lambda_grl > 0``, a 2-way domain classifier is also applied to the
    nuisance map through a gradient reversal layer.

    Submodule/parameter names and their creation order define the
    ``state_dict`` keys and the seeded initialisation.
    """

    def __init__(self, channels):
        super().__init__()

        self.nuisance = MultiScaleNuisanceEstimator(channels)

        # Per-channel gate from the globally pooled nuisance map: (B, C, 1, 1).
        self.gate_c = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels, 1),
            nn.Sigmoid(),
        )
        # Per-location gate from a 1x1 projection: (B, 1, H, W).
        self.gate_s = nn.Sequential(
            nn.Conv2d(channels, 1, 1),
            nn.Sigmoid(),
        )

        # Initial attenuation (alpha) and amplification (beta) strengths.
        self.alpha = nn.Parameter(torch.tensor(0.3))
        self.beta = nn.Parameter(torch.tensor(0.1))

        # Binary domain classifier on the pooled nuisance map.
        self.domain_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, 2),
        )

    def forward(self, F, lambda_grl=0.0):
        """Refine the feature map ``F``.

        Returns:
            ``(F_ref, N_hat, G, dom)``: refined features, nuisance estimate,
            combined gate, and domain logits (``None`` unless ``lambda_grl > 0``).
        """
        N_hat = self.nuisance(F)
        G = self.gate_c(N_hat) * self.gate_s(N_hat)   # broadcasts to (B, C, H, W)

        a = self.alpha.clamp(0.0, 1.0)
        b = self.beta.clamp(0.0, 1.0)
        F_ref = F * (1 - a * G + b * (1 - G))

        dom = self.domain_head(GradReverse.apply(N_hat, lambda_grl)) if lambda_grl > 0 else None
        return F_ref, N_hat, G, dom

## 7. Tokenization & Classifier

In [None]:
class EmbeddingHead(nn.Module):
    """Turn a feature map into a short token sequence.

    A 1x1 conv projects to ``embed_dim`` channels. Adaptive average pooling to
    ``(num_tokens, 1)`` then splits the height into ``num_tokens`` horizontal
    bands, each averaged over the full width.
    Output shape: ``(B, num_tokens, embed_dim)``.
    """
    def __init__(self, in_channels, embed_dim=256, num_tokens=8):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, 1)
        self.pool = nn.AdaptiveAvgPool2d((num_tokens, 1))

    def forward(self, x):
        pooled = self.pool(self.proj(x))           # (B, D, N, 1)
        return pooled.flatten(2).transpose(1, 2)   # (B, N, D)

In [None]:
class TokenClassifier(nn.Module):
    """Two-layer Transformer encoder over tokens, mean-pooled to one logit.

    Input ``(B, N, embed_dim)``; output ``(B,)`` raw logits (no sigmoid).
    No positional embedding is added, so in eval mode the result does not
    depend on token order.
    """
    def __init__(self, embed_dim):
        super().__init__()
        block = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=4,
            dim_feedforward=512,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(block, num_layers=2)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        pooled = self.encoder(x).mean(dim=1)   # (B, D)
        logits = self.fc(pooled)               # (B, 1)
        return logits.squeeze(-1)

In [None]:
class DeepfakeDetector(nn.Module):
    """End-to-end detector: ResNet-34 features -> NARR refinement -> tokens -> logit.

    Args:
        embed_dim: token width, passed to both ``EmbeddingHead`` and
            ``TokenClassifier`` so the two always agree. The default (256)
            reproduces the original architecture and checkpoints.
    """
    def __init__(self, embed_dim=256):
        super().__init__()
        self.backbone = CNNBackbone()
        self.narr = NARR(self.backbone.out_channels)
        self.embedder = EmbeddingHead(self.backbone.out_channels, embed_dim=embed_dim)
        self.classifier = TokenClassifier(embed_dim)

    def forward(self, x):
        """Return one raw logit per image, shape ``(B,)``; higher means more likely fake."""
        f = self.backbone(x)
        # lambda_grl is left at its default (0), so no domain logits are computed.
        f_ref, _, _, _ = self.narr(f)
        tokens = self.embedder(f_ref)
        return self.classifier(tokens)

## 8. Losses

In [None]:
# Real/fake classification on raw logits; labels are float 0/1 targets.
criterion = nn.BCEWithLogitsLoss(reduction="mean")

# Clean-vs-corrupted domain classification on 2-way logits; labels are class indices.
domain_criterion = nn.CrossEntropyLoss(reduction="mean")

def invariance_contrastive_loss(z1, z2, temp=0.2):
    """InfoNCE loss matching each sample in ``z1`` to its counterpart in ``z2``.

    Both token sequences are mean-pooled over dim 1 and L2-normalised. The
    ``B x B`` cosine-similarity matrix is divided by ``temp`` and clamped to
    [-50, 50], then scored with cross-entropy against diagonal targets: row
    ``i`` of ``z1`` must pick row ``i`` of ``z2`` among all rows of ``z2``.
    The loss is one-directional (``z1`` -> ``z2`` only).
    """
    anchors = F.normalize(z1.mean(dim=1), dim=1)
    positives = F.normalize(z2.mean(dim=1), dim=1)

    similarity = torch.matmul(anchors, positives.t()) / temp
    similarity = torch.clamp(similarity, min=-50, max=50)
    targets = torch.arange(anchors.shape[0], device=anchors.device)
    return F.cross_entropy(similarity, targets)

## 9. Training & Evaluation

In [None]:
def train_epoch(loader, model, optimizer):
    """Run one training epoch and return the mean total loss per batch.

    Per batch the objective is
        loss_cls + CFG.LAMBDA_INV * loss_inv + CFG.LAMBDA_DOM * loss_dom
    where
      * loss_cls: BCE on the classifier logit from NARR-refined clean features;
      * loss_inv: contrastive loss between nuisance tokens of the clean batch
        and of a randomly corrupted copy (``corrupt_image`` or ``freq_mix``);
      * loss_dom: cross-entropy of the NARR domain head (clean = 0,
        corrupted = 1), computed through a gradient reversal layer with
        coefficient 0.1.

    Returns ``float('nan')`` if ``loader`` yields no batches, instead of
    raising ``ZeroDivisionError``.
    """
    model.train()
    total = 0.0
    n_batches = 0

    for x, y in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)

        # ---------- CLEAN FORWARD PASS ----------
        f = model.backbone(x)

        # GRL enabled (lambda_grl > 0) so NARR also returns domain logits.
        f_ref, N_hat, _, dom_clean = model.narr(f, lambda_grl=0.1)
        tok_n = model.embedder(N_hat)  # Tokens from nuisance features

        # ---------- CORRUPTED VIEW GENERATION ----------
        with torch.no_grad():
            # 50/50: spatial/noise corruption or Fourier-magnitude perturbation
            if torch.rand(1) < 0.5:
                x_corr = corrupt_image(x)
            else:
                x_corr = freq_mix(x)

        # Forward pass through corrupted view (gradients flow from here on)
        f_c = model.backbone(x_corr)
        _, N_hat_c, _, dom_corrupt = model.narr(f_c, lambda_grl=0.1)
        tok_n_c = model.embedder(N_hat_c)

        # ---------- INVARIANCE CONTRASTIVE LOSS ----------
        # Match nuisance tokens of each clean image to those of its corrupted view
        loss_inv = invariance_contrastive_loss(tok_n, tok_n_c)

        # ---------- CLASSIFICATION LOSS (clean view only) ----------
        tok = model.embedder(f_ref)  # Tokens from refined features
        logit = model.classifier(tok)
        loss_cls = criterion(logit, y)

        # ---------- DOMAIN ADVERSARIAL LOSS ----------
        # Domain head learns clean (0) vs corrupted (1); the reversed gradient
        # pushes the upstream layers to make the two indistinguishable.
        dom_y_clean = torch.zeros(x.size(0), dtype=torch.long, device=device)
        dom_y_corrupt = torch.ones(x.size(0), dtype=torch.long, device=device)
        loss_dom = domain_criterion(dom_clean, dom_y_clean) + domain_criterion(dom_corrupt, dom_y_corrupt)

        # ---------- TOTAL LOSS ----------
        loss = loss_cls + CFG.LAMBDA_INV * loss_inv + CFG.LAMBDA_DOM * loss_dom

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total += loss.item()
        n_batches += 1

    # Count batches explicitly: no ZeroDivisionError on an empty loader, and
    # the mean is over the batches actually seen.
    return total / n_batches if n_batches else float("nan")

In [None]:
@torch.no_grad()
def evaluate(loader, model, threshold=0.5):
    """Score ``model`` on ``loader`` and return binary-classification metrics.

    Logits are collected over the whole loader and mapped to probabilities
    with a sigmoid. Probabilities ``>= threshold`` are predicted as class 1
    (fake). Returns a dict with keys ``acc``, ``auc`` (computed from the
    probabilities), ``precision``, ``recall`` and ``f1``. Zero-division in
    precision/recall/F1 is reported as 0.
    """
    model.eval()
    logit_chunks = []
    label_chunks = []

    for images, targets in tqdm(loader, desc="Evaluating", leave=False):
        logit_chunks.append(model(images.to(device)).cpu())
        label_chunks.append(targets)

    scores = torch.cat(logit_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()

    probs = 1 / (1 + np.exp(-scores))  # sigmoid
    y_pred = (probs >= threshold).astype(int)

    metrics = {}
    metrics["acc"] = accuracy_score(y_true, y_pred)
    metrics["auc"] = roc_auc_score(y_true, probs)
    metrics["precision"] = precision_score(y_true, y_pred, zero_division=0)
    metrics["recall"] = recall_score(y_true, y_pred, zero_division=0)
    metrics["f1"] = f1_score(y_true, y_pred, zero_division=0)
    return metrics

## 10. Training Loop

In [None]:
# Checkpoint directory (no error if it already exists).
os.makedirs(CFG.WEIGHTS_DIR, exist_ok=True)

model = DeepfakeDetector().to(device)

# Per-module learning-rate multipliers: the pretrained backbone trains at
# 0.2x the base LR, the newly initialised modules at the full LR.
lr_scale = [
    (model.backbone, 0.2),
    (model.narr, 1.0),
    (model.embedder, 1.0),
    (model.classifier, 1.0),
]
optimizer = torch.optim.Adam(
    [{"params": module.parameters(), "lr": CFG.LR * scale} for module, scale in lr_scale]
)

EPOCHS = 5

# Cosine decay of every parameter group's LR over EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

train_ds = BinaryImageFolder(os.path.join(CFG.DATA_ROOT, "train"), train_tfms)
val_ds = BinaryImageFolder(os.path.join(CFG.DATA_ROOT, "val"), eval_tfms)

loader_kwargs = dict(batch_size=CFG.BATCH_SIZE, num_workers=CFG.NUM_WORKERS)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# Checkpoint selection tracks an exponential moving average of validation
# AUC to damp epoch-to-epoch noise.
ema_auc = None
ema_decay = 0.8
best_ema_auc = -1.0



In [None]:
# Take the epoch count from the scheduler so the loop length cannot drift from
# the cosine schedule's T_max. CosineAnnealingLR is periodic, so stepping it
# past T_max would start raising the learning rate again.
EPOCHS = scheduler.T_max

for epoch in range(EPOCHS):
    # Training phase
    avg_loss = train_epoch(train_loader, model, optimizer)

    # Validation phase
    val_metrics = evaluate(val_loader, model)
    current_auc = val_metrics["auc"]

    # Update the exponential moving average of validation AUC
    if ema_auc is None:
        ema_auc = current_auc
    else:
        ema_auc = ema_decay * ema_auc + (1 - ema_decay) * current_auc

    # Logging
    print(
        f"Epoch {epoch+1:02d} | "
        f"Loss: {avg_loss:.4f} | "
        f"Val Acc: {val_metrics['acc']:.4f} | "
        f"AUC: {current_auc:.4f} | "
        f"EMA-AUC: {ema_auc:.4f} | "
        f"P: {val_metrics['precision']:.4f} | "
        f"R: {val_metrics['recall']:.4f} | "
        f"F1: {val_metrics['f1']:.4f}"
    )

    # Checkpoint whenever the smoothed AUC improves
    if ema_auc > best_ema_auc:
        best_ema_auc = ema_auc
        torch.save(
            model.state_dict(),
            f"{CFG.WEIGHTS_DIR}/best_NARR.pt"
        )
        print(f"  ✓ Saved new best model (EMA-AUC={best_ema_auc:.4f})")

    # One scheduler step per epoch
    scheduler.step()

                                                             

Epoch 01 | Loss: 0.6926 | Val Acc: 0.7471 | AUC: 0.8872 | EMA-AUC: 0.8872 | P: 0.9646 | R: 0.7185 | F1: 0.8236
  ✓ Saved new best model (EMA-AUC=0.8872)


                                                             

Epoch 02 | Loss: 0.5595 | Val Acc: 0.8822 | AUC: 0.9352 | EMA-AUC: 0.8968 | P: 0.9491 | R: 0.9051 | F1: 0.9266
  ✓ Saved new best model (EMA-AUC=0.8968)


Training:  96%|█████████▌| 2290/2391 [17:52<00:44,  2.28it/s]

## 11. Load Best Model (Once)

In [None]:
print("Loading best model...")
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. It avoids executing code from a tampered
# checkpoint and silences torch's FutureWarning about the default.
model.load_state_dict(
    torch.load(f"{CFG.WEIGHTS_DIR}/best_NARR.pt", map_location=device, weights_only=True)
)
model.eval()
print("✓ Best model loaded")
# NOTE(review): the evaluation pipelines below look deterministic (no random
# transforms, model in eval mode), so repeated runs are expected to give
# identical metrics. Confirm before reading the average as a variance estimate.
NUM_RUNS = 3  # Number of evaluation runs averaged per test

Loading best model...
✓ Best model loaded


  torch.load(f"{CFG.WEIGHTS_DIR}/best_NARR.pt", map_location=device)


## 12. FF++ Test Set Evaluation

In [None]:
# Header derives the run count from NUM_RUNS so it cannot go stale.
print(f"\n===== FF++ TEST | AVERAGED OVER {NUM_RUNS} RUNS =====")

all_metrics = []

# FF++ test split with the evaluation transforms
ffpp_test_ds = BinaryImageFolder(
    os.path.join(CFG.DATA_ROOT, "test"),
    eval_tfms
)

ffpp_test_loader = DataLoader(
    ffpp_test_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS
)

# Evaluate NUM_RUNS times, re-seeding before each run
for run_idx in range(NUM_RUNS):
    set_seed(CFG.SEED + run_idx)
    metrics = evaluate(ffpp_test_loader, model)
    all_metrics.append(metrics)

# Per-metric mean across runs (keys follow evaluate()'s dict order)
avg_metrics = {
    k: sum(m[k] for m in all_metrics) / NUM_RUNS
    for k in all_metrics[0]
}

for k, v in avg_metrics.items():
    print(f"{k.upper():>10}: {v:.4f}")


===== FF++ TEST | AVERAGED OVER 3 RUNS =====


                                                             

       ACC: 0.7737
       AUC: 0.9182
 PRECISION: 0.9761
    RECALL: 0.7433
        F1: 0.8439




## 13. JPEG Compression Robustness Test

In [None]:
print(f"\n===== JPEG COMPRESSION TEST | AVERAGED OVER {NUM_RUNS} RUNS =====")

jpeg_qualities = [100, 90, 75, 50, 30]

for q in jpeg_qualities:
    print(f"\n--- JPEG Quality {q}% ---")

    # The dataset and loader depend only on q, so build them once per quality
    # instead of re-scanning the test directory on every run.
    jpeg_ds = BinaryImageFolder(
        os.path.join(CFG.DATA_ROOT, "test"),
        build_jpeg_tfms(q)
    )
    jpeg_loader = DataLoader(
        jpeg_ds,
        batch_size=CFG.BATCH_SIZE,
        shuffle=False,
        num_workers=CFG.NUM_WORKERS
    )

    run_metrics = []
    for run_idx in range(NUM_RUNS):
        set_seed(CFG.SEED + run_idx)
        run_metrics.append(evaluate(jpeg_loader, model))

    # Average metrics for this compression level
    avg_auc = sum(m["auc"] for m in run_metrics) / NUM_RUNS
    avg_acc = sum(m["acc"] for m in run_metrics) / NUM_RUNS
    avg_f1  = sum(m["f1"]  for m in run_metrics) / NUM_RUNS

    # A single f-string: passing several arguments to print() added an extra
    # space after each "|" separator.
    print(f"AUC: {avg_auc:.4f} | ACC: {avg_acc:.4f} | F1: {avg_f1:.4f}")


===== JPEG COMPRESSION TEST | AVERAGED OVER 3 RUNS =====

--- JPEG Quality 100% ---


                                                             

AUC: 0.9198 |  ACC: 0.7761 |  F1: 0.8458

--- JPEG Quality 90% ---


                                                             

AUC: 0.9209 |  ACC: 0.8018 |  F1: 0.8664

--- JPEG Quality 75% ---


                                                             

AUC: 0.9013 |  ACC: 0.6977 |  F1: 0.7785

--- JPEG Quality 50% ---


                                                             

AUC: 0.8648 |  ACC: 0.5819 |  F1: 0.6643

--- JPEG Quality 30% ---


                                                             

AUC: 0.8233 |  ACC: 0.4590 |  F1: 0.5163




## 14. Cross-Dataset Evaluations (DFDC, Celeb-DF)

In [None]:
# Header derives the run count from NUM_RUNS so it cannot go stale.
print(f"\n===== DFDC CROSS-DATASET TEST | AVERAGED OVER {NUM_RUNS} RUNS =====")

DFDC_ROOT = "./DFDC/validation"

# DFDC validation split (a different distribution from FF++)
dfdc_ds = BinaryImageFolder(
    DFDC_ROOT,
    eval_tfms
)

dfdc_loader = DataLoader(
    dfdc_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS
)

all_metrics = []

for run_idx in range(NUM_RUNS):
    set_seed(CFG.SEED + run_idx)
    metrics = evaluate(dfdc_loader, model, threshold=0.5)
    all_metrics.append(metrics)

# Per-metric mean across runs
avg_metrics = {
    k: sum(m[k] for m in all_metrics) / NUM_RUNS
    for k in all_metrics[0]
}

for k, v in avg_metrics.items():
    print(f"{k.upper():>10}: {v:.4f}")


===== DFDC CROSS-DATASET TEST | AVERAGED OVER 3 RUNS =====


                                                               

       ACC: 0.6606
       AUC: 0.6278
 PRECISION: 0.8291
    RECALL: 0.7282
        F1: 0.7753




In [None]:
# Header derives the run count from NUM_RUNS so it cannot go stale.
print(f"\n===== CELEB-DF CROSS-DATASET (NARR) | AVERAGED OVER {NUM_RUNS} RUNS =====")

CELEBDF_ROOT = "./CelebDF_images/test"

# Celeb-DF test split (a different distribution from FF++)
celeb_ds = BinaryImageFolder(
    CELEBDF_ROOT,
    eval_tfms
)

celeb_loader = DataLoader(
    celeb_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS,
    # Pinned host memory only speeds up host-to-CUDA copies; skip it on CPU.
    pin_memory=torch.cuda.is_available()
)

# Reload the best checkpoint so this cell is self-contained. weights_only=True
# limits unpickling to tensors and plain containers, enough for a state_dict.
model.load_state_dict(
    torch.load(f"{CFG.WEIGHTS_DIR}/best_NARR.pt", map_location=device, weights_only=True)
)
model.eval()

all_metrics = []

for run_idx in range(NUM_RUNS):
    set_seed(CFG.SEED + run_idx)
    metrics = evaluate(
        celeb_loader,
        model,
        threshold=0.5
    )
    all_metrics.append(metrics)

# Per-metric mean across runs
avg_metrics = {
    k: sum(m[k] for m in all_metrics) / NUM_RUNS
    for k in all_metrics[0]
}

for k, v in avg_metrics.items():
    print(f"{k.upper():>10}: {v:.4f}")


===== CELEB-DF CROSS-DATASET (NARR) | AVERAGED OVER 3 RUNS =====


  torch.load(f"{CFG.WEIGHTS_DIR}/best_NARR.pt", map_location=device)
                                                             

       ACC: 0.6174
       AUC: 0.6943
 PRECISION: 0.7850
    RECALL: 0.5746
        F1: 0.6636


