In [None]:
# Installs all of the libraries listed in the 'offline-pytorch-2-1-2' Kaggle
# dataset, fully offline: --no-index disables PyPI, so every requirement must be
# resolved from the wheels bundled in the dataset's wheels/ directory.
# NOTE(review): `--q` only works because pip accepts unambiguous long-option
# abbreviations (here of `--quiet`); `-q` is the conventional spelling.
!pip install \
   --requirement /kaggle/input/offline-pytorch-2-1-2/requirements.txt \
   --no-index \
   --find-links file:///kaggle/input/offline-pytorch-2-1-2/wheels  \
--q

In [None]:
import torch
import platform
import subprocess

def check_suitability():
    print("="*60)
    print("  STF-Mamba V8.0 — Compatibility Test Report")
    print("="*60)
    
    # 1. PyTorch Version Check
    torch_version = torch.__version__.split('+')[0]
    print(f"1. PyTorch Version: {torch.__version__}")
    
    # Ideal versions are 2.4.0 or 2.1.0 for pre-compiled wheels
    if torch_version in ["2.4.0", "2.1.0"]:
        print("   ✓ EXCELLENT: This version has pre-compiled Mamba kernels available.")
    elif torch_version > "2.4.0":
        print("   ⚠ WARNING: Version too new. You will likely have to compile from source (30-60 min).")
    else:
        print("   ✗ ERROR: Version too old. Mamba 2.0+ requires PyTorch 2.1.0 minimum.")

    # 2. CUDA & GPU Check
    cuda_available = torch.cuda.is_available()
    print(f"\n2. CUDA Available: {cuda_available}")
    if cuda_available:
        gpu_name = torch.cuda.get_device_name(0)
        print(f"   ✓ GPU Device: {gpu_name}")
        # Mamba requires Ampere (A100/3060) or newer for best speed, 
        # but T4 (Kaggle) is supported via specific kernels.
    else:
        print("   ✗ ERROR: No GPU detected. Mamba cannot run on CPU.")

    # 3. Python Version
    python_version = platform.python_version()
    print(f"\n3. Python Version: {python_version}")
    if python_version.startswith("3.10"):
        print("   ✓ MATCH: Python 3.10 is the standard for Mamba wheels.")
    else:
        print(f"   ⚠ NOTE: Wheels are usually built for 3.10. You are on {python_version}.")

    # 4. Summary Verdict
    print("\n" + "="*60)
    if torch_version == "2.4.0" and cuda_available:
        print("  VERDICT: PERFECT SUITABILITY. Use the direct .whl links.")
    elif cuda_available:
        print("  VERDICT: SEMI-SUITABLE. You can run it, but expect long install times.")
    else:
        print("  VERDICT: NOT SUITABLE. Switch to a GPU-enabled notebook.")
    print("="*60)

# Print the environment report (PyTorch version, CUDA/GPU, Python version and verdict).
check_suitability()

In [None]:
# ============================================================================
# Section 1: High-Speed Setup (PyTorch 2.1.2 + cu118)
# ============================================================================
import torch, os, sys
print(f"Verified Active Version: {torch.__version__}")

# Direct links for Python 3.10 + CUDA 11.8 + Torch 2.1
# These wheels bypass the 33-minute CPU compilation process.
# The wheel tags (cp310, cu118, torch2.1, cxx11abiFalse) must match the runtime;
# compare them with the check_suitability() report.
# NOTE(review): `!pip` runs the shell's pip; `%pip` guarantees the kernel's environment.
print("Installing native Mamba kernels...")
!pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.1.3/causal_conv1d-1.1.3+cu118torch2.1cxx11abiFalse-cp310-cp310-linux_x86_64.whl -q
!pip install https://github.com/state-spaces/mamba/releases/download/v1.1.1/mamba_ssm-1.1.1+cu118torch2.1cxx11abiFalse-cp310-cp310-linux_x86_64.whl -q

# Install required vision and utility libraries
# NOTE(review): these are unpinned; dlib may build from source if no wheel matches.
print("Installing vision utilities...")
!pip install dlib imutils albumentations einops -q

# Smoke-test the freshly installed kernels. Both imports must resolve for the
# native Mamba path. Any failure is printed but does not stop the notebook.
banner = "=" * 40
try:
    from mamba_ssm import Mamba
    from causal_conv1d import causal_conv1d_fn
except Exception as e:
    print(f"\n⚠ Error: {e}")
    print("If failure persists, please ensure your GPU is set to T4.")
else:
    print("\n" + banner)
    print("✓ SUCCESS: Mamba Kernels Ready")
    print(banner)

print("\n✓ Section 1 complete")

In [None]:
# ============================================================================
# Section 2: Clone Repo + Path Configuration
# ============================================================================
# Clone the project repo on first run; on later runs update it with `git pull`.
# `{REPO_URL}` / `{REPO_DIR}` inside the `!` lines are IPython variable interpolation.
# `os` and `sys` come from the Section 1 cell's imports.
REPO_URL = "https://github.com/AbdelRahman-Madboly/STF-Mamba_V8.0.git"
REPO_DIR = "/kaggle/working/STF-Mamba_V8.0"

