# Paper 6 – Deepfake Detection Benchmark


This notebook implements the **Paper 6** variant of the deepfake detector.



The code cells below typically follow this structure:

- Import libraries and set global configuration (paths, batch size, epochs).
- Instantiate datasets and data loaders for training and evaluation.
- Define the model architecture and optimization setup for Paper 6.
- Train the model and evaluate it on FF++ and cross-dataset benchmarks.



> Run the cells from top to bottom to reproduce the results reported for Paper 6.

Paper link : https://arxiv.org/pdf/2404.04584 (Yang_D3_Scaling_Up_Deepfake_Detection_by_Learning_from_Discrepancy_CVPR_2025_paper.pdf)

In [2]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import open_clip

from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_score,
    recall_score,
    f1_score
)

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMG_SIZE = 224
BATCH_SIZE = 8
EPOCHS = 5
LR = 1e-4

FFPP_REAL_PATH = r""
FFPP_FAKE_PATH = r""

# Use preprocessed CelebDF images for cross-dataset testing
CELEBDF_REAL_PATH = r""
CELEBDF_FAKE_PATH = r""

# Use DFDC validation split for faster cross-dataset evaluation
DFDC_REAL_PATH = r""
DFDC_FAKE_PATH = r""


In [4]:
class ImageDataset(Dataset):

    def __init__(self, real_path, fake_path, jpeg_quality=None):

        self.samples = []

        for f in os.listdir(real_path):
            self.samples.append((os.path.join(real_path,f),0))

        for f in os.listdir(fake_path):
            self.samples.append((os.path.join(fake_path,f),1))

        self.jpeg_quality = jpeg_quality

        self.tf = T.Compose([
            T.Resize((IMG_SIZE,IMG_SIZE)),
            T.ToTensor(),
            T.Normalize([0.5]*3,[0.5]*3)
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):

        path,label = self.samples[idx]
        img = Image.open(path).convert("RGB")

        if self.jpeg_quality:
            from io import BytesIO
            buf = BytesIO()
            img.save(buf,"JPEG",quality=self.jpeg_quality)
            img = Image.open(buf)

        img = self.tf(img)
        return img,label


In [5]:
clip_model,_,_ = open_clip.create_model_and_transforms(
    "ViT-B-16",
    pretrained="openai"
)

clip_model = clip_model.to(DEVICE)

for p in clip_model.parameters():
    p.requires_grad = False




In [6]:
def patch_shuffle(x, grid=4):

    B,C,H,W = x.shape
    ph = H//grid
    pw = W//grid

    patches = []

    for i in range(grid):
        for j in range(grid):
            patches.append(
                x[:,:,i*ph:(i+1)*ph,j*pw:(j+1)*pw]
            )

    patches = torch.stack(patches,dim=1)
    idx = torch.randperm(patches.size(1))
    patches = patches[:,idx]

    out = torch.zeros_like(x)

    k=0
    for i in range(grid):
        for j in range(grid):
            out[:,:,i*ph:(i+1)*ph,j*pw:(j+1)*pw] = patches[:,k]
            k+=1

    return out


In [7]:
class D3Model(nn.Module):

    def __init__(self):
        super().__init__()

        self.clip = clip_model

        # feature interaction
        self.fc1 = nn.Linear(512*2,512)
        self.fc2 = nn.Linear(512,2)

    def forward(self,x):

        # original branch
        feat_o = self.clip.encode_image(x)

        # discrepancy branch
        x_shuf = patch_shuffle(x)
        feat_s = self.clip.encode_image(x_shuf)

        # interaction (paper discrepancy idea)
        feat = torch.cat([feat_o,feat_s],dim=1)
        feat = F.relu(self.fc1(feat))

        logits = self.fc2(feat)
        return logits


In [8]:
model = D3Model().to(DEVICE)

optimizer = torch.optim.AdamW(
    filter(lambda p:p.requires_grad,model.parameters()),
    lr=LR
)

criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(
    ImageDataset(FFPP_REAL_PATH,FFPP_FAKE_PATH),
    batch_size=BATCH_SIZE,
    shuffle=True
)


In [9]:
for epoch in range(EPOCHS):

    model.train()
    total_loss = 0

    for imgs,labels in tqdm(train_loader):

        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        logits = model(imgs)
        loss = criterion(logits,labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print("Epoch",epoch+1,"Loss:",total_loss/len(train_loader))

100%|██████████| 4782/4782 [38:40<00:00,  2.06it/s] 


Epoch 1 Loss: 0.3840912442178388


100%|██████████| 4782/4782 [38:22<00:00,  2.08it/s] 


Epoch 2 Loss: 0.30219872515034735


100%|██████████| 4782/4782 [38:18<00:00,  2.08it/s]  


Epoch 3 Loss: 0.25919862440650926


100%|██████████| 4782/4782 [39:43<00:00,  2.01it/s]  


Epoch 4 Loss: 0.23062606203015437


100%|██████████| 4782/4782 [37:15<00:00,  2.14it/s] 

Epoch 5 Loss: 0.20427273787706676





In [20]:
SAVE_DIR = "./checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

MODEL_NAME = "paper6_model"
best_loss = float("inf")

def save_checkpoint(model, optimizer, epoch, loss):
    path = os.path.join(SAVE_DIR, f"{MODEL_NAME}_BEST.pth")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss
    }, path)
    print("Saved BEST checkpoint:", path)

# Save checkpoint after training (uses last epoch's stats)
save_checkpoint(model, optimizer, epoch+1, total_loss/len(train_loader))

Saved BEST checkpoint: ./checkpoints\paper6_model_BEST.pth


