In [17]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T

from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_score,
    recall_score,
    f1_score
)


In [18]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 5
LR = 3e-4

FFPP_REAL_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\train\\real"
FFPP_FAKE_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\train\\fake"

In [19]:
class ImageDataset(Dataset):

    def __init__(self,real_path,fake_path,jpeg_quality=None):

        self.samples=[]

        for f in os.listdir(real_path):
            self.samples.append((os.path.join(real_path,f),0))

        for f in os.listdir(fake_path):
            self.samples.append((os.path.join(fake_path,f),1))

        self.jpeg_quality=jpeg_quality

        self.tf=T.Compose([
            T.Resize((IMG_SIZE,IMG_SIZE)),
            T.ToTensor(),
            T.Normalize([0.5]*3,[0.5]*3)
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self,idx):

        path,label=self.samples[idx]
        img=Image.open(path).convert("RGB")

        if self.jpeg_quality:
            from io import BytesIO
            buf=BytesIO()
            img.save(buf,"JPEG",quality=self.jpeg_quality)
            img=Image.open(buf)

        img=self.tf(img)
        return img,label


In [20]:
class SeparableConv2d(nn.Module):

    def __init__(self,in_c,out_c,kernel=3,stride=1,padding=1):
        super().__init__()

        self.depthwise = nn.Conv2d(
            in_c,in_c,kernel,stride,padding,
            groups=in_c,bias=False
        )

        self.pointwise = nn.Conv2d(in_c,out_c,1,bias=False)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self,x):
        x=self.depthwise(x)
        x=self.pointwise(x)
        return self.bn(x)


In [21]:
class Block(nn.Module):

    def __init__(self,in_c,out_c,reps,stride=1):
        super().__init__()

        self.skip = nn.Conv2d(in_c,out_c,1,stride=stride,bias=False)
        self.skipbn = nn.BatchNorm2d(out_c)

        layers=[]
        filters=in_c

        for _ in range(reps):
            # ✅ REMOVE inplace=True
            layers.append(nn.ReLU())
            layers.append(SeparableConv2d(filters,out_c))
            filters=out_c

        if stride!=1:
            layers.append(nn.MaxPool2d(3,stride,1))

        self.rep = nn.Sequential(*layers)

    def forward(self,x):

        skip=self.skipbn(self.skip(x))
        x=self.rep(x)

        # residual add
        return x+skip

In [22]:
class XceptionNet(nn.Module):

    def __init__(self):
        super().__init__()

        # Entry Flow
        self.conv1 = nn.Conv2d(3,32,3,2,0,bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.block1 = Block(64,128,2,2)
        self.block2 = Block(128,256,2,2)
        self.block3 = Block(256,728,2,2)

        # Middle Flow (8 blocks)
        self.middle = nn.Sequential(
            *[Block(728,728,3) for _ in range(8)]
        )

        # Exit Flow
        self.block12 = Block(728,1024,2,2)

        self.conv3 = SeparableConv2d(1024,1536)
        self.conv4 = SeparableConv2d(1536,2048)

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048,2)

    def forward(self,x):

        x=torch.relu(self.bn1(self.conv1(x)))
        x=torch.relu(self.bn2(self.conv2(x)))


        x=self.block1(x)
        x=self.block2(x)
        x=self.block3(x)

        x=self.middle(x)

        x=self.block12(x)

        x=F.relu(self.conv3(x))
        x=F.relu(self.conv4(x))

        x=self.pool(x).flatten(1)
        logits=self.fc(x)

        return logits


In [23]:
model = XceptionNet().to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(),lr=LR)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(
    ImageDataset(FFPP_REAL_PATH,FFPP_FAKE_PATH),
    batch_size=BATCH_SIZE,
    shuffle=True
)


In [24]:
for epoch in range(EPOCHS):

    model.train()
    total_loss=0

    for imgs,labels in tqdm(train_loader):

        imgs=imgs.to(DEVICE)
        labels=labels.to(DEVICE)

        optimizer.zero_grad()

        logits=model(imgs)
        loss=criterion(logits,labels)

        loss.backward()
        optimizer.step()

        total_loss+=loss.item()

    print("Epoch",epoch+1,"Loss:",total_loss/len(train_loader))


100%|██████████| 2391/2391 [14:08<00:00,  2.82it/s]


