# Data Preparation

In [None]:
import os
import cv2
import torch
import random
import numpy as np
from tqdm import tqdm
from facenet_pytorch import MTCNN

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


In [None]:
RAW_ROOT = "/kaggle/input/celeb-df-v2"

In [None]:
OUT_ROOT = "/kaggle/working/celebdf_processed"
TRAIN_DIR = os.path.join(OUT_ROOT, "train")
VAL_DIR   = os.path.join(OUT_ROOT, "val")

os.makedirs(TRAIN_DIR, exist_ok=True)
os.makedirs(VAL_DIR, exist_ok=True)

mtcnn = MTCNN(
    image_size=224,
    margin=20,
    select_largest=True,
    post_process=False,   # IMPORTANT: we handle normalization ourselves
    device=device
)

def sample_frame_indices(total_frames, num_frames=24):
    if total_frames <= num_frames:
        return list(range(total_frames))
    return np.linspace(0, total_frames - 1, num_frames).astype(int)

def extract_rgb_dct(video_path, num_frames=24):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_ids = sample_frame_indices(total_frames, num_frames)

    rgb_faces = []
    dct_maps = []

    for fid in frame_ids:
        cap.set(cv2.CAP_PROP_POS_FRAMES, fid)
        ret, frame = cap.read()
        if not ret:
            continue

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        face = mtcnn(frame_rgb)
        if face is None:
            continue

        # ---- RGB FACE (uint8) ----
        face = face.clamp(0, 255).byte()
        rgb_faces.append(face)

        # ---- DCT (112x112, fp16) ----
        gray = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)
        gray = cv2.resize(gray, (112, 112))
        dct = cv2.dct(np.float32(gray))
        dct = np.log(np.abs(dct) + 1e-6)
        dct_maps.append(
            torch.tensor(dct, dtype=torch.float16).unsqueeze(0)
        )

    cap.release()

    if len(rgb_faces) < 5:
        return None

    return torch.stack(rgb_faces), torch.stack(dct_maps)

In [None]:
def process_and_save(video_path, label, out_dir):
    result = extract_rgb_dct(video_path)
    if result is None:
        return False

    rgb, dct = result
    name = os.path.splitext(os.path.basename(video_path))[0]

    torch.save(
        {
            "rgb": rgb,   # uint8 (T,3,224,224)
            "dct": dct,   # float16 (T,1,112,112)
            "label": label
        },
        os.path.join(out_dir, name + ".pt")
    )
    return True

In [None]:
real_videos = []
fake_videos = []

for folder in ["Celeb-real", "YouTube-real"]:
    path = os.path.join(RAW_ROOT, folder)
    for f in os.listdir(path):
        real_videos.append(os.path.join(path, f))

for f in os.listdir(os.path.join(RAW_ROOT, "Celeb-synthesis")):
    fake_videos.append(os.path.join(RAW_ROOT, "Celeb-synthesis", f))

random.shuffle(real_videos)
random.shuffle(fake_videos)

In [None]:
TRAIN_REAL = 400
TRAIN_FAKE = 400
VAL_REAL   = 100
VAL_FAKE   = 100

train_real = real_videos[:TRAIN_REAL]
train_fake = fake_videos[:TRAIN_FAKE]

val_real = real_videos[TRAIN_REAL:TRAIN_REAL + VAL_REAL]
val_fake = fake_videos[TRAIN_FAKE:TRAIN_FAKE + VAL_FAKE]

In [None]:
print("Processing TRAIN REAL")
for v in tqdm(train_real):
    process_and_save(v, 0, TRAIN_DIR)

print("Processing TRAIN FAKE")
for v in tqdm(train_fake):
    process_and_save(v, 1, TRAIN_DIR)

print("Processing VAL REAL")
for v in tqdm(val_real):
    process_and_save(v, 0, VAL_DIR)

print("Processing VAL FAKE")
for v in tqdm(val_fake):
    process_and_save(v, 1, VAL_DIR)

In [31]:
print("Train videos:", len(os.listdir(TRAIN_DIR)))
print("Val videos:", len(os.listdir(VAL_DIR)))

sample = torch.load(os.path.join(TRAIN_DIR, os.listdir(TRAIN_DIR)[0]))
print("RGB shape:", sample["rgb"].shape, sample["rgb"].dtype)
print("DCT shape:", sample["dct"].shape, sample["dct"].dtype)