In [21]:
# Load best saved model for evaluation
BEST_MODEL_PATH = "checkpoints/paper6_model_BEST.pth"

print("\nLoading best trained model from:", BEST_MODEL_PATH)

best_model = D3Model().to(DEVICE)

state = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
best_model.load_state_dict(state["model_state_dict"])

best_model.eval()
print("✔ Best model loaded successfully")


Loading best trained model from: checkpoints/paper6_model_BEST.pth


  state = torch.load(BEST_MODEL_PATH, map_location=DEVICE)


✔ Best model loaded successfully


In [22]:
@torch.no_grad()
def evaluate_model(loader):

    model.eval()

    all_probs=[]
    all_labels=[]
    all_preds=[]

    correct=0
    total=0

    # progress bar so large datasets (e.g., DFDC) are visible
    for imgs,labels in tqdm(loader, desc="Evaluating", leave=False):

        imgs=imgs.to(DEVICE)
        labels=labels.to(DEVICE)

        logits=model(imgs)

        probs=torch.softmax(logits,1)[:,1]
        preds=logits.argmax(1)

        correct+=(preds==labels).sum().item()
        total+=labels.size(0)

        all_probs.extend(probs.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

    acc=correct/total
    auc=roc_auc_score(all_labels,all_probs)
    ap=average_precision_score(all_labels,all_probs)

    precision=precision_score(all_labels,all_preds)
    recall=recall_score(all_labels,all_preds)
    f1=f1_score(all_labels,all_preds)

    fpr,tpr,_=roc_curve(all_labels,all_probs)
    fnr=1-tpr
    eer=fpr[np.nanargmin(np.abs(fnr-fpr))]

    return {
        "ACC":acc,
        "AUC":auc,
        "Precision":precision,
        "Recall":recall,
        "F1":f1,
        "AP":ap,
        "EER":eer
    }


In [None]:
print("\n===== FF++ Evaluation (TEST SET) =====")

# Use FF++ test split for evaluation
FFPP_REAL_PATH = r""
FFPP_FAKE_PATH = r""

ffpp_loader = DataLoader(
    ImageDataset(FFPP_REAL_PATH,FFPP_FAKE_PATH),
    batch_size=BATCH_SIZE,
    shuffle=False
)

print(evaluate_model(ffpp_loader))


===== FF++ Evaluation (TEST SET) =====


                                                               

{'ACC': 0.8276141982141539, 'AUC': 0.8480660469181622, 'Precision': 0.9071679290596711, 'Recall': 0.8806492109038737, 'F1': 0.8937118937118937, 'AP': 0.9626801526473261, 'EER': np.float64(0.23926636098374324)}




In [24]:
print("\n===== Cross Dataset =====")

celeb_loader = DataLoader(
    ImageDataset(CELEBDF_REAL_PATH,CELEBDF_FAKE_PATH),
    batch_size=BATCH_SIZE
)

dfdc_loader = DataLoader(
    ImageDataset(DFDC_REAL_PATH,DFDC_FAKE_PATH),
    batch_size=BATCH_SIZE
)

print("CelebDF:",evaluate_model(celeb_loader))
print("DFDC:",evaluate_model(dfdc_loader))



===== Cross Dataset =====


                                                               

CelebDF: {'ACC': 0.891591267535486, 'AUC': 0.6425871484335195, 'Precision': 0.9083044241723415, 'Recall': 0.9782739056229531, 'F1': 0.9419916496402239, 'AP': 0.9333208925628333, 'EER': np.float64(0.4033126293995859)}


                                                                  

DFDC: {'ACC': 0.7746156223029631, 'AUC': 0.5152509091501681, 'Precision': 0.7791700220110592, 'Recall': 0.9919895015993657, 'F1': 0.8727937506389478, 'AP': 0.8082076669679421, 'EER': np.float64(0.5023431083627229)}


In [None]:
print("\n===== JPEG Robustness (FF++ TEST) =====")

# Ensure FF++ test split here
FFPP_REAL_PATH = r""
FFPP_FAKE_PATH = r""

for q in [90,70,50,30]:

    jpeg_loader = DataLoader(
        ImageDataset(FFPP_REAL_PATH,FFPP_FAKE_PATH,jpeg_quality=q),
        batch_size=BATCH_SIZE
    )

    print(f"JPEG {q}:",evaluate_model(jpeg_loader))


===== JPEG Robustness (FF++ TEST) =====


                                                               

JPEG 90: {'ACC': 0.7386170762305365, 'AUC': 0.8503698262008978, 'Precision': 0.9423389909323413, 'Recall': 0.7268651362984218, 'F1': 0.820694542877392, 'AP': 0.9629989895279025, 'EER': np.float64(0.23759899958315964)}


                                                               

JPEG 70: {'ACC': 0.7096893218212678, 'AUC': 0.8005505081026707, 'Precision': 0.920139697322468, 'Recall': 0.7087517934002869, 'F1': 0.8007294093810151, 'AP': 0.949428154888267, 'EER': np.float64(0.28803668195081283)}


                                                               

JPEG 50: {'ACC': 0.570216220205151, 'AUC': 0.7726723847155348, 'Precision': 0.9478816408876933, 'Recall': 0.5055595408895266, 'F1': 0.6594152046783626, 'AP': 0.9418938587227492, 'EER': np.float64(0.31012922050854524)}


                                                               

JPEG 30: {'ACC': 0.4711829385285219, 'AUC': 0.7735406027619112, 'Precision': 0.9641360037261295, 'Recall': 0.3712338593974175, 'F1': 0.5360611161465751, 'AP': 0.9397746737615432, 'EER': np.float64(0.2988745310546061)}


