In [None]:
# ============================================================
# 08 - External Evaluation on Celeb-DF
# ============================================================

from pathlib import Path
import json
import cv2
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import timm
from torchvision import transforms
from sklearn.metrics import roc_auc_score

In [None]:
# --- Paths --------------------------------------------------------------------
# Everything is resolved relative to the parent of the current working directory.
ROOT = Path.cwd().parent

# Inputs
CELEBDF_VIDEOS = ROOT.joinpath("videos", "CelebDF")
LABELS_JSON = ROOT.joinpath("data", "labels.json")

# Output folders created by this cell
FRAMES_ROOT = ROOT.joinpath("preprocessed", "frames", "celebdf")
EMB_ROOT = ROOT.joinpath("embeddings", "celebdf")

# Trained artifacts loaded later in the notebook
SPATIAL_CKPT = ROOT.joinpath("checkpoints", "spatial", "spatial_best_valAUC.pth")
TEMPORAL_CKPT = ROOT.joinpath("checkpoints", "temporal", "temporal_best_valAUC.pth")
ENSEMBLE_CKPT = ROOT.joinpath("ensemble_outputs", "ensemble_calibrated.joblib")

# --- Inference settings -------------------------------------------------------
NUM_FRAMES = 8   # frames per video
IMG_SIZE = 224   # side length of the square spatial-model input
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

FRAMES_ROOT.mkdir(parents=True, exist_ok=True)
EMB_ROOT.mkdir(parents=True, exist_ok=True)

print("Device:", DEVICE)

In [None]:
# Ground-truth labels from LABELS_JSON; get_label() resolves a video stem to an entry.
# JSON is UTF-8 by specification, so don't depend on the platform's default encoding.
with open(LABELS_JSON, "r", encoding="utf-8") as f:
    labels_map = json.load(f)

def get_label(stem, mapping=None):
    """Return the integer label for the video whose file stem is ``stem``.

    Args:
        stem: video file name without extension (as produced by ``Path.stem``).
        mapping: dict of label keys -> label values. Defaults to the global
            ``labels_map`` loaded from LABELS_JSON.

    Returns:
        ``int(value)`` of the matching entry.

    Raises:
        KeyError: if no key matches ``stem``.

    A key whose own stem equals ``stem`` exactly (e.g. ``"id0_0000.mp4"`` or
    ``"some/dir/id0_0000.mp4"``) always wins. Only when there is no exact match
    do we fall back to the first key that merely *contains* ``stem``. A pure
    substring match is ambiguous: ``"id0_0000"`` is also contained in
    ``"id10_id0_0000.mp4"``, so it could return another video's label depending
    on dict order.
    """
    if mapping is None:
        mapping = labels_map

    fallback = None
    for key, value in mapping.items():
        if Path(key).stem == stem:
            return int(value)
        if fallback is None and stem in key:
            fallback = value

    if fallback is None:
        raise KeyError(stem)
    return int(fallback)

# Every Celeb-DF video to evaluate, in a deterministic (sorted) order.
celeb_videos = sorted(CELEBDF_VIDEOS.glob("*.mp4"))
# Fail fast on a wrong path: an empty list would otherwise only surface as a
# confusing error from roc_auc_score at the very end of the notebook.
if not celeb_videos:
    raise FileNotFoundError(f"No .mp4 files found in {CELEBDF_VIDEOS}")
print("CelebDF videos:", len(celeb_videos))

