# Baseline Deepfake Detector (Benchmark Notebook)

This notebook is the **fair baseline** used alongside NARR: same backbone/tokenization/classifier training protocol family, but **without** the NARR module.

**Run order (typical):**
1. Imports + config
2. Dataset / augmentations
3. Model
4. Training loop
5. Evaluation blocks (FF++, JPEG, DFDC, CelebDF)
6. Params/FLOPs report

## 1. Imports

In [69]:
import os
import io
import random
import numpy as np
from PIL import Image
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score
)

## 2. Config & Reproducibility

### Notes
- `CFG.DATA_ROOT` should point at a folder with `train/`, `val/`, `test/` and each split containing `real/` + `fake/`.
- `CFG.WEIGHTS_DIR` is where checkpoints are saved/loaded from.
- Evaluation blocks use `NUM_RUNS` to average metrics over multiple seeded runs.

In [70]:
class CFG:
    SEED = 42
    IMG_SIZE = 224
    BATCH_SIZE = 16
    NUM_WORKERS = 0
    LR = 1e-4
    
    LAMBDA_INV = 0.05
    LAMBDA_DOM = 0.1

    DATA_ROOT = "FFPP_CViT"
    WEIGHTS_DIR = "weights"


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(CFG.SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## 3. Dataset (Unified Binary Folder)

In [71]:
class BinaryImageFolder(Dataset):
    def __init__(self, root, transform=None):
        self.samples = []
        self.transform = transform

        for label, cls in enumerate(["real", "fake"]):
            cls_dir = os.path.join(root, cls)
            for fname in os.listdir(cls_dir):
                if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.samples.append(
                        (os.path.join(cls_dir, fname), label)
                    )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

## 4. Standard Augmentations & Test Corruptions

In [72]:
class JPEGCompression:
    def __init__(self, quality):
        self.quality = quality

    def __call__(self, img):
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=self.quality)
        buffer.seek(0)
        return Image.open(buffer).convert("RGB")


class RandomGamma:
    def __init__(self, gamma_range=(0.7, 1.5), p=0.5):
        self.gamma_range = gamma_range
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            gamma = random.uniform(*self.gamma_range)
            return transforms.functional.adjust_gamma(img, gamma)
        return img

In [73]:
train_tfms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    transforms.RandomAffine(2, translate=(0.02, 0.02), scale=(0.95, 1.05), shear=2),
    transforms.ColorJitter(0.6, 0.6, 0.6, 0.15),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))
    ], p=0.3),
    transforms.RandomGrayscale(p=0.2),
    RandomGamma(p=0.5),
    transforms.RandomApply([
        transforms.RandomAdjustSharpness(0.5)
    ], p=0.3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

eval_tfms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    transforms.ToTensor(),
])


def build_jpeg_tfms(q):
    return transforms.Compose([
        JPEGCompression(q),
        transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
        transforms.ToTensor(),
    ])

## 4.5 Corruption Functions (Training Only)

In [74]:
def corrupt_image(x):
    out = x.clone()

    if torch.rand(1).item() < 0.5:
        out = F.interpolate(out, scale_factor=0.75, mode="bilinear", align_corners=False)
        out = F.interpolate(out, size=x.shape[-2:], mode="bilinear", align_corners=False)

    if torch.rand(1).item() < 0.5:
        out = torch.clamp(out + 0.03 * torch.randn_like(out), 0, 1)

    return out


def freq_mix(x, alpha=0.15):
    fft = torch.fft.fft2(x)
    mag, phase = torch.abs(fft), torch.angle(fft)
    mag = mag * (1 + alpha * torch.randn_like(mag))
    return torch.real(torch.fft.ifft2(mag * torch.exp(1j * phase)))

## 5. Backbone

In [75]:
class CNNBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        model = models.resnet34(pretrained=True)
        self.features = nn.Sequential(*list(model.children())[:-2])
        self.out_channels = 512

        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                for p in m.parameters():
                    p.requires_grad = False

    def forward(self, x):
        return self.features(x)


class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, λ):
        ctx.λ = λ
        return x

    @staticmethod
    def backward(ctx, grad):
        return -ctx.λ * grad, None

## 6. Tokenization & Classifier

In [76]:
class EmbeddingHead(nn.Module):
    def __init__(self, in_channels, embed_dim=256, num_tokens=8):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, 1)
        self.pool = nn.AdaptiveAvgPool2d((num_tokens, 1))

    def forward(self, x):
        x = self.proj(x)
        x = self.pool(x)
        return x.squeeze(-1).permute(0, 2, 1)

In [77]:
class TokenClassifier(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=4,
            dim_feedforward=512,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, 2)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        x = self.encoder(x).mean(dim=1)
        return self.fc(x).squeeze(-1)