if not os.path.exists(REPO_DIR):
    !git clone {REPO_URL} {REPO_DIR}
else:
    print(f"Repo already cloned at {REPO_DIR}")
    !cd {REPO_DIR} && git pull

# Make the repo's packages (stf_mamba, data, training) importable and make
# relative paths resolve against the repo.
# NOTE(review): re-running this cell inserts REPO_DIR into sys.path again (harmless duplicate).
sys.path.insert(0, REPO_DIR)
os.chdir(REPO_DIR)

from stf_mamba import STFMambaV8, STFMambaLoss, is_mamba_available
from data import SBIVideoDataset, get_train_transforms, get_val_transforms, load_all_splits
from training import Trainer, build_optimizer, build_scheduler

# is_mamba_available() selects the label below: native Mamba kernels vs. the Conv1d fallback.
print(f"\nRepo: {REPO_DIR}")
print(f"Mamba SSM: {'native' if is_mamba_available() else 'Conv1d fallback'}")
print("✓ Section 2 complete — all imports OK")

In [None]:
# ============================================================================
# Section 3: Dataset Paths — Load Pre-Built Cache
# ============================================================================
from pathlib import Path

KAGGLE_INPUT = Path("/kaggle/input")


def _find_cache_dataset(search_root):
    """Locate the pre-built STF-Mamba cache below ``search_root``.

    Primary rule: the parent of the first (sorted) ``crops`` directory that
    contains at least one ``*_crops.npz`` file.
    Fallback: the first directory whose name is one of the known cache dataset
    names, tried in the listed priority order.

    Returns the directory as ``str``, or ``None`` if nothing matches.
    """
    for crops_dir in sorted(search_root.rglob("crops")):
        if crops_dir.is_dir() and any(f.name.endswith("_crops.npz") for f in crops_dir.iterdir()):
            return str(crops_dir.parent)  # parent of crops/

    # Fallback: search by dataset name. Returning on the first hit keeps the
    # priority order of this list. A nested `break` would only leave the inner
    # loop and let a lower-priority name overwrite an earlier match.
    for name in ["stf-mamba-v8-cache", "stf-cache", "stf_cache"]:
        for candidate in sorted(search_root.rglob(name)):
            if candidate.is_dir():
                return str(candidate)
    return None


CACHE_DATASET = _find_cache_dataset(KAGGLE_INPUT)

if CACHE_DATASET is None:
    raise RuntimeError(
        "Pre-built cache not found!\n"
        "Run the preprocessing notebook first and attach 'stf-mamba-v8-cache' dataset."
    )

# Verify cache contents. A directory found by the name-based fallback may lack
# crops/, so fail with an explicit message instead of a bare listdir traceback.
crops_dir = os.path.join(CACHE_DATASET, "crops")
if not os.path.isdir(crops_dir):
    raise FileNotFoundError(f"Cache dataset {CACHE_DATASET} has no 'crops/' directory")
n_crops = len([f for f in os.listdir(crops_dir) if f.endswith(".npz")])
sbi_dir = os.path.join(CACHE_DATASET, "sbi_seed42")
n_sbi = len([f for f in os.listdir(sbi_dir) if f.endswith(".npz")]) if os.path.isdir(sbi_dir) else 0
print(f"✓ Cache found: {CACHE_DATASET}")
print(f"  crops/     : {n_crops} files")
print(f"  sbi_seed42/: {n_sbi} files")

# Kaggle input is read-only, so expose the cache under the working dir via a symlink.
CACHE_DIR = "/kaggle/working/cache"
if os.path.islink(CACHE_DIR) and os.path.realpath(CACHE_DIR) != os.path.realpath(CACHE_DATASET):
    # Stale or dangling link from an earlier session. os.path.exists() is False for
    # a dangling link, which made os.symlink() fail with FileExistsError.
    os.unlink(CACHE_DIR)
if not os.path.lexists(CACHE_DIR):
    os.symlink(CACHE_DATASET, CACHE_DIR)
    print(f"  Symlinked to: {CACHE_DIR}")
elif os.path.realpath(CACHE_DIR) != os.path.realpath(CACHE_DATASET):
    # A real directory (not a link) occupies the path; leave it untouched but say so.
    print(f"  ⚠ {CACHE_DIR} exists and is not a link to {CACHE_DATASET}; using it as-is")

# FF++ video dir (needed only if SBI cache needs regeneration — shouldn't happen)
FF_VIDEO_DIR = "/kaggle/input/datasets/xdxd003/ff-c23/FaceForensics++_C23/original"


def _find_celeb_df_root(search_root):
    """Return the first directory below ``search_root`` that matches Celeb-DF.

    A match is a directory whose name contains 'celeb' (case-insensitive) and
    that has a sub-directory whose name contains 'synth' (e.g. Celeb-synthesis/).
    Returns ``None`` if there is no match.

    Directories are walked in sorted order, so the choice is deterministic.
    Individual files are never stat'ed, whereas rglob("*") visited every file
    of every attached dataset.
    """
    search_root = os.fspath(search_root)
    for dirpath, dirnames, _filenames in os.walk(search_root):
        dirnames.sort()  # deterministic traversal order
        if dirpath == search_root:
            continue  # the search root itself is not a candidate
        if "celeb" in os.path.basename(dirpath).lower() and any("synth" in d.lower() for d in dirnames):
            return dirpath
    return None