In [None]:
def extract_frames(video_path, out_dir, num_frames=8):
    """Sample ``num_frames`` evenly spaced frames from a video and save them as JPEGs.

    The video is decoded sequentially (no seeking), and the frames at indices
    ``np.linspace(0, total - 1, num_frames)`` are kept. They are written to
    ``out_dir`` as ``frame_00.jpg``, ``frame_01.jpg``, ...

    Args:
        video_path: path to the input video.
        out_dir: ``Path`` of the folder to write frames into (created if needed).
        num_frames: number of frames to sample.

    Returns:
        Number of frames successfully written. It is 0, and ``out_dir`` is not
        created, when the video cannot be opened, reports no frames, or yields
        no decodable frames.

    NOTE(review): frames are saved full-size with no face detection or cropping.
    If the spatial model was trained on face crops, this is a train/test
    mismatch. Confirm against the training preprocessing.
    """
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        if not cap.isOpened():
            return 0
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total <= 0:
            # linspace(0, -1, ...) would yield negative indices; nothing to sample.
            return 0
        # A set gives O(1) membership and collapses the duplicate indices that
        # linspace produces when the video is shorter than num_frames.
        wanted = set(np.linspace(0, total - 1, num_frames, dtype=int).tolist())

        for i in range(total):
            ret, frame = cap.read()
            if not ret:
                break  # reported frame count can exceed the decodable frames
            if i in wanted:
                frames.append(frame)  # keep OpenCV's BGR order; imwrite expects BGR
    finally:
        cap.release()

    if not frames:
        return 0

    out_dir.mkdir(parents=True, exist_ok=True)
    written = 0
    for i, frame in enumerate(frames):
        if cv2.imwrite(str(out_dir / f"frame_{i:02d}.jpg"), frame):
            written += 1
    return written

# Extract frames for every video, re-using frames cached by a previous run.
for vp in tqdm(celeb_videos, desc="extracting frames"):
    out_dir = FRAMES_ROOT / vp.stem
    # Only treat a video as done if its folder actually holds frames. An empty
    # folder left by a failed or interrupted run would otherwise be skipped
    # forever and break the spatial-feature step.
    if out_dir.is_dir() and any(out_dir.glob("*.jpg")):
        continue
    # Pass the configured frame count explicitly rather than relying on the
    # function's default happening to equal NUM_FRAMES.
    extract_frames(vp, out_dir, num_frames=NUM_FRAMES)

print("Frame extraction done.")

