# Paper 8 – CLIP/ForgeLens Deepfake Detection Benchmark

This notebook implements a CLIP-based deepfake detector (ForgeLens-style) built on top of an OpenCLIP visual transformer.

The main stages are:

- Import libraries and configure paths, image size, batch size, and optimizer settings.
- Build an `ImageDataset` wrapper over real/fake frames (with optional JPEG compression).
- Define the CLIP feature extractor, WSGM, FAFormer, and ForgeLens classifier.
- Train the model on FF++ training frames and monitor the loss.
- Save the trained model and evaluate it on an FF++ test split using standard metrics.

Run the cells from top to bottom to train, save, and evaluate the model.

Paper link : https://arxiv.org/pdf/2408.13697 (2408.13697v2.pdf)

In [1]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import open_clip

from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_score,
    recall_score,
    f1_score
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
DEVICE="cuda" if torch.cuda.is_available() else "cpu"

IMG_SIZE=224
BATCH_SIZE=8
EPOCHS=5
LR=1e-4

FFPP_REAL_PATH = r""
FFPP_FAKE_PATH = r""



In [3]:
class ImageDataset(Dataset):

    def __init__(self,real_path,fake_path,jpeg_quality=None):

        self.samples=[]

        for f in os.listdir(real_path):
            self.samples.append((os.path.join(real_path,f),0))

        for f in os.listdir(fake_path):
            self.samples.append((os.path.join(fake_path,f),1))

        self.jpeg_quality=jpeg_quality

        self.tf=T.Compose([
            T.Resize((IMG_SIZE,IMG_SIZE)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.48145466,0.4578275,0.40821073],
                std=[0.26862954,0.26130258,0.27577711]
            )
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self,idx):

        path,label=self.samples[idx]
        img=Image.open(path).convert("RGB")

        if self.jpeg_quality:
            from io import BytesIO
            buf=BytesIO()
            img.save(buf,"JPEG",quality=self.jpeg_quality)
            img=Image.open(buf)

        img=self.tf(img)
        return img,label


## Dataset & Preprocessing

The `ImageDataset` class wraps folders of real and fake frames and applies:

- Optional JPEG compression (controlled by `jpeg_quality`).
- Resize to `IMG_SIZE × IMG_SIZE`.
- Normalization to the range [-1, 1] using mean and std of 0.5.

It is reused for FF++ training and later for test-time evaluation.


In [4]:
clip_model,_,_ = open_clip.create_model_and_transforms(
    "ViT-B-16",
    pretrained="openai"
)

clip_model = clip_model.to(DEVICE)

# Freeze CLIP
for p in clip_model.parameters():
    p.requires_grad=False

clip_model.eval()




CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine

In [5]:
class CLIPFeatureExtractor(nn.Module):

    def __init__(self,clip_model):
        super().__init__()
        self.clip=clip_model

    def forward(self,x):

        visual=self.clip.visual

        x = visual.conv1(x)
        x = x.reshape(x.shape[0],x.shape[1],-1).permute(0,2,1)

        cls = visual.class_embedding.to(x.dtype)
        cls = cls.unsqueeze(0).unsqueeze(0).expand(x.size(0),-1,-1)

        x = torch.cat([cls,x],dim=1)

        x = x + visual.positional_embedding
        x = visual.ln_pre(x)

        x = x.permute(1,0,2)

        for blk in visual.transformer.resblocks:
            x = blk(x)

        x = x.permute(1,0,2)

        x = visual.ln_post(x)

        return x   # [B, tokens, 512]


In [6]:
class WSGM(nn.Module):

    def __init__(self,dim):
        super().__init__()

        self.guidance = nn.Sequential(
            nn.Linear(dim,dim),
            nn.GELU(),
            nn.Linear(dim,dim)
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self,tokens):

        g = self.guidance(tokens)
        tokens = tokens + g

        return self.norm(tokens)


In [7]:
class FAFormer(nn.Module):

    def __init__(self,dim,heads=8,depth=2):
        super().__init__()

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                batch_first=True
            )
            for _ in range(depth)
        ])

    def forward(self,x):

        for layer in self.layers:
            x = layer(x)

        return x


In [8]:
class ForgeLens(nn.Module):

    def __init__(self):
        super().__init__()

        self.extractor = CLIPFeatureExtractor(clip_model)

        dim = clip_model.visual.transformer.width

        print("ForgeLens token dim =", dim)

        self.wsgm = WSGM(dim)
        self.faformer = FAFormer(dim, heads=8, depth=2)

        self.cls_head = nn.Linear(dim, 2)

    def forward(self, x):

        tokens = self.extractor(x)

        tokens = self.wsgm(tokens)
        tokens = self.faformer(tokens)

        cls_token = tokens[:,0] + tokens[:,1:].mean(dim=1)

        logits = self.cls_head(cls_token)
        return logits


## Model Architecture – ForgeLens + CLIP

The overall model is composed of:

- `CLIPFeatureExtractor` to obtain token embeddings from the OpenCLIP vision transformer.
- `WSGM` (guidance module) that refines token representations with a residual MLP and LayerNorm.
- `FAFormer`, a small Transformer encoder stack operating over tokens.
- A classification head on the CLS token to predict real vs fake.

Only the ForgeLens components are trained; the CLIP backbone is kept frozen.


In [9]:
model = ForgeLens().to(DEVICE)

optimizer = torch.optim.AdamW(
    filter(lambda p:p.requires_grad,model.parameters()),
    lr=LR
)

criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(
    ImageDataset(FFPP_REAL_PATH,FFPP_FAKE_PATH),
    batch_size=BATCH_SIZE,
    shuffle=True
)


ForgeLens token dim = 768


## Training Loop

We now train ForgeLens on FF++ training frames using cross-entropy loss and the AdamW optimizer.

The loop iterates over epochs, accumulating the average loss per epoch for monitoring.

In [10]:
for epoch in range(EPOCHS):

    model.train()
    total_loss=0

    for imgs,labels in tqdm(train_loader):

        imgs=imgs.to(DEVICE)
        labels=labels.to(DEVICE)

        optimizer.zero_grad()

        logits=model(imgs)
        loss=criterion(logits,labels)

        loss.backward()
        optimizer.step()

        total_loss+=loss.item()

    print("Epoch",epoch+1,"Loss:",total_loss/len(train_loader))


100%|██████████| 4782/4782 [14:53<00:00,  5.35it/s]


Epoch 1 Loss: 0.5760109235413001


100%|██████████| 4782/4782 [14:25<00:00,  5.52it/s]


Epoch 2 Loss: 0.5405354090447844


100%|██████████| 4782/4782 [14:22<00:00,  5.55it/s]


Epoch 3 Loss: 0.5315555519226117


100%|██████████| 4782/4782 [15:32<00:00,  5.13it/s]


Epoch 4 Loss: 0.5338648674360682


100%|██████████| 4782/4782 [15:35<00:00,  5.11it/s] 

Epoch 5 Loss: 0.5385450400709669





In [11]:
# Save trained model for Paper 8
import os

BEST_MODEL_PATH = os.path.join("checkpoints", "paper8_model_BEST.pth")
os.makedirs("checkpoints", exist_ok=True)

torch.save(model.state_dict(), BEST_MODEL_PATH)
print("Saved trained model to:", BEST_MODEL_PATH)

Saved trained model to: checkpoints\paper8_model_BEST.pth


In [12]:
# Reload best model for testing

print("\nLoading best trained model from:", BEST_MODEL_PATH)

model = ForgeLens().to(DEVICE)
state_dict = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

print("✔ Best model loaded successfully")


Loading best trained model from: checkpoints\paper8_model_BEST.pth
ForgeLens token dim = 768


  state_dict = torch.load(BEST_MODEL_PATH, map_location=DEVICE)


✔ Best model loaded successfully


## Evaluation & Metrics

We define a reusable `evaluate` helper that runs the model on a data loader and computes:

- Accuracy
- ROC AUC
- Precision, Recall
- F1 score

This mirrors the evaluation used in Paper 1 so that results are comparable across models.


In [13]:
# Evaluation utilities (Paper 8)

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score,
    roc_curve,
    average_precision_score,
)
import torch.nn.functional as F


@torch.no_grad()
def evaluate(loader, model):
    model.eval()

    all_probs = []
    all_preds = []
    all_labels = []

    for imgs, labels in tqdm(loader, desc="Evaluating", leave=False):
        imgs = imgs.to(DEVICE)

        logits = model(imgs)
        probs = F.softmax(logits, dim=1)[:, 1]
        preds = (probs >= 0.5).long().cpu()

        all_probs.append(probs.cpu())
        all_preds.append(preds)
        all_labels.append(labels)

    probs = torch.cat(all_probs).numpy()
    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()

    return {
        "acc": accuracy_score(labels, preds),
        "auc": roc_auc_score(labels, probs),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "f1": f1_score(labels, preds, zero_division=0),
    }

In [None]:
# FF++ TEST SET EVALUATION (Paper 8)

print("\n===== FF++ TEST (Paper 8) =====")

FFPP_TEST_REAL_PATH = r""
FFPP_TEST_FAKE_PATH = r""

ffpp_test_dataset = ImageDataset(FFPP_TEST_REAL_PATH, FFPP_TEST_FAKE_PATH)
ffpp_test_loader = DataLoader(
    ffpp_test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    )

ffpp_metrics = evaluate(ffpp_test_loader, model)
print("FF++ Test Metrics:", ffpp_metrics)


===== FF++ TEST (Paper 8) =====


                                                               

FF++ Test Metrics: {'acc': 0.7929304110397757, 'auc': 0.5861088925143965, 'precision': 0.8285826771653543, 'recall': 0.9435975609756098, 'f1': 0.8823578735535804}


## Celeb-DF Cross-Dataset Test

Finally, we test cross-dataset generalization on Celeb-DF frames,
again reporting average detection metrics over the full evaluation split.


