# NARR Deepfake Detector (Benchmark Notebook)

This notebook trains/evaluates the **NARR** model (nuisance-aware representation refinement) on FF++ and optionally tests robustness/cross-dataset generalization.

**Run order (typical):**
1. Imports + config
2. Dataset / augmentations
3. Model + losses
4. (Optional) training loop
5. Load best weights
6. Evaluation blocks (FF++, JPEG, DFDC, CelebDF)
7. Params/FLOPs report

## 1. Imports

In [3]:
import os
import io
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score
)


## 2. Config & Reproducibility

### Notes
- `CFG.DATA_ROOT` should point at a folder with `train/`, `val/`, `test/` and each split containing `real/` + `fake/`.
- `CFG.WEIGHTS_DIR` is where checkpoints are saved/loaded from.
- Most blocks below are written so you can **toggle training/evaluation**: set `EPOCHS = 0` to skip training, and set `NUM_RUNS = 0` to skip the evaluation blocks.

In [4]:
class CFG:
    """Notebook-wide hyper-parameters and filesystem locations."""

    # Reproducibility: value handed to set_seed().
    SEED = 42

    # Input pipeline: square side used by the Resize transforms, plus the
    # DataLoader batch size and worker count.
    IMG_SIZE = 224
    BATCH_SIZE = 16
    NUM_WORKERS = 0

    # Base learning rate; the optimizer scales it per parameter group.
    LR = 0.0001

    # Weights of the invariance and domain-adversarial terms that are added
    # to the classification loss during training.
    LAMBDA_INV = 0.05
    LAMBDA_DOM = 0.1

    # Dataset root (train/val/test splits) and checkpoint directory.
    DATA_ROOT = "FFPP_CViT"
    WEIGHTS_DIR = "weights"


def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs.

    Args:
        seed: Integer seed applied to all four generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed every RNG once up front so the cells below start from a known state.
set_seed(CFG.SEED)
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: as the last line of a notebook cell it echoes the chosen device.
device


device(type='cuda')

## 3. Datasets