# Celeb-DF
CELEB_DF_ROOT = _find_celeb_df_root(KAGGLE_INPUT)
print(f"  Celeb-DF: {CELEB_DF_ROOT or 'not found (cross-dataset eval skipped)'}")

# dlib predictor (for Celeb-DF eval only). Copy the first match from any attached
# dataset into the working dir, unless a copy is already there. This is
# best-effort: if nothing is found or the copy fails, PREDICTOR_PATH does not
# exist and the dataset constructors get predictor_path=None.
import shutil

PREDICTOR_PATH = "/kaggle/working/shape_predictor_81_face_landmarks.dat"
if not os.path.exists(PREDICTOR_PATH):
    predictor_src = next(KAGGLE_INPUT.rglob("shape_predictor_81_face_landmarks.dat"), None)
    if predictor_src is not None:
        # shutil.copy instead of os.system("cp '...'"): no shell is involved, so
        # quotes or other special characters in the path cannot break the command.
        # Failures are reported rather than lost in an unchecked exit status.
        try:
            shutil.copy(predictor_src, PREDICTOR_PATH)
        except OSError as exc:
            print(f"  ⚠ Could not copy dlib predictor from {predictor_src}: {exc}")

# Splits + checkpoint dirs: the split files ship inside the cloned repo, and
# checkpoints go to the writable working directory.
SPLITS_DIR = os.path.join(REPO_DIR, "splits")
CHECKPOINT_DIR = "/kaggle/working/checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

splits = load_all_splits(SPLITS_DIR)
split_sizes = ", ".join(f"{part}={len(splits[part])}" for part in ("train", "val", "test"))
print(f"  Splits: {split_sizes}")
print("\n✓ Section 3 complete")

In [None]:
# ============================================================================
# Section 4: Load Cached Datasets (instant — no face extraction!)
# ============================================================================
NUM_FRAMES = 32
IMG_SIZE = 224

train_tf = get_train_transforms(IMG_SIZE)
val_tf = get_val_transforms(IMG_SIZE)

# Constructor arguments shared by both splits; only the split file, phase and
# transform differ between train and val.
predictor_arg = PREDICTOR_PATH if os.path.exists(PREDICTOR_PATH) else None
common_ds_kwargs = dict(
    video_dir=FF_VIDEO_DIR,
    cache_dir=CACHE_DIR,
    num_frames=NUM_FRAMES,
    img_size=IMG_SIZE,
    sbi_seed=42,
    predictor_path=predictor_arg,
)

train_ds = SBIVideoDataset(
    split_path=os.path.join(SPLITS_DIR, "Dataset_Split_train.json"),
    phase="train",
    transform=train_tf,
    **common_ds_kwargs,
)
val_ds = SBIVideoDataset(
    split_path=os.path.join(SPLITS_DIR, "Dataset_Split_val.json"),
    phase="val",
    transform=val_tf,
    **common_ds_kwargs,
)

# Quick sanity check on the first training item
sample = train_ds[0]
print("Sanity check:")
for caption, value in (
    ("frames:", sample["frames"].shape),
    ("label: ", sample["label"]),
    ("id:    ", sample["video_id"]),
):
    print(f"  {caption} {value}")
print(f"\nDataset sizes: train={len(train_ds)}, val={len(val_ds)}")
print("\n✓ Section 4 complete — loaded from cache (no preprocessing needed)")

In [None]:
# ============================================================================
# Section 5: Model Init + Param Count + Forward Pass Verify
# ============================================================================
import torch
import torch.nn as nn
from stf_mamba import STFMambaV8

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Build model. d_conv MUST be 2, 3, or 4 to avoid a 'causal_conv1d' RuntimeError.
print("\nLoading DINOv2-ViT-B/14 + Hydra-Mamba...")
model = STFMambaV8(pretrained_backbone=True, d_conv=4)

# Param count
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nParameters:")
print(f"  Total:     {total / 1e6:.1f}M")
print(f"  Trainable: {trainable / 1e6:.1f}M")
print(f"  Frozen:    { (total - trainable) / 1e6:.1f}M")

# Forward pass verify
print(f"\nForward pass test...")
model.eval()
model = model.to(device)

# Batch format: [Batch, Frames, Channels, Height, Width]. Use the dataset
# settings from Section 4 so the smoke test matches what training will feed.
x_test = torch.randn(1, NUM_FRAMES, 3, IMG_SIZE, IMG_SIZE, device=device)
with torch.no_grad():
    out = model(x_test)

print(f"  Logits:   {out['logits'].shape}  ← expected (1, 2)")
print(f"  Variance: {out['variance'].shape} ← expected (1, 1)")

# Verification asserts
assert out['logits'].shape == (1, 2), f"Logits shape mismatch: {out['logits'].shape}"
assert out['variance'].shape == (1, 1), f"Variance shape mismatch: {out['variance'].shape}"