## 7. Baseline Model

In [78]:
class DummyDomainHead(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

    def forward(self, x, lambda_grl=0.0):
        if lambda_grl > 0:
            x = GradReverse.apply(x, lambda_grl)
        return self.head(x)


class DeepfakeDetectorBaselinePP(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = CNNBackbone()
        self.domain_head = DummyDomainHead(self.backbone.out_channels)
        self.embedder = EmbeddingHead(self.backbone.out_channels)
        self.classifier = TokenClassifier(256)

    def forward(self, x):
        f = self.backbone(x)
        tokens = self.embedder(f)
        return self.classifier(tokens)

### Fair Baseline Protocol
This baseline keeps the same training protocol family as NARR (classification + invariance + domain adversarial), but removes the NARR module itself.
- No nuisance estimator / suppression block
- Backbone, tokenizer, and classifier remain
- Domain adversarial control head kept for protocol parity

## 8. Training Setup (Fair Baseline: NARR Removed)

In [79]:
EPOCHS = 5

model = DeepfakeDetectorBaselinePP().to(device)
optimizer = torch.optim.Adam([
    {"params": model.backbone.parameters(),  "lr": CFG.LR * 0.2},
    {"params": model.domain_head.parameters(),"lr": CFG.LR},
    {"params": model.embedder.parameters(),  "lr": CFG.LR},
    {"params": model.classifier.parameters(),"lr": CFG.LR},
])

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS
)

criterion = nn.BCEWithLogitsLoss()
domain_criterion = nn.CrossEntropyLoss()

def invariance_contrastive_loss(z1, z2, temp=0.2):
    z1 = F.normalize(z1.mean(1), dim=1)
    z2 = F.normalize(z2.mean(1), dim=1)
    logits = (z1 @ z2.T / temp).clamp(-50, 50)
    labels = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, labels)

## 9. Training & Evaluation Functions

In [80]:
def now():
    return str(datetime.datetime.now().time())[:-7]

In [81]:
def train_epoch_baselinepp(loader, model, optimizer):
    model.train()
    total = 0.0

    for x, y in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)

        f = model.backbone(x)
        dom_clean = model.domain_head(f, lambda_grl=0.1)
        tok_clean = model.embedder(f)

        with torch.no_grad():
            if torch.rand(1) < 0.5:
                x_corr = corrupt_image(x)
            else:
                x_corr = freq_mix(x)

        f_c = model.backbone(x_corr)
        dom_corrupt = model.domain_head(f_c, lambda_grl=0.1)
        tok_corr = model.embedder(f_c)

        loss_inv = invariance_contrastive_loss(tok_clean, tok_corr)

        logit = model.classifier(tok_clean)
        loss_cls = criterion(logit, y)

        dom_y_clean = torch.zeros(x.size(0), dtype=torch.long, device=device)
        dom_y_corrupt = torch.ones(x.size(0), dtype=torch.long, device=device)

        loss_dom = (
            domain_criterion(dom_clean, dom_y_clean)
            + domain_criterion(dom_corrupt, dom_y_corrupt)
        )

        loss = loss_cls + CFG.LAMBDA_INV * loss_inv + CFG.LAMBDA_DOM * loss_dom

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total += loss.item()

    return total / len(loader)

In [82]:
@torch.no_grad()
def evaluate(loader, model, threshold=0.5):
    model.eval()
    logits, labels = [], []

    for x, y in tqdm(loader, desc="Evaluating", leave=False):
        x = x.to(device)
        logits.append(model(x).cpu())
        labels.append(y)

    logits = torch.cat(logits).numpy()
    labels = torch.cat(labels).numpy()

    probs = 1 / (1 + np.exp(-logits))
    preds = (probs >= threshold).astype(int)

    return {
        "acc": accuracy_score(labels, preds),
        "auc": roc_auc_score(labels, probs),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "f1": f1_score(labels, preds, zero_division=0),
    }

## 10. DataLoaders

In [83]:
train_ds = BinaryImageFolder(os.path.join(CFG.DATA_ROOT, "train"), train_tfms)
val_ds   = BinaryImageFolder(os.path.join(CFG.DATA_ROOT, "val"),   eval_tfms)
test_ds  = BinaryImageFolder(os.path.join(CFG.DATA_ROOT, "test"),  eval_tfms)

train_loader = DataLoader(train_ds, CFG.BATCH_SIZE, True,  num_workers=CFG.NUM_WORKERS)
val_loader   = DataLoader(val_ds,   CFG.BATCH_SIZE, False, num_workers=CFG.NUM_WORKERS)
test_loader  = DataLoader(test_ds,  CFG.BATCH_SIZE, False, num_workers=CFG.NUM_WORKERS)