Train videos: 800
Val videos: 200
RGB shape: torch.Size([24, 3, 224, 224]) torch.uint8
DCT shape: torch.Size([24, 1, 112, 112]) torch.float16


# Training Part

In [10]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [11]:
DATA_ROOT = "/kaggle/working/celebdf_processed"
TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR   = os.path.join(DATA_ROOT, "val")

In [12]:
def sample_frames(x, num_frames=24):
    T = x.shape[0]
    if T <= num_frames:
        return x
    idx = torch.linspace(0, T - 1, num_frames).long()
    return x[idx]

def make_windows(x, window=5, stride=4):
    return torch.stack([
        x[i:i+window]
        for i in range(0, x.shape[0] - window + 1, stride)
    ])

In [13]:
class CelebDFVideoDataset(Dataset):
    def __init__(self, root):
        self.files = [
            os.path.join(root, f)
            for f in os.listdir(root)
            if f.endswith(".pt")
        ]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        data = torch.load(self.files[idx], map_location="cpu")

        rgb = data["rgb"].float() / 255.0     # (T,3,224,224)
        dct = data["dct"].float()             # (T,1,112,112)
        label = torch.tensor(data["label"], dtype=torch.float32)

        rgb = sample_frames(rgb)
        dct = sample_frames(dct)

        rgb_w = make_windows(rgb)
        dct_w = make_windows(dct)

        return rgb_w, dct_w, label

In [14]:
class PixelCNN(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(
            weights="IMAGENET1K_V1"
        )
        backbone.fc = nn.Linear(512, 256)
        self.net = backbone

    def forward(self, x):
        B, W, C, H, W_ = x.shape
        x = x.view(B * W, C, H, W_)
        x = self.net(x)
        return x.view(B, W, -1)

class DCTCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Linear(64, 256)

    def forward(self, x):
        B, W, C, H, W_ = x.shape 
        x = x.view(B * W, C, H, W_)
        x = self.conv(x).squeeze(-1).squeeze(-1) # (B*W, 64, 1, 1) --> (B*W,64)
        x = self.fc(x)
        return x.view(B, W, -1)

In [32]:
class TemporalEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(256, 256, batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out[:, -1]

In [17]:
class VideoDeepfakeModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.pixel = PixelCNN()
        self.dct = DCTCNN()

        self.pixel_lstm = TemporalEncoder()
        self.dct_lstm = TemporalEncoder()

        self.video_lstm = nn.LSTM(512, 256, batch_first=True)
        self.classifier = nn.Linear(256, 1)

    def forward(self, rgb, dct):
        B, N, W, C, H, W_ = rgb.shape

        rgb = rgb.view(B * N, W, C, H, W_)
        dct = dct.view(B * N, W, 1, dct.shape[-2], dct.shape[-1])

        p = self.pixel_lstm(self.pixel(rgb))
        f = self.dct_lstm(self.dct(dct))

        fused = torch.cat([p, f], dim=-1)
        fused = fused.view(B, N, -1)

        v, _ = self.video_lstm(fused)
        return self.classifier(v[:, -1]).squeeze(1)

In [18]:
model = VideoDeepfakeModel().to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=8
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 44.7M/44.7M [00:00<00:00, 185MB/s]


In [19]:
for p in model.pixel.parameters():
    p.requires_grad = False

In [20]:
def train_epoch(loader):
    model.train()
    correct, total, loss_sum = 0, 0, 0

    for rgb, dct, y in tqdm(loader):
        rgb = rgb.to(device)
        dct = dct.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = model(rgb, dct)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        preds = (torch.sigmoid(logits) > 0.5).float()
        correct += (preds == y).sum().item()
        total += y.size(0)
        loss_sum += loss.item()

    return correct / total, loss_sum / len(loader)


def eval_epoch(loader):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for rgb, dct, y in loader:
            rgb = rgb.to(device)
            dct = dct.to(device)
            y = y.to(device)

            logits = model(rgb, dct)
            preds = (torch.sigmoid(logits) > 0.5).float()
            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total


In [21]:
train_ds = CelebDFVideoDataset(TRAIN_DIR)
val_ds   = CelebDFVideoDataset(VAL_DIR)

train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

In [22]:
EPOCHS = 8

for epoch in range(EPOCHS):

    if epoch == 4:
        print("ðŸ”“ Unfreezing Pixel CNN")
        for p in model.pixel.parameters():
            p.requires_grad = True

    train_acc, train_loss = train_epoch(train_loader)
    val_acc = eval_epoch(val_loader)
    scheduler.step()

    print(
        f"Epoch {epoch+1}/{EPOCHS} | "
        f"Train Acc: {train_acc:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:17<00:00, 11.25it/s]


Epoch 1/8 | Train Acc: 0.611 | Val Acc: 0.715


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:16<00:00, 11.80it/s]


Epoch 2/8 | Train Acc: 0.705 | Val Acc: 0.760


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:16<00:00, 11.77it/s]


Epoch 3/8 | Train Acc: 0.714 | Val Acc: 0.685


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:17<00:00, 11.76it/s]


