# 06 - Train temporal model (LSTM + attention) on embeddings
Trains a video-level LSTM aggregator using per-frame embeddings produced by the spatial model.
Saves checkpoints: checkpoints/temporal/


In [1]:
from pathlib import Path
import json, time
import random
from pprint import pprint
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
#from torch.cuda.amp import autocast, GradScaler
# AMP disabled for temporal model (FP32 is more stable)
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score

# ------------- USER CONFIG -------------
ROOT = Path.cwd().parent
EMB_ROOT = ROOT / "embeddings"             # embeddings/<split>/<video_stem>.npy
LABELS_JSON = ROOT / "data" / "labels.json"
CHECKPOINT_DIR = ROOT / "checkpoints" / "temporal"
NUM_EPOCHS = 25
BATCH_SIZE = 16            # number of videos per batch
LR = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 0            # keep 0 in notebooks; increase on robust machines
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PRINT_FREQ = 20
LSTM_HIDDEN = 512
LSTM_LAYERS = 2
DROPOUT = 0.3
ATTENTION = True           # use attention pooling over LSTM outputs
BIDIRECTIONAL = True
# ---------------------------------------

BAD_EMB_LOG = CHECKPOINT_DIR / "bad_embeddings.txt"
if BAD_EMB_LOG.exists():
    BAD_EMB_LOG.unlink()

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print("Device:", DEVICE)
print("Emb root:", EMB_ROOT)
print("Checkpoint dir:", CHECKPOINT_DIR)


Device: cuda
Emb root: c:\Users\lkmah\OneDrive\Desktop\Lokesh\VS Code\DeepFake_Detection_SIC\embeddings
Checkpoint dir: c:\Users\lkmah\OneDrive\Desktop\Lokesh\VS Code\DeepFake_Detection_SIC\checkpoints\temporal


In [2]:
# reproducibility
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

with open(LABELS_JSON, "r") as f:
    labels_map = json.load(f)

def get_label_from_stem(stem):
    if stem in labels_map:
        return int(labels_map[stem])
    for k,v in labels_map.items():
        if stem in k:
            return int(v)
    raise KeyError(f"Label for {stem} not found")


In [3]:
class VideoEmbeddingDataset(Dataset):
    def __init__(self, split):
        self.root = EMB_ROOT / split
        if not self.root.exists():
            raise RuntimeError(f"No embeddings for split: {split}")
        self.items = sorted([p for p in self.root.glob("*.npy")])
        # optional: filter if empty
        self.items = [p for p in self.items if np.load(p).shape[0] > 0]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        p = self.items[idx]
        stem = p.stem

        try:
            arr = np.load(p)
        except Exception as e:
            with open(BAD_EMB_LOG, "a") as f:
                f.write(f"[LOAD ERROR] {p} | {e}\n")
            raise

        if arr.ndim != 2 or arr.shape[0] == 0:
            with open(BAD_EMB_LOG, "a") as f:
                f.write(f"[EMPTY] {p} | shape={arr.shape}\n")
            raise ValueError(f"Empty embedding: {p}")

        if np.isnan(arr).any() or np.isinf(arr).any():
            with open(BAD_EMB_LOG, "a") as f:
                f.write(
                    f"[NaN/Inf] {p} | "
                    f"shape={arr.shape} "
                    f"min={np.nanmin(arr)} "
                    f"max={np.nanmax(arr)} "
                    f"mean={np.nanmean(arr)}\n"
                )
            raise ValueError(f"NaN/Inf in embedding: {p}")

        emb = torch.from_numpy(arr.astype(np.float32))
        label = get_label_from_stem(stem)
        return emb, torch.tensor(label, dtype=torch.float32), stem

# quick sanity
# ds = VideoEmbeddingDataset("train")
# print("Train videos:", len(ds))


In [4]:
def collate_fn(batch):
    """
    batch: list of (emb [T,feat], label, stem)
    Pads sequences to longest T in batch (simple zero padding).
    Returns tensors: seqs [B, Tmax, feat], lengths [B], labels [B]
    """
    seqs, labels, stems = zip(*batch)
    lengths = [s.shape[0] for s in seqs]
    maxlen = max(lengths)
    feat_dim = seqs[0].shape[1]
    out = torch.zeros(len(seqs), maxlen, feat_dim, dtype=torch.float32)
    for i, s in enumerate(seqs):
        out[i, :s.shape[0], :] = s
    labels = torch.stack(labels)
    return out, torch.tensor(lengths, dtype=torch.long), labels, list(stems)