## 11. Training Loop

In [84]:
os.makedirs(CFG.WEIGHTS_DIR, exist_ok=True)

best_auc = -1.0
EPOCHS = 5

for epoch in range(EPOCHS):
    avg_loss = train_epoch_baselinepp(train_loader, model, optimizer)
    val_metrics = evaluate(val_loader, model)
    current_auc = val_metrics["auc"]

    print(
        f"Epoch {epoch+1:02d} | "
        f"Loss: {avg_loss:.4f} | "
        f"Val Acc: {val_metrics['acc']:.4f} | "
        f"AUC: {current_auc:.4f} | "
        f"P: {val_metrics['precision']:.4f} | "
        f"R: {val_metrics['recall']:.4f} | "
        f"F1: {val_metrics['f1']:.4f} | "
        f"Time: {now()}"
    )

    if current_auc > best_auc:
        best_auc = current_auc
        torch.save(model.state_dict(), f"{CFG.WEIGHTS_DIR}/best_Baseline.pt")
        print(f"  ✓ Saved new best model (AUC={best_auc:.4f})")

    scheduler.step()

                                                             

Epoch 01 | Loss: 0.5492 | Val Acc: 0.7489 | AUC: 0.8947 | P: 0.9648 | R: 0.7206 | F1: 0.8250 | Time: 13:26:19
  ✓ Saved new best model (AUC=0.8947)


                                                             

Epoch 02 | Loss: 0.4129 | Val Acc: 0.8167 | AUC: 0.9213 | P: 0.9661 | R: 0.8052 | F1: 0.8783 | Time: 13:43:18
  ✓ Saved new best model (AUC=0.9213)


                                                             

Epoch 03 | Loss: 0.3428 | Val Acc: 0.8470 | AUC: 0.9226 | P: 0.9607 | R: 0.8484 | F1: 0.9011 | Time: 13:59:48
  ✓ Saved new best model (AUC=0.9226)


                                                             

Epoch 04 | Loss: 0.2936 | Val Acc: 0.8440 | AUC: 0.9299 | P: 0.9639 | R: 0.8416 | F1: 0.8986 | Time: 14:15:33
  ✓ Saved new best model (AUC=0.9299)


                                                             

Epoch 05 | Loss: 0.2572 | Val Acc: 0.8324 | AUC: 0.9232 | P: 0.9635 | R: 0.8273 | F1: 0.8903 | Time: 14:33:02




## 12. FF++ Test Evaluation

In [None]:
model.load_state_dict(
    torch.load(f"{CFG.WEIGHTS_DIR}/best_Baseline.pt", map_location=device)
)
model.eval()

# Number of repeated evaluation runs (different seeds) to average metrics over.
NUM_RUNS = 1
all_metrics = []

with torch.no_grad():
    for i in range(NUM_RUNS):
        set_seed(CFG.SEED + i)
        test_metrics = evaluate(test_loader, model)
        all_metrics.append(test_metrics)

# Average metrics
avg_metrics = {}
for key in all_metrics[0].keys():
    avg_metrics[key] = sum(m[key] for m in all_metrics) / NUM_RUNS

print(f"\n===== FF++ TEST (BASELINE) | AVERAGED OVER {NUM_RUNS} RUN(S) =====")
for k, v in avg_metrics.items():
    print(f"{k.upper():>10}: {v:.4f}")


  torch.load(f"{CFG.WEIGHTS_DIR}/best_Baseline.pt", map_location=device)
                                                             


===== FF++ TEST (BASELINE) | AVERAGED OVER 3 RUNS =====
       ACC: 0.8455
       AUC: 0.9331
 PRECISION: 0.9698
    RECALL: 0.8383
        F1: 0.8993




## 13. JPEG Compression Robustness

In [None]:
print(f"\n===== JPEG ROBUSTNESS (BASELINE | {NUM_RUNS}-RUN AVG) =====")

jpeg_qualities = [100, 90, 75, 50, 30]

for q in jpeg_qualities:
    print(f"\n--- JPEG {q} ---")

    run_metrics = []

    for run_idx in range(NUM_RUNS):
        set_seed(CFG.SEED + run_idx)
        jpeg_ds = BinaryImageFolder(
            os.path.join(CFG.DATA_ROOT, "test"),
            build_jpeg_tfms(q)
)

        jpeg_loader = DataLoader(
            jpeg_ds,
            CFG.BATCH_SIZE,
            shuffle=False,
            num_workers=CFG.NUM_WORKERS
)

        metrics = evaluate(jpeg_loader, model)
        run_metrics.append(metrics)

    avg_auc = sum(m["auc"] for m in run_metrics) / NUM_RUNS
    avg_acc = sum(m["acc"] for m in run_metrics) / NUM_RUNS
    avg_f1  = sum(m["f1"]  for m in run_metrics) / NUM_RUNS

    print(
        f"AUC: {avg_auc:.4f} | "
        f"ACC: {avg_acc:.4f} | "
        f"F1: {avg_f1:.4f}"
)



