# Step 4: Temporal Consistency Verification Module
## Ablation B — Does Temporal Modeling Help?

**Step 3 result (honest baseline):** EfficientNet-B0, frame-level only → Celeb-DF AUC = 0.6135

**This notebook adds the temporal module on top:**
- Same EfficientNet-B0 architecture, initialised from ImageNet weights (frozen in phase 1, fine-tuned in phase 2)
- Input: 8 frames per video → sequence of per-frame embeddings → temporal model
- Temporal model: lightweight 2-layer bidirectional GRU (fits T4 VRAM, fast to train)
- Video-level prediction from temporal sequence

**Why GRU before Mamba?**  
GRU is fast, well-understood, and a proven temporal baseline. If the GRU improves over the  
frame-level model by +3 AUC points or more, Mamba is expected to do better still. If the GRU  
shows no improvement, temporal modeling needs rethinking before investing in Mamba.

**Ablation question:** Does modeling temporal consistency across frames add signal  
beyond what a single frame already provides?

**Expected:** +5–10% Celeb-DF AUC (0.6135 → 0.65–0.72)

## Section 1 — Setup

In [None]:
import os, json, random, time, warnings, sys
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, roc_curve
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
warnings.filterwarnings('ignore')

# ── Reproducibility ─────────────────────────────────────────────────────────
# Seed Python's `random`, NumPy and torch (CPU + every CUDA device) from one value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# ── Device ──────────────────────────────────────────────────────────────────
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device : {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU    : {torch.cuda.get_device_name(0)}")
    print(f"VRAM   : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ── Output locations (created if missing) ───────────────────────────────────
OUTPUT_DIR = Path('/kaggle/working') / 'step4'
CKPT_DIR = OUTPUT_DIR.joinpath('checkpoints')
PLOTS_DIR = OUTPUT_DIR.joinpath('plots')
for d in (OUTPUT_DIR, CKPT_DIR, PLOTS_DIR):
    d.mkdir(parents=True, exist_ok=True)
print(f"Outputs â†’ {OUTPUT_DIR}")


In [None]:
# Experiment configuration (read everywhere below as CFG['...']).
CFG = dict(
    # ── Data ──
    img_size=224,
    n_frames=8,            # ↑ was 4 — more frames = better temporal signal
    n_train_real=600,
    n_train_fake=600,      # 150 per method × 4 methods
    n_val_each=50,

    # ── Model ──
    embed_dim=1280,        # EfficientNet-B0 output dim
    temporal_hidden=512,   # GRU hidden size
    temporal_layers=2,     # GRU layers
    dropout=0.3,

    # ── Training — two phases ──
    # Phase 1: train temporal head only (backbone frozen) — fast convergence
    # Phase 2: fine-tune everything together — refine
    phase1_epochs=10,
    phase2_epochs=10,
    lr_head=1e-3,          # high LR for fresh temporal head
    lr_backbone=1e-5,      # low LR for pretrained backbone
    weight_decay=1e-4,
    label_smoothing=0.0,
    batch_size=16,         # lower batch — videos not frames now
)

# FF++ manipulation methods whose videos form the "fake" training class.
TRAIN_METHODS = ['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures']

print("Config:")
print("\n".join(f"  {key:22s}: {value}" for key, value in CFG.items()))
print(f"Training methods: {TRAIN_METHODS}")


## Section 2 — Dataset Paths & ID-Based Splits

In [None]:
# Root under which Kaggle mounts attached datasets.
KAGGLE_INPUT = Path('/kaggle', 'input')

def locate_ff_root(base):
    """Find the FaceForensics++ (C23) root below ``base``.

    The known Kaggle layout is tried first. Otherwise every path under
    ``base`` is scanned in sorted order, and the first directory containing
    at least two of the Deepfakes / Face2Face / FaceSwap sub-folders is
    returned.

    Args:
        base: ``pathlib.Path`` to search.

    Returns:
        The matching ``Path``, or None when nothing qualifies.
    """
    preferred = base.joinpath('datasets', 'xdxd003', 'ff-c23', 'FaceForensics++_C23')
    if preferred.exists():
        return preferred
    markers = ('Deepfakes', 'Face2Face', 'FaceSwap')
    # NOTE: sorting materialises the full tree listing (files included) before scanning.
    return next(
        (cand for cand in sorted(base.rglob('*'))
         if cand.is_dir() and sum((cand / m).exists() for m in markers) >= 2),
        None,
    )

def locate_celeb_root(base):
    """Find the Celeb-DF (v2) root below ``base``.

    The known Kaggle layout is tried first. Otherwise the first directory,
    in sorted path order, that has a ``Celeb-real`` entry is returned.

    Args:
        base: ``pathlib.Path`` to search.

    Returns:
        The matching ``Path``, or None when nothing qualifies.
    """
    preferred = base.joinpath('datasets', 'reubensuju', 'celeb-df-v2')
    if preferred.exists():
        return preferred
    for cand in sorted(base.rglob('*')):
        if not cand.is_dir():
            continue
        if (cand / 'Celeb-real').exists():
            return cand
    return None

# Resolve dataset roots (None when a dataset is not attached).
FF_ROOT = locate_ff_root(KAGGLE_INPUT)
CELEB_ROOT = locate_celeb_root(KAGGLE_INPUT)
print(f"FF++    : {FF_ROOT}")
print(f"Celeb-DF: {CELEB_ROOT}")

# Real FF++ videos: .mp4 files inside any "original*" folder. If that finds
# nothing, fall back to every .mp4 whose full path mentions "original"
# (case-insensitive).
FF_REAL = []
if FF_ROOT:
    FF_REAL = sorted(FF_ROOT.rglob('original*/*.mp4'))
    if not FF_REAL:
        FF_REAL = sorted(p for p in FF_ROOT.rglob('*.mp4') if 'original' in str(p).lower())

# Fake FF++ videos: one sorted list per training method (empty if the folder is missing).
FF_FAKE_BY_METHOD = {}
for method in TRAIN_METHODS:
    method_dir = FF_ROOT / method if FF_ROOT else None
    has_dir = method_dir is not None and method_dir.exists()
    FF_FAKE_BY_METHOD[method] = sorted(method_dir.glob('*.mp4')) if has_dir else []
    print(f"  FF++/{method:20s}: {len(FF_FAKE_BY_METHOD[method])} videos")
print(f"  FF++/{'real':20s}: {len(FF_REAL)} videos")

# Celeb-DF v2: Celeb-real + YouTube-real form the real pool, Celeb-synthesis the fake pool.
CDF_REAL, CDF_FAKE = [], []
if CELEB_ROOT:
    CDF_REAL = [*sorted((CELEB_ROOT / 'Celeb-real').glob('*.mp4')),
                *sorted((CELEB_ROOT / 'YouTube-real').glob('*.mp4'))]
    CDF_FAKE = sorted((CELEB_ROOT / 'Celeb-synthesis').glob('*.mp4'))
    print(f"  Celeb-DF real: {len(CDF_REAL)} | fake: {len(CDF_FAKE)}")


In [None]:
# ── ID-based split (no overlap of the leading video ID) ─────────────────────
def get_video_id(path):
    """Return the part of the filename stem before the first ``_``.

    NOTE(review): FF++ manipulated videos appear to be named ``<id1>_<id2>``
    (presumably target_source). Only ``<id1>`` decides the split, so the
    ``<id2>`` identity may still occur in both train and val. Confirm whether
    that matters for the FF++ val numbers.
    """
    return Path(path).stem.split('_')[0]


rng = random.Random(SEED)
all_ids = sorted({get_video_id(p) for p in FF_REAL})
rng.shuffle(all_ids)
n_train_ids = int(len(all_ids) * 0.75)
train_ids = set(all_ids[:n_train_ids])
val_ids = set(all_ids[n_train_ids:])
print(f"Video IDs â€” train: {len(train_ids)}, val: {len(val_ids)} (no overlap)")

n_per_method = CFG['n_train_fake'] // len(TRAIN_METHODS)


def _sample_split(id_set, n_real, n_fake_per_method):
    """Sample one split and return shuffled ``(path, label)`` pairs (0 = real, 1 = fake).

    Draws up to ``n_real`` real videos and up to ``n_fake_per_method`` fakes
    per training method, keeping only videos whose ID is in ``id_set``.

    The shared ``rng`` is consumed in a fixed order, so the split is
    reproducible from SEED: the real sample, then each method's sample in
    TRAIN_METHODS order, then the final shuffle.
    """
    # Filter each pool once; the original code rebuilt it for the length check.
    real_pool = [p for p in FF_REAL if get_video_id(p) in id_set]
    pairs = [(p, 0) for p in rng.sample(real_pool, min(n_real, len(real_pool)))]
    for method in TRAIN_METHODS:
        fake_pool = [p for p in FF_FAKE_BY_METHOD[method] if get_video_id(p) in id_set]
        picked = rng.sample(fake_pool, min(n_fake_per_method, len(fake_pool)))
        pairs += [(p, 1) for p in picked]
    rng.shuffle(pairs)
    return pairs


# Training set
TRAIN_DATA = _sample_split(train_ids, CFG['n_train_real'], n_per_method)

# Validation set
VAL_DATA = _sample_split(val_ids, CFG['n_val_each'], CFG['n_val_each'] // len(TRAIN_METHODS))

# Celeb-DF: balanced real/fake subset, at most 200 of each
n_cdf    = min(200, len(CDF_REAL), len(CDF_FAKE))
CDF_TEST = ([(p, 0) for p in rng.sample(CDF_REAL, n_cdf)] +
            [(p, 1) for p in rng.sample(CDF_FAKE, n_cdf)])

print(f"Train: {sum(1 for _,l in TRAIN_DATA if l==0)} real + "
      f"{sum(1 for _,l in TRAIN_DATA if l==1)} fake = {len(TRAIN_DATA)}")
print(f"Val  : {sum(1 for _,l in VAL_DATA   if l==0)} real + "
      f"{sum(1 for _,l in VAL_DATA   if l==1)} fake = {len(VAL_DATA)}")
print(f"CDF  : {n_cdf} real + {n_cdf} fake = {len(CDF_TEST)}")


## Section 3 — Video Dataset

**Key difference from Steps 2-3:** each sample is now a video clip (8 evenly spaced frames),  
not an individual frame. The model receives a sequence and predicts once per video.  
This is what enables temporal modeling.

In [None]:
# ImageNet channel statistics used to normalise frames for the pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]


def _compose_frame_pipeline(*augmentations):
    """Build a per-frame transform: array → PIL → [augmentations] → normalized tensor."""
    return transforms.Compose([
        transforms.ToPILImage(),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


# Evaluation pipeline: deterministic conversion + normalisation only.
frame_tf = _compose_frame_pipeline()

# Training pipeline: random horizontal flip and colour jitter applied in PIL space.
aug_tf = _compose_frame_pipeline(
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
)


def load_video_clip(video_path, n_frames, img_size):
    """Sample ``n_frames`` evenly spaced frames from a video.

    Each decoded frame is converted BGR→RGB and centre-cropped to the middle
    90 % of its height and 80 % of its width. This is a fixed crop, not face
    detection. The crop is then resized to ``img_size`` × ``img_size``.

    Args:
        video_path: path to a video file readable by ``cv2.VideoCapture``.
        n_frames: number of frame positions to sample, spread evenly over the
            container's reported frame count.
        img_size: output side length in pixels.

    Returns:
        A list of ``n_frames`` RGB arrays of shape (img_size, img_size, 3).
        Returns None if the video cannot be opened, reports no frames, or no
        sampled frame decodes. Positions that fail to decode are skipped, and
        the list is padded by repeating the last decoded frame.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total < 1:
            return None
        frames = []
        for pos in np.linspace(0, total - 1, n_frames, dtype=int):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(pos))
            ret, frame = cap.read()
            if not ret:
                continue   # undecodable position: skipped, padded below
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            h, w = frame.shape[:2]
            frame = frame[int(h * 0.05):int(h * 0.95), int(w * 0.10):int(w * 0.90)]
            frames.append(cv2.resize(frame, (img_size, img_size)))
    finally:
        # Release the capture on every path, including exceptions raised while
        # decoding, colour-converting or resizing.
        cap.release()
    if not frames:
        return None
    while len(frames) < n_frames:
        frames.append(frames[-1])
    return frames[:n_frames]


class VideoDataset(Dataset):
    """
    Each item is a full video clip: a stacked (n_frames, ...) tensor plus a long label.
    Pre-extracts all clips at construction — no video I/O in the DataLoader.

    With ``augment=True``, the random augmentation parameters are drawn once per
    clip and reused for every frame of that clip. This way augmentation never
    adds frame-to-frame flicker that a temporal model could learn as a cue.
    """
    def __init__(self, video_label_pairs, n_frames, img_size, augment=False):
        self.augment   = augment
        self.transform = aug_tf if augment else frame_tf
        self.clips     = []   # list of (frames_list, label)
        failed = 0
        for path, label in tqdm(video_label_pairs, ncols=80, desc='Loading clips'):
            frames = load_video_clip(str(path), n_frames, img_size)
            if frames is None:
                failed += 1
                continue
            self.clips.append((frames, label))
        print(f"  {len(self.clips)} clips ready ({failed} failed)")

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        frames, label = self.clips[idx]
        if self.augment:
            # torchvision's random transforms draw from torch's global RNG.
            # Restoring the same RNG state before each frame gives every frame
            # identical flip / colour-jitter parameters. The state left after
            # the last frame has still advanced, so the next clip gets fresh ones.
            rng_state = torch.get_rng_state()
            per_frame = []
            for f in frames:
                torch.set_rng_state(rng_state)
                per_frame.append(self.transform(f))
        else:
            per_frame = [self.transform(f) for f in frames]
        # (T, 3, H, W) with this notebook's transforms
        return torch.stack(per_frame), torch.tensor(label, dtype=torch.long)


print("Pre-extracting video clips (~8 min for 1200 videos Ã— 8 frames)...")
t0 = time.time()
clip_kw = dict(n_frames=CFG['n_frames'], img_size=CFG['img_size'])
train_ds = VideoDataset(TRAIN_DATA, augment=True, **clip_kw)
val_ds = VideoDataset(VAL_DATA, augment=False, **clip_kw)
cdf_ds = VideoDataset(CDF_TEST, augment=False, **clip_kw)
print(f"Done in {time.time()-t0:.1f}s")

# Clips are already decoded in memory, so batches are assembled in the main process.
loader_kw = dict(batch_size=CFG['batch_size'], num_workers=0, pin_memory=False)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kw)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kw)
cdf_loader = DataLoader(cdf_ds, shuffle=False, **loader_kw)

print(f"Train clips: {len(train_ds)} | Val: {len(val_ds)} | CDF: {len(cdf_ds)}")
# Sanity-check one training batch. NOTE: this consumes torch RNG draws
# (shuffle order, augmentation) before training starts.
x, y = next(iter(train_loader))
print(f"Batch: x={x.shape} (B, T, C, H, W), labels={y.unique().tolist()}")


## Section 4 — Model: EfficientNet-B0 + Temporal GRU

**Architecture:**
1. EfficientNet-B0 extracts per-frame embeddings independently → (B, T, 1280)
2. The embeddings are projected and a bidirectional GRU processes the sequence → models frame-to-frame consistency
3. GRU outputs mean-pooled over time → classification head

**Two-phase training:**
- Phase 1 (epochs 1-10): backbone FROZEN, only train projection + GRU + head
- Phase 2 (epochs 11-20): unfreeze backbone with very low LR

In [None]:
class TemporalDeepfakeDetector(nn.Module):
    """
    EfficientNet-B0 spatial backbone + bidirectional GRU temporal model.

    Pipeline:
      1. Each frame is embedded independently by the backbone → (B, T, 1280).
      2. Linear projection + LayerNorm + GELU.
      3. Bidirectional GRU.
      4. Mean over time.
      5. MLP head producing 2 logits; index 1 is read as "fake" by the
         evaluation code.

    The GRU is meant to pick up inconsistencies across frames that a single
    frame does not reveal; whether it does is what this ablation measures.

    While the backbone is frozen (``freeze_backbone``), it is also held in
    eval mode, even across later ``.train()`` calls. Its BatchNorm running
    statistics therefore stay fixed while only the temporal head trains.
    """
    def __init__(self):
        super().__init__()
        # Set by freeze_backbone()/unfreeze_backbone(); consulted by train().
        self._backbone_frozen = False

        # ── Spatial backbone ─────────────────────────────────────────────────
        effnet = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        # Remove the classifier — keep the convolutional feature extractor only
        self.backbone = effnet.features   # (N, 1280, 7, 7) for 224×224 input
        self.pool     = nn.AdaptiveAvgPool2d(1)   # → (N, 1280, 1, 1)

        # ── Temporal module ──────────────────────────────────────────────────
        # Project to a smaller dim before the GRU to save memory
        self.proj = nn.Sequential(
            nn.Linear(CFG['embed_dim'], CFG['temporal_hidden']),
            nn.LayerNorm(CFG['temporal_hidden']),
            nn.GELU(),
        )

        # Bidirectional GRU: one direction reads frames 1→T, the other T→1.
        # Every output step therefore carries context from both earlier and
        # later frames, e.g. frame 8 can be related to frame 1.
        self.gru = nn.GRU(
            input_size  = CFG['temporal_hidden'],
            hidden_size = CFG['temporal_hidden'],
            num_layers  = CFG['temporal_layers'],
            batch_first = True,
            bidirectional = True,
            dropout = CFG['dropout'] if CFG['temporal_layers'] > 1 else 0.0,
        )

        # ── Classification head ──────────────────────────────────────────────
        gru_out_dim = CFG['temporal_hidden'] * 2   # bidirectional
        self.head = nn.Sequential(
            nn.Dropout(CFG['dropout']),
            nn.Linear(gru_out_dim, 256),
            nn.GELU(),
            nn.Dropout(CFG['dropout'] * 0.5),
            nn.Linear(256, 2),
        )

        # Orthogonal init for GRU (prevents dead-branch bug from V7.3)
        for name, param in self.gru.named_parameters():
            if 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def extract_frame_embeddings(self, x_video):
        """
        x_video: (B, T, C, H, W)
        returns: (B, T, embed_dim)
        """
        B, T, C, H, W = x_video.shape
        # Fold time into the batch so the backbone sees B*T independent frames.
        # reshape (unlike view) also accepts non-contiguous input.
        feats = self.backbone(x_video.reshape(B * T, C, H, W))   # (B*T, 1280, h, w)
        feats = self.pool(feats).flatten(1)                      # (B*T, 1280)
        return feats.reshape(B, T, -1)                           # (B, T, 1280)

    def forward(self, x_video):
        """
        x_video: (B, T, C, H, W)
        returns: logits (B, 2)
        """
        embeds = self.extract_frame_embeddings(x_video)   # (B, T, 1280)
        proj = self.proj(embeds)                          # (B, T, temporal_hidden)
        out, _ = self.gru(proj)                           # (B, T, 2*temporal_hidden)
        # Mean-pool the GRU outputs over time; the final hidden state is not used.
        temporal_feat = out.mean(dim=1)                   # (B, 2*temporal_hidden)
        return self.head(temporal_feat)                   # (B, 2)

    def train(self, mode=True):
        """Set train/eval mode like ``nn.Module.train``, but keep a frozen backbone in eval mode.

        Without this, ``model.train()`` would put the frozen backbone's
        BatchNorm layers back into training mode. They would then keep
        updating their running statistics even though no weights are learned.
        """
        super().train(mode)
        if getattr(self, '_backbone_frozen', False):
            self.backbone.eval()
        return self

    def freeze_backbone(self):
        """Stop gradients into the backbone and hold it in eval mode."""
        for p in self.backbone.parameters():
            p.requires_grad = False
        self._backbone_frozen = True
        self.backbone.eval()
        print("Backbone FROZEN")

    def unfreeze_backbone(self):
        """Re-enable backbone gradients and re-sync its mode with the model's."""
        for p in self.backbone.parameters():
            p.requires_grad = True
        self._backbone_frozen = False
        self.backbone.train(self.training)
        print("Backbone UNFROZEN")

    def _temporal_params(self):
        """Parameters of everything after the backbone: proj, GRU, head (in that order)."""
        return (list(self.proj.parameters()) +
                list(self.gru.parameters()) +
                list(self.head.parameters()))

    def get_param_groups(self, phase):
        """Optimizer param groups for a training phase.

        Phase 1: only proj + GRU + head, at ``CFG['lr_head']``.
        Any other phase: the backbone at ``CFG['lr_backbone']``, plus
        proj + GRU + head at ``CFG['lr_head'] / 10`` for stability.
        """
        if phase == 1:
            return [{'params': self._temporal_params(), 'lr': CFG['lr_head']}]
        return [
            {'params': list(self.backbone.parameters()), 'lr': CFG['lr_backbone']},
            {'params': self._temporal_params(), 'lr': CFG['lr_head'] / 10},
        ]


model = TemporalDeepfakeDetector().to(DEVICE)

# Parameter counts: everything / backbone / everything after the backbone.
total = sum(p.numel() for p in model.parameters())
backbone_p = sum(p.numel() for p in model.backbone.parameters())
temporal_p = sum(p.numel() for p in list(model.proj.parameters()) +
                              list(model.gru.parameters()) +
                              list(model.head.parameters()))
print(f"Total params   : {total/1e6:.2f}M")
print(f"Backbone params: {backbone_p/1e6:.2f}M")
print(f"Temporal params: {temporal_p/1e6:.2f}M")

# Forward-pass smoke test, run in eval mode. torch.no_grad() does NOT stop
# BatchNorm from updating running statistics in train mode, so a random-noise
# batch in train mode would shift the pretrained backbone's BN statistics.
model.eval()
with torch.no_grad():
    side = CFG['img_size']
    # Sample on CPU, then move, so the same RNG stream is consumed as before.
    test_in  = torch.randn(2, CFG['n_frames'], 3, side, side).to(DEVICE)
    test_out = model(test_in)
    print(f"Forward: (2, {CFG['n_frames']}, 3, {side}, {side}) â†’ {test_out.shape} âœ“")
model.train()   # restore the mode the freshly constructed model was in


## Section 5 — Two-Phase Training

In [None]:
# Loss shared by training and evaluation (default 'mean' reduction over the batch).
criterion = nn.CrossEntropyLoss(label_smoothing=CFG['label_smoothing'])


def train_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        model: the detector; switched to train mode here.
        loader: yields ``(x, y)`` batches of clips and integer labels.
        optimizer: stepped once per batch.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over samples rather than
        batches, so a smaller final batch counts with its true size.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch_n = y.size(0)
        # criterion returns the batch mean; re-weight by the batch size.
        total_loss += loss.item() * batch_n
        correct += (logits.detach().argmax(1) == y).sum().item()
        total += batch_n
    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return total_loss / total, correct / total


def evaluate(model, loader):
    """Score every clip in ``loader`` in eval mode without gradients.

    Returns:
        dict with
          'auc'    — ROC AUC of P(fake); 0.5 when fewer than two classes are present,
          'acc'    — accuracy of ``P(fake) > 0.5`` (NaN for an empty loader),
          'loss'   — per-sample mean of ``criterion`` (0.0 for an empty loader),
          'labels' — ndarray of true labels,
          'probs'  — ndarray of P(fake) = softmax(logits)[:, 1].
    """
    model.eval()
    all_labels, all_probs = [], []
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            batch_n = y.size(0)
            # criterion returns the batch mean; weight by batch size so the
            # smaller final batch does not skew the average.
            total_loss += criterion(logits, y).item() * batch_n
            n_samples += batch_n
            all_probs.extend(F.softmax(logits, dim=1)[:, 1].cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    labels = np.array(all_labels)
    probs = np.array(all_probs)
    auc = roc_auc_score(labels, probs) if len(np.unique(labels)) > 1 else 0.5
    acc = ((probs > 0.5).astype(int) == labels).mean() if labels.size else float('nan')
    return {'auc': auc, 'acc': acc, 'loss': total_loss / max(n_samples, 1),
            'labels': labels, 'probs': probs}


def run_phase(phase, epochs, model, loader_tr, loader_val):
    """Train ``model`` for one phase, checkpointing the best validation AUC.

    The optimizer is AdamW over ``model.get_param_groups(phase)``. The LR
    multiplier ramps linearly over a 2-epoch warm-up, then follows a
    half-cosine, and is stepped once per epoch. Each time validation AUC
    beats the phase's best so far, the state_dict is written to
    ``CKPT_DIR / f'best_phase{phase}.pth'``.

    Returns:
        ``(history, best_auc)``. ``history`` maps 'train_loss', 'train_acc',
        'val_auc' and 'val_loss' to per-epoch lists.
    """
    warmup = 2

    def lr_lambda(ep):
        if ep >= warmup:
            progress = (ep - warmup) / max(1, epochs - warmup)
            return 0.5 * (1 + np.cos(np.pi * progress))
        return (ep + 1) / warmup

    optimizer = torch.optim.AdamW(model.get_param_groups(phase),
                                  weight_decay=CFG['weight_decay'])
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    history = {key: [] for key in ('train_loss', 'train_acc', 'val_auc', 'val_loss')}
    best_auc, best_epoch = 0.0, 0

    phase_desc = 'Temporal head only (backbone frozen)' if phase == 1 else 'Full fine-tune'
    print(f"\n{'=' * 68}")
    print(f"PHASE {phase} â€” {phase_desc}")
    print('=' * 68)
    print(f"{'Ep':>3} {'TrLoss':>8} {'TrAcc':>7} {'VaLoss':>8} "
          f"{'VaAUC':>7} {'VaAcc':>7} {'t':>5}")
    print('-' * 68)

    for ep_idx in range(epochs):
        tic = time.time()
        tr_loss, tr_acc = train_epoch(model, loader_tr, optimizer)
        val_m = evaluate(model, loader_val)
        scheduler.step()

        # history keys are in insertion order, matching this tuple.
        for key, value in zip(history, (tr_loss, tr_acc, val_m['auc'], val_m['loss'])):
            history[key].append(value)

        improved = val_m['auc'] > best_auc
        mark = ' âœ“' if improved else ''
        print(f"{ep_idx+1:>3} {tr_loss:>8.4f} {tr_acc:>7.3f} {val_m['loss']:>8.4f} "
              f"{val_m['auc']:>7.4f} {val_m['acc']:>7.3f} "
              f"{time.time()-tic:>4.0f}s{mark}")
        sys.stdout.flush()

        if improved:
            best_auc, best_epoch = val_m['auc'], ep_idx + 1
            torch.save({'epoch': ep_idx, 'model_state': model.state_dict(),
                        'val_auc': best_auc, 'phase': phase},
                       CKPT_DIR / f'best_phase{phase}.pth')

    print(f"Phase {phase} best: AUC={best_auc:.4f} at epoch {best_epoch}")
    return history, best_auc

print("âœ… Training functions ready")


In [None]:
start_time = time.time()

# ── Phase 1: backbone frozen — train projection + GRU + head only ──────────
model.freeze_backbone()
hist1, best_auc_p1 = run_phase(
    1, CFG['phase1_epochs'], model, train_loader, val_loader)

# ── Phase 2: unfreeze and fine-tune everything ─────────────────────────────
# Continues from the weights at the END of phase 1, not its best checkpoint.
model.unfreeze_backbone()
hist2, best_auc_p2 = run_phase(
    2, CFG['phase2_epochs'], model, train_loader, val_loader)

total_time = time.time() - start_time
print(f"\nTotal training time: {total_time/60:.1f} min")
for phase_no, phase_best in ((1, best_auc_p1), (2, best_auc_p2)):
    print(f"Best phase {phase_no} AUC: {phase_best:.4f}")

# Reload the better phase's best-val-AUC checkpoint (ties go to phase 2).
best_phase = 2 if best_auc_p2 >= best_auc_p1 else 1
best_ckpt_path = CKPT_DIR / f'best_phase{best_phase}.pth'
# weights_only=False: the checkpoint was written by this notebook (trusted) and
# may contain non-tensor values (e.g. a NumPy scalar for val_auc).
ckpt = torch.load(best_ckpt_path, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state'])
print(f"Loaded best model from phase {best_phase}, epoch {ckpt['epoch']+1}")


## Section 6 — Evaluation & Comparison

In [None]:
ff_m  = evaluate(model, val_loader)
cdf_m = evaluate(model, cdf_loader)

# NOTE(review): val_loader also chose the best checkpoint in run_phase, so the
# FF++ "Val AUC" below is optimistically biased. Celeb-DF took no part in
# model selection.

# Baselines from previous steps (hard-coded from the Step 3 run)
STEP3 = {'ff_auc': 0.6850, 'cdf_auc': 0.6135}

# Column width 17 = len('Step 4 (temporal)'), so headers line up with the values.
_w = 17
_rule = "=" * (35 + 2 * (_w + 1))

print("\n" + _rule)
print("ABLATION RESULTS â€” Temporal vs Frame-Level")
print(_rule)
print(f"{'Metric':<35} {'Step 3 (frame)':>{_w}} {'Step 4 (temporal)':>{_w}}")
print("-" * len(_rule))
print(f"{'FF++ Val AUC':<35} {STEP3['ff_auc']:>{_w}.4f} {ff_m['auc']:>{_w}.4f}")
print(f"{'Celeb-DF AUC (cross-dataset)':<35} {STEP3['cdf_auc']:>{_w}.4f} {cdf_m['auc']:>{_w}.4f}")
delta = cdf_m['auc'] - STEP3['cdf_auc']
print(f"{'Improvement from temporal module':<35} {'':>{_w}} {delta:>+{_w}.4f}")
print(_rule)

# Verdict thresholds are absolute Celeb-DF AUC differences vs. Step 3.
if delta >= 0.05:
    verdict = "ðŸŸ¢ TEMPORAL HELPS â€” +5%+ improvement. Mamba will do better."
elif delta >= 0.02:
    verdict = "ðŸŸ¡ MODEST IMPROVEMENT â€” Temporal adds signal. Continue to Step 5 (B4)."
elif delta >= 0.0:
    verdict = "ðŸŸ¡ MARGINAL â€” Temporal module barely helps with B0. B4 backbone needed first."
else:
    verdict = "ðŸ”´ NO IMPROVEMENT â€” Investigate: temporal head may need more epochs or data."
print(f"\n{verdict}")


In [None]:
# ── Plots ───────────────────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
ax_loss, ax_auc, ax_roc = axes
fig.suptitle('Step 4: Temporal GRU â€” Training Curves & Results', fontsize=14, fontweight='bold')

# Concatenate the phase-1 and phase-2 histories onto one epoch axis.
all_tr_loss = hist1['train_loss'] + hist2['train_loss']
all_va_loss = hist1['val_loss'] + hist2['val_loss']
all_va_auc = hist1['val_auc'] + hist2['val_auc']
x = range(1, len(all_tr_loss) + 1)
split = CFG['phase1_epochs']
boundary = split + 0.5   # vertical line between the last P1 and the first P2 epoch

# Panel 1: train / val loss
ax_loss.plot(x, all_tr_loss, color='#3498db', linewidth=2, label='Train loss')
ax_loss.plot(x, all_va_loss, color='#e74c3c', linewidth=2, label='Val loss')
ax_loss.axvline(boundary, color='gray', linestyle='--', alpha=0.7, label='Phase boundary')
ax_loss.set_title('Loss (P1: frozen | P2: fine-tune)')
ax_loss.set_xlabel('Epoch')
ax_loss.legend()
ax_loss.grid(True, alpha=0.3)

# Panel 2: FF++ val AUC per epoch plus horizontal reference lines
best_val_auc = max(all_va_auc)
ax_auc.plot(x, all_va_auc, color='#2ecc71', linewidth=2.5, label='Val AUC')
ax_auc.axvline(boundary, color='gray', linestyle='--', alpha=0.7)
reference_lines = [
    (best_val_auc,     '#2ecc71', '--', 0.5, f'Best={best_val_auc:.4f}'),
    (STEP3['ff_auc'],  'gray',    ':',  0.6, f'Step3={STEP3["ff_auc"]:.4f}'),
    (cdf_m['auc'],     '#e74c3c', '--', 0.7, f'CDF={cdf_m["auc"]:.4f}'),
    (STEP3['cdf_auc'], '#e74c3c', ':',  0.5, f'Step3 CDF={STEP3["cdf_auc"]:.4f}'),
]
for level, colour, style, alpha, text in reference_lines:
    ax_auc.axhline(level, color=colour, linestyle=style, alpha=alpha, label=text)
ax_auc.set_title('Val AUC')
ax_auc.set_xlabel('Epoch')
ax_auc.set_ylim(0.40, 1.0)
ax_auc.legend(fontsize=8)
ax_auc.grid(True, alpha=0.3)

# Panel 3: ROC curves for FF++ val and Celeb-DF
roc_inputs = [
    ('#3498db', ff_m,  f"FF++ Val (AUC={ff_m['auc']:.4f})"),
    ('#e74c3c', cdf_m, f"Celeb-DF (AUC={cdf_m['auc']:.4f})"),
]
for colour, metrics, text in roc_inputs:
    fpr, tpr, _ = roc_curve(metrics['labels'], metrics['probs'])
    ax_roc.plot(fpr, tpr, color=colour, linewidth=2, label=text)
ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.4, label='Random')
ax_roc.set_title('ROC Curves')
ax_roc.set_xlabel('FPR')
ax_roc.set_ylabel('TPR')
ax_roc.legend(fontsize=9)
ax_roc.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(PLOTS_DIR / 'step4_results.png', dpi=150, bbox_inches='tight')
plt.show()
print("âœ… step4_results.png")


## Section 7 — Save Results

In [None]:
def _rounded_metrics(m):
    """AUC and accuracy from an evaluate() result, each rounded to 4 decimals."""
    return {'auc': round(m['auc'], 4), 'acc': round(m['acc'], 4)}


temporal_gain = round(cdf_m['auc'] - STEP3['cdf_auc'], 4)
results = {
    'model':           'EfficientNet-B0 + Bidirectional GRU (2 layers)',
    'n_frames':        CFG['n_frames'],
    'train_methods':   TRAIN_METHODS,
    'phase1_best_auc': round(best_auc_p1, 4),
    'phase2_best_auc': round(best_auc_p2, 4),
    'ff_val':          _rounded_metrics(ff_m),
    'celeb_df':        _rounded_metrics(cdf_m),
    'step3_cdf_auc':   STEP3['cdf_auc'],
    'temporal_improvement': temporal_gain,
    'training_minutes': round(total_time / 60, 1),
}

results_path = OUTPUT_DIR / 'step4_results.json'
results_path.write_text(json.dumps(results, indent=2))

banner = "=" * 60
print(banner)
print("STEP 4 COMPLETE")
print(banner)
print(f"  Frame-level baseline (Step 3): CDF AUC = {STEP3['cdf_auc']:.4f}")
print(f"  + Temporal GRU    (Step 4): CDF AUC = {cdf_m['auc']:.4f}")
print(f"  Temporal contribution       : {results['temporal_improvement']:+.4f}")
print()
print("Next: Step 5 â€” upgrade backbone to EfficientNet-B4")
print(f"âœ… Results â†’ {results_path}")