In [5]:
class AttentionPool(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.att = nn.Linear(hidden_dim, 1)

    def forward(self, h, lengths):
        B, T, _ = h.shape

        # never allow zero length
        lengths = torch.clamp(lengths, min=1)

        scores = self.att(h).squeeze(-1)  # [B, T]

        mask = torch.arange(T, device=h.device).unsqueeze(0) >= lengths.unsqueeze(1)
        scores = scores.masked_fill(mask, -1e9)

        weights = torch.softmax(scores, dim=1)

        # ABSOLUTE safety net
        weights = torch.nan_to_num(weights, nan=0.0, posinf=0.0, neginf=0.0)

        out = (h * weights.unsqueeze(-1)).sum(dim=1)
        return out, weights

class TemporalModel(nn.Module):
    def __init__(self, feat_dim):
        super().__init__()
        self.lstm = nn.LSTM(
            feat_dim, LSTM_HIDDEN, LSTM_LAYERS,
            batch_first=True,
            bidirectional=BIDIRECTIONAL,
            dropout=DROPOUT if LSTM_LAYERS > 1 else 0
        )
        out_dim = LSTM_HIDDEN * (2 if BIDIRECTIONAL else 1)
        self.attn = AttentionPool(out_dim)
        self.head = nn.Sequential(
            nn.Linear(out_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, x, lengths):
        # x: [B, T, feat], lengths: [B] (LongTensor)
        # Handle case all lengths equal -> no need to pack (but pack still works).
        if lengths.numel() == 0:
            raise ValueError("Empty lengths tensor in TemporalModel.forward")

        # sort by lengths (descending)
        lengths_sorted, perm_idx = lengths.sort(descending=True)
        x_sorted = x[perm_idx]

        # pack (pack expects CPU lengths)
        packed = rnn_utils.pack_padded_sequence(x_sorted, lengths_sorted.cpu(), batch_first=True, enforce_sorted=True)
        packed_out, _ = self.lstm(packed)
        out_unpacked, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)  # [B_sorted, Tmax_sorted, H_out]

        # unsort back to original order
        _, unperm_idx = perm_idx.sort()
        out = out_unpacked[unperm_idx]
        lengths = lengths[unperm_idx]

        pooled, att_weights = self.attn(out, lengths)
        logits = self.head(pooled).squeeze(1)

        # restore att_weights to original order too
        if att_weights is not None:
            att_weights = att_weights[unperm_idx]

        return logits, att_weights


In [6]:
def safe_auc(y_true, y_pred):
    if np.isnan(y_pred).any():
        print("⚠ NaNs detected in predictions — skipping AUC")
        return float("nan")
    try:
        return roc_auc_score(y_true, y_pred)
    except Exception as e:
        print("AUC error:", e)
        return float("nan")


In [7]:
def save_checkpoint(state, fname):
    # copy model weights to CPU to reduce CUDA memory pressure and make file portable
    cpu_state = state.copy()
    cpu_state["model_state"] = {k: v.cpu() for k, v in state["model_state"].items()}
    # optimizer state may contain tensors — move them to CPU as well (if present)
    if "optimizer_state" in state and state["optimizer_state"] is not None:
        opt_state = state["optimizer_state"]
        # shallow copy
        cpu_opt_state = {}
        cpu_opt_state['state'] = {}
        cpu_opt_state['param_groups'] = opt_state.get('param_groups', [])
        for k, v in opt_state.get('state', {}).items():
            cpu_opt_state['state'][k] = {sk: sv.cpu() if isinstance(sv, torch.Tensor) else sv
                                         for sk, sv in v.items()}
        cpu_state["optimizer_state"] = cpu_opt_state
    torch.save(cpu_state, fname)

In [8]:
# Build one dataset to read feat_dim
train_ds = VideoEmbeddingDataset("train")
val_ds = VideoEmbeddingDataset("val")
# sanity check 
if len(train_ds) == 0:
    raise RuntimeError("No train embeddings found. Run extract_embeddings first.")

sample_emb = np.load(train_ds.items[0])
FEAT_DIM = int(sample_emb.shape[1])
print("Feat dim:", FEAT_DIM, "Train videos:", len(train_ds))

pin_memory = torch.cuda.is_available()

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
                        collate_fn=collate_fn, pin_memory=pin_memory, persistent_workers=(NUM_WORKERS > 0))
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
                        collate_fn=collate_fn, pin_memory=pin_memory, persistent_workers=(NUM_WORKERS > 0))