===== JPEG ROBUSTNESS (BASELINE | 3-RUN AVG) =====

--- JPEG 100 ---


                                                             

AUC: 0.9331 | ACC: 0.8509 | F1: 0.9034

--- JPEG 90 ---


                                                             

AUC: 0.9257 | ACC: 0.8656 | F1: 0.9147

--- JPEG 75 ---


                                                             

AUC: 0.9147 | ACC: 0.8039 | F1: 0.8681

--- JPEG 50 ---


                                                             

AUC: 0.8811 | ACC: 0.7332 | F1: 0.8120

--- JPEG 30 ---


                                                             

AUC: 0.8467 | ACC: 0.6794 | F1: 0.7651




## 14. DFDC Cross-Dataset Test

In [None]:
print(f"\n===== DFDC CROSS-DATASET (BASELINE | {NUM_RUNS}-RUN AVG) =====")

DFDC_ROOT = "./DFDC/validation"

dfdc_ds = BinaryImageFolder(DFDC_ROOT, eval_tfms)
dfdc_loader = DataLoader(
    dfdc_ds,
    CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS
)

all_metrics = []

for run_idx in range(NUM_RUNS):
    set_seed(CFG.SEED + run_idx)
    metrics = evaluate(dfdc_loader, model, threshold=0.5)
    all_metrics.append(metrics)

avg_metrics = {
    k: sum(m[k] for m in all_metrics) / NUM_RUNS
    for k in all_metrics[0].keys()
}

for k, v in avg_metrics.items():
    print(f"{k.upper():>10}: {v:.4f}")



===== DFDC CROSS-DATASET (BASELINE | 3-RUN AVG) =====


                                                               

       ACC: 0.7544
       AUC: 0.6548
 PRECISION: 0.8342
    RECALL: 0.8669
        F1: 0.8502


## 15. CelebDF Cross-Dataset Test

In [None]:
print(f"\n===== CELEB-DF CROSS-DATASET (BASELINE | {NUM_RUNS}-RUN AVG) =====")

CELEBDF_ROOT = "./CelebDF_images/test"

celeb_ds = BinaryImageFolder(CELEBDF_ROOT, eval_tfms)
celeb_loader = DataLoader(
    celeb_ds,
    CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS
)

all_metrics = []

for run_idx in range(NUM_RUNS):
    set_seed(CFG.SEED + run_idx)
    metrics = evaluate(celeb_loader, model, threshold=0.5)
    all_metrics.append(metrics)

avg_metrics = {
    k: sum(m[k] for m in all_metrics) / NUM_RUNS
    for k in all_metrics[0].keys()
}

for k, v in avg_metrics.items():
    print(f"{k.upper():>10}: {v:.4f}")



===== CELEB-DF CROSS-DATASET (BASELINE | 3-RUN AVG) =====


                                                             

       ACC: 0.6677
       AUC: 0.6837
 PRECISION: 0.7524
    RECALL: 0.7362
        F1: 0.7442


In [None]:
# =====================================
# COMPUTE PARAMS + FLOPs FOR PAPER
# =====================================
from thop import profile
from thop import clever_format

def compute_model_cost(model, name="Model"):
    """Compute Params/FLOPs using a single dummy forward pass."""
    model.eval().to(device)

    # Dummy input (same as training resolution)
    dummy = torch.randn(1, 3, CFG.IMG_SIZE, CFG.IMG_SIZE).to(device)

    # Compute FLOPs + Params
    flops, params = profile(
        model,
        inputs=(dummy,),
        verbose=False
)

    # Pretty formatting (K, M, G)
    flops_str, params_str = clever_format([flops, params], "%.3f")

    print(f"\n===== {name} =====")
    print(f"Params : {params_str}")
    print(f"FLOPs  : {flops_str}")

    return flops, params


## 16. Model Cost (Params + FLOPs)
Uses `thop` to estimate Params/FLOPs with a dummy input at `CFG.IMG_SIZE`. Install once via: `pip install thop`.

In [None]:
# Instantiate a fresh model for cost reporting (weights are not required for Params/FLOPs).
baseline_model = DeepfakeDetectorBaselinePP().to(device)

# Compute costs
flops_base, params_base = compute_model_cost(
    baseline_model,
    name="Baseline Detector",
)



===== Baseline Detector =====
Params : 21.944M
FLOPs  : 3.689G
