# STF-Mamba V8.0 — Kaggle Training Notebook

**Semantic Temporal Forensics via Hydra-Mamba and DINOv2**

Target: CVPR/ICCV 2026 | Platform: Kaggle T4 x2 | 25 epochs

---

| Section | Contents |
|---------|----------|
| 1 | Setup + GPU check + install dependencies |
| 2 | Clone repo + path configuration |
| 3 | Dataset paths (FF++ and Celeb-DF) |
| 4 | Preprocessing + SBI cache build |
| 5 | Model init + param count + forward pass verify |
| 6 | Training loop (25 epochs, batch 8) |
| 7 | Evaluation — FF++ val + Celeb-DF AUC |
| 8 | Variance visualization |

## Section 1: Setup + GPU Check + Install Dependencies

In [None]:
# ============================================================================
# Section 1: Setup + GPU Check + Install Dependencies
# ============================================================================
import subprocess, sys, os

# GPU check
print("=" * 60)
print("  STF-Mamba V8.0 — Kaggle Environment Check")
print("=" * 60)
!nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader

import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({torch.cuda.get_device_properties(i).total_mem / 1e9:.1f} GB)")

# Install mamba_ssm (may fail — Conv1d fallback is built in)
print("\n--- Installing mamba_ssm ---")
try:
    import mamba_ssm
    print(f"mamba_ssm already installed: {mamba_ssm.__version__}")
except ImportError:
    try:
        !pip install mamba_ssm -q
        import mamba_ssm
        print(f"mamba_ssm installed: {mamba_ssm.__version__}")
    except Exception as e:
        print(f"mamba_ssm install failed (expected on some Kaggle kernels): {e}")
        print("Using Conv1d fallback — this is fine for training.")

# Install other dependencies
!pip install albumentations einops imutils -q

print("\n✓ Section 1 complete")

## Section 2: Clone Repo + Path Configuration

In [None]:
# ============================================================================
# Section 2: Clone Repo + Path Configuration
# ============================================================================
REPO_URL = "https://github.com/AbdelRahman-Madboly/STF-Mamba_V8.0.git"
REPO_DIR = "/kaggle/working/STF-Mamba_V8.0"

if not os.path.exists(REPO_DIR):
    !git clone {REPO_URL} {REPO_DIR}
else:
    print(f"Repo already cloned at {REPO_DIR}")
    !cd {REPO_DIR} && git pull

# Add repo to Python path
sys.path.insert(0, REPO_DIR)
os.chdir(REPO_DIR)

# Verify imports
from stf_mamba import STFMambaV8, STFMambaLoss, is_mamba_available
from data import SBIVideoDataset, get_train_transforms, get_val_transforms, load_all_splits
from training import Trainer, build_optimizer, build_scheduler

print(f"\nRepo: {REPO_DIR}")
print(f"Mamba SSM: {'native' if is_mamba_available() else 'Conv1d fallback'}")
print("✓ Section 2 complete — all imports OK")

## Section 3: Dataset Paths

**Required Kaggle datasets (attach before running):**
- FaceForensics++ (original sequences, raw videos)
- Celeb-DF v2 (for cross-dataset evaluation)

Update the paths below to match your attached dataset names.

In [None]:
# ============================================================================
# Section 3: Dataset Paths
# ============================================================================
import shutil

# UPDATE THESE to match your attached Kaggle dataset names:
FF_ROOT = "/kaggle/input/faceforensics"          # FF++ root
CELEB_DF_ROOT = "/kaggle/input/celeb-df-v2"      # Celeb-DF root

# Auto-detect FF++ video directory: the first candidate that exists and holds
# at least one .mp4 wins, so the more specific layouts are listed first.
FF_VIDEO_CANDIDATES = [
    os.path.join(FF_ROOT, "original_sequences/youtube/raw/videos"),
    os.path.join(FF_ROOT, "original_sequences/youtube/c23/videos"),
    os.path.join(FF_ROOT, "youtube/raw/videos"),
    os.path.join(FF_ROOT, "videos"),
    FF_ROOT,
]

FF_VIDEO_DIR = None
for candidate in FF_VIDEO_CANDIDATES:
    if os.path.isdir(candidate):
        mp4s = [f for f in os.listdir(candidate) if f.endswith('.mp4')]
        if mp4s:
            FF_VIDEO_DIR = candidate
            print(f"FF++ videos found: {candidate} ({len(mp4s)} videos)")
            break

if FF_VIDEO_DIR is None:
    print("WARNING: FF++ video directory not found!")
    print("Available at FF_ROOT:")
    if os.path.exists(FF_ROOT):
        # Print the top 4 directory levels (depth 0-3) to help fix FF_ROOT.
        for root, dirs, _files in os.walk(FF_ROOT):
            depth = root.replace(FF_ROOT, '').count(os.sep)
            indent = ' ' * 2 * depth
            print(f"{indent}{os.path.basename(root)}/")
            if depth >= 3:
                # Nothing deeper is printed, so don't descend into it: pruning
                # `dirs` in place avoids walking the whole (large) dataset tree.
                dirs[:] = []
    else:
        print(f"  {FF_ROOT} does not exist — attach the FF++ dataset!")

# Working directories
SPLITS_DIR = os.path.join(REPO_DIR, "splits")
CACHE_DIR = "/kaggle/working/cache"
CHECKPOINT_DIR = "/kaggle/working/checkpoints"
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# dlib 81-point predictor: nothing is downloaded here, it is only copied from
# an attached Kaggle dataset when one of the known locations exists.
PREDICTOR_PATH = "/kaggle/working/shape_predictor_81_face_landmarks.dat"
if not os.path.exists(PREDICTOR_PATH):
    print("\nLooking for dlib 81-point landmark predictor in attached datasets...")
    dlib_candidates = [
        "/kaggle/input/dlib-shape-predictor/shape_predictor_81_face_landmarks.dat",
        "/kaggle/input/shape-predictor/shape_predictor_81_face_landmarks.dat",
    ]
    found = False
    for c in dlib_candidates:
        if os.path.exists(c):
            # shutil.copy raises on failure, unlike `!cp`, which failed silently.
            shutil.copy(c, PREDICTOR_PATH)
            found = True
            print(f"  Copied from: {c}")
            break
    if not found:
        print("  NOTE: 81-point predictor not found.")
        print("  Attach 'dlib-shape-predictor' dataset, or download manually.")
        print("  Source: https://github.com/codeniko/shape_predictor_81_face_landmarks")

# Verify splits
splits = load_all_splits(SPLITS_DIR)
print(f"\nSplits: train={len(splits['train'])}, val={len(splits['val'])}, test={len(splits['test'])}")

print("\n✓ Section 3 complete")

## Section 4: Preprocessing + SBI Cache Build

First run: ~15 min to extract face crops for every train/val video and build the SBI cache.  
Subsequent runs: load almost instantly from the NPZ cache.

In [None]:
# ============================================================================
# Section 4: Preprocessing + SBI Cache Build
# ============================================================================
import time
from data.preprocessing import FacePreprocessor
from data.splits import get_video_ids

# Clip geometry shared by the preprocessor and both datasets.
NUM_FRAMES = 32
IMG_SIZE = 224

# Resolve the optional landmark predictor once. Every consumer below gets the
# same value: the path when the .dat file exists, otherwise None.
predictor_arg = PREDICTOR_PATH if os.path.exists(PREDICTOR_PATH) else None


def _print_banner(title, leading_newline=False):
    """Print *title* between two 60-character '=' rules."""
    rule = "=" * 60
    print(("\n" + rule) if leading_newline else rule)
    print(f"  {title}")
    print(rule)


# --- 4a: Pre-extract face crops for all videos in train + val splits ---
_print_banner("Phase 4a: Face Crop Extraction")

preprocessor = FacePreprocessor(
    video_dir=FF_VIDEO_DIR,
    cache_dir=os.path.join(CACHE_DIR, "crops"),
    num_frames=NUM_FRAMES,
    img_size=IMG_SIZE,
    predictor_path=predictor_arg,
)

# De-duplicated, sorted union of the train and val video IDs.
train_ids = get_video_ids(splits['train'])
val_ids = get_video_ids(splits['val'])
all_ids = sorted(set(train_ids) | set(val_ids))
print(f"Videos to preprocess: {len(all_ids)} (train: {len(train_ids)}, val: {len(val_ids)})")

start = time.time()
preprocessor.preprocess_all(all_ids, show_progress=True)
print(f"Face crops done in {time.time() - start:.0f}s")

# --- 4b: Build SBI fake cache ---
_print_banner("Phase 4b: SBI Fake Generation", leading_newline=True)

train_tf = get_train_transforms(IMG_SIZE)
val_tf = get_val_transforms(IMG_SIZE)

# Settings common to both splits; only split file, phase and transform differ.
shared_ds_kwargs = dict(
    video_dir=FF_VIDEO_DIR,
    cache_dir=CACHE_DIR,
    num_frames=NUM_FRAMES,
    img_size=IMG_SIZE,
    sbi_seed=42,
    predictor_path=predictor_arg,
)
train_ds = SBIVideoDataset(
    split_path=os.path.join(SPLITS_DIR, "Dataset_Split_train.json"),
    phase="train",
    transform=train_tf,
    **shared_ds_kwargs,
)
val_ds = SBIVideoDataset(
    split_path=os.path.join(SPLITS_DIR, "Dataset_Split_val.json"),
    phase="val",
    transform=val_tf,
    **shared_ds_kwargs,
)

start = time.time()
for ds in (train_ds, val_ds):
    ds.build_cache(show_progress=True)
print(f"SBI cache done in {time.time() - start:.0f}s")

# Quick sanity check on the first training clip
sample = train_ds[0]
print(f"\nSanity check:")
print(f"  frames: {sample['frames'].shape}")
print(f"  label:  {sample['label']}")
print(f"  id:     {sample['video_id']}")
print(f"\nDataset sizes: train={len(train_ds)}, val={len(val_ds)}")

print("\n✓ Section 4 complete")

## Section 5: Model Init + Param Count + Forward Pass Verify

In [None]:
# ============================================================================
# Section 5: Model Init + Param Count + Forward Pass Verify
# ============================================================================
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Build the model with pretrained backbone weights.
print("\nLoading DINOv2-ViT-B/14...")
model = STFMambaV8(pretrained_backbone=True)

# Single pass over the parameters, split by whether they receive gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n = param.numel()
    total_params += n
    if param.requires_grad:
        trainable_params += n
frozen_params = total_params - trainable_params

print(f"\nParameter Count:")
print(f"  Total:     {total_params / 1e6:.1f}M")
print(f"  Trainable: {trainable_params / 1e6:.1f}M")
print(f"  Frozen:    {frozen_params / 1e6:.1f}M")

# Smoke-test one random clip (batch=1) and check the output shapes.
print(f"\nForward pass test...")
model = model.to(device).eval()
x_test = torch.randn(1, NUM_FRAMES, 3, IMG_SIZE, IMG_SIZE).to(device)

with torch.no_grad():
    out = model(x_test)

logits, variance = out['logits'], out['variance']
print(f"  Input:    {x_test.shape}")
print(f"  Logits:   {logits.shape}  ← expected (1, 2)")
print(f"  Variance: {variance.shape} ← expected (1, 1)")
assert logits.shape == (1, 2), "Logits shape mismatch!"
assert variance.shape == (1, 1), "Variance shape mismatch!"

# Free the test tensors before training starts.
del x_test, out, logits, variance
torch.cuda.empty_cache()

# Move model back to CPU before Trainer wraps it
model = model.cpu()

print("\n✓ Section 5 complete — forward pass verified")

## Section 6: Training (25 Epochs)

Config (defined inline in the cell below; reference: `configs/v8_kaggle.yaml`):
- Batch size: 8 (T4 VRAM constraint)
- Frames: 32 per clip
- LR: backbone=5e-6, temporal+head=1e-4
- Loss: CE + 0.1 * variance auxiliary
- DataParallel on T4 x2

In [None]:
# ============================================================================
# Section 6: Training Loop
# ============================================================================
import numpy as np
import random

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Deterministic cuDNN kernels trade some speed for run-to-run repeatability.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Kaggle config (set inline here; the training markdown lists the same values)
config = {
    "epochs": 25,
    "batch_size": 8,
    "lr_backbone": 5e-6,
    "lr_temporal": 1e-4,
    "lr_head": 1e-4,
    "weight_decay": 1e-4,
    "warmup_epochs": 3,
    "grad_clip": 1.0,
    "lambda_var": 0.1,
}

# Page-locked host memory only speeds up host->GPU copies. Without CUDA,
# PyTorch warns and ignores it, so enable it only when training on the GPU.
PIN_MEMORY = device == "cuda"

# DataLoaders — num_workers=0 on Kaggle (CRITICAL: prevents deadlock)
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=0,
    collate_fn=SBIVideoDataset.collate_fn,
    pin_memory=PIN_MEMORY,
    drop_last=True,  # keep every training batch full-sized
)

val_loader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=0,
    collate_fn=SBIVideoDataset.collate_fn,
    pin_memory=PIN_MEMORY,
    drop_last=False,  # evaluate every validation clip
)

print(f"Train: {len(train_ds)} clips → {len(train_loader)} batches (bs={config['batch_size']})")
print(f"Val:   {len(val_ds)} clips → {len(val_loader)} batches")

# Loss (label_smoothing=0.0 ALWAYS — Bug #1)
criterion = STFMambaLoss(lambda_var=config["lambda_var"])

# Trainer
trainer = Trainer(
    model=model,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    save_dir=CHECKPOINT_DIR,
    device=device,
)

# Train!
print(f"\n{'='*60}")
print(f"  Starting training: {config['epochs']} epochs on {device}")
print(f"{'='*60}")

history = trainer.train(num_epochs=config["epochs"])

print(f"\nBest val AUC: {trainer.best_val_auc:.4f}")
print("\n✓ Section 6 complete")

In [None]:
# --- 6b: Training Curves ---
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
epochs = range(1, len(history['train_loss']) + 1)

# Top row: train vs val curves for loss and accuracy.
for ax, train_key, val_key, title in (
    (axes[0, 0], 'train_loss', 'val_loss', 'Loss'),
    (axes[0, 1], 'train_acc', 'val_acc', 'Accuracy'),
):
    ax.plot(epochs, history[train_key], 'b-', label='Train')
    ax.plot(epochs, history[val_key], 'r-', label='Val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Bottom-left: validation AUC with the two reference target lines.
auc_ax = axes[1, 0]
auc_ax.plot(epochs, history['val_auc'], 'g-', marker='o', markersize=3)
for level, colour, caption in (
    (0.75, 'orange', 'Stage 5 target (0.75)'),
    (0.90, 'red', 'Paper target (0.90)'),
):
    auc_ax.axhline(y=level, color=colour, linestyle='--', alpha=0.7, label=caption)
auc_ax.set_title('Val AUC (FF++)')
auc_ax.set_xlabel('Epoch')
auc_ax.set_ylim([0.4, 1.0])
auc_ax.legend()
auc_ax.grid(True, alpha=0.3)

# Bottom-right: variance gap (fake - real) as recorded per epoch by the Trainer.
gap_ax = axes[1, 1]
gap_ax.plot(epochs, history['var_gap'], 'm-', marker='o', markersize=3)
gap_ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
gap_ax.set_title('Variance Gap (fake - real)')
gap_ax.set_xlabel('Epoch')
gap_ax.legend(['Var gap'])
gap_ax.grid(True, alpha=0.3)

fig.suptitle('STF-Mamba V8.0 — Kaggle Training (25 epochs)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, 'training_curves.png'), dpi=150, bbox_inches='tight')
plt.show()
print("Training curves saved.")

## Section 7: Evaluation — FF++ Val + Celeb-DF AUC

Load best checkpoint, evaluate on:
1. FF++ validation split (SBI)
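2. Celeb-DF v2 (cross-dataset — the main paper number; measure it on the official `List_of_testing_videos.txt` split so it is comparable to published baselines)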

In [None]:
# ============================================================================
# Section 7: Evaluation
# ============================================================================
from sklearn.metrics import roc_auc_score, accuracy_score, roc_curve
import numpy as np

# Load best checkpoint
best_path = os.path.join(CHECKPOINT_DIR, "best.pth")
print(f"Loading best checkpoint: {best_path}")
# This checkpoint is written by this notebook's Trainer. Besides tensors it
# holds metadata ('epoch', 'val_metrics'), which may include non-tensor objects
# such as numpy scalars. Recent PyTorch defaults to weights_only=True, which
# rejects those, so opt out explicitly for this trusted, self-produced file.
ckpt = torch.load(best_path, map_location=device, weights_only=False)
print(f"  Saved at epoch {ckpt['epoch']}, val AUC: {ckpt['val_metrics']['auc']:.4f}")

# Training may run under a DataParallel wrapper (see Section 6). If its state
# dict was saved, keys carry a "module." prefix that a bare model rejects;
# strip it. This is a no-op for already-unwrapped state dicts.
_prefix = "module."
state_dict = {
    (k[len(_prefix):] if k.startswith(_prefix) else k): v
    for k, v in ckpt['model_state_dict'].items()
}

eval_model = STFMambaV8(pretrained_backbone=True).to(device)
eval_model.load_state_dict(state_dict)
eval_model.eval()

# ---- 7a: FF++ Val AUC (SBI) ----
print("\n--- 7a: FF++ Validation (SBI) ---")
all_probs, all_labels = [], []
all_variances = []

with torch.no_grad():
    for batch in val_loader:
        frames = batch['frames'].to(device)
        labels = batch['label']
        out = eval_model(frames)
        # Column 1 of the 2-way softmax is the "fake" probability (label 1).
        probs = torch.softmax(out['logits'], dim=1)[:, 1].cpu().numpy()
        var = out['variance'].cpu().numpy().flatten()
        all_probs.extend(probs)
        all_labels.extend(labels.numpy())
        all_variances.extend(var)

ff_val_auc = roc_auc_score(all_labels, all_probs)
ff_val_acc = accuracy_score(all_labels, [1 if p > 0.5 else 0 for p in all_probs])

# Mean predicted identity variance per class (0 = real, 1 = fake).
all_labels_np = np.array(all_labels)
all_var_np = np.array(all_variances)
var_real = all_var_np[all_labels_np == 0].mean()
var_fake = all_var_np[all_labels_np == 1].mean()

print(f"  FF++ Val AUC:  {ff_val_auc:.4f}")
print(f"  FF++ Val Acc:  {ff_val_acc:.4f}")
print(f"  Variance real: {var_real:.6f}")
print(f"  Variance fake: {var_fake:.6f}")
print(f"  Variance gap:  {var_fake - var_real:+.6f}")

In [None]:
# ---- 7b: Celeb-DF Cross-Dataset Evaluation ----
print("\n--- 7b: Celeb-DF v2 Cross-Dataset ---")

# Celeb-DF evaluation: load real + fake videos, run model
import cv2  # NOTE(review): cv2 appears unused in this cell
from data.preprocessing import FacePreprocessor
from data.augmentation import apply_transform_to_clip

# Candidate sub-directories for each class: the official folder names plus
# lowercase/underscore variants. Non-existent ones are allowed.
cdf_real_dirs = [
    os.path.join(CELEB_DF_ROOT, sub)
    for sub in ("Celeb-real", "YouTube-real", "celeb_real")
]
cdf_fake_dirs = [
    os.path.join(CELEB_DF_ROOT, sub)
    for sub in ("Celeb-synthesis", "celeb_synthesis")
]

def find_videos(dir_list, extensions=('.mp4', '.avi')):
    """Return the sorted paths of every video file below the given directories.

    Each entry of *dir_list* that is an existing directory is searched
    recursively; other entries are ignored. A file matches when its
    lower-cased name ends with one of *extensions*.
    """
    existing_dirs = (d for d in dir_list if os.path.isdir(d))
    return sorted(
        os.path.join(root, name)
        for base in existing_dirs
        for root, _, names in os.walk(base)
        for name in names
        if name.lower().endswith(extensions)
    )

cdf_real_videos = find_videos(cdf_real_dirs)
cdf_fake_videos = find_videos(cdf_fake_dirs)

# Restrict to the official Celeb-DF v2 test split when its list file is
# attached, so the AUC is comparable to published numbers (e.g. the SBI
# reference in 7c). Without it, every video found above is evaluated.
# NOTE(review): assumes each line is "<label> <relative/path.mp4>" relative to
# CELEB_DF_ROOT — confirm against the attached List_of_testing_videos.txt.
CDF_TEST_LIST = os.path.join(CELEB_DF_ROOT, "List_of_testing_videos.txt")
if os.path.isfile(CDF_TEST_LIST):
    with open(CDF_TEST_LIST, encoding="utf-8") as fh:
        cdf_test_rel = {line.split()[-1] for line in fh if line.strip()}

    def _in_cdf_test_split(path):
        """True if *path* (under CELEB_DF_ROOT) is listed in the official test split."""
        return os.path.relpath(path, CELEB_DF_ROOT).replace(os.sep, "/") in cdf_test_rel

    cdf_real_videos = [v for v in cdf_real_videos if _in_cdf_test_split(v)]
    cdf_fake_videos = [v for v in cdf_fake_videos if _in_cdf_test_split(v)]
    print(f"  Using official test list: {CDF_TEST_LIST}")
else:
    print("  Official test list not found — evaluating ALL Celeb-DF videos "
          "(not comparable to published numbers)")

print(f"  Celeb-DF real:  {len(cdf_real_videos)} videos")
print(f"  Celeb-DF fake:  {len(cdf_fake_videos)} videos")

if cdf_real_videos and cdf_fake_videos:
    # Process Celeb-DF videos
    from tqdm import tqdm

    cdf_preprocessor = FacePreprocessor(
        video_dir=CELEB_DF_ROOT,
        cache_dir=os.path.join(CACHE_DIR, "cdf_crops"),
        num_frames=NUM_FRAMES,
        img_size=IMG_SIZE,
        predictor_path=PREDICTOR_PATH if os.path.exists(PREDICTOR_PATH) else None,
    )

    def evaluate_video_list(video_paths, label, model, transform, preprocessor):
        """Run *model* on each video and return the list of P(fake) scores.

        Videos that raise are skipped so one bad file cannot abort the whole
        evaluation. The number of failures and the first error are reported,
        so a mostly-failed run is not mistaken for a valid AUC.

        NOTE(review): the preprocessor is rooted at CELEB_DF_ROOT but only
        receives the bare file stem, while the videos live in sub-directories.
        Confirm that FacePreprocessor.get_video can locate them; otherwise every
        video fails, which the warning below makes visible.
        """
        probs_list = []
        n_failed = 0
        first_error = None
        for vpath in tqdm(video_paths, desc=f"Eval label={label}"):
            vid_id = os.path.splitext(os.path.basename(vpath))[0]
            try:
                crops, _ = preprocessor.get_video(vid_id)
                frames_t = apply_transform_to_clip(crops[:NUM_FRAMES], transform)
                frames_t = frames_t.unsqueeze(0).to(device)  # (1, T, 3, H, W)
                with torch.no_grad():
                    out = model(frames_t)
                    prob = torch.softmax(out['logits'], dim=1)[0, 1].item()
            except Exception as e:  # best-effort per video; failures are counted
                n_failed += 1
                if first_error is None:
                    first_error = f"{vid_id}: {e!r}"
                continue
            probs_list.append(prob)
        if n_failed:
            print(f"  WARNING: {n_failed}/{len(video_paths)} videos with label={label} "
                  f"failed; first error: {first_error}")
        return probs_list

    cdf_real_probs = evaluate_video_list(cdf_real_videos, 0, eval_model, val_tf, cdf_preprocessor)
    cdf_fake_probs = evaluate_video_list(cdf_fake_videos, 1, eval_model, val_tf, cdf_preprocessor)

    cdf_labels = [0] * len(cdf_real_probs) + [1] * len(cdf_fake_probs)
    cdf_probs = cdf_real_probs + cdf_fake_probs

    if len(set(cdf_labels)) > 1:
        cdf_auc = roc_auc_score(cdf_labels, cdf_probs)
        cdf_acc = accuracy_score(cdf_labels, [1 if p > 0.5 else 0 for p in cdf_probs])
        print(f"\n  Celeb-DF AUC:  {cdf_auc:.4f}")
        print(f"  Celeb-DF Acc:  {cdf_acc:.4f}")
        print(f"  Videos evaluated: {len(cdf_real_probs)} real + {len(cdf_fake_probs)} fake")
    else:
        cdf_auc = None
        print("  Could not compute AUC — need both classes")
else:
    cdf_auc = None
    print("  Celeb-DF videos not found — skip cross-dataset eval")
    print(f"  Looked in: {cdf_real_dirs + cdf_fake_dirs}")

In [None]:
# ---- 7c: Results Summary Table ----
print("\n" + "=" * 60)
print("  RESULTS SUMMARY — STF-Mamba V8.0 (Kaggle 25 epochs)")
print("=" * 60)

# `cdf_auc` is None when the Celeb-DF evaluation was skipped. Test for None
# explicitly: a (degenerate) AUC of 0.0 is falsy but is still a measurement.
cdf_measured = cdf_auc is not None
var_gap = var_fake - var_real

print(f"\n{'Metric':<30} {'Value':>10}")
print("-" * 42)
print(f"{'FF++ Val AUC (SBI)':<30} {ff_val_auc:>10.4f}")
print(f"{'FF++ Val Acc':<30} {ff_val_acc:>10.4f}")
if cdf_measured:
    print(f"{'Celeb-DF AUC':<30} {cdf_auc:>10.4f}")
print(f"{'Variance gap (fake-real)':<30} {var_gap:>+10.6f}")
print(f"{'Best epoch':<30} {ckpt['epoch']:>10d}")

print(f"\n--- Comparison to Baselines ---")
print(f"{'Model':<35} {'FF++ Val':>10} {'CDF':>10}")
print("-" * 57)
print(f"{'B0 frame-level (Step 3)':<35} {'0.6850':>10} {'0.6135':>10}")
print(f"{'B0 + GRU temporal (Step 4)':<35} {'0.5954':>10} {'0.5524':>10}")
print(f"{'SBI reference (EffNet-B4)':<35} {'—':>10} {'0.9382':>10}")
cdf_str = f"{cdf_auc:.4f}" if cdf_measured else "—"
print(f"{'V8.0 (this run)':<35} {ff_val_auc:>10.4f} {cdf_str:>10}")

# Stage 5 exit criteria check
print(f"\n--- Stage 5 Exit Criteria ---")
checks = []
if not cdf_measured:
    checks.append("  ? CDF AUC not measured — attach Celeb-DF dataset")
elif cdf_auc > 0.75:
    checks.append(f"  ✓ CDF AUC {cdf_auc:.4f} > 0.75")
else:
    checks.append(f"  ✗ CDF AUC {cdf_auc:.4f} ≤ 0.75 — needs improvement")

if var_gap > 0:
    checks.append(f"  ✓ Variance gap positive ({var_gap:+.6f})")
else:
    checks.append(f"  ✗ Variance gap not positive — check Section 5.2 troubleshooting")

# Overfitting heuristic: the final val loss must be at most 1.2x the first of
# the last five epochs' val losses (needs at least five epochs of history).
last5_val = history['val_loss'][-5:]
if len(last5_val) >= 5 and last5_val[-1] <= last5_val[0] * 1.2:
    checks.append(f"  ✓ No overfitting (last 5 val loss stable)")
else:
    checks.append(f"  ? Check val loss trend in last 5 epochs")

for c in checks:
    print(c)

print("\n✓ Section 7 complete")

## Section 8: Variance Visualization

Plot per-frame identity similarity traces and the distribution of clip-level identity variance for real vs fake validation clips.
**Key hypothesis:** Fake clips should show higher temporal variance in identity embeddings.

In [None]:
# ============================================================================
# Section 8: Variance Visualization
# ============================================================================
import matplotlib.pyplot as plt
import numpy as np

# Collect per-frame similarities for up to `n_samples` real and `n_samples`
# fake validation clips.
n_samples = min(10, len(val_ds) // 2)

real_similarities = []
fake_similarities = []
real_variances = []
fake_variances = []

eval_model.eval()
with torch.no_grad():
    # Bucket clips by their own label (0 = real, 1 = fake, as in Section 7a)
    # rather than assuming real and fake clips alternate by index; nothing in
    # this notebook guarantees that ordering.
    for idx in range(len(val_ds)):
        if len(real_variances) >= n_samples and len(fake_variances) >= n_samples:
            break
        sample = val_ds[idx]
        if int(sample['label']) == 1:
            sims_bucket, var_bucket = fake_similarities, fake_variances
        else:
            sims_bucket, var_bucket = real_similarities, real_variances
        if len(var_bucket) >= n_samples:
            continue  # this class is already full; keep scanning for the other
        clip_out = eval_model(sample['frames'].unsqueeze(0).to(device))
        sims_bucket.append(clip_out['similarities'][0].cpu().numpy())
        var_bucket.append(clip_out['variance'][0].item())

# --- Plot 1: Per-frame similarity traces (first 5 clips per class) ---
fig, axes = plt.subplots(1, 2, figsize=(16, 5))
for ax, kind, sims_list, variances in (
    (axes[0], 'Real', real_similarities, real_variances),
    (axes[1], 'Fake', fake_similarities, fake_variances),
):
    for i, sims in enumerate(sims_list[:5]):
        ax.plot(range(len(sims)), sims, alpha=0.6, label=f'{kind} {i}')
    ax.set_title(f'{kind} Clips — Per-Frame Cosine Similarity\nMean σ² = {np.mean(variances):.6f}')
    ax.set_xlabel('Frame')
    ax.set_ylabel('cos(h_t, mean(H))')
    ax.set_ylim([0.5, 1.05])
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

fig.suptitle('STF-Mamba V8.0 — Identity Consistency Signal', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, 'similarity_traces.png'), dpi=150, bbox_inches='tight')
plt.show()

# --- Plot 2: Variance distribution ---
fig, ax = plt.subplots(figsize=(10, 5))
# Shared bin edges spanning both classes. With no data, or when every value
# is identical, explicit linspace edges would be degenerate (all equal), so
# fall back to letting matplotlib choose 30 bins.
all_vars = real_variances + fake_variances
if all_vars and min(all_vars) < max(all_vars):
    bins = np.linspace(min(all_vars), max(all_vars), 30)
else:
    bins = 30
ax.hist(real_variances, bins=bins, alpha=0.6, label=f'Real (n={len(real_variances)})', color='green')
ax.hist(fake_variances, bins=bins, alpha=0.6, label=f'Fake (n={len(fake_variances)})', color='red')
ax.set_title('Temporal Variance Distribution: Real vs Fake')
ax.set_xlabel('Identity Variance (σ²)')
ax.set_ylabel('Count')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, 'variance_distribution.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nVariance Statistics:")
print(f"  Real mean σ²: {np.mean(real_variances):.6f} ± {np.std(real_variances):.6f}")
print(f"  Fake mean σ²: {np.mean(fake_variances):.6f} ± {np.std(fake_variances):.6f}")
print(f"  Gap:           {np.mean(fake_variances) - np.mean(real_variances):+.6f}")

print("\n✓ Section 8 complete")
print("\n" + "=" * 60)
print("  STF-Mamba V8.0 — Kaggle Notebook Complete")
print("=" * 60)
print(f"  Best FF++ Val AUC: {ff_val_auc:.4f}")
# Explicit None check: an AUC of 0.0 is falsy but still a measured value.
if cdf_auc is not None:
    print(f"  Celeb-DF AUC:      {cdf_auc:.4f}")
print(f"  Checkpoint:        {best_path}")
print(f"  Next: Stage 6 — RunPod A100 full 50-epoch training")