# Cleanup to preserve T4 memory for training. Move the model off the GPU
# *before* emptying the cache: empty_cache() can only release blocks that are
# no longer referenced.
del x_test, out
model = model.cpu()
torch.cuda.empty_cache()
print("\n✓ Section 5 complete — Forward pass successful with d_conv=4")

In [None]:
# ============================================================================
# Section 6: Training Loop — all fixes applied:
#   1. Clear module cache → loads patched consistency_head.py
#   2. lambda_var=1.0, lr_backbone=1e-5
#   3. Warmup sets LR from base values (not compound-multiplies)
#   4. num_workers=2 + fork (faster loading)
#   5. var_gap tracked in history
# ============================================================================
import sys, gc, time, random, os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

# ── Clear module cache so patched consistency_head.py is loaded ──────────────
# Drop every cached stf_mamba module (the package and its submodules) so the
# import below re-executes the sources on disk.
# NOTE(review): objects or modules that already hold references to the old
# classes keep them.
stale_modules = [mod_name for mod_name in sys.modules if "stf_mamba" in mod_name]
for mod_name in stale_modules:
    sys.modules.pop(mod_name)

from stf_mamba import STFMambaV8, STFMambaLoss
print("✓ Loaded patched stf_mamba")

# ── Reproducibility: seed Python, NumPy, torch CPU and all CUDA devices ─────
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# ── Config: training hyper-parameters ────────────────────────────────────────
config = dict(
    epochs=25,
    batch_size=4,
    lr_backbone=1e-5,     # was 5e-6
    lr_temporal=1e-4,
    lr_head=1e-4,
    weight_decay=1e-4,
    warmup_epochs=3,
    grad_clip=1.0,
    lambda_var=1.0,       # was 0.1 — variance was dead
)

# ── Clear GPU ─────────────────────────────────────────────────────────────────
# Drop references left by earlier cells, or by a previous run of this cell, so
# their GPU memory can be reclaimed before the new model is built. The optimizer
# and scheduler are included because they keep the old model's parameters
# alive (the scheduler through its optimizer).
# Each name is removed on its own, so a missing `model` no longer skips the rest.
# No bare `except:` (which would also swallow KeyboardInterrupt).
for _stale_name in ("model", "trainer", "optimizer", "scheduler", "eval_model"):
    globals().pop(_stale_name, None)
del _stale_name
# Collect unreachable objects first so empty_cache() can return their blocks.
gc.collect(); torch.cuda.empty_cache()

# ── DataLoaders ───────────────────────────────────────────────────────────────
# Both loaders share batch size, worker, pinning and collation settings. Only
# shuffling and dropping the last partial batch differ between train and val.
loader_common = dict(
    batch_size=config["batch_size"],
    num_workers=2,
    multiprocessing_context="fork",
    pin_memory=True,
    collate_fn=SBIVideoDataset.collate_fn,
)
train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, drop_last=True, **loader_common)
val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, drop_last=False, **loader_common)

# ── Model ─────────────────────────────────────────────────────────────────────
device = torch.device("cuda:0")
model = STFMambaV8(pretrained_backbone=True, d_conv=4)
model = model.to(device)

# ── Criterion ─────────────────────────────────────────────────────────────────
criterion = STFMambaLoss(lambda_var=config["lambda_var"])


# ── Optimizer: differential LR ────────────────────────────────────────────────
def _lr_group_index(param_name):
    """Map a parameter name to its LR group: 0 = backbone, 1 = temporal, 2 = head.

    The warmup in the training loop indexes optimizer.param_groups in this order.
    """
    if "backbone" in param_name or "dinov2" in param_name:
        return 0
    if "mamba" in param_name or "temporal" in param_name:
        return 1
    return 2


grouped_params = ([], [], [])
for pname, param in model.named_parameters():
    if param.requires_grad:
        grouped_params[_lr_group_index(pname)].append(param)
backbone_params, temporal_params, head_params = grouped_params

group_specs = [
    (backbone_params, "lr_backbone"),
    (temporal_params, "lr_temporal"),
    (head_params, "lr_head"),
]
optimizer = AdamW(
    [{"params": params, "lr": config[lr_key]} for params, lr_key in group_specs],
    weight_decay=config["weight_decay"],
)

scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"], eta_min=1e-7)

# ── History: per-epoch metrics plus the best-AUC bookkeeping ─────────────────
history = {key: [] for key in ("train_loss", "train_acc", "val_loss", "val_auc", "val_acc", "var_gap")}
best_auc, best_epoch = 0.0, 0

rule = "=" * 65
banner_lines = [
    rule,
    f"  STF-Mamba V8.0 — Training ({config['epochs']} epochs, batch={config['batch_size']})",
    f"  Train: {len(train_ds)} | Val: {len(val_ds)} | Device: {device}",
    f"  lambda_var={config['lambda_var']} | lr_backbone={config['lr_backbone']}",
    rule,
]
print("\n".join(banner_lines))