In [None]:
class SpatialModel(nn.Module):
    """Binary frame-level classifier: EfficientNet-B3 backbone + small MLP head.

    ``forward`` maps a batch of images to one logit per image, shape ``(batch,)``.
    The attribute names (``backbone``, ``head``) and the layer order inside
    ``head`` must stay as-is so the checkpoint's state_dict keys still match.
    """

    def __init__(self):
        super().__init__()
        # num_classes=0 -> timm returns pooled features instead of class logits.
        # pretrained=False because the trained weights come from the checkpoint.
        self.backbone = timm.create_model("efficientnet_b3", pretrained=False, num_classes=0)
        hidden = 512
        self.head = nn.Sequential(
            nn.Linear(self.backbone.num_features, hidden),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        features = self.backbone(x)
        logits = self.head(features)  # (batch, 1)
        return logits.squeeze(1)

# Rebuild the spatial architecture and restore its trained weights.
spatial = SpatialModel().to(DEVICE)
# torch.load unpickles the file: only load checkpoints from a trusted source.
spatial_ckpt = torch.load(SPATIAL_CKPT, map_location=DEVICE)
spatial.load_state_dict(spatial_ckpt["model_state"])
# Drop the checkpoint dict so its tensors (already copied into the model) don't
# stay resident on DEVICE for the rest of the notebook.
del spatial_ckpt
spatial.eval()

In [None]:
# Per-frame preprocessing for the spatial model. Frames are read with OpenCV and
# converted to RGB uint8 numpy arrays. torchvision's Resize only accepts PIL
# images or tensors (an ndarray raises TypeError), so convert to PIL first.
tfm = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# video folder name -> summary statistics of its per-frame sigmoid probabilities
spatial_feats = {}
empty_dirs = []

with torch.no_grad():
    # Ignore stray files (e.g. OS metadata) sitting next to the per-video folders.
    video_dirs = sorted(p for p in FRAMES_ROOT.iterdir() if p.is_dir())
    for vdir in tqdm(video_dirs, desc="spatial inference"):
        frames = []
        for img_p in sorted(vdir.glob("*.jpg")):
            bgr = cv2.imread(str(img_p))
            if bgr is None:  # unreadable / corrupt image file
                continue
            frames.append(tfm(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))

        if not frames:
            # np.max on an empty list would raise; record the folder and move on.
            empty_dirs.append(vdir.name)
            continue

        # Score all frames of one video in a single forward pass.
        logits = spatial(torch.stack(frames).to(DEVICE))
        probs = torch.sigmoid(logits).cpu().numpy().astype(np.float64)

        spatial_feats[vdir.name] = {
            "mean": np.mean(probs),
            "max": np.max(probs),
            "median": np.median(probs),
            "std": np.std(probs),
        }

if empty_dirs:
    print(f"Warning: {len(empty_dirs)} frame folder(s) had no readable frames and were skipped.")

In [None]:
class TemporalModel(nn.Module):
    """Bidirectional-LSTM classifier over a sequence of per-frame feature vectors.

    Expects a ``(batch, seq_len, feat_dim)`` tensor (``batch_first=True``).
    The LSTM outputs are averaged over the sequence axis and projected to one
    logit per sequence, so ``forward`` returns shape ``(batch,)``. The default
    ``feat_dim=1536`` presumably matches the EfficientNet-B3 pooled features.
    TODO: confirm against the embedding step.
    """

    def __init__(self, feat_dim=1536):
        super().__init__()
        # 2 stacked layers, 512 hidden units per direction -> 1024-d outputs.
        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=512,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(in_features=2 * 512, out_features=1)

    def forward(self, x):
        seq_out = self.lstm(x)[0]      # (batch, seq_len, 1024)
        pooled = seq_out.mean(dim=1)   # average over time steps
        logit = self.fc(pooled)        # (batch, 1)
        return logit.squeeze(1)

# Rebuild the temporal architecture and restore its trained weights.
# NOTE(review): `temporal` is not used by the evaluation cell in this notebook.
# Either wire it into the ensemble features or drop this cell.
temporal = TemporalModel().to(DEVICE)
# torch.load unpickles the file: only load checkpoints from a trusted source.
temporal_ckpt = torch.load(TEMPORAL_CKPT, map_location=DEVICE)
temporal.load_state_dict(temporal_ckpt["model_state"])
# Drop the checkpoint dict so its tensors (already copied into the model) don't
# stay resident on DEVICE for the rest of the notebook.
del temporal_ckpt
temporal.eval()

In [None]:
from joblib import load

# joblib.load unpickles arbitrary objects: only load artifacts you trust.
ensemble = load(ENSEMBLE_CKPT)

# Feature column order fed to the ensemble (must match how it was fitted).
# NOTE(review): only the 4 spatial summary statistics are used; the temporal
# model loaded earlier does not contribute. Confirm this matches the feature set
# the calibrated ensemble was trained on.
SPATIAL_FEATURE_ORDER = ("mean", "max", "median", "std")

y_true, feature_rows, missing = [], [], []
for vp in celeb_videos:
    feats = spatial_feats.get(vp.stem)
    if feats is None:
        # No spatial features for this video (e.g. its frames could not be extracted).
        missing.append(vp.stem)
        continue
    y_true.append(get_label(vp.stem))
    feature_rows.append([feats[name] for name in SPATIAL_FEATURE_ORDER])

if missing:
    print(f"Warning: {len(missing)} video(s) have no spatial features and are excluded from the AUC.")
if not feature_rows:
    raise RuntimeError("No Celeb-DF videos with spatial features; nothing to evaluate.")

# One vectorized predict_proba call instead of one call per video.
X = np.asarray(feature_rows, dtype=np.float64)
y_pred = ensemble.predict_proba(X)[:, 1].tolist()

print(f"Evaluated videos: {len(y_true)}")
print("CelebDF AUC:", roc_auc_score(y_true, y_pred))