# 07 – Train Ensemble Model (Spatial + Temporal)

This notebook trains a **video-level ensemble classifier** that fuses:

• Spatial stream (EfficientNet-B3; consumed here as one precomputed prediction score per video)  
• Temporal stream (LSTM + Attention video-level predictions)

The ensemble produces the **final deepfake decision**.

✔ Spatial model: already trained  
✔ Temporal model: already trained  
✘ Frequency stream: not included in this ensemble (future work)

Outputs:
- checkpoints/ensemble/ensemble_best_valAUC.pth
- Final video-level test predictions (kept in memory; the ensemble test AUC is printed)

In [None]:
from pathlib import Path
import json
import time
import random
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score

In [None]:
# ---------------- USER CONFIG ----------------
# All paths are resolved relative to the parent of the current working
# directory (presumably the repository root when the notebook is launched
# from a sub-folder such as `notebooks/` -- confirm for your layout).
ROOT = Path.cwd().parent
DATA_DIR = ROOT.joinpath("data")
LABELS_JSON = DATA_DIR.joinpath("labels.json")
CHECKPOINT_DIR = ROOT.joinpath("checkpoints", "ensemble")

# Per-split folders of per-video prediction files, one per stream.
SPATIAL_PRED_DIR = ROOT.joinpath("predictions", "spatial")     # video-level spatial preds
TEMPORAL_PRED_DIR = ROOT.joinpath("predictions", "temporal")   # video-level temporal preds

# Optimisation hyper-parameters.
NUM_EPOCHS, BATCH_SIZE = 15, 64
LR = 1e-4
WEIGHT_DECAY = 1e-4

# Reproducibility and hardware selection (GPU when one is visible).
SEED = 42
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# ---------------------------------------------

# Make sure the checkpoint folder exists before training tries to write to it.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print("Device:", DEVICE)
print("Checkpoint dir:", CHECKPOINT_DIR)



In [None]:
def set_seed(seed=SEED):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU), and, when CUDA
    is available, all CUDA devices as well.

    Args:
        seed: Integer seed; defaults to the module-level ``SEED`` as it was
            when this function was defined.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing more to do on CPU-only machines.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed once up front so everything that follows starts from a known RNG state.
set_seed(SEED)

# Parse the label file once; `labels_map` maps a video key to its class label
# and is consulted by `get_label`.
labels_map = json.loads(LABELS_JSON.read_text())

def get_label(stem):
    """Return the integer class label for the video identified by ``stem``.

    Keys of ``labels_map`` are tried in decreasing order of strictness:

    1. a key exactly equal to ``stem``;
    2. keys whose file name or file stem equals ``stem``
       (e.g. ``"real/vid_1.mp4"`` for ``"vid_1"``);
    3. keys that merely contain ``stem`` as a substring (legacy fallback).

    The first tier that yields a single, unambiguous label wins. Preferring
    name equality over a raw substring test prevents ``"video_1"`` from
    silently picking up the label of ``"video_10.mp4"``.

    Args:
        stem: File stem of a prediction file (string).

    Returns:
        The label converted with ``int``.

    Raises:
        KeyError: If no key matches, or if the matching keys of the deciding
            tier disagree on the label.
    """
    if stem in labels_map:
        return int(labels_map[stem])

    name_labels = set()       # labels of keys whose file name/stem equals `stem`
    substring_labels = set()  # labels of every key containing `stem`
    for key, value in labels_map.items():
        # A name match always implies a substring match, so this one test
        # filters out irrelevant keys for both tiers.
        if stem not in key:
            continue
        label = int(value)
        substring_labels.add(label)
        # Accept both "/" and "\\" as path separators inside the key.
        file_name = Path(key.replace("\\", "/")).name
        if stem in (file_name, Path(file_name).stem):
            name_labels.add(label)

    for candidates in (name_labels, substring_labels):
        if len(candidates) == 1:
            return candidates.pop()
        if candidates:
            # Conflicting labels: refuse to guess rather than train on a
            # possibly wrong target.
            raise KeyError(
                f"Ambiguous label for {stem}: candidates {sorted(candidates)}"
            )

    raise KeyError(f"Label not found for {stem}")

In [None]:
class EnsembleDataset(Dataset):
    """Per-video pairs of stream scores for one data split.

    Each sample is built from two ``.npy`` files sharing the same stem, one
    under ``SPATIAL_PRED_DIR / split`` and one under
    ``TEMPORAL_PRED_DIR / split``. Only stems present in both folders are
    kept. Every file must hold exactly one value (it is read with
    ``ndarray.item()``); per the original notes these are sigmoid scores.
    """

    def __init__(self, split):
        """Index the videos of ``split`` that have both predictions on disk."""
        self.spatial_dir = SPATIAL_PRED_DIR / split
        self.temporal_dir = TEMPORAL_PRED_DIR / split

        spatial_stems = {f.stem for f in self.spatial_dir.glob("*.npy")}
        temporal_stems = {f.stem for f in self.temporal_dir.glob("*.npy")}
        # Sorted so that sample order is stable across runs.
        self.items = sorted(spatial_stems.intersection(temporal_stems))

    def __len__(self):
        """Return the number of videos available in both streams."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(features, label, stem)`` for the video at ``idx``.

        ``features`` is a float32 tensor ``[spatial_score, temporal_score]``
        and ``label`` a float32 scalar tensor obtained via ``get_label``.
        """
        stem = self.items[idx]
        file_name = f"{stem}.npy"

        # Spatial first, temporal second: this fixes the feature order.
        scores = [
            np.load(folder / file_name).item()
            for folder in (self.spatial_dir, self.temporal_dir)
        ]

        features = torch.tensor(scores, dtype=torch.float32)
        target = torch.tensor(get_label(stem), dtype=torch.float32)
        return features, target, stem