for epoch in range(1, config["epochs"] + 1):
    t0 = time.time()

    # ── Warmup: set LR from base values each epoch (not compound-multiply) ────
    # Linear ramp epoch / warmup_epochs. Group indices follow the optimizer's
    # construction order: 0 = backbone, 1 = temporal, 2 = head. The scheduler
    # still steps every epoch; during warmup its value is overwritten here at
    # the start of the next epoch.
    if epoch <= config["warmup_epochs"]:
        wf = epoch / config["warmup_epochs"]
        optimizer.param_groups[0]["lr"] = config["lr_backbone"] * wf
        optimizer.param_groups[1]["lr"] = config["lr_temporal"] * wf
        optimizer.param_groups[2]["lr"] = config["lr_head"]     * wf

    # ── TRAIN ─────────────────────────────────────────────────────────────────
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    pbar = tqdm(train_loader, desc=f"Ep {epoch:02d}/{config['epochs']} [Train]",
                leave=True, ncols=120)
    for batch in pbar:
        frames = batch["frames"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        out       = model(frames)
        # The criterion returns a dict; "total", "ce" and "var" entries are read here.
        loss_dict = criterion(out["logits"], labels, out["variance"])
        loss      = loss_dict["total"]
        loss.backward()

        if config["grad_clip"] > 0:
            nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
        optimizer.step()

        bs = labels.size(0)
        running_loss += loss.item() * bs
        preds         = out["logits"].argmax(dim=1)
        correct      += (preds == labels).sum().item()
        total        += bs

        pbar.set_postfix({
            "loss": f"{loss_dict['total'].item():.4f}",
            "ce":   f"{loss_dict['ce'].item():.4f}",
            "var":  f"{loss_dict['var'].item():.4f}",
            "acc":  f"{correct/total:.3f}",
            "lr":   f"{optimizer.param_groups[1]['lr']:.1e}",
        })
    pbar.close()

    # With drop_last=True, a train set smaller than batch_size yields no batches.
    # Fail with a clear message instead of a ZeroDivisionError.
    if total == 0:
        raise RuntimeError(
            f"No training batches in epoch {epoch} "
            f"(train set smaller than batch_size={config['batch_size']} with drop_last=True?)"
        )
    train_loss = running_loss / total
    train_acc  = correct / total

    # ── VALIDATE ──────────────────────────────────────────────────────────────
    model.eval()
    val_loss_sum = val_correct = val_total = 0
    val_var_real, val_var_fake = [], []
    all_probs, all_labels = [], []

    with torch.no_grad():
        vbar = tqdm(val_loader, desc=f"Ep {epoch:02d}/{config['epochs']} [ Val ]",
                    leave=True, ncols=120)
        for batch in vbar:
            frames = batch["frames"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)

            out       = model(frames)
            loss_dict = criterion(out["logits"], labels, out["variance"])
            loss      = loss_dict["total"]

            bs = labels.size(0)
            val_loss_sum += loss.item() * bs
            probs  = torch.softmax(out["logits"], dim=1)[:, 1]  # P(class 1 = fake)
            preds  = out["logits"].argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total   += bs
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # Split predicted variances by ground-truth class for the var_gap metric.
            for v, lbl in zip(out["variance"].cpu().numpy().flatten(), labels.cpu().numpy()):
                (val_var_real if lbl == 0 else val_var_fake).append(float(v))
            vbar.set_postfix({"loss": f"{loss.item():.4f}"})
        vbar.close()

    if val_total == 0:
        raise RuntimeError(f"Validation loader produced no batches in epoch {epoch} (empty val set?)")
    val_loss = val_loss_sum / val_total
    val_acc  = val_correct / val_total
    try:
        val_auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # AUC is undefined (e.g. only one class present): fall back to chance level.
        # Catching only ValueError keeps KeyboardInterrupt and real bugs visible.
        val_auc = 0.5

    _vr     = float(np.mean(val_var_real)) if val_var_real else 0.0
    _vf     = float(np.mean(val_var_fake)) if val_var_fake else 0.0
    var_gap = _vf - _vr

    scheduler.step()

    # ── Checkpoint (only when val AUC improves) ───────────────────────────────
    if val_auc > best_auc:
        best_auc   = val_auc
        best_epoch = epoch
        torch.save({
            "epoch":                epoch,
            "model_state_dict":     model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_auc":              val_auc,
        }, os.path.join(CHECKPOINT_DIR, "best.pth"))

    # ── History ───────────────────────────────────────────────────────────────
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_auc"].append(val_auc)
    history["val_acc"].append(val_acc)
    history["var_gap"].append(var_gap)

    elapsed = time.time() - t0
    print(f"  Ep {epoch:02d} | TrLoss {train_loss:.4f} | TrAcc {train_acc:.3f} "
          f"| VaLoss {val_loss:.4f} | VaAUC {val_auc:.4f} | VaAcc {val_acc:.3f} "
          f"| VarGap {var_gap:+.4f} | {elapsed:.0f}s {'★ best' if epoch == best_epoch else ''}")

# Final training report: best validation AUC and where its checkpoint was written.
print(
    f"\n✓ Done! Best Val AUC: {best_auc:.4f} at epoch {best_epoch}",
    f"  Checkpoint: {CHECKPOINT_DIR}/best.pth",
    sep="\n",
)

In [None]:
# ============================================================================
# Section 6b + 7 + 8: Curves, Evaluation, Variance Visualization
# ============================================================================
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
from stf_mamba import STFMambaV8

# ── Training Curves ───────────────────────────────────────────────────────────
n_epochs_run = len(history["train_loss"])
ep = range(1, n_epochs_run + 1)
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

axes[0, 0].plot(ep, history["train_loss"], "b-", label="Train")
axes[0, 0].plot(ep, history["val_loss"],   "r-", label="Val")
axes[0, 0].set_title("Loss"); axes[0, 0].set_ylabel("Loss")
axes[0, 0].legend(); axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(ep, history["train_acc"], "b-", label="Train")
axes[0, 1].plot(ep, history["val_acc"],   "r-", label="Val")
axes[0, 1].set_title("Accuracy"); axes[0, 1].set_ylabel("Accuracy")
axes[0, 1].legend(); axes[0, 1].grid(True, alpha=0.3)

axes[1, 0].plot(ep, history["val_auc"], "g-", marker="o", markersize=3)
axes[1, 0].axhline(y=0.75, color="orange", linestyle="--", alpha=0.7, label="Stage 5 (0.75)")
axes[1, 0].axhline(y=0.90, color="red",    linestyle="--", alpha=0.7, label="Paper (0.90)")
axes[1, 0].set_title("Val AUC"); axes[1, 0].set_ylabel("AUC"); axes[1, 0].set_ylim([0.4, 1.0])
axes[1, 0].legend(); axes[1, 0].grid(True, alpha=0.3)

# var_gap = mean predicted variance on fake val clips minus mean on real ones.
axes[1, 1].plot(ep, history["var_gap"], "m-", marker="o", markersize=3)
axes[1, 1].axhline(y=0, color="gray", linestyle="--", alpha=0.5)
axes[1, 1].set_title("Variance Gap (fake − real)"); axes[1, 1].set_ylabel("σ² fake − σ² real")
axes[1, 1].grid(True, alpha=0.3)

for ax in axes.flat:
    ax.set_xlabel("Epoch")

# The title uses the number of epochs actually recorded instead of a hard-coded 25,
# so a changed config["epochs"] or an interrupted run is labelled correctly.
fig.suptitle(f"STF-Mamba V8.0 — Kaggle Training ({n_epochs_run} epochs)", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, "training_curves.png"), dpi=150, bbox_inches="tight")
plt.show()
print("✓ Training curves saved")

