# Paper 4 – Deepfake Detection Benchmark


This notebook implements the **Paper 4** variant of the deepfake detector.



The code cells below typically follow this structure:

- Import libraries and configure paths/devices.
- Prepare datasets and data loaders for the relevant benchmarks.
- Define the model architecture and loss functions used in Paper 4.
- Train and evaluate the model, printing metrics for comparison.



> Run the cells from top to bottom to reproduce the results reported for Paper 4.

Paper link : https://arxiv.org/pdf/2505.20653 (2505.20653v1.pdf)

In [2]:
import os
import random
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torchvision.models as models


KeyboardInterrupt: 

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)
IMG_SIZE = 224
BATCH_SIZE = 8
EPOCHS = 5
LR = 0.005

# RoGA params from paper
RHO = 0.1
ALPHA = 0.0002

FFPP_REAL_PATH = r"C:\Users\vk200\OneDrive\Desktop\Benchmarking\FFPP_CViT\train\real"
FFPP_FAKE_PATH = r"C:\Users\vk200\OneDrive\Desktop\Benchmarking\FFPP_CViT\train\fake"

# optional cross dataset
CELEBDF_PATH = "PATH_TO_CELEBDF"
DFDC_PATH = "PATH_TO_DFDC"


cuda


In [None]:
class ImageDataset(Dataset):
    def __init__(self, real_path, fake_path=None, jpeg_quality=None):

        self.samples = []

        if fake_path is not None:
            for f in os.listdir(real_path):
                self.samples.append((os.path.join(real_path,f),0))
            for f in os.listdir(fake_path):
                self.samples.append((os.path.join(fake_path,f),1))
        else:
            for f in os.listdir(real_path):
                self.samples.append((os.path.join(real_path,f),0))

        self.jpeg_quality = jpeg_quality

        self.transform = T.Compose([
            T.Resize((IMG_SIZE,IMG_SIZE)),
            T.ToTensor(),
            T.Normalize([0.5]*3,[0.5]*3)
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self,idx):
        path,label=self.samples[idx]

        img = Image.open(path).convert("RGB")

        if self.jpeg_quality is not None:
            from io import BytesIO
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=self.jpeg_quality)
            img = Image.open(buffer)

        img = self.transform(img)
        return img,label


In [None]:
model = models.resnet34(pretrained=True)
model.fc = nn.Linear(model.fc.in_features,2)
model = model.to(DEVICE)


In [None]:
def roga_step(model, imgs, labels, optimizer, criterion):

    # ===== Forward 1 (ERM gradient) =====
    logits = model(imgs)
    loss = criterion(logits, labels)

    optimizer.zero_grad()
    loss.backward(create_graph=True)

    # compute gradient norm
    grad_norm = torch.norm(
        torch.stack([
            p.grad.norm()
            for p in model.parameters()
            if p.grad is not None
        ])
    )

    # save original weights
    original_params = []
    for p in model.parameters():
        if p.grad is None:
            original_params.append(None)
            continue

        eps = RHO * p.grad / (grad_norm + 1e-12)
        original_params.append(p.data.clone())
        p.data = p.data + eps

    # ===== Forward 2 (perturbed) =====
    logits_pert = model(imgs)
    loss_pert = criterion(logits_pert, labels)

    optimizer.zero_grad()
    loss_pert.backward(create_graph=True)

    # ===== Gradient Alignment Term =====
    align = 0.0
    for p, orig in zip(model.parameters(), original_params):
        if p.grad is None or orig is None:
            continue
        align += (p.grad * (p.data - orig)).sum()

    roga_loss = loss_pert - ALPHA * align

    # restore original weights BEFORE final backward
    for p, orig in zip(model.parameters(), original_params):
        if orig is not None:
            p.data = orig

    optimizer.zero_grad()
    roga_loss.backward()
    optimizer.step()

    return roga_loss.item()


In [None]:
train_set = ImageDataset(FFPP_REAL_PATH,FFPP_FAKE_PATH)
train_loader = DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)

optimizer = torch.optim.SGD(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()


In [None]:
for epoch in range(EPOCHS):

    model.train()
    total_loss = 0

    for imgs,labels in tqdm(train_loader):

        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        loss = roga_step(model,imgs,labels,optimizer,criterion)
        total_loss += loss

    print("Epoch",epoch+1,"Loss:",total_loss/len(train_loader))


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
100%|██████████| 4782/4782 [36:58<00:00,  2.16it/s]  


Epoch 1 Loss: 0.7571298806413553


100%|██████████| 4782/4782 [35:22<00:00,  2.25it/s] 


Epoch 2 Loss: 0.5813964470342232


100%|██████████| 4782/4782 [34:57<00:00,  2.28it/s]


Epoch 3 Loss: 0.5549187639775709


100%|██████████| 4782/4782 [35:31<00:00,  2.24it/s] 


Epoch 4 Loss: 0.5411997244933667


100%|██████████| 4782/4782 [36:30<00:00,  2.18it/s] 

Epoch 5 Loss: 0.5307824743661558





In [None]:
SAVE_DIR = "./checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

MODEL_NAME = "paper4_model"
best_loss = float("inf")

def save_checkpoint(model, optimizer, epoch, loss):
    path = os.path.join(SAVE_DIR, f"{MODEL_NAME}_BEST.pth")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss
    }, path)
    print("Saved BEST checkpoint:", path)