Epoch 1 Loss: 0.5138833650369516


100%|██████████| 2391/2391 [13:47<00:00,  2.89it/s]


Epoch 2 Loss: 0.4411662678148197


100%|██████████| 2391/2391 [13:51<00:00,  2.88it/s]


Epoch 3 Loss: 0.37120652378590124


100%|██████████| 2391/2391 [13:56<00:00,  2.86it/s]


Epoch 4 Loss: 0.28049582284253943


100%|██████████| 2391/2391 [14:15<00:00,  2.79it/s]

Epoch 5 Loss: 0.20361332557916278





In [25]:
# Save final model weights
os.makedirs("checkpoints", exist_ok=True)
BEST_MODEL_PATH = os.path.join("checkpoints", "paper9_model_BEST.pth")
torch.save(model.state_dict(), BEST_MODEL_PATH)
print("Saved model to:", BEST_MODEL_PATH)



Saved model to: checkpoints\paper9_model_BEST.pth


In [26]:
# Load the best saved model
BEST_MODEL_PATH = os.path.join("checkpoints", "paper9_model_BEST.pth")

print("Loading best trained model from:", BEST_MODEL_PATH)

# Create fresh model instance
model = XceptionNet().to(DEVICE)

# Load weights
state_dict = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)

model.eval()

print("✔ Best model loaded successfully")

Loading best trained model from: checkpoints\paper9_model_BEST.pth


  state_dict = torch.load(BEST_MODEL_PATH, map_location=DEVICE)


✔ Best model loaded successfully


In [27]:
# Test on FF++ test split
TEST_REAL_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\test\\real"
TEST_FAKE_PATH = r"C:\\Users\\vk200\\OneDrive\\Desktop\\Benchmarking\\FFPP_CViT\\test\\fake"

test_loader = DataLoader(
    ImageDataset(TEST_REAL_PATH, TEST_FAKE_PATH),
    batch_size=BATCH_SIZE,
    shuffle=False
)

model.eval()
all_probs = []
all_preds = []
all_labels = []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Testing", leave=False):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        probs = F.softmax(logits, dim=1)[:, 1]
        preds = (probs >= 0.5).long().cpu()

        all_probs.append(probs.cpu())
        all_preds.append(preds)
        all_labels.append(labels)

probs = torch.cat(all_probs).numpy()
preds = torch.cat(all_preds).numpy()
labels = torch.cat(all_labels).numpy()

print("Test ACC:", (preds == labels).mean())
print("Test AUC:", roc_auc_score(labels, probs))
print("Test Precision:", precision_score(labels, preds, zero_division=0))
print("Test Recall:", recall_score(labels, preds, zero_division=0))
print("Test F1:", f1_score(labels, preds, zero_division=0))

                                                          

Test ACC: 0.7965463803409343
Test AUC: 0.843145278729839
Test Precision: 0.9201281153037734
Test Recall: 0.8243364418938307
Test F1: 0.8696022324173486




In [28]:
# Evaluation function for comprehensive metrics
@torch.no_grad()
def evaluate(loader, model):
    model.eval()

    all_probs = []
    all_preds = []
    all_labels = []

    for imgs, labels in tqdm(loader, desc="Evaluating", leave=False):
        imgs = imgs.to(DEVICE)

        logits = model(imgs)
        probs = F.softmax(logits, dim=1)[:,1]   # fake prob

        preds = (probs >= 0.5).long().cpu()

        all_probs.append(probs.cpu())
        all_preds.append(preds)
        all_labels.append(labels)

    probs = torch.cat(all_probs).numpy()
    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()

    return {
        "acc": (preds == labels).mean(),
        "auc": roc_auc_score(labels, probs),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "f1": f1_score(labels, preds, zero_division=0),
    }


In [29]:
# JPEG Compression Testing
import io

class JPEGCompression:
    def __init__(self, quality):
        self.quality = quality

    def __call__(self, img_tensor):
        # UNNORMALIZE
        img = img_tensor.clone()
        img = img * 0.5 + 0.5     # [-1,1] -> [0,1]
        img = img.clamp(0,1)

        img = img.permute(1,2,0).cpu().numpy()
        img = (img * 255).astype(np.uint8)

        pil_img = Image.fromarray(img)
        buffer = io.BytesIO()
        pil_img.save(buffer, format="JPEG", quality=self.quality)
        buffer.seek(0)

        comp = Image.open(buffer).convert("RGB")
        comp = np.array(comp) / 255.0
        comp = torch.tensor(comp).permute(2,0,1).float()

        # RENORMALIZE
        comp = (comp - 0.5) / 0.5

        return comp