# ── Section 7a: FF++ Val AUC (reload best checkpoint) ────────────────────────
best_path = os.path.join(CHECKPOINT_DIR, "best.pth")
print(f"\nLoading best checkpoint: {best_path}")
# NOTE(review): torch.load unpickles the file. That is fine for a checkpoint
# written by this notebook; do not point it at untrusted files.
ckpt = torch.load(best_path, map_location=device)

# d_conv=4 must match what was used during training
eval_model = STFMambaV8(pretrained_backbone=True, d_conv=4).to(device)
eval_model.load_state_dict(ckpt["model_state_dict"])
eval_model.eval()

# Re-score the whole validation set with the best weights.
all_probs, all_labels, all_variances = [], [], []
with torch.no_grad():
    for batch in val_loader:
        frames = batch["frames"].to(device)
        labels = batch["label"]
        out = eval_model(frames)
        probs = torch.softmax(out["logits"], dim=1)[:, 1].cpu().numpy()  # P(class 1 = fake)
        var   = out["variance"].cpu().numpy().flatten()
        all_probs.extend(probs)
        all_labels.extend(labels.numpy())
        all_variances.extend(var)

all_labels_np = np.array(all_labels)
all_var_np    = np.array(all_variances)

try:
    ff_val_auc = roc_auc_score(all_labels, all_probs)
except ValueError as exc:
    # AUC is undefined (e.g. a single class in the labels). Report nan instead
    # of crashing the whole evaluation cell.
    print(f"  ⚠ FF++ Val AUC undefined: {exc}")
    ff_val_auc = float("nan")
ff_val_acc = accuracy_score(all_labels, [1 if p > 0.5 else 0 for p in all_probs])

# Per-class mean variance. An absent class gives nan explicitly, without the
# RuntimeWarning that np.mean of an empty slice would raise.
real_mask = all_labels_np == 0
fake_mask = all_labels_np == 1
var_real = float(all_var_np[real_mask].mean()) if real_mask.any() else float("nan")
var_fake = float(all_var_np[fake_mask].mean()) if fake_mask.any() else float("nan")

print(f"  FF++ Val AUC : {ff_val_auc:.4f}")
print(f"  FF++ Val Acc : {ff_val_acc:.4f}")
print(f"  Variance gap : {var_fake - var_real:+.6f}")

# ── Section 7b: Celeb-DF Cross-Dataset ───────────────────────────────────────
import cv2
from tqdm import tqdm
from data.preprocessing import FacePreprocessor
from data.augmentation import apply_transform_to_clip

# Per-class cap on the Celeb-DF videos scored below (bounds evaluation time).
CDF_MAX_VIDEOS_PER_CLASS = 200