In [None]:
# One dataset per split; each pairs the spatial and temporal score of a video.
train_ds = EnsembleDataset("train")
val_ds = EnsembleDataset("val")
test_ds = EnsembleDataset("test")

# Report how many videos have predictions from both streams in each split.
for split_name, split_ds in (("Train", train_ds), ("Val", val_ds), ("Test", test_ds)):
    print(f"{split_name} videos:", len(split_ds))

# Quick sanity check of a single sample.
x, y, s = train_ds[0]
print("Sample features:", x, "Label:", y, "Stem:", s)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
class EnsembleModel(nn.Module):
    """Small MLP that fuses two per-video scores into one logit.

    Architecture: Linear(2 -> 32) -> ReLU -> Dropout(0.3) -> Linear(32 -> 1).
    The output is a raw logit (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 1),
        ]
        # Kept under the attribute name `net` so state-dict keys stay `net.*`.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(batch, 2)`` input to a ``(batch,)`` tensor of logits."""
        logits = self.net(x)
        # Drop the trailing singleton feature dimension.
        return logits.squeeze(dim=1)
    

# Build the fusion network, move it to the selected device and show its layers.
model = EnsembleModel()
model = model.to(DEVICE)
print(model)

In [None]:
# Binary cross-entropy on raw logits (the sigmoid is folded into the loss).
criterion = nn.BCEWithLogitsLoss()

# AdamW with decoupled weight decay over all model parameters.
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# mode="max": the monitored quantity is expected to increase (the training
# loop steps this scheduler with the validation AUC). The learning rate is
# halved once the metric stops improving for longer than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.5,
    patience=2,
)

In [None]:
# Best validation AUC seen so far. Starting from -inf (rather than 0.0) makes
# the first epoch always write a checkpoint, so the evaluation step never
# fails on a missing file even if the first validation AUC is exactly 0.0.
best_val_auc = float("-inf")

for epoch in range(NUM_EPOCHS):
    t0 = time.time()

    # ---------------- training pass ----------------
    model.train()

    all_preds, all_labels = [], []
    running_loss = 0.0

    for x, y, _ in tqdm(train_loader, desc=f"Epoch {epoch}"):
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)

        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size so
        # the epoch average below is a true per-sample mean.
        running_loss += loss.item() * x.size(0)
        # Note: these probabilities come from train mode (dropout active).
        all_preds.append(torch.sigmoid(logits).detach().cpu())
        all_labels.append(y.cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    train_auc = roc_auc_score(all_labels, all_preds)
    train_loss = running_loss / len(train_ds)

    # ---------------- validation pass ----------------
    model.eval()
    val_preds, val_labels = [], []
    val_loss = 0.0

    with torch.no_grad():
        for x, y, _ in val_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            logits = model(x)
            loss = criterion(logits, y)

            val_loss += loss.item() * x.size(0)
            val_preds.append(torch.sigmoid(logits).cpu())
            val_labels.append(y.cpu())

    val_preds = torch.cat(val_preds).numpy()
    val_labels = torch.cat(val_labels).numpy()

    val_auc = roc_auc_score(val_labels, val_preds)
    val_loss /= len(val_ds)

    # The scheduler was created with mode="max", so it is driven by the AUC.
    scheduler.step(val_auc)

    # Keep only the weights of the best epoch by validation AUC.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(
            {
                "model_state": model.state_dict(),
                # roc_auc_score may return a NumPy scalar. Store a builtin
                # float so the file contains only tensors and plain Python
                # types and stays loadable with torch.load(weights_only=True),
                # which is the default from PyTorch 2.6 on.
                "val_auc": float(val_auc),
            },
            CHECKPOINT_DIR / "ensemble_best_valAUC.pth"
        )

    print(
        f"Epoch {epoch} | "
        f"train_loss={train_loss:.4f} train_auc={train_auc:.4f} | "
        f"val_loss={val_loss:.4f} val_auc={val_auc:.4f} | "
        f"time={time.time()-t0:.1f}s"
    )

In [None]:
# Restore the weights saved at the best validation AUC and switch to
# inference mode (disables dropout).
best_ckpt_path = CHECKPOINT_DIR / "ensemble_best_valAUC.pth"
ck = torch.load(best_ckpt_path, map_location=DEVICE)
model.load_state_dict(ck["model_state"])
model.eval()

test_preds = []
test_labels = []

with torch.no_grad():
    for x, y, _ in test_loader:
        x = x.to(DEVICE)
        logits = model(x)
        # Convert logits to probabilities; labels never leave the CPU.
        test_preds.append(torch.sigmoid(logits).cpu())
        test_labels.append(y)

# Flatten the per-batch tensors into one NumPy array each.
test_preds = torch.cat(test_preds).numpy()
test_labels = torch.cat(test_labels).numpy()

test_auc = roc_auc_score(test_labels, test_preds)
print("Ensemble Test AUC:", test_auc)