print("\n===== JPEG COMPRESSION TEST (Paper9) =====")

jpeg_qualities = [100, 90, 75, 50, 30]

for q in jpeg_qualities:
    print(f"\n--- JPEG Quality {q} ---")

    class JPEGWrapper(torch.utils.data.Dataset):
        def __init__(self, base_dataset, quality):
            self.base = base_dataset
            self.comp = JPEGCompression(quality)

        def __len__(self):
            return len(self.base)

        def __getitem__(self, idx):
            img, label = self.base[idx]
            img = self.comp(img)
            return img, label

    jpeg_dataset = JPEGWrapper(
        ImageDataset(TEST_REAL_PATH, TEST_FAKE_PATH), q
    )
    jpeg_loader = DataLoader(
        jpeg_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    metrics = evaluate(jpeg_loader, model)
    print(f"ACC: {metrics['acc']:.4f} | AUC: {metrics['auc']:.4f} | "
          f"Precision: {metrics['precision']:.4f} | Recall: {metrics['recall']:.4f} | "
          f"F1: {metrics['f1']:.4f}")



===== JPEG COMPRESSION TEST (Paper9) =====

--- JPEG Quality 100 ---


                                                             

ACC: 0.7896 | AUC: 0.8426 | Precision: 0.9225 | Recall: 0.8126 | F1: 0.8641

--- JPEG Quality 90 ---


                                                             

ACC: 0.7614 | AUC: 0.8382 | Precision: 0.9295 | Recall: 0.7684 | F1: 0.8413

--- JPEG Quality 75 ---


                                                             

ACC: 0.7622 | AUC: 0.8298 | Precision: 0.9280 | Recall: 0.7708 | F1: 0.8421

--- JPEG Quality 50 ---


                                                             

ACC: 0.7100 | AUC: 0.8117 | Precision: 0.9327 | Recall: 0.6980 | F1: 0.7984

--- JPEG Quality 30 ---


                                                             

ACC: 0.6343 | AUC: 0.7855 | Precision: 0.9361 | Recall: 0.5964 | F1: 0.7286




In [30]:
# DFDC Cross-Dataset Test
DFDC_REAL_PATH = r"C:\Users\vk200\OneDrive\Desktop\Benchmarking\DFDC\train\real"
DFDC_FAKE_PATH = r"C:\Users\vk200\OneDrive\Desktop\Benchmarking\DFDC\train\fake"

print("\n===== DFDC CROSS-DATASET (Paper9) =====")

dfdc_dataset = ImageDataset(DFDC_REAL_PATH, DFDC_FAKE_PATH)
dfdc_loader = DataLoader(
    dfdc_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

metrics = evaluate(dfdc_loader, model)
print(f"ACC: {metrics['acc']:.4f}")
print(f"AUC: {metrics['auc']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1: {metrics['f1']:.4f}")



===== DFDC CROSS-DATASET (Paper9) =====


                                                                 

ACC: 0.7600
AUC: 0.4198
Precision: 0.7770
Recall: 0.9707
F1: 0.8631




In [31]:
# Celeb-DF Cross-Dataset Test
CELEB_REAL_PATH = r"C:\Users\vk200\OneDrive\Desktop\Benchmarking\CelebDF_images\train\real"
CELEB_FAKE_PATH = r"C:\Users\vk200\OneDrive\Desktop\Benchmarking\CelebDF_images\train\fake"

print("\n===== CELEB-DF CROSS-DATASET (Paper9) =====")

celeb_dataset = ImageDataset(CELEB_REAL_PATH, CELEB_FAKE_PATH)
celeb_loader = DataLoader(
    celeb_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

metrics = evaluate(celeb_loader, model)
print(f"ACC: {metrics['acc']:.4f}")
print(f"AUC: {metrics['auc']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1: {metrics['f1']:.4f}")



===== CELEB-DF CROSS-DATASET (Paper9) =====


                                                               

ACC: 0.8239
AUC: 0.7013
Precision: 0.9256
Recall: 0.8747
F1: 0.8994