# Save checkpoint after training (uses last epoch's stats)
save_checkpoint(model, optimizer, epoch+1, total_loss/len(train_loader))

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

@torch.no_grad()
def evaluate_model(loader):

    model.eval()

    all_probs = []
    all_labels = []

    correct = 0
    total = 0

    # show progress so long runs (e.g., cross-dataset) are visible
    for imgs, labels in tqdm(loader, desc="Evaluating", leave=False):

        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(imgs)

        probs = torch.softmax(logits, dim=1)[:,1]

        preds = torch.argmax(logits, dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

        all_probs.extend(probs.detach().cpu().numpy())
        all_labels.extend(labels.detach().cpu().numpy())

    # ===== Metrics =====
    acc = correct / total

    auc = roc_auc_score(all_labels, all_probs)
    ap  = average_precision_score(all_labels, all_probs)

    # ===== EER =====
    fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
    fnr = 1 - tpr
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]

    return {
        "ACC": acc,
        "AUC": auc,
        "AP": ap,
        "EER": eer
    }

In [None]:
print("\n===== FF++ Evaluation (TEST SET) =====")

# Use FF++ test split for evaluation
FFPP_REAL_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\test\\real"
FFPP_FAKE_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\test\\fake"

ffpp_test_loader = DataLoader(
    ImageDataset(FFPP_REAL_PATH, FFPP_FAKE_PATH),
    batch_size=BATCH_SIZE,
    shuffle=False
)

metrics = evaluate_model(ffpp_test_loader)

print("FF++ Test Metrics:")
print(metrics)


===== FF++ Evaluation =====
FF++ Metrics:
{'ACC': 0.7458694970197637, 'AUC': 0.8445491239755181, 'AP': 0.9382704997967297, 'EER': np.float64(0.2450892857142857)}


In [None]:
print("\n===== Cross Dataset Evaluation =====")

CELEBDF_REAL_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\CelebDF_images\\train\\real"
CELEBDF_FAKE_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\CelebDF_images\\train\\fake"

# Use DFDC validation split for faster cross-dataset evaluation
DFDC_REAL_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\DFDC\\validation\\real"
DFDC_FAKE_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\DFDC\\validation\\fake"

print("\n===== Cross Dataset Evaluation =====")

# ===== CelebDF =====
if CELEBDF_REAL_PATH is not None and CELEBDF_FAKE_PATH is not None:

    celeb_loader = DataLoader(
        ImageDataset(CELEBDF_REAL_PATH, CELEBDF_FAKE_PATH),
        batch_size=BATCH_SIZE,
        shuffle=False
    )

    metrics = evaluate_model(celeb_loader)
    print("CelebDF:", metrics)


# ===== DFDC =====
if DFDC_REAL_PATH is not None and DFDC_FAKE_PATH is not None:

    dfdc_loader = DataLoader(
        ImageDataset(DFDC_REAL_PATH, DFDC_FAKE_PATH),
        batch_size=BATCH_SIZE,
        shuffle=False
    )

    metrics = evaluate_model(dfdc_loader)
    print("DFDC:", metrics)




===== Cross Dataset Evaluation =====

===== Cross Dataset Evaluation =====
CelebDF: {'ACC': 0.8998505852079356, 'AUC': 0.5775024575117633, 'AP': 0.9051846860594184, 'EER': np.float64(0.4451345755693582)}
DFDC: {'ACC': 0.7794529743322004, 'AUC': 0.5398191363815406, 'AP': 0.8218131207698823, 'EER': np.float64(0.43862022319918836)}


In [None]:
print("\n===== JPEG Robustness Evaluation (FF++ TEST) =====")

# Ensure we use FF++ test split here as well
FFPP_REAL_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\test\\real"
FFPP_FAKE_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\test\\fake"

jpeg_qualities = [90, 70, 50, 30]

for q in jpeg_qualities:

    jpeg_loader = DataLoader(
        ImageDataset(
            FFPP_REAL_PATH,
            FFPP_FAKE_PATH,
            jpeg_quality=q
        ),
        batch_size=BATCH_SIZE,
        shuffle=False
    )

    metrics = evaluate_model(jpeg_loader)

    print(f"JPEG Quality {q} Metrics:")
    print(metrics)


===== JPEG Robustness Evaluation =====
JPEG Quality 90 Metrics:
{'ACC': 0.7465492000418279, 'AUC': 0.8438919521397943, 'AP': 0.9379180947162962, 'EER': np.float64(0.24732142857142858)}
JPEG Quality 70 Metrics:
{'ACC': 0.7450067970302207, 'AUC': 0.8422462030797828, 'AP': 0.9371441674328989, 'EER': np.float64(0.24803571428571428)}
JPEG Quality 50 Metrics:
{'ACC': 0.7439349576492732, 'AUC': 0.8402491431845547, 'AP': 0.9361300484228315, 'EER': np.float64(0.2517857142857143)}
JPEG Quality 30 Metrics:
{'ACC': 0.7450329394541462, 'AUC': 0.8351262845630638, 'AP': 0.9334116056080912, 'EER': np.float64(0.2588392857142857)}