cdf_auc = None
if CELEB_DF_ROOT and os.path.isdir(CELEB_DF_ROOT):
    print("\n--- Celeb-DF v2 Cross-Dataset ---")

    # Candidate folder names (uploads differ in naming); missing folders are skipped.
    cdf_real_dirs = [os.path.join(CELEB_DF_ROOT, d) for d in ["Celeb-real", "YouTube-real", "celeb_real"]]
    cdf_fake_dirs = [os.path.join(CELEB_DF_ROOT, d) for d in ["Celeb-synthesis", "celeb_synthesis"]]

    def find_videos(dir_list):
        """Return .mp4/.avi paths from every existing directory in ``dir_list``.

        Paths are sorted by file name within each directory and concatenated in
        ``dir_list`` order.
        """
        videos = []
        for d in dir_list:
            if os.path.isdir(d):
                for f in sorted(os.listdir(d)):
                    if f.lower().endswith((".mp4", ".avi")):
                        videos.append(os.path.join(d, f))
        return videos

    cdf_real = find_videos(cdf_real_dirs)
    cdf_fake = find_videos(cdf_fake_dirs)
    print(f"  Real: {len(cdf_real)}, Fake: {len(cdf_fake)}")

    if cdf_real and cdf_fake:
        cdf_preprocessor = FacePreprocessor(
            video_dir=CELEB_DF_ROOT,
            cache_dir="/kaggle/working/cdf_cache",
            num_frames=NUM_FRAMES, img_size=IMG_SIZE,
            predictor_path=PREDICTOR_PATH if os.path.exists(PREDICTOR_PATH) else None,
        )

        def eval_videos(video_paths, label, max_videos=CDF_MAX_VIDEOS_PER_CLASS):
            """Score up to ``max_videos`` videos with ``eval_model``.

            Returns P(class 1 = fake) for each video that could be processed.
            Failures are skipped (best-effort), but they are counted and the
            first error is printed, so a systematic problem is visible instead
            of silently shrinking the evaluation set.

            NOTE(review): only the first ``max_videos`` paths are used, so the
            real class is drawn from the first real folder in list order when
            that folder holds at least ``max_videos`` files.
            NOTE(review): get_video() receives only the file stem. Whether the
            preprocessor resolves it inside Celeb-real/ etc. is not visible
            here; check the skip report.
            """
            selected = video_paths[:max_videos]
            probs = []
            n_failed, first_error = 0, None
            for vpath in tqdm(selected, desc=f"label={label}"):
                vid_id = os.path.splitext(os.path.basename(vpath))[0]
                try:
                    crops, _ = cdf_preprocessor.get_video(vid_id)
                    frames_t = apply_transform_to_clip(crops[:NUM_FRAMES], val_tf)
                    frames_t = frames_t.unsqueeze(0).to(device)
                    with torch.no_grad():
                        out = eval_model(frames_t)
                        prob = torch.softmax(out["logits"], dim=1)[0, 1].item()
                except Exception as exc:  # best-effort: one bad video must not abort the eval
                    n_failed += 1
                    if first_error is None:
                        first_error = f"{vid_id}: {type(exc).__name__}: {exc}"
                    continue
                probs.append(prob)
            if n_failed:
                print(f"  ⚠ label={label}: skipped {n_failed}/{len(selected)} videos (first: {first_error})")
            return probs

        real_probs = eval_videos(cdf_real, 0)
        fake_probs = eval_videos(cdf_fake, 1)

        if real_probs and fake_probs:
            cdf_labels = [0] * len(real_probs) + [1] * len(fake_probs)
            cdf_probs  = real_probs + fake_probs
            cdf_auc    = roc_auc_score(cdf_labels, cdf_probs)
            print(f"  Celeb-DF AUC: {cdf_auc:.4f}")
        else:
            print("  ✗ Celeb-DF AUC not computed: no video of one class could be scored")
    else:
        print("  ✗ Celeb-DF AUC not computed: real or fake video folder is empty or missing")
else:
    print("\nCeleb-DF not found — skipping cross-dataset eval")

# ── Section 7c: Results Summary ───────────────────────────────────────────────
# The epoch count comes from the recorded history rather than a hard-coded 25.
n_epochs_run = len(history["val_loss"])
var_gap_final = var_fake - var_real

print("\n" + "=" * 60)
print(f"  RESULTS SUMMARY — STF-Mamba V8.0 (Kaggle {n_epochs_run} epochs)")
print("=" * 60)
print(f"\n{'Metric':<30} {'Value':>10}")
print("-" * 42)
print(f"{'FF++ Val AUC (SBI)':<30} {ff_val_auc:>10.4f}")
print(f"{'FF++ Val Acc':<30} {ff_val_acc:>10.4f}")
# `is not None`: an AUC of exactly 0.0 is a (bad) measurement, not a missing one.
if cdf_auc is not None:
    print(f"{'Celeb-DF AUC':<30} {cdf_auc:>10.4f}")
print(f"{'Variance gap (fake-real)':<30} {var_gap_final:>+10.6f}")
print(f"{'Best epoch':<30} {ckpt['epoch']:>10d}")