In [5]:
class BinaryImageFolder(Dataset):
    """Binary real-vs-fake image dataset read from a two-folder directory.

    ``root`` must contain a ``real/`` and a ``fake/`` sub-directory. Files in
    ``real/`` get label 0 and files in ``fake/`` get label 1. Only files whose
    name ends in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) are indexed.

    Args:
        root: Directory that holds the ``real/`` and ``fake/`` folders.
        transform: Optional callable applied to each RGB ``PIL.Image``.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. ``FileNotFoundError``)
            when one of the class folders cannot be listed.
    """

    CLASSES = ("real", "fake")
    EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root, transform=None):
        self.samples = []
        self.transform = transform

        for label, cls in enumerate(self.CLASSES):
            cls_dir = os.path.join(root, cls)
            # os.listdir returns entries in arbitrary, filesystem-dependent
            # order; sorting keeps sample indices stable across runs/machines.
            for fname in sorted(os.listdir(cls_dir)):
                if fname.lower().endswith(self.EXTENSIONS):
                    self.samples.append((os.path.join(cls_dir, fname), label))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where label is a float32 scalar tensor."""
        path, label = self.samples[idx]
        # The context manager closes the file handle promptly; convert()
        # returns a new image, so it stays usable after the file is closed.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)


## 4. Augmentations & Corruptions

In [6]:
class JPEGCompression:
    """Transform that round-trips a PIL image through in-memory JPEG encoding.

    Args:
        quality: JPEG quality setting forwarded to ``Image.save``.
    """

    def __init__(self, quality):
        self.quality = quality

    def __call__(self, img):
        """Encode ``img`` as JPEG in memory and return the decoded RGB image."""
        stream = io.BytesIO()
        img.save(stream, format="JPEG", quality=self.quality)
        # Rewind so the decoder starts reading at the first encoded byte.
        stream.seek(0)
        decoded = Image.open(stream)
        return decoded.convert("RGB")


class RandomGamma:
    """Apply a random gamma adjustment with probability ``p``.

    Args:
        gamma_range: ``(low, high)`` bounds the gamma value is drawn from
            with ``random.uniform``.
        p: Probability of applying the adjustment to a given image.
    """

    def __init__(self, gamma_range=(0.7, 1.5), p=0.5):
        self.gamma_range = gamma_range
        self.p = p

    def __call__(self, img):
        """Return ``img`` gamma-adjusted, or unchanged when the draw misses."""
        # Guard clause: skip the augmentation unless the coin flip succeeds.
        if not random.random() < self.p:
            return img
        chosen_gamma = random.uniform(*self.gamma_range)
        return transforms.functional.adjust_gamma(img, chosen_gamma)


# Training-time augmentation pipeline (PIL image in, float tensor out).
# Steps run in the listed order; ToTensor comes last, so every augmentation
# above it operates on the PIL image.
train_tfms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    # Small geometric jitter: first positional argument is the rotation range in degrees.
    transforms.RandomAffine(2, translate=(0.02, 0.02), scale=(0.95, 1.05), shear=2),

    # Positional order per torchvision: brightness, contrast, saturation, hue.
    transforms.ColorJitter(0.6, 0.6, 0.6, 0.15),

    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))
    ], p=0.3),

    transforms.RandomGrayscale(p=0.2),
    RandomGamma(p=0.5),

    # NOTE(review): RandomAdjustSharpness has its own application probability
    # (default 0.5 per the torchvision docs), so wrapped in RandomApply(p=0.3)
    # the effective rate is presumably lower than 0.3 — confirm this is intended.
    transforms.RandomApply([
        transforms.RandomAdjustSharpness(0.5)
    ], p=0.3),

    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


# Evaluation pipeline: resize + tensor conversion only, with no random
# augmentation, so validation/test inputs are the same on every pass.
eval_tfms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    transforms.ToTensor(),
])


def build_jpeg_tfms(q):
    """Build an evaluation pipeline that first JPEG-compresses at quality ``q``.

    The compression runs on the original-resolution image, before the resize
    to ``CFG.IMG_SIZE`` and the tensor conversion.
    """
    side = CFG.IMG_SIZE
    steps = [
        JPEGCompression(q),
        transforms.Resize((side, side)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


## 5. Corruption Functions (Training Only)

In [7]:
def corrupt_image(x):
    """Return a randomly degraded copy of an image batch; ``x`` is not modified.

    Two independent corruptions are each applied with probability 0.5:
    a bilinear down-scale to 75% followed by an up-scale back to the input
    size, and additive Gaussian noise (std 0.03) clamped to ``[0, 1]``.

    Args:
        x: Image batch; assumed 4-D ``(N, C, H, W)`` since bilinear
            ``F.interpolate`` is used — TODO confirm against callers.
    """
    target_hw = x.shape[-2:]
    degraded = x.clone()

    # Resolution loss: shrink, then restore the original spatial size.
    if torch.rand(1).item() < 0.5:
        shrunk = F.interpolate(
            degraded, scale_factor=0.75, mode="bilinear", align_corners=False
        )
        degraded = F.interpolate(
            shrunk, size=target_hw, mode="bilinear", align_corners=False
        )

    # Sensor-style noise; the clamp is only applied on this branch.
    if torch.rand(1).item() < 0.5:
        noise = 0.03 * torch.randn_like(degraded)
        degraded = torch.clamp(degraded + noise, 0, 1)

    return degraded


def freq_mix(x, alpha=0.15):
    """Perturb the 2-D Fourier magnitude of ``x`` while keeping its phase.

    Each magnitude coefficient is scaled by ``1 + alpha * N(0, 1)``; the
    spectrum is then inverted and the real part returned. The output is not
    clamped.

    Args:
        x: Tensor whose last two dimensions are transformed by ``fft2``.
        alpha: Standard deviation of the multiplicative magnitude jitter.
    """
    spectrum = torch.fft.fft2(x)
    magnitude = spectrum.abs()
    phase = spectrum.angle()

    jittered = magnitude * (1 + alpha * torch.randn_like(magnitude))

    # Recombine the jittered magnitude with the untouched phase.
    rebuilt = torch.fft.ifft2(jittered * torch.exp(1j * phase))
    return rebuilt.real


## 6. Model Components

In [8]:
class CNNBackbone(nn.Module):
    """ImageNet-pretrained ResNet-34 trunk with permanently frozen BatchNorm.

    The last two children of the torchvision ResNet-34 are dropped, so
    ``forward`` returns a spatial feature map with ``out_channels`` (512)
    channels.

    Every ``BatchNorm2d`` layer is kept in eval mode and its affine parameters
    are excluded from gradient updates. Putting the layers in eval mode only
    once in ``__init__`` is not enough: any later ``.train()`` call on this
    module or a parent would switch them back and update the running
    statistics again. ``train`` is therefore overridden to re-apply eval mode.
    """

    def __init__(self):
        super().__init__()
        # Prefer the ``weights=`` API where torchvision provides it;
        # ``pretrained=True`` is deprecated there. Older versions fall back
        # to the legacy flag. Both select the ImageNet weights.
        if hasattr(models, "ResNet34_Weights"):
            model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        else:
            model = models.resnet34(pretrained=True)
        self.features = nn.Sequential(*list(model.children())[:-2])
        self.out_channels = 512

        for bn in self._batchnorm_layers():
            bn.eval()
            for p in bn.parameters():
                p.requires_grad = False

    def _batchnorm_layers(self):
        """Yield every BatchNorm2d module inside the feature extractor."""
        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                yield m

    def train(self, mode=True):
        """Switch train/eval mode but always leave BatchNorm layers in eval."""
        super().train(mode)
        if mode:
            for bn in self._batchnorm_layers():
                bn.eval()
        return self

    def forward(self, x):
        """Return the feature map produced by the truncated ResNet."""
        return self.features(x)

class MultiScaleNuisanceEstimator(nn.Module):
    """Three parallel conv branches fused into a non-negative feature map.

    One 1x1 branch and two 3x3 branches with dilation 2 and 4 (padding equal
    to the dilation, so spatial size is preserved) each produce
    ``channels // 4`` maps. Their concatenation is projected back to
    ``channels`` by a 1x1 convolution followed by ReLU.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        branch_width = channels // 4

        self.conv1 = nn.Conv2d(channels, branch_width, kernel_size=1)
        self.conv3 = nn.Conv2d(
            channels, branch_width, kernel_size=3, padding=2, dilation=2
        )
        self.conv5 = nn.Conv2d(
            channels, branch_width, kernel_size=3, padding=4, dilation=4
        )

        self.proj = nn.Conv2d(3 * branch_width, channels, kernel_size=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the fused multi-scale estimate, same spatial size as ``x``."""
        branches = [conv(x) for conv in (self.conv1, self.conv3, self.conv5)]
        fused = self.proj(torch.cat(branches, dim=1))
        return self.act(fused)

class GradReverse(torch.autograd.Function):
    """Gradient-reversal op: identity forward, gradient times ``-λ`` backward."""

    @staticmethod
    def forward(ctx, x, λ):
        """Stash the reversal strength on ``ctx`` and pass ``x`` through."""
        ctx.scale = λ
        return x

    @staticmethod
    def backward(ctx, grad):
        """Return the negated, scaled gradient for ``x`` and ``None`` for ``λ``."""
        reversed_grad = grad * (-ctx.scale)
        return reversed_grad, None


In [9]:
class NARR(nn.Module):
    """Nuisance-aware representation refinement block.

    A nuisance map is estimated from the input features and turned into a
    channel gate and a spatial gate. Their product ``G`` rescales the
    features as ``F * (1 - alpha * G + beta * (1 - G))``, where ``alpha`` and
    ``beta`` are learnable scalars clamped to ``[0, 1]`` at use time.
    A 2-way domain head on the nuisance map, fed through gradient reversal,
    is evaluated only when ``lambda_grl > 0``.

    Args:
        channels: Channel count of the incoming feature map.
    """

    def __init__(self, channels):
        super().__init__()

        self.nuisance = MultiScaleNuisanceEstimator(channels)

        # Channel gate: global average pool -> 1x1 conv -> sigmoid.
        channel_gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels, 1),
            nn.Sigmoid(),
        ]
        self.gate_c = nn.Sequential(*channel_gate_layers)

        # Spatial gate: 1x1 conv down to one channel -> sigmoid.
        spatial_gate_layers = [nn.Conv2d(channels, 1, 1), nn.Sigmoid()]
        self.gate_s = nn.Sequential(*spatial_gate_layers)

        # Learnable scalars of the refinement equation, with their initial values.
        self.alpha = nn.Parameter(torch.tensor(0.3))
        self.beta = nn.Parameter(torch.tensor(0.1))

        domain_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        ]
        self.domain_head = nn.Sequential(*domain_layers)

    def forward(self, F, lambda_grl=0.0):
        """Refine ``F`` and optionally produce domain logits.

        Note that the parameter name ``F`` shadows the module-level
        ``torch.nn.functional`` alias inside this method.

        Returns:
            Tuple ``(F_ref, N_hat, G, dom)``: refined features, nuisance map,
            combined gate, and domain logits (``None`` unless
            ``lambda_grl > 0``).
        """
        nuisance_map = self.nuisance(F)

        channel_gate = self.gate_c(nuisance_map)    # [B, C, 1, 1]
        spatial_gate = self.gate_s(nuisance_map)    # [B, 1, H, W]
        gate = channel_gate * spatial_gate          # [B, C, H, W]

        alpha = self.alpha.clamp(0.0, 1.0)
        beta = self.beta.clamp(0.0, 1.0)

        # Per-element multiplier applied to the incoming features.
        scale = 1 - alpha * gate + beta * (1 - gate)
        refined = F * scale

        if lambda_grl > 0:
            reversed_map = GradReverse.apply(nuisance_map, lambda_grl)
            domain_logits = self.domain_head(reversed_map)
        else:
            domain_logits = None

        return refined, nuisance_map, gate, domain_logits


## 7. Tokenization & Classifier

In [10]:
class EmbeddingHead(nn.Module):
    """Turn a feature map into a short token sequence.

    A 1x1 convolution maps ``in_channels`` to ``embed_dim``; adaptive average
    pooling to ``(num_tokens, 1)`` then yields ``num_tokens`` vectors.

    Args:
        in_channels: Channels of the incoming feature map.
        embed_dim: Dimensionality of each output token.
        num_tokens: Number of tokens produced per sample.
    """

    def __init__(self, in_channels, embed_dim=256, num_tokens=8):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool2d((num_tokens, 1))

    def forward(self, x):
        """Map ``[B, C, H, W]`` features to ``[B, N, D]`` tokens."""
        pooled = self.pool(self.proj(x))            # [B, D, N, 1]
        return pooled.squeeze(-1).transpose(1, 2)   # [B, N, D]

In [11]:
class TokenClassifier(nn.Module):
    """Two-layer Transformer encoder with a single-logit linear head.

    Args:
        embed_dim: Token dimensionality; must be divisible by the 4 heads.
    """

    def __init__(self, embed_dim):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            embed_dim,
            4,
            dim_feedforward=512,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return one logit per sample for ``[B, N, D]`` token input."""
        encoded = self.encoder(x)          # [B, N, D]
        pooled = encoded.mean(dim=1)       # average over the token axis
        logits = self.fc(pooled)
        return logits.squeeze(-1)

In [12]:
class DeepfakeDetector(nn.Module):
    """End-to-end detector: CNN backbone -> NARR -> tokens -> Transformer head.

    ``forward`` returns one raw logit per image; callers apply the sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.backbone = CNNBackbone()
        feature_channels = self.backbone.out_channels
        self.narr = NARR(feature_channels)
        self.embedder = EmbeddingHead(feature_channels)
        # 256 matches EmbeddingHead's default embed_dim.
        self.classifier = TokenClassifier(256)

    def forward(self, x):
        """Classify an image batch using only NARR's refined features."""
        features = self.backbone(x)
        # NARR is called with its default lambda_grl, so no domain logits
        # are computed on this path; only the first output is used.
        refined = self.narr(features)[0]
        return self.classifier(self.embedder(refined))


## 8. Losses

In [13]:
# Real/fake classification loss; it consumes raw logits, so the model output
# must not be passed through a sigmoid first.
criterion = nn.BCEWithLogitsLoss()
# Loss for the 2-way domain head (class 0 = clean view, class 1 = corrupted view).
domain_criterion = nn.CrossEntropyLoss()

def invariance_contrastive_loss(z1, z2, temp=0.2):
    """Contrastive loss that matches each sample in ``z1`` to its twin in ``z2``.

    Both inputs are averaged over dimension 1 and L2-normalised. The scaled
    cosine-similarity matrix is then scored with cross-entropy, using the
    diagonal (same batch index) as the correct class for each row.

    Args:
        z1: First view, indexed ``[batch, tokens, dim]`` by the code below.
        z2: Second view with the same layout.
        temp: Temperature dividing the similarities.

    Returns:
        Scalar loss tensor.
    """
    anchors = F.normalize(z1.mean(1), dim=1)
    positives = F.normalize(z2.mean(1), dim=1)
    similarity = torch.clamp(anchors @ positives.T / temp, -50, 50)
    targets = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(similarity, targets)


## 9. Training & Evaluation

In [14]:
# %%
def train_epoch(loader, model, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is encoded twice: once clean and once after a random
    corruption (``corrupt_image`` or ``freq_mix``, chosen with equal
    probability). The total loss is the classification loss on the clean
    refined features, plus ``CFG.LAMBDA_INV`` times the contrastive loss
    between the embedded nuisance maps of the two views, plus
    ``CFG.LAMBDA_DOM`` times the clean-vs-corrupted domain loss.

    Args:
        loader: Iterable of ``(images, labels)`` batches.
        model: Object exposing ``backbone``, ``narr``, ``embedder`` and
            ``classifier`` sub-modules.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Sum of per-batch total losses divided by ``len(loader)``. An empty
        loader therefore raises ``ZeroDivisionError``.
    """
    model.train()
    total = 0.0

    for x, y in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)

        # ---------- CLEAN FORWARD ----------
        f = model.backbone(x)

        # NARR forward with GRL enabled (domain adversarial)
        # The reversal strength 0.1 is hard-coded here, independent of CFG.LAMBDA_DOM.
        f_ref, N_hat, _, dom_clean = model.narr(f, lambda_grl=0.1)
        tok_n = model.embedder(N_hat)

        # ---------- CORRUPTED VIEW ----------
        # Only the creation of the corrupted input is excluded from autograd;
        # the forward pass on it below is tracked as usual.
        with torch.no_grad():
            if torch.rand(1) < 0.5:
                x_corr = corrupt_image(x)
            else:
                x_corr = freq_mix(x)

        f_c = model.backbone(x_corr)
        _, N_hat_c, _, dom_corrupt = model.narr(f_c, lambda_grl=0.1)
        tok_n_c = model.embedder(N_hat_c)

        # ---------- INVARIANCE LOSS ----------
        # Compares the nuisance-map tokens of the two views, not the refined features.
        loss_inv = invariance_contrastive_loss(tok_n, tok_n_c)

        # ---------- CLASSIFICATION ----------
        # Only the clean view contributes to the real/fake classification loss.
        tok = model.embedder(f_ref)
        logit = model.classifier(tok)
        loss_cls = criterion(logit, y)

        # ---------- DOMAIN ADVERSARIAL LOSS ----------
        # Domain labels: 0 for the clean view, 1 for the corrupted view.
        dom_y_clean = torch.zeros(x.size(0), dtype=torch.long, device=device)
        dom_y_corrupt = torch.ones(x.size(0), dtype=torch.long, device=device)
        loss_dom = domain_criterion(dom_clean, dom_y_clean) + domain_criterion(dom_corrupt, dom_y_corrupt)

        # ---------- TOTAL ----------
        loss = loss_cls + CFG.LAMBDA_INV * loss_inv + CFG.LAMBDA_DOM * loss_dom

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total += loss.item()

    return total / len(loader)


In [15]:
@torch.no_grad()
def evaluate(loader, model, threshold=0.5):
    """Score ``model`` on ``loader`` and return binary-classification metrics.

    Args:
        loader: Iterable of ``(images, labels)`` batches with 0/1 labels.
        model: Module returning one raw logit per image.
        threshold: Probability at or above which a sample is predicted
            positive (label 1).

    Returns:
        Dict with keys ``acc``, ``auc``, ``precision``, ``recall`` and ``f1``.
        ``auc`` is NaN (and a warning is emitted) when the labels contain
        only one class, because ROC-AUC is undefined in that case.
    """
    import warnings

    model.eval()
    batch_logits, batch_labels = [], []

    for x, y in tqdm(loader, desc="Evaluating", leave=False):
        x = x.to(device)
        batch_logits.append(model(x).cpu())
        batch_labels.append(y)

    logits = torch.cat(batch_logits)
    labels = torch.cat(batch_labels).numpy()

    # torch.sigmoid is numerically stable for large-magnitude logits, whereas
    # a hand-written 1 / (1 + exp(-x)) overflows in exp() for very negative x.
    probs = torch.sigmoid(logits.float()).numpy()
    preds = (probs >= threshold).astype(int)

    # ROC-AUC needs both classes; report NaN instead of aborting the whole
    # evaluation so the remaining metrics are still available.
    if np.unique(labels).size < 2:
        warnings.warn(
            "Only one class present in labels; AUC is undefined and reported as NaN."
        )
        auc = float("nan")
    else:
        auc = roc_auc_score(labels, probs)

    return {
        "acc": accuracy_score(labels, preds),
        "auc": auc,
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "f1": f1_score(labels, preds, zero_division=0),
    }

## 10. Training Loop

In [16]:
# Make sure the checkpoint directory exists before anything is saved to it.
os.makedirs(CFG.WEIGHTS_DIR, exist_ok=True)

model = DeepfakeDetector().to(device)
# One parameter group per sub-module: the backbone uses 0.2x the base
# learning rate, the other three groups use the base rate.
optimizer = torch.optim.Adam([
    {"params": model.backbone.parameters(),  "lr": CFG.LR * 0.2},
    {"params": model.narr.parameters(),      "lr": CFG.LR},
    {"params": model.embedder.parameters(),  "lr": CFG.LR},
    {"params": model.classifier.parameters(),"lr": CFG.LR},
])
# NOTE(review): this value only fixes the scheduler's T_max below. The
# training cell reassigns EPOCHS, so if a different value is used there the
# cosine schedule length no longer matches the number of epochs actually run.
EPOCHS = 5

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS
)

# Training uses the augmenting pipeline; validation uses the plain eval one.
train_ds = BinaryImageFolder(os.path.join(CFG.DATA_ROOT, "train"), train_tfms)
val_ds   = BinaryImageFolder(os.path.join(CFG.DATA_ROOT, "val"),   eval_tfms)

train_loader = DataLoader(
    train_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=True,
    num_workers=CFG.NUM_WORKERS
)

val_loader = DataLoader(
    val_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS
)

# State for checkpoint selection: an exponential moving average of the
# validation AUC (initialised on the first epoch) and the best value so far.
ema_auc = None
ema_decay = 0.8
best_ema_auc = -1.0



In [None]:
# Set EPOCHS > 0 to train. Keep at 0 to skip training (useful when you only want to run eval).
# With EPOCHS = 0 the loop body never executes, so no checkpoint is written by this cell.
EPOCHS = 0

for epoch in range(EPOCHS):

    avg_loss = train_epoch(train_loader, model, optimizer)
    val_metrics = evaluate(val_loader, model)
    current_auc = val_metrics["auc"]

    # Smooth the validation AUC; the first epoch seeds the moving average.
    if ema_auc is None:
        ema_auc = current_auc
    else:
        ema_auc = ema_decay * ema_auc + (1 - ema_decay) * current_auc

    print(
        f"Epoch {epoch+1:02d} | "
        f"Loss: {avg_loss:.4f} | "
        f"Val Acc: {val_metrics['acc']:.4f} | "
        f"AUC: {current_auc:.4f} | "
        f"EMA-AUC: {ema_auc:.4f} | "
        f"P: {val_metrics['precision']:.4f} | "
        f"R: {val_metrics['recall']:.4f} | "
        f"F1: {val_metrics['f1']:.4f}"
    )

    # Checkpoint selection is driven by the smoothed AUC, but the weights
    # saved are those of the current epoch.
    if ema_auc > best_ema_auc:
        best_ema_auc = ema_auc
        torch.save(
            model.state_dict(),
            f"{CFG.WEIGHTS_DIR}/best_NARR.pt"
        )
        print(f"  ✓ Saved new best model (EMA-AUC={best_ema_auc:.4f})")

    # Advance the cosine learning-rate schedule once per epoch.
    scheduler.step()

## 11. Load Best Model (Once)

In [None]:
# Load a trained checkpoint before running the evaluation blocks below.
print("Loading best model...")
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict contains. It avoids executing arbitrary pickled
# code from the checkpoint file and silences torch.load's FutureWarning.
model.load_state_dict(
    torch.load(
        f"{CFG.WEIGHTS_DIR}/best_NARR.pt",
        map_location=device,
        weights_only=True,
    )
)
model.eval()
print("✓ Best model loaded")

# Number of repeated evaluation runs (different seeds) to average metrics over.
# Set to 1 for a single run, >1 to reduce variance, or 0 to skip evaluation blocks safely.
NUM_RUNS = 1

Loading best model...


  torch.load(f"{CFG.WEIGHTS_DIR}/best_NARR.pt", map_location=device)


✓ Best model loaded


## 12. FF++ Test Set Evaluation

In [None]:
print(f"\n===== FF++ TEST | AVERAGED OVER {NUM_RUNS} RUN(S) =====")

# Guard: with no runs, all_metrics would stay empty and all_metrics[0] below
# would raise IndexError.
if NUM_RUNS <= 0:
    print("NUM_RUNS <= 0; skipping FF++ evaluation.")
else:
    all_metrics = []

    # Held-out FF++ split, loaded with the non-augmenting eval pipeline.
    ffpp_test_ds = BinaryImageFolder(
        os.path.join(CFG.DATA_ROOT, "test"),
        eval_tfms
)

    ffpp_test_loader = DataLoader(
        ffpp_test_ds,
        batch_size=CFG.BATCH_SIZE,
        shuffle=False,
        num_workers=CFG.NUM_WORKERS
)

    # NOTE(review): nothing on this evaluation path appears to consume
    # randomness (eval transforms, shuffle=False, model.eval()), so repeated
    # runs with different seeds presumably return the same metrics — confirm.
    for run_idx in range(NUM_RUNS):
        set_seed(CFG.SEED + run_idx)
        metrics = evaluate(ffpp_test_loader, model)
        all_metrics.append(metrics)

    # Arithmetic mean of every metric over the runs.
    avg_metrics = {
        k: sum(m[k] for m in all_metrics) / NUM_RUNS
        for k in all_metrics[0]
    }

    for k, v in avg_metrics.items():
        print(f"{k.upper():>10}: {v:.4f}")


===== FF++ TEST | AVERAGED OVER 3 RUNS =====


IndexError: list index out of range

## 13. JPEG Compression Robustness Test

In [None]:
print(f"\n===== JPEG COMPRESSION TEST | AVERAGED OVER {NUM_RUNS} RUN(S) =====")

# Guard: the averages below divide by NUM_RUNS.
if NUM_RUNS <= 0:
    print("NUM_RUNS <= 0; skipping JPEG robustness evaluation.")
else:
    jpeg_qualities = [100, 90, 75, 50, 30]

    for q in jpeg_qualities:
        print(f"\n--- JPEG Quality {q}% ---")

        # The dataset and loader depend only on the quality level, so build
        # them once per quality instead of re-listing the test directory on
        # every run.
        jpeg_ds = BinaryImageFolder(
            os.path.join(CFG.DATA_ROOT, "test"),
            build_jpeg_tfms(q),
        )
        jpeg_loader = DataLoader(
            jpeg_ds,
            batch_size=CFG.BATCH_SIZE,
            shuffle=False,
            num_workers=CFG.NUM_WORKERS,
        )

        run_metrics = []
        for run_idx in range(NUM_RUNS):
            set_seed(CFG.SEED + run_idx)
            metrics = evaluate(jpeg_loader, model)
            run_metrics.append(metrics)

        avg_auc = sum(m["auc"] for m in run_metrics) / NUM_RUNS
        avg_acc = sum(m["acc"] for m in run_metrics) / NUM_RUNS
        avg_f1 = sum(m["f1"] for m in run_metrics) / NUM_RUNS

        # Three separate arguments: print() joins them with a space, which
        # yields the double space after each "|" in the logged output.
        print(
            f"AUC: {avg_auc:.4f} | ",
            f"ACC: {avg_acc:.4f} | ",
            f"F1: {avg_f1:.4f}"
        )


===== JPEG COMPRESSION TEST | AVERAGED OVER 3 RUNS =====

--- JPEG Quality 100% ---


                                                             

AUC: 0.9379 |  ACC: 0.8726 |  F1: 0.9188

--- JPEG Quality 90% ---


                                                             

AUC: 0.9348 |  ACC: 0.8835 |  F1: 0.9272

--- JPEG Quality 75% ---


                                                             

AUC: 0.9214 |  ACC: 0.8036 |  F1: 0.8675

--- JPEG Quality 50% ---


                                                             

AUC: 0.8888 |  ACC: 0.7130 |  F1: 0.7929

--- JPEG Quality 30% ---


                                                             

AUC: 0.8533 |  ACC: 0.6320 |  F1: 0.7175




## 14. DFDC Cross-Dataset Evaluation

In [None]:
print(f"\n===== DFDC CROSS-DATASET TEST | AVERAGED OVER {NUM_RUNS} RUN(S) =====")

# Guard: with no runs, all_metrics[0] below would raise IndexError.
if NUM_RUNS <= 0:
    print("NUM_RUNS <= 0; skipping DFDC cross-dataset evaluation.")
else:
    # Must contain real/ and fake/ sub-folders, as BinaryImageFolder expects.
    DFDC_ROOT = "./DFDC/validation"

    dfdc_ds = BinaryImageFolder(
        DFDC_ROOT,
        eval_tfms
)

    dfdc_loader = DataLoader(
        dfdc_ds,
        batch_size=CFG.BATCH_SIZE,
        shuffle=False,
        num_workers=CFG.NUM_WORKERS
)

    all_metrics = []

    # Uses whatever weights are currently loaded in `model`; this cell does
    # not reload the checkpoint itself.
    for run_idx in range(NUM_RUNS):
        set_seed(CFG.SEED + run_idx)
        metrics = evaluate(dfdc_loader, model, threshold=0.5)
        all_metrics.append(metrics)

    # Arithmetic mean of every metric over the runs.
    avg_metrics = {
        k: sum(m[k] for m in all_metrics) / NUM_RUNS
        for k in all_metrics[0]
    }

    for k, v in avg_metrics.items():
        print(f"{k.upper():>10}: {v:.4f}")


===== DFDC CROSS-DATASET TEST | AVERAGED OVER 3 RUNS =====


                                                               

       ACC: 0.6909
       AUC: 0.7052
 PRECISION: 0.8741
    RECALL: 0.7193
        F1: 0.7892




In [None]:
print(f"\n===== CELEB-DF CROSS-DATASET (NARR) | AVERAGED OVER {NUM_RUNS} RUN(S) =====")

# Guard: with no runs, all_metrics[0] below would raise IndexError.
if NUM_RUNS <= 0:
    print("NUM_RUNS <= 0; skipping Celeb-DF cross-dataset evaluation.")
else:
    # Must contain real/ and fake/ sub-folders, as BinaryImageFolder expects.
    CELEBDF_ROOT = "./CelebDF_images/test"

    celeb_ds = BinaryImageFolder(CELEBDF_ROOT, eval_tfms)

    celeb_loader = DataLoader(
        celeb_ds,
        batch_size=CFG.BATCH_SIZE,
        shuffle=False,
        num_workers=CFG.NUM_WORKERS,
        pin_memory=True,
    )

    # (Re)load best NARR model to be explicit about the weights used for reporting.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict contains, and avoids executing arbitrary
    # pickled code from the checkpoint file.
    model.load_state_dict(
        torch.load(
            f"{CFG.WEIGHTS_DIR}/best_NARR.pt",
            map_location=device,
            weights_only=True,
        )
    )
    model.eval()

    all_metrics = []

    for run_idx in range(NUM_RUNS):
        set_seed(CFG.SEED + run_idx)
        metrics = evaluate(celeb_loader, model, threshold=0.5)
        all_metrics.append(metrics)

    # Arithmetic mean of every metric over the runs.
    avg_metrics = {
        k: sum(m[k] for m in all_metrics) / NUM_RUNS
        for k in all_metrics[0]
    }

    for k, v in avg_metrics.items():
        print(f"{k.upper():>10}: {v:.4f}")


===== CELEB-DF CROSS-DATASET (NARR) | AVERAGED OVER 3 RUNS =====


  torch.load(f"{CFG.WEIGHTS_DIR}/best_NARR.pt", map_location=device)
                                                             

       ACC: 0.6856
       AUC: 0.7113
 PRECISION: 0.7578
    RECALL: 0.7660
        F1: 0.7619




In [None]:
# =====================================
# COMPUTE PARAMS + FLOPs FOR PAPER
# =====================================
from thop import profile
from thop import clever_format

def compute_model_cost(model, name="Model"):
    """Print and return the thop cost estimate for one forward pass.

    The model is switched to eval mode and moved to ``device`` as a side
    effect. Profiling uses one random image at ``CFG.IMG_SIZE`` resolution.

    Args:
        model: Module to profile.
        name: Label used in the printed header.

    Returns:
        Tuple ``(flops, params)`` of the raw numbers reported by thop.
    """
    model.eval().to(device)

    # Single dummy image at the resolution the pipelines resize to.
    input_shape = (1, 3, CFG.IMG_SIZE, CFG.IMG_SIZE)
    dummy = torch.randn(*input_shape).to(device)

    # NOTE(review): thop's first return value is commonly documented as
    # multiply-accumulate operations (MACs) — confirm the convention before
    # reporting it as FLOPs.
    flops, params = profile(model, inputs=(dummy,), verbose=False)

    # Human-readable strings with K/M/G suffixes.
    flops_str, params_str = clever_format([flops, params], "%.3f")

    print(f"\n===== {name} =====")
    print(f"Params : {params_str}")
    print(f"FLOPs  : {flops_str}")

    return flops, params

## 15. Model Cost (Params + FLOPs)
Uses `thop` to estimate compute cost with a dummy input at `CFG.IMG_SIZE`. If you haven’t installed it yet: `pip install thop`.

In [None]:
# Instantiate a fresh model for cost reporting (weights are not required for Params/FLOPs).
# A separate instance is used, so the evaluated `model` is not the object handed to the profiler.
narr_model = DeepfakeDetector().to(device)

# Compute costs
# The raw (unformatted) numbers are kept in case they are needed later.
flops_narr, params_narr = compute_model_cost(
    narr_model,
    name="NARR Detector",
)




===== NARR Detector =====
Params : 23.650M
FLOPs  : 3.760G
