In [62]:
import os
import random
import numpy as np
from PIL import Image
from glob import glob

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from sklearn.metrics import accuracy_score, roc_auc_score

# ----------------- Seed -----------------
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ----------------- Transforms -----------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# ----------------- Dataset paths -----------------
train_path = '/kaggle/input/celebdf-v2image-dataset/Celeb_V2/Train'
val_path   = '/kaggle/input/celebdf-v2image-dataset/Celeb_V2/Val'
test_path  = '/kaggle/input/celebdf-v2image-dataset/Celeb_V2/Test'

# ----------------- Datasets -----------------
train_dataset = datasets.ImageFolder(root=train_path, transform=transform)
val_dataset   = datasets.ImageFolder(root=val_path, transform=transform)
test_dataset  = datasets.ImageFolder(root=test_path, transform=transform)

# ----------------- Dataloaders -----------------
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_siz e=16, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=16, shuffle=False)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")


Train: 80824, Val: 10104, Test: 10103


In [63]:
import timm
import torch.fft

class FrequencyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(112*112, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
        )

    def forward(self, x):
        gray = torch.mean(x, dim=1)  # b x H x W
        fft = torch.fft.fft2(gray)
        fft_shift = torch.fft.fftshift(fft)
        mag = torch.log1p(torch.abs(fft_shift))
        b, H, W = mag.shape
        mag_crop = mag[:, :H//2, :W//2]
        mag_flat = mag_crop.reshape(b, -1)
        out = self.fc(mag_flat)
        return out

class HybridDetector(nn.Module):
    def __init__(self, backbone_name='tf_efficientnet_b3_ns', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0, global_pool='')
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.spatial_head = nn.Sequential(
            nn.Linear(1536, 128),
            nn.ReLU(),
        )
        self.freq_head = FrequencyHead()
        self.fc = nn.Sequential(
            nn.Linear(1536 + 128 + 128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        last = self.backbone(x)
        pooled = self.pool(last).view(last.size(0), -1)
        spat = self.spatial_head(pooled)
        freq = self.freq_head(x)
        concat = torch.cat([pooled, spat, freq], dim=1)
        out = self.fc(concat).squeeze(1)
        return out


In [64]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridDetector(pretrained=True).to(device)


In [65]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.6)


In [66]:
from sklearn.metrics import roc_curve

def compute_eer(y_true, y_score):
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    fnr = 1 - tpr
    eer_threshold = thresholds[np.nanargmin(np.absolute((fnr - fpr)))]
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    return eer


In [None]:
num_epochs = 5
best_val_auc = 0.0

for epoch in range(num_epochs):
    model.train()
    train_losses = []

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.float().to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()

    # Validation
    model.eval()
    val_losses = []
    all_labels = []
    all_preds  = []

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.float().to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_losses.append(loss.item())
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(outputs.cpu().numpy())

    val_acc = accuracy_score(all_labels, np.round(all_preds))
    val_auc = roc_auc_score(all_labels, all_preds)
    val_eer = compute_eer(np.array(all_labels), np.array(all_preds))

    print(f"Epoch {epoch+1}/{num_epochs} | "
          f"TrainLoss: {np.mean(train_losses):.4f} | "
          f"ValLoss: {np.mean(val_losses):.4f} | "
          f"ValAcc: {val_acc:.4f} | "
          f"ValAUC: {val_auc:.4f} | "
          f"EER: {val_eer:.4f}")

    # Save best model
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), '/kaggle/working/best_hybrid_model.pth')
        print("Saved Best Model!")


Epoch 1/5 | TrainLoss: 0.0758 | ValLoss: 0.0230 | ValAcc: 0.9923 | ValAUC: 0.9996 | EER: 0.0077
Saved Best Model!
Epoch 2/5 | TrainLoss: 0.0255 | ValLoss: 0.0310 | ValAcc: 0.9906 | ValAUC: 0.9998 | EER: 0.0071
Saved Best Model!
Epoch 3/5 | TrainLoss: 0.0171 | ValLoss: 0.0205 | ValAcc: 0.9917 | ValAUC: 0.9999 | EER: 0.0036
Saved Best Model!
Epoch 4/5 | TrainLoss: 0.0064 | ValLoss: 0.0127 | ValAcc: 0.9949 | ValAUC: 0.9999 | EER: 0.0047


In [None]:
# Load best model
model.load_state_dict(torch.load('/kaggle/working/best_hybrid_model.pth'))
model.eval()

all_labels = []
all_preds = []

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.float().to(device)
        outputs = model(imgs)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(outputs.cpu().numpy())

test_acc = accuracy_score(all_labels, np.round(all_preds))
test_auc = roc_auc_score(all_labels, all_preds)
test_eer = compute_eer(np.array(all_labels), np.array(all_preds))

print(f"Test Acc: {test_acc:.4f} | Test AUC: {test_auc:.4f} | Test EER: {test_eer:.4f}")