# Hard-coded reference numbers for comparison.
print("\n--- Comparison to Baselines ---")
print(f"{'Model':<35} {'FF++ Val':>10} {'CDF':>10}")
print("-" * 57)
print(f"{'B0 frame-level (Step 3)':<35} {'0.6850':>10} {'0.6135':>10}")
print(f"{'B0 + GRU temporal (Step 4)':<35} {'0.5954':>10} {'0.5524':>10}")
print(f"{'SBI reference (EffNet-B4)':<35} {'—':>10} {'0.9382':>10}")
cdf_str = f"{cdf_auc:.4f}" if cdf_auc is not None else "—"
print(f"{'V8.0 (this run)':<35} {ff_val_auc:>10.4f} {cdf_str:>10}")

print("\n--- Stage 5 Exit Criteria ---")
if cdf_auc is None:
    print("  ? CDF AUC not measured")
elif cdf_auc > 0.75:
    print(f"  ✓ CDF AUC {cdf_auc:.4f} > 0.75")
else:
    print(f"  ✗ CDF AUC {cdf_auc:.4f} ≤ 0.75")

if var_gap_final > 0:
    print(f"  ✓ Variance gap positive ({var_gap_final:+.6f})")
else:
    print(f"  ✗ Variance gap not positive")

# Overfitting heuristic: the latest val loss must stay within +20% of the value
# five epochs earlier. Every outcome is reported, not just the passing one.
last5 = history["val_loss"][-5:]
if len(last5) < 5:
    print("  ? Overfitting check skipped (fewer than 5 epochs recorded)")
elif last5[-1] <= last5[0] * 1.2:
    print(f"  ✓ No overfitting (val loss stable)")
else:
    print(f"  ✗ Val loss rose >20% over the last 5 epochs ({last5[0]:.4f} → {last5[-1]:.4f})")

print("\n✓ Section 7 complete")

# ── Section 8: Variance / Similarity Visualization ───────────────────────────
# Gather up to n_samples validation clips *per class*. val_ds is scanned in
# index order, stopping once both classes are filled. Taking only the first
# 2·n_samples indices could yield a single class if val_ds is grouped by
# label, which leaves one panel empty and its statistics nan.
n_samples = min(10, len(val_ds) // 2)
real_sims, fake_sims = [], []
real_vars, fake_vars = [], []

eval_model.eval()
with torch.no_grad():
    for i in range(len(val_ds)):
        if len(real_sims) >= n_samples and len(fake_sims) >= n_samples:
            break
        sample = val_ds[i]
        is_real = int(sample["label"]) == 0
        sims_bucket, vars_bucket = (real_sims, real_vars) if is_real else (fake_sims, fake_vars)
        if len(sims_bucket) >= n_samples:
            continue  # this class is already full; skip the forward pass
        frames = sample["frames"].unsqueeze(0).to(device)
        out = eval_model(frames)
        sims_bucket.append(out["similarities"][0].cpu().numpy())
        vars_bucket.append(out["variance"][0].item())


def _fmt_stat(values, stat_fn):
    """Format stat_fn(values) with 6 decimals, or 'n/a' when values is empty."""
    return f"{stat_fn(values):.6f}" if values else "n/a"


fig, axes = plt.subplots(1, 2, figsize=(16, 5))
for ax, class_name, sims_list, vars_list in (
    (axes[0], "Real", real_sims, real_vars),
    (axes[1], "Fake", fake_sims, fake_vars),
):
    for k, s in enumerate(sims_list[:5]):
        ax.plot(s, alpha=0.6, label=f"{class_name} {k}")
    ax.set_title(f"{class_name} — Mean σ²={_fmt_stat(vars_list, np.mean)}")
    ax.set_ylabel("Similarity")
    ax.set_ylim([0.5, 1.05]); ax.grid(True, alpha=0.3)
    if sims_list:  # legend() with no labelled lines only emits a warning
        ax.legend(fontsize=8)

fig.suptitle("STF-Mamba V8.0 — Identity Consistency Signal", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, "similarity_traces.png"), dpi=150, bbox_inches="tight")
plt.show()

gap_str = (f"{np.mean(fake_vars) - np.mean(real_vars):+.6f}"
           if real_vars and fake_vars else "n/a")
print(f"\nVariance Statistics:")
print(f"  Real: {_fmt_stat(real_vars, np.mean)} ± {_fmt_stat(real_vars, np.std)}")
print(f"  Fake: {_fmt_stat(fake_vars, np.mean)} ± {_fmt_stat(fake_vars, np.std)}")
print(f"  Gap:  {gap_str}")
print("\n✓ Section 8 complete")

# Closing banner. ff_val_auc is the re-evaluated AUC of the best checkpoint.
print("\n" + "=" * 60)
print("  STF-Mamba V8.0 — Training Complete")
print("=" * 60)
print(f"  Best FF++ Val AUC : {ff_val_auc:.4f}")
# `is not None`: an AUC of 0.0 is still a measurement and must be shown.
if cdf_auc is not None:
    print(f"  Celeb-DF AUC      : {cdf_auc:.4f}")
print(f"  Next: Stage 6 — RunPod A100 full 50-epoch training")