model = TemporalModel(feat_dim=FEAT_DIM).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)
criterion = nn.BCEWithLogitsLoss()
print(model)

Feat dim: 1536 Train videos: 4066
TemporalModel(
  (lstm): LSTM(1536, 512, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (attn): AttentionPool(
    (att): Linear(in_features=1024, out_features=1, bias=True)
  )
  (head): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)


In [9]:
# quick sanity test
# import torch
# feat_dim = FEAT_DIM  # from earlier
# model_test = TemporalModel(feat_dim=feat_dim).to(DEVICE)
# B, T = 4, 8
# dummy = torch.randn(B, T, feat_dim).to(DEVICE)
# lengths = torch.tensor([8,6,5,7], dtype=torch.long).to(DEVICE)
# with torch.no_grad():
#     logits, att = model_test(dummy, lengths)
# print("logits.shape:", logits.shape, "att shape:", None if att is None else att.shape)

In [10]:
best_val_auc = 0.0
start_epoch = 0
last_ckpt = CHECKPOINT_DIR / "temporal_last.pth"
if last_ckpt.exists():
    ck = torch.load(last_ckpt, map_location=DEVICE)
    model.load_state_dict(ck["model_state"])
    optimizer.load_state_dict(ck["optimizer_state"])
    start_epoch = ck.get("epoch", 0) + 1
    best_val_auc = ck.get("best_val_auc", 0.0)
    print("Resumed temporal from", start_epoch, "best", best_val_auc)


for epoch in range(start_epoch, NUM_EPOCHS):
    t0 = time.time()
    model.train()
    all_preds, all_labels = [], []
    running_loss = 0.0

    for seqs, lengths, labels, stems in tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch", disable=False):
        seqs = seqs.to(DEVICE)
        lengths = lengths.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        logits, _ = model(seqs, lengths)

        if torch.isnan(logits).any():
            with open(BAD_EMB_LOG, "a") as f:
                f.write(f"[NaN LOGITS] Epoch={epoch} | stems={stems}\n")
            print("NaN logits detected for batch stems:", stems)
            break
        
        loss = criterion(logits, labels)

        loss.backward()

        # (optional but STRONGLY recommended)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

        optimizer.step()

        running_loss += loss.item() * seqs.size(0)
        all_preds.append(torch.sigmoid(logits).detach().cpu())
        all_labels.append(labels.detach().cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    print("NaNs in preds:", np.isnan(all_preds).any())
    print("NaNs in labels:", np.isnan(all_labels).any())
    print("Pred min/max:", np.nanmin(all_preds), np.nanmax(all_preds))
    train_auc = safe_auc(all_labels, all_preds)
    #train_auc = roc_auc_score(all_labels, all_preds)
    train_loss = running_loss / len(train_ds)

    # validation
    model.eval()
    val_preds, val_labels = [], []
    val_loss = 0.0
    with torch.no_grad():
        for seqs, lengths, labels, stems in tqdm(val_loader):
            seqs = seqs.to(DEVICE); lengths = lengths.to(DEVICE); labels = labels.to(DEVICE)

            logits, _ = model(seqs, lengths)
            loss = criterion(logits, labels)
            val_loss += loss.item() * seqs.size(0)
            val_preds.append(torch.sigmoid(logits).cpu())
            val_labels.append(labels.cpu())

    val_preds = torch.cat(val_preds).numpy()
    val_labels = torch.cat(val_labels).numpy()
    val_auc = safe_auc(val_labels, val_preds)
    #val_auc = roc_auc_score(val_labels, val_preds)
    val_loss = val_loss / len(val_loader.dataset)

    scheduler.step(val_auc)

    ck = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_val_auc": best_val_auc,
        "val_auc": val_auc,
        "scheduler_state": scheduler.state_dict()  # optional
    }

    last_path = CHECKPOINT_DIR / "temporal_last.pth"
    save_checkpoint(ck, last_path)

    if val_auc > best_val_auc:
        best_val_auc = val_auc
        best_path = CHECKPOINT_DIR / "temporal_best_valAUC.pth"
        save_checkpoint(ck, best_path)
        print(f"Saved new best model at epoch {epoch} val_auc={val_auc:.4f}")

    # also save epoch checkpoint (optional)
    epoch_path = CHECKPOINT_DIR / f"temporal_epoch_{epoch}.pth"
    save_checkpoint(ck, epoch_path)

    print(f"Epoch {epoch} done. train_loss={train_loss:.4f} train_auc={train_auc:.4f} val_loss={val_loss:.4f} val_auc={val_auc:.4f} time={(time.time()-t0):.1f}s")


Epoch 0: 100%|██████████| 255/255 [00:03<00:00, 80.66batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.6622769e-06 0.999997


100%|██████████| 48/48 [00:00<00:00, 230.39it/s]


Saved new best model at epoch 0 val_auc=0.9944
Epoch 0 done. train_loss=0.0369 train_auc=0.9999 val_loss=0.1647 val_auc=0.9944 time=3.8s


Epoch 1: 100%|██████████| 255/255 [00:02<00:00, 88.70batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 6.219851e-07 0.9999994


100%|██████████| 48/48 [00:00<00:00, 227.80it/s]


Saved new best model at epoch 1 val_auc=0.9945
Epoch 1 done. train_loss=0.0029 train_auc=1.0000 val_loss=0.1918 val_auc=0.9945 time=3.5s


Epoch 2: 100%|██████████| 255/255 [00:02<00:00, 86.04batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 3.5436005e-07 0.99999976


100%|██████████| 48/48 [00:00<00:00, 226.48it/s]


Epoch 2 done. train_loss=0.0036 train_auc=1.0000 val_loss=0.1762 val_auc=0.9943 time=3.4s


Epoch 3: 100%|██████████| 255/255 [00:02<00:00, 88.28batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 9.946188e-08 0.9999999


100%|██████████| 48/48 [00:00<00:00, 238.50it/s]


Epoch 3 done. train_loss=0.0012 train_auc=1.0000 val_loss=0.2556 val_auc=0.9942 time=3.3s


Epoch 4: 100%|██████████| 255/255 [00:02<00:00, 87.50batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 7.178688e-08 0.9999999


100%|██████████| 48/48 [00:00<00:00, 223.92it/s]


Epoch 4 done. train_loss=0.0031 train_auc=1.0000 val_loss=0.1750 val_auc=0.9942 time=3.4s


Epoch 5: 100%|██████████| 255/255 [00:02<00:00, 86.16batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 4.0579252e-07 0.99999964


100%|██████████| 48/48 [00:00<00:00, 210.53it/s]


Epoch 5 done. train_loss=0.0017 train_auc=1.0000 val_loss=0.2094 val_auc=0.9943 time=3.4s


Epoch 6: 100%|██████████| 255/255 [00:03<00:00, 84.21batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.562776e-07 0.9999999


100%|██████████| 48/48 [00:00<00:00, 223.35it/s]


Epoch 6 done. train_loss=0.0010 train_auc=1.0000 val_loss=0.2388 val_auc=0.9944 time=3.5s


Epoch 7: 100%|██████████| 255/255 [00:02<00:00, 85.20batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.350048e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 225.35it/s]


Epoch 7 done. train_loss=0.0017 train_auc=1.0000 val_loss=0.2183 val_auc=0.9944 time=3.5s


Epoch 8: 100%|██████████| 255/255 [00:03<00:00, 83.91batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 5.2334407e-08 0.9999999


100%|██████████| 48/48 [00:00<00:00, 229.68it/s]


Epoch 8 done. train_loss=0.0011 train_auc=1.0000 val_loss=0.2372 val_auc=0.9944 time=3.5s


Epoch 9: 100%|██████████| 255/255 [00:02<00:00, 85.53batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 2.3164771e-08 0.9999999


100%|██████████| 48/48 [00:00<00:00, 227.49it/s]


Epoch 9 done. train_loss=0.0008 train_auc=1.0000 val_loss=0.2525 val_auc=0.9944 time=3.5s


Epoch 10: 100%|██████████| 255/255 [00:03<00:00, 84.46batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.8666995e-08 0.9999999


100%|██████████| 48/48 [00:00<00:00, 227.49it/s]


Epoch 10 done. train_loss=0.0007 train_auc=1.0000 val_loss=0.2650 val_auc=0.9944 time=3.5s


Epoch 11: 100%|██████████| 255/255 [00:03<00:00, 84.69batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.5976385e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 232.48it/s]


Epoch 11 done. train_loss=0.0006 train_auc=1.0000 val_loss=0.2697 val_auc=0.9944 time=3.5s


Epoch 12: 100%|██████████| 255/255 [00:02<00:00, 85.35batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.2679311e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 202.53it/s]


Epoch 12 done. train_loss=0.0006 train_auc=1.0000 val_loss=0.2731 val_auc=0.9944 time=3.5s


Epoch 13: 100%|██████████| 255/255 [00:03<00:00, 78.93batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.7801202e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 215.45it/s]


Epoch 13 done. train_loss=0.0006 train_auc=1.0000 val_loss=0.2788 val_auc=0.9944 time=3.7s


Epoch 14: 100%|██████████| 255/255 [00:03<00:00, 84.70batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 9.472233e-09 1.0


100%|██████████| 48/48 [00:00<00:00, 224.95it/s]


Epoch 14 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2816 val_auc=0.9944 time=3.5s


Epoch 15: 100%|██████████| 255/255 [00:03<00:00, 82.65batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.4027222e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 222.22it/s]


Epoch 15 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2859 val_auc=0.9944 time=3.6s


Epoch 16: 100%|██████████| 255/255 [00:03<00:00, 81.72batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.0025991e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 217.20it/s]


Epoch 16 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2895 val_auc=0.9944 time=3.6s


Epoch 17: 100%|██████████| 255/255 [00:03<00:00, 84.06batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.32518885e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 206.15it/s]


Epoch 17 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2913 val_auc=0.9944 time=3.5s


Epoch 18: 100%|██████████| 255/255 [00:03<00:00, 82.09batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.4221974e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 199.17it/s]


Epoch 18 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2926 val_auc=0.9944 time=3.7s


Epoch 19: 100%|██████████| 255/255 [00:03<00:00, 80.71batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.3673618e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 188.98it/s]


Epoch 19 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2943 val_auc=0.9944 time=3.7s


Epoch 20: 100%|██████████| 255/255 [00:03<00:00, 83.02batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.1235046e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 201.68it/s]


Epoch 20 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2953 val_auc=0.9944 time=3.6s


Epoch 21: 100%|██████████| 255/255 [00:03<00:00, 84.04batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.2008399e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 146.26it/s]


Epoch 21 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2968 val_auc=0.9944 time=3.6s


Epoch 22: 100%|██████████| 255/255 [00:03<00:00, 84.49batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.5468105e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 217.20it/s]


Epoch 22 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2977 val_auc=0.9944 time=3.5s


Epoch 23: 100%|██████████| 255/255 [00:03<00:00, 84.20batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.4891004e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 214.29it/s]


Epoch 23 done. train_loss=0.0005 train_auc=1.0000 val_loss=0.2982 val_auc=0.9944 time=3.5s


Epoch 24: 100%|██████████| 255/255 [00:03<00:00, 82.69batch/s]


NaNs in preds: False
NaNs in labels: False
Pred min/max: 1.29379965e-08 1.0


100%|██████████| 48/48 [00:00<00:00, 200.84it/s]


Epoch 24 done. train_loss=0.0004 train_auc=1.0000 val_loss=0.2990 val_auc=0.9944 time=3.6s


In [14]:
# After training: evaluate on test split (video-level)
best = CHECKPOINT_DIR / "temporal_best_valAUC.pth"
if best.exists():
    ck = torch.load(best, map_location=DEVICE)
    model.load_state_dict(ck["model_state"])
    print("Loaded best temporal model with val_auc:", ck.get("best_val_auc"))
    # test dataset and loader
    test_loader = DataLoader(VideoEmbeddingDataset("test"), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    model.eval()
    t_preds, t_labels, stems_all = [], [], []
    with torch.no_grad():
        for seqs, lengths, labels, stems in test_loader:
            seqs = seqs.to(DEVICE); lengths = lengths.to(DEVICE)
            
            logits, _ = model(seqs, lengths)
            t_preds.append(torch.sigmoid(logits).cpu())
            t_labels.append(labels)
            stems_all.extend(stems)
    t_preds = torch.cat(t_preds).numpy()
    t_labels = torch.cat(t_labels).numpy()
    print("Test AUC:", roc_auc_score(t_labels, t_preds))


Loaded best temporal model with val_auc: 0.9943908256371856
Test AUC: 0.9900270869244028