Epoch 4/8 | Train Acc: 0.744 | Val Acc: 0.780
ðŸ”“ Unfreezing Pixel CNN


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:37<00:00,  5.37it/s]


Epoch 5/8 | Train Acc: 0.660 | Val Acc: 0.830


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:37<00:00,  5.38it/s]


Epoch 6/8 | Train Acc: 0.812 | Val Acc: 0.730


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:37<00:00,  5.38it/s]


Epoch 7/8 | Train Acc: 0.899 | Val Acc: 0.880


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [00:37<00:00,  5.39it/s]


Epoch 8/8 | Train Acc: 0.963 | Val Acc: 0.950


In [27]:
torch.save(model.state_dict(),'/kaggle/working/final_model.pt')

# Testing Part

In [37]:
TEST_DIR = os.path.join(OUT_ROOT, "test")
os.makedirs(TEST_DIR, exist_ok=True)

In [40]:
TEST_REAL = 100
TEST_FAKE = 100

test_real = real_videos[
    TRAIN_REAL + VAL_REAL :
    TRAIN_REAL + VAL_REAL + TEST_REAL
]

test_fake = fake_videos[
    TRAIN_FAKE + VAL_FAKE :
    TRAIN_FAKE + VAL_FAKE + TEST_FAKE
]


In [41]:
print("Processing TEST REAL")
for v in tqdm(test_real):
    process_and_save(v, 0, TEST_DIR)

print("Processing TEST FAKE")
for v in tqdm(test_fake):
    process_and_save(v, 1, TEST_DIR)

Processing TEST REAL


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 100/100 [03:36<00:00,  2.16s/it]


Processing TEST FAKE


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 100/100 [03:41<00:00,  2.21s/it]


In [42]:
print("Test videos:", len(os.listdir(TEST_DIR)))

Test videos: 199


In [43]:
TEST_DIR = os.path.join(DATA_ROOT, "test")

test_ds = CelebDFVideoDataset(TEST_DIR)

test_loader = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

In [44]:
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score
)

def evaluate_test(loader):
    model.eval()

    y_true = []
    y_pred = []
    y_prob = []

    with torch.no_grad():
        for rgb, dct, y in tqdm(loader):
            rgb = rgb.to(device)
            dct = dct.to(device)

            logits = model(rgb, dct)
            probs = torch.sigmoid(logits)

            y_true.extend(y.numpy())
            y_prob.extend(probs.cpu().numpy())
            y_pred.extend((probs > 0.5).cpu().numpy())

    return (
        np.array(y_true),
        np.array(y_pred),
        np.array(y_prob)
    )


In [45]:
y_true, y_pred, y_prob = evaluate_test(test_loader)

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 50/50 [00:03<00:00, 12.56it/s]


In [48]:
acc  = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred)
rec  = recall_score(y_true, y_pred)
f1   = f1_score(y_true, y_pred)
auc  = roc_auc_score(y_true, y_prob)

print(f"Test Accuracy : {acc:.4f}")
print(f"Precision     : {prec:.4f}")
print(f"Recall        : {rec:.4f}")
print(f"F1-score      : {f1:.4f}")
print(f"AUC           : {auc:.4f}")

confuson_Matrix = confusion_matrix(y_true, y_pred)
print("Confuson Matrix:")
print(consfusion_Matrix)

Test Accuracy : 0.9246
Precision     : 0.8899
Recall        : 0.9700
F1-score      : 0.9282
AUC           : 0.9779
Confuson Matrix:
[[87 12]
 [ 3 97]]
