# STF-Mamba V8.0 — Kaggle Training Notebook

**Semantic Temporal Forensics via Hydra-Mamba and DINOv2**

Target: CVPR/ICCV 2026

**Settings:** GPU T4 x2, Internet ON (required to clone the repo and pip-install `causal-conv1d` / `mamba-ssm`)

**Required datasets:**
- `stf-mamba-v8-cache` (pre-built face crops + SBI cache from preprocessing notebook)
- `celeb-df-v2` (for cross-dataset evaluation)
- `shape-predictor81` (dlib landmarks for Celeb-DF eval)

## Section 1: Setup + GPU Check + Install Dependencies

In [None]:
# # ----------------------------------------------------------------------------
# # LEGACY (kept for reference only): the original single-cell setup. It is
# # superseded by cells 1.1–1.4 below, which run the GPU check and the
# # package installs as separate cells. Unlike cell 1.3, this version tried
# # `from mamba_ssm import Mamba` first and only compiled on ImportError.
# # ----------------------------------------------------------------------------
# # ============================================================================
# # Section 1: Setup + GPU Check + Install Dependencies
# # ============================================================================
# import os, sys, subprocess, time

# # GPU check
# print("=" * 60)
# print("  STF-Mamba V8.0 — Kaggle Training")
# print("=" * 60)
# !nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

# import torch
# print(f"\nPyTorch: {torch.__version__}")
# print(f"CUDA: {torch.cuda.is_available()}, GPUs: {torch.cuda.device_count()}")
# for i in range(torch.cuda.device_count()):
#     print(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB)")

# # Install mamba_ssm — try import first, else compile
# print("\n--- Mamba SSM ---")
# MAMBA_OK = False
# try:
#     from mamba_ssm import Mamba
#     print(f"✓ mamba_ssm already installed")
#     MAMBA_OK = True
# except ImportError:
#     print("Compiling mamba_ssm from source (~20-30 min)...")
#     print("(CPU will be at 400% — this is normal)")
#     os.environ["MAX_JOBS"] = "2"
#     !pip install causal-conv1d --no-build-isolation -q 2>&1 | tail -3
#     !pip install mamba-ssm --no-build-isolation -q 2>&1 | tail -3
#     try:
#         from mamba_ssm import Mamba
#         print(f"✓ mamba_ssm compiled successfully")
#         MAMBA_OK = True
#     except:
#         print("⚠ mamba_ssm failed — using Conv1d fallback")

# !pip install dlib imutils albumentations einops -q
# print("\n✓ Section 1 complete")

In [None]:
# 1.1 GPU Check
import os, sys, torch
print("=" * 60)
print("  STF-Mamba V8.0 — GPU Check")
print("=" * 60)
!nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
print(f"PyTorch: {torch.__version__} | CUDA: {torch.cuda.is_available()}")

In [None]:
# 1.2 Fast CUDA Kernel Build
print("Installing causal-conv1d (Fast kernels)...")
os.environ["MAX_JOBS"] = "2"
!pip install causal-conv1d --no-build-isolation -q
print("✓ causal-conv1d installed")

In [None]:
# 1.3 Heavy Mamba Compilation
print("Compiling mamba-ssm from source")
!pip install mamba-ssm --no-build-isolation -q
try:
    from mamba_ssm import Mamba
    print("✓ mamba_ssm compiled successfully")
except ImportError:
    print("⚠ Error: mamba_ssm failed to compile.")

In [None]:
# 1.4 CV Dependencies
# dlib provides the 81-point landmark predictor located via PREDICTOR_PATH in
# Section 3; imutils / albumentations / einops are presumably imported by the
# repo's data and model code — TODO confirm against the repo requirements.
print("Installing vision utilities...")
!pip install dlib imutils albumentations einops -q
print("\n✓ Section 1 complete — Environment is Ready")

## Section 2: Clone Repo + Path Configuration

In [None]:
# ============================================================================
# Section 2: Clone Repo + Path Configuration
# ============================================================================
REPO_URL = "https://github.com/AbdelRahman-Madboly/STF-Mamba_V8.0.git"
REPO_DIR = "/kaggle/working/STF-Mamba_V8.0"

if not os.path.exists(REPO_DIR):
    !git clone {REPO_URL} {REPO_DIR}
else:
    print(f"Repo already cloned at {REPO_DIR}")
    !cd {REPO_DIR} && git pull

sys.path.insert(0, REPO_DIR)
os.chdir(REPO_DIR)

from stf_mamba import STFMambaV8, STFMambaLoss, is_mamba_available
from data import SBIVideoDataset, get_train_transforms, get_val_transforms, load_all_splits
from training import Trainer, build_optimizer, build_scheduler

print(f"\nRepo: {REPO_DIR}")
print(f"Mamba SSM: {'native' if is_mamba_available() else 'Conv1d fallback'}")
print("✓ Section 2 complete — all imports OK")

## Section 3: Dataset Paths

**Key change:** the face-crop and SBI cache is loaded from the pre-built `stf-mamba-v8-cache` dataset instead of being built on the fly.

In [None]:
# ============================================================================
# Section 3: Dataset Paths — Load Pre-Built Cache
# ============================================================================
import shutil
from pathlib import Path

KAGGLE_INPUT = Path("/kaggle/input")

# --- Find pre-built cache dataset ---
# Primary: the first (sorted) `crops/` directory containing *_crops.npz files;
# the cache root is its parent.
CACHE_DATASET = None
for p in sorted(KAGGLE_INPUT.rglob("crops")):
    if p.is_dir() and any(f.name.endswith("_crops.npz") for f in p.iterdir()):
        CACHE_DATASET = str(p.parent)  # parent of crops/
        break

if CACHE_DATASET is None:
    # Fallback: search by dataset name in priority order and stop at the first
    # hit (an inner-loop-only `break` would let later names overwrite it).
    for name in ["stf-mamba-v8-cache", "stf-cache", "stf_cache"]:
        match = next((p for p in KAGGLE_INPUT.rglob(name) if p.is_dir()), None)
        if match is not None:
            CACHE_DATASET = str(match)
            break

if CACHE_DATASET is None:
    raise RuntimeError(
        "Pre-built cache not found!\n"
        "Run the preprocessing notebook first and attach 'stf-mamba-v8-cache' dataset."
    )

# The name-based fallback does not guarantee a crops/ directory; fail clearly.
CROPS_DIR = os.path.join(CACHE_DATASET, "crops")
if not os.path.isdir(CROPS_DIR):
    raise RuntimeError(
        f"Cache dataset found at {CACHE_DATASET} but it has no crops/ directory.\n"
        "Re-run the preprocessing notebook and re-attach 'stf-mamba-v8-cache'."
    )

# Verify cache contents
n_crops = sum(1 for f in os.listdir(CROPS_DIR) if f.endswith(".npz"))
sbi_dir = os.path.join(CACHE_DATASET, "sbi_seed42")
n_sbi = sum(1 for f in os.listdir(sbi_dir) if f.endswith(".npz")) if os.path.isdir(sbi_dir) else 0
print(f"✓ Cache found: {CACHE_DATASET}")
print(f"  crops/     : {n_crops} files")
print(f"  sbi_seed42/: {n_sbi} files")

# Cache is read-only on Kaggle input, so we symlink it into the working dir.
CACHE_DIR = "/kaggle/working/cache"
# A dangling symlink makes os.path.exists() False while os.symlink() would
# still raise FileExistsError, so remove stale links first.
if os.path.islink(CACHE_DIR) and not os.path.exists(CACHE_DIR):
    os.unlink(CACHE_DIR)
if not os.path.lexists(CACHE_DIR):
    os.symlink(CACHE_DATASET, CACHE_DIR)
    print(f"  Symlinked to: {CACHE_DIR}")
elif os.path.realpath(CACHE_DIR) != os.path.realpath(CACHE_DATASET):
    print(f"  ⚠ {CACHE_DIR} already exists and does not point to {CACHE_DATASET} — using it as-is")

# FF++ video dir (needed only if SBI cache needs regeneration — shouldn't happen)
FF_VIDEO_DIR = "/kaggle/input/datasets/xdxd003/ff-c23/FaceForensics++_C23/original"

# Celeb-DF: first directory whose name contains "celeb" and which has a
# sub-directory containing "synth". os.walk yields directory names directly,
# avoiding a stat() of every file that rglob("*") + is_dir() would perform.
CELEB_DF_ROOT = None
for dirpath, dirnames, _files in os.walk(KAGGLE_INPUT):
    if "celeb" in os.path.basename(dirpath).lower() and any("synth" in d.lower() for d in dirnames):
        CELEB_DF_ROOT = dirpath
        break
print(f"  Celeb-DF: {CELEB_DF_ROOT or 'not found (cross-dataset eval skipped)'}")

# dlib predictor (for Celeb-DF eval only). shutil.copy raises on failure
# instead of silently ignoring it like a shelled-out `cp`.
PREDICTOR_PATH = "/kaggle/working/shape_predictor_81_face_landmarks.dat"
if not os.path.exists(PREDICTOR_PATH):
    src = next(KAGGLE_INPUT.rglob("shape_predictor_81_face_landmarks.dat"), None)
    if src is not None:
        shutil.copy(src, PREDICTOR_PATH)
print(f"  dlib predictor: {PREDICTOR_PATH if os.path.exists(PREDICTOR_PATH) else 'not found'}")

# Splits + checkpoint dirs
SPLITS_DIR = os.path.join(REPO_DIR, "splits")
CHECKPOINT_DIR = "/kaggle/working/checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

splits = load_all_splits(SPLITS_DIR)
print(f"  Splits: train={len(splits['train'])}, val={len(splits['val'])}, test={len(splits['test'])}")
print("\n✓ Section 3 complete")

## Section 4: Load Cached Datasets

**No preprocessing needed** — everything loads from pre-built NPZ files.

In [None]:
# ============================================================================
# Section 4: Load Cached Datasets (instant — no face extraction!)
# ============================================================================
NUM_FRAMES = 32  # frames per clip
IMG_SIZE = 224   # square crop side length

train_tf = get_train_transforms(IMG_SIZE)
val_tf = get_val_transforms(IMG_SIZE)

# Settings common to both splits; only split file, phase and transform differ.
_predictor = PREDICTOR_PATH if os.path.exists(PREDICTOR_PATH) else None
_shared_ds_kwargs = dict(
    video_dir=FF_VIDEO_DIR,
    cache_dir=CACHE_DIR,
    num_frames=NUM_FRAMES,
    img_size=IMG_SIZE,
    sbi_seed=42,
    predictor_path=_predictor,
)

train_ds = SBIVideoDataset(
    split_path=os.path.join(SPLITS_DIR, "Dataset_Split_train.json"),
    phase="train",
    transform=train_tf,
    **_shared_ds_kwargs,
)
val_ds = SBIVideoDataset(
    split_path=os.path.join(SPLITS_DIR, "Dataset_Split_val.json"),
    phase="val",
    transform=val_tf,
    **_shared_ds_kwargs,
)

# Smoke-test: pull one training item end to end and show what it contains.
sample = train_ds[0]
print("Sanity check:")
print(f"  frames: {sample['frames'].shape}")
print(f"  label:  {sample['label']}")
print(f"  id:     {sample['video_id']}")
print(f"\nDataset sizes: train={len(train_ds)}, val={len(val_ds)}")
print("\n✓ Section 4 complete — loaded from cache (no preprocessing needed)")

## Section 5: Model Init + Param Count + Forward Pass Verify

In [None]:
# ============================================================================
# Section 5: Model Init + Param Count + Forward Pass Verify
# ============================================================================
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Build model — use default d_conv from config
print("\nLoading DINOv2-ViT-B/14...")
model = STFMambaV8(pretrained_backbone=True)

# Param count
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nParameters:")
print(f"  Total:     {total / 1e6:.1f}M")
print(f"  Trainable: {trainable / 1e6:.1f}M")
print(f"  Frozen:    {(total - trainable) / 1e6:.1f}M")

# Forward pass verify — use the same clip geometry as the datasets built in
# Section 4 so the smoke test exercises the configuration that is trained.
print(f"\nForward pass test...")
model.eval()
model = model.to(device)
x_test = torch.randn(1, NUM_FRAMES, 3, IMG_SIZE, IMG_SIZE).to(device)
with torch.no_grad():
    out = model(x_test)
print(f"  Logits:   {out['logits'].shape}  ← expected (1, 2)")
print(f"  Variance: {out['variance'].shape} ← expected (1, 1)")
assert out['logits'].shape == (1, 2), f"unexpected logits shape {tuple(out['logits'].shape)}"
assert out['variance'].shape == (1, 1), f"unexpected variance shape {tuple(out['variance'].shape)}"

del x_test, out
# Move the model off the GPU first so its memory can actually be released.
model = model.cpu()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("\n✓ Section 5 complete")

## Section 6: Training (25 Epochs)

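Config: values follow `v8_kaggle.yaml` but are defined inline in the cell below (the YAML file is not read) — batch=8, 32 frames, differential LR (backbone 5e-6, temporal/head 1e-4), 3 warm-up epochs, cosine schedule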

In [None]:
# ============================================================================
# Section 6: Training Loop
# ============================================================================
import numpy as np
import random

# Seed Python, NumPy and torch (CPU + all GPUs) before the model and loaders
# are created, so shuffling and any random initialisation are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Config
config = {
    "epochs": 25,
    "batch_size": 8,
    "lr_backbone": 5e-6,
    "lr_temporal": 1e-4,
    "lr_head": 1e-4,
    "weight_decay": 1e-4,
    "warmup_epochs": 3,
    "grad_clip": 1.0,
    "lambda_var": 0.1,
    "label_smoothing": 0.0,  # CRITICAL: Bug #1 — never > 0 for K=2
}
# NOTE(review): "label_smoothing" is not passed to STFMambaLoss below —
# presumably Trainer reads it from `config`; confirm.

# Loader settings shared by train and val.
loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=0,  # NOTE(review): single-process loading — confirm this is deliberate
    collate_fn=SBIVideoDataset.collate_fn,
    # Pinned host memory only helps host→GPU copies; on CPU it just warns.
    pin_memory=(device == "cuda"),
)
train_loader = torch.utils.data.DataLoader(
    train_ds, shuffle=True, drop_last=True, **loader_kwargs,
)
val_loader = torch.utils.data.DataLoader(
    val_ds, shuffle=False, **loader_kwargs,
)

criterion = STFMambaLoss(lambda_var=config["lambda_var"])
# Fresh model, built after seeding (the Section 5 instance was only a smoke test).
model = STFMambaV8(pretrained_backbone=True)

trainer = Trainer(
    model=model, criterion=criterion,
    train_loader=train_loader, val_loader=val_loader,
    config=config, save_dir=CHECKPOINT_DIR, device=device,
)

print("\n" + "=" * 60)
print(f"  Starting STF-Mamba V8.0 Training: {config['epochs']} Epochs")
print("=" * 60)

history = trainer.train(num_epochs=config["epochs"])
print(f"\nTraining Complete! Best Val AUC: {trainer.best_val_auc:.4f}")
print("✓ Section 6 complete")

In [None]:
# --- 6b: Training Curves ---
# 2x2 grid: loss, accuracy, val AUC (with reference lines) and the per-epoch
# variance gap recorded by the Trainer in `history`.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
epochs = range(1, len(history['train_loss']) + 1)

axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train')
axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val')
axes[0, 0].set_title('Loss'); axes[0, 0].legend(); axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Train')
axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Val')
axes[0, 1].set_title('Accuracy'); axes[0, 1].legend(); axes[0, 1].grid(True, alpha=0.3)

axes[1, 0].plot(epochs, history['val_auc'], 'g-', marker='o', markersize=3)
axes[1, 0].axhline(y=0.75, color='orange', linestyle='--', alpha=0.7, label='Stage 5 (0.75)')
axes[1, 0].axhline(y=0.90, color='red', linestyle='--', alpha=0.7, label='Paper (0.90)')
axes[1, 0].set_title('Val AUC'); axes[1, 0].set_ylim([0.4, 1.0]); axes[1, 0].legend(); axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].plot(epochs, history['var_gap'], 'm-', marker='o', markersize=3)
axes[1, 1].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[1, 1].set_title('Variance Gap'); axes[1, 1].grid(True, alpha=0.3)

# Title reports the number of epochs actually recorded, not a hard-coded 25.
fig.suptitle(f'STF-Mamba V8.0 — Kaggle Training ({len(epochs)} epochs)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, 'training_curves.png'), dpi=150, bbox_inches='tight')
plt.show()

## Section 7: Evaluation — FF++ Val + Celeb-DF AUC

In [None]:
# ============================================================================
# Section 7a: FF++ Val AUC
# ============================================================================
from sklearn.metrics import roc_auc_score, accuracy_score

best_path = os.path.join(CHECKPOINT_DIR, "best.pth")
print(f"Loading best checkpoint: {best_path}")
# The checkpoint was written by this notebook's Trainer (trusted), so full
# unpickling is acceptable. weights_only=False is explicit because PyTorch
# >= 2.6 defaults to weights_only=True, which rejects non-tensor metadata
# (e.g. NumPy scalars) that the checkpoint may contain — confirm contents.
ckpt = torch.load(best_path, map_location=device, weights_only=False)

eval_model = STFMambaV8(pretrained_backbone=True).to(device)
eval_model.load_state_dict(ckpt['model_state_dict'])
eval_model.eval()

# Collect per-batch arrays, then concatenate once.
prob_chunks, label_chunks, var_chunks = [], [], []
with torch.no_grad():
    for batch in val_loader:
        out = eval_model(batch['frames'].to(device))
        prob_chunks.append(torch.softmax(out['logits'], dim=1)[:, 1].cpu().numpy())
        label_chunks.append(batch['label'].numpy())
        var_chunks.append(out['variance'].cpu().numpy().flatten())

if not prob_chunks:
    raise RuntimeError("Validation loader yielded no batches — cannot evaluate.")

all_probs = np.concatenate(prob_chunks)
all_labels = np.concatenate(label_chunks)
all_variances = np.concatenate(var_chunks)

ff_val_auc = roc_auc_score(all_labels, all_probs)
ff_val_acc = accuracy_score(all_labels, (all_probs > 0.5).astype(int))
all_labels_np = all_labels
all_var_np = all_variances
var_real = all_var_np[all_labels_np == 0].mean()
var_fake = all_var_np[all_labels_np == 1].mean()

print(f"\nFF++ Val AUC: {ff_val_auc:.4f}")
print(f"FF++ Val Acc: {ff_val_acc:.4f}")
print(f"Variance gap: {var_fake - var_real:+.6f}")

In [None]:
# ============================================================================
# Section 7b: Celeb-DF Cross-Dataset Evaluation
# ============================================================================
import cv2  # not used directly in this cell
from data.preprocessing import FacePreprocessor
from data.augmentation import apply_transform_to_clip

# Max videos scored per class. NOTE(review): these are the first N paths in
# sorted directory order (e.g. Celeb-real before YouTube-real), not an
# official test list, so the AUC is measured on a convenience subset.
CDF_MAX_VIDEOS = 200

cdf_auc = None

if CELEB_DF_ROOT and os.path.isdir(CELEB_DF_ROOT):
    print("--- Celeb-DF v2 Cross-Dataset ---")

    cdf_real_dirs = [os.path.join(CELEB_DF_ROOT, d) for d in ["Celeb-real", "YouTube-real", "celeb_real"]]
    cdf_fake_dirs = [os.path.join(CELEB_DF_ROOT, d) for d in ["Celeb-synthesis", "celeb_synthesis"]]

    def find_videos(dir_list):
        """Return .mp4/.avi paths from each existing dir in *dir_list* (dir order, then sorted names)."""
        videos = []
        for d in dir_list:
            if os.path.isdir(d):
                videos.extend(
                    os.path.join(d, f)
                    for f in sorted(os.listdir(d))
                    if f.lower().endswith(('.mp4', '.avi'))
                )
        return videos

    cdf_real = find_videos(cdf_real_dirs)
    cdf_fake = find_videos(cdf_fake_dirs)
    print(f"  Real: {len(cdf_real)}, Fake: {len(cdf_fake)}")

    if cdf_real and cdf_fake:
        from tqdm import tqdm
        cdf_preprocessor = FacePreprocessor(
            video_dir=CELEB_DF_ROOT,
            cache_dir="/kaggle/working/cdf_cache",
            num_frames=NUM_FRAMES, img_size=IMG_SIZE,
            predictor_path=PREDICTOR_PATH if os.path.exists(PREDICTOR_PATH) else None,
        )

        def eval_videos(video_paths, label, max_videos=CDF_MAX_VIDEOS):
            """Score up to *max_videos* videos and return their P(fake) values.

            Videos that fail (no face, unreadable file, id not found by the
            preprocessor, ...) are skipped, but counted and reported so a
            systematic failure does not silently produce "no AUC".
            """
            probs = []
            n_failed, first_error = 0, None
            subset = video_paths[:max_videos]
            for vpath in tqdm(subset, desc=f"label={label}"):
                vid_id = os.path.splitext(os.path.basename(vpath))[0]
                try:
                    crops, _ = cdf_preprocessor.get_video(vid_id)
                    frames_t = apply_transform_to_clip(crops[:NUM_FRAMES], val_tf)
                    frames_t = frames_t.unsqueeze(0).to(device)
                    with torch.no_grad():
                        out = eval_model(frames_t)
                    probs.append(torch.softmax(out['logits'], dim=1)[0, 1].item())
                except Exception as err:  # best-effort per video; Ctrl-C still interrupts
                    n_failed += 1
                    if first_error is None:
                        first_error = f"{vid_id}: {type(err).__name__}: {err}"
            if n_failed:
                print(f"  ⚠ label={label}: skipped {n_failed}/{len(subset)} videos (first error — {first_error})")
            return probs

        real_probs = eval_videos(cdf_real, 0)
        fake_probs = eval_videos(cdf_fake, 1)

        if real_probs and fake_probs:
            cdf_labels = [0] * len(real_probs) + [1] * len(fake_probs)
            cdf_probs = real_probs + fake_probs
            cdf_auc = roc_auc_score(cdf_labels, cdf_probs)
            print(f"\n  Celeb-DF AUC: {cdf_auc:.4f}  (real={len(real_probs)}, fake={len(fake_probs)})")
        else:
            print("  ⚠ Too few Celeb-DF videos could be scored — AUC not computed")
    else:
        print("  ⚠ No Celeb-DF videos found in the expected sub-directories — skipping")
else:
    print("Celeb-DF not found — skipping")

In [None]:
# ============================================================================
# Section 7c: Results Summary
# ============================================================================
n_epochs_run = len(history['val_loss'])
var_gap = var_fake - var_real

print("\n" + "=" * 60)
print(f"  RESULTS SUMMARY — STF-Mamba V8.0 (Kaggle {n_epochs_run} epochs)")
print("=" * 60)
print(f"\n{'Metric':<30} {'Value':>10}")
print("-" * 42)
print(f"{'FF++ Val AUC (SBI)':<30} {ff_val_auc:>10.4f}")
print(f"{'FF++ Val Acc':<30} {ff_val_acc:>10.4f}")
# `is not None` rather than truthiness: an AUC of 0.0 is still a measurement.
if cdf_auc is not None:
    print(f"{'Celeb-DF AUC':<30} {cdf_auc:>10.4f}")
print(f"{'Variance gap (fake-real)':<30} {var_gap:>+10.6f}")
print(f"{'Best epoch':<30} {ckpt['epoch']:>10d}")

print(f"\n--- Comparison to Baselines ---")
print(f"{'Model':<35} {'FF++ Val':>10} {'CDF':>10}")
print("-" * 57)
print(f"{'B0 frame-level (Step 3)':<35} {'0.6850':>10} {'0.6135':>10}")
print(f"{'B0 + GRU temporal (Step 4)':<35} {'0.5954':>10} {'0.5524':>10}")
print(f"{'SBI reference (EffNet-B4)':<35} {'—':>10} {'0.9382':>10}")
cdf_str = f"{cdf_auc:.4f}" if cdf_auc is not None else "—"
print(f"{'V8.0 (this run)':<35} {ff_val_auc:>10.4f} {cdf_str:>10}")

print(f"\n--- Stage 5 Exit Criteria ---")
if cdf_auc is None:
    print(f"  ? CDF AUC not measured")
elif cdf_auc > 0.75:
    print(f"  ✓ CDF AUC {cdf_auc:.4f} > 0.75")
else:
    print(f"  ✗ CDF AUC {cdf_auc:.4f} ≤ 0.75")

if var_gap > 0:
    print(f"  ✓ Variance gap positive ({var_gap:+.6f})")
else:
    print(f"  ✗ Variance gap not positive ({var_gap:+.6f})")

# Overfitting heuristic: val loss at the last epoch must be within +20% of
# the val loss five epochs earlier.
last5 = history['val_loss'][-5:]
if len(last5) < 5:
    print(f"  ? Fewer than 5 epochs recorded — overfitting check skipped")
elif last5[-1] <= last5[0] * 1.2:
    print(f"  ✓ No overfitting")
else:
    print(f"  ✗ Possible overfitting: val loss {last5[0]:.4f} → {last5[-1]:.4f} over last 5 epochs")
print("\n✓ Section 7 complete")

## Section 8: Variance Visualization

In [None]:
# ============================================================================
# Section 8: Variance Visualization
# ============================================================================
import matplotlib.pyplot as plt

# Collect up to n_samples clips PER CLASS. The val set is scanned in order
# until both classes are full, so the plot does not depend on how val_ds
# orders its labels (taking the first 2n items could yield one class only).
n_samples = min(10, len(val_ds) // 2)
real_sims, fake_sims = [], []
real_vars, fake_vars = [], []

eval_model.eval()
with torch.no_grad():
    for i in range(len(val_ds)):
        if len(real_vars) >= n_samples and len(fake_vars) >= n_samples:
            break
        sample = val_ds[i]
        is_real = int(sample['label']) == 0
        sims_bucket, vars_bucket = (real_sims, real_vars) if is_real else (fake_sims, fake_vars)
        if len(vars_bucket) >= n_samples:
            continue  # this class is already full — skip the forward pass
        out = eval_model(sample['frames'].unsqueeze(0).to(device))
        sims_bucket.append(out['similarities'][0].cpu().numpy())
        vars_bucket.append(out['variance'][0].item())


def _mean_std(values):
    """Return (mean, std) of *values*, or (nan, nan) when empty (no NumPy empty-slice warnings)."""
    if not values:
        return float("nan"), float("nan")
    return float(np.mean(values)), float(np.std(values))


real_mean, real_std = _mean_std(real_vars)
fake_mean, fake_std = _mean_std(fake_vars)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))
panels = ((axes[0], real_sims, 'Real', real_mean), (axes[1], fake_sims, 'Fake', fake_mean))
for ax, sims_list, cls_name, mean_var in panels:
    for i, s in enumerate(sims_list[:5]):
        ax.plot(s, alpha=0.6, label=f'{cls_name} {i}')
    ax.set_title(f'{cls_name} — Mean σ²={mean_var:.6f}')
    ax.set_ylim([0.5, 1.05])
    if sims_list:  # avoid matplotlib's "no artists with labels" warning
        ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

fig.suptitle('STF-Mamba V8.0 — Identity Consistency Signal', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, 'similarity_traces.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nVariance Statistics (real n={len(real_vars)}, fake n={len(fake_vars)}):")
print(f"  Real: {real_mean:.6f} ± {real_std:.6f}")
print(f"  Fake: {fake_mean:.6f} ± {fake_std:.6f}")
print(f"  Gap:  {fake_mean - real_mean:+.6f}")

print("\n✓ Section 8 complete")
print("\n" + "=" * 60)
print("  STF-Mamba V8.0 — Training Complete")
print("=" * 60)
print(f"  Best FF++ Val AUC: {ff_val_auc:.4f}")
if cdf_auc is not None:
    print(f"  Celeb-DF AUC: {cdf_auc:.4f}")
print(f"  Next: Stage 6 — RunPod A100 full 50-epoch training")