In [None]:
print("\n===== CELEB-DF CROSS-DATASET (Paper 8) =====")

CELEB_REAL_PATH = r""
CELEB_FAKE_PATH = r""

celeb_dataset = ImageDataset(CELEB_REAL_PATH, CELEB_FAKE_PATH)
celeb_loader = DataLoader(
    celeb_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    )

metrics_runs = []
NUM_RUNS = 1

for run in range(NUM_RUNS):
    print(f"Run {run+1}/{NUM_RUNS}")
    metrics = evaluate(celeb_loader, model)
    metrics_runs.append(metrics)

avg = {k: np.mean([m[k] for m in metrics_runs]) for k in metrics_runs[0]}
print("\nAVG:", avg)


===== CELEB-DF CROSS-DATASET (Paper 8) =====
Run 1/1


                                                               


AVG: {'acc': np.float64(0.8960322071885116), 'auc': np.float64(0.5602905254982902), 'precision': np.float64(0.9004260295714643), 'recall': np.float64(0.9944185617417778), 'f1': np.float64(0.9450910764779378)}




## DFDC Cross-Dataset Test

We evaluate how well the model trained on FF++ generalizes to the DFDC dataset,
using real and fake frames from a held-out DFDC split.


In [None]:
print("\n===== DFDC CROSS-DATASET (Paper 8) =====")

DFDC_REAL_PATH = r""
DFDC_FAKE_PATH = r""

dfdc_dataset = ImageDataset(DFDC_REAL_PATH, DFDC_FAKE_PATH)
dfdc_loader = DataLoader(
    dfdc_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    )

metrics_runs = []
NUM_RUNS = 1

for run in range(NUM_RUNS):
    print(f"Run {run+1}/{NUM_RUNS}")
    metrics = evaluate(dfdc_loader, model)
    metrics_runs.append(metrics)

avg = {k: np.mean([m[k] for m in metrics_runs]) for k in metrics_runs[0]}
print("\nAVG:", avg)


===== DFDC CROSS-DATASET (Paper 8) =====
Run 1/1


                                                                  


AVG: {'acc': np.float64(0.7776416310613407), 'auc': np.float64(0.5343964246664671), 'precision': np.float64(0.7793927475980292), 'recall': np.float64(0.9968969571041911), 'f1': np.float64(0.8748283080318853)}


## JPEG Compression Test

We test how robust the model is to different JPEG compression qualities on the FF++ test set.

For each quality level, we recompute the metrics using `ImageDataset` with `jpeg_quality` set,
and report the averaged results.


In [None]:
print("\n===== JPEG COMPRESSION TEST (Paper 8) =====")

jpeg_qualities = [100, 90, 75, 50, 30]
NUM_RUNS = 1

FFPP_TEST_REAL_PATH = r""
FFPP_TEST_FAKE_PATH = r""

for q in jpeg_qualities:
    print(f"\n--- JPEG Quality {q} ---")

    metrics_runs = []

    for run in range(NUM_RUNS):
        jpeg_dataset = ImageDataset(FFPP_TEST_REAL_PATH, FFPP_TEST_FAKE_PATH, jpeg_quality=q)
        jpeg_loader = DataLoader(
            jpeg_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=0,
        )

        metrics = evaluate(jpeg_loader, model)
        metrics_runs.append(metrics)

    avg = {k: np.mean([m[k] for m in metrics_runs]) for k in metrics_runs[0]}
    print("AVG:", avg)


===== JPEG COMPRESSION TEST (Paper 8) =====

--- JPEG Quality 100 ---


                                                               

AVG: {'acc': np.float64(0.7931517969153568), 'auc': np.float64(0.5930055220880531), 'precision': np.float64(0.8281581636663784), 'recall': np.float64(0.9446736011477762), 'f1': np.float64(0.8825870229966908)}

--- JPEG Quality 90 ---


                                                               

AVG: {'acc': np.float64(0.7862888347723416), 'auc': np.float64(0.5875768605462701), 'precision': np.float64(0.8290816326530612), 'recall': np.float64(0.9325681492109039), 'f1': np.float64(0.8777852802160703)}

--- JPEG Quality 75 ---


                                                               

AVG: {'acc': np.float64(0.7989816249723267), 'auc': np.float64(0.589965282491569), 'precision': np.float64(0.8281931464174455), 'recall': np.float64(0.9535509325681492), 'f1': np.float64(0.8864621540513504)}

--- JPEG Quality 50 ---


                                                               

AVG: {'acc': np.float64(0.7990554202641872), 'auc': np.float64(0.5972607548697658), 'precision': np.float64(0.8281043207473725), 'recall': np.float64(0.9538199426111909), 'f1': np.float64(0.8865274826019919)}

--- JPEG Quality 30 ---


                                                               

AVG: {'acc': np.float64(0.8017120507711608), 'auc': np.float64(0.6027977380879049), 'precision': np.float64(0.8291469010031884), 'recall': np.float64(0.9560616929698709), 'f1': np.float64(0.8880929573945275)}


