# 03 - Extract per-frame embeddings (EfficientNet-B3 backbone)
This notebook produces embeddings/<split>/<video_stem>.npy for each video (shape: T x feat_dim).
It will use your local spatial checkpoint if you provide its path; otherwise it will use a timm pretrained EfficientNet-B3 backbone.
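Run the cells in order. Extraction runs on the GPU when CUDA is available (CPU otherwise) and is resumable: videos whose output .npy already exists are skipped.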


In [None]:
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import timm
from tqdm import tqdm
import json
import os

# ------------- USER CONFIG -------------
# All paths are resolved relative to the directory the notebook is started from.
ROOT = Path.cwd()
FRAMES_ROOT = ROOT / "preprocessed" / "frames"   # <split>/<video_stem>/frame_00.jpg ...
EMB_ROOT = ROOT / "embeddings"                  # outputs go here
# Split folder names looked up under FRAMES_ROOT; splits whose folder is absent are skipped later.
SPLITS = ["train", "val", "test"]
BATCH_FRAMES = 16      # how many frames to run through backbone at once (GPU memory)
# Side length (pixels) every frame is resized to before entering the backbone.
IMG_SIZE = 224
# Prefer CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SPATIAL_CKPT = None    # if you have a local checkpoint file path, set it (string) else keep None
# Example: SPATIAL_CKPT = "/home/me/checkpoints/spatial/spatial_best_valAUC.pth"
USE_TIMM_PRETRAINED = True   # if SPATIAL_CKPT is None, use timm pretrained weights
# ---------------------------------------

# Echo the effective configuration so a run's output records what it used.
print("Device:", DEVICE)
print("Frames root:", FRAMES_ROOT)
print("Embeddings root:", EMB_ROOT)
# Create the output root up front; safe to re-run because exist_ok=True.
EMB_ROOT.mkdir(parents=True, exist_ok=True)


In [None]:
from torchvision import transforms

# Keep transforms simple and deterministic - same normalization used for training typically
# Pipeline: resize straight to IMG_SIZE x IMG_SIZE (aspect ratio is not preserved),
# convert to a float tensor, then normalize each channel with the mean/std below
# (the commonly used ImageNet statistics).
# NOTE(review): if SPATIAL_CKPT was trained with a different input size or
# normalization, these values should match that training setup - confirm.
preprocess = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),    # gives [C,H,W], float in [0,1]
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))
])

def load_frame_tensor(p: Path):
    """Load one frame image from disk and return it as a preprocessed tensor.

    Args:
        p: Path to the frame image file.

    Returns:
        The result of applying ``preprocess`` to the RGB image: a tensor [3,H,W].
    """
    # Image.open is lazy and keeps the file open; the context manager releases
    # the handle deterministically instead of waiting for garbage collection,
    # which matters when thousands of frames are opened in a row.
    with Image.open(str(p)) as img:
        # convert() forces the pixel data to be read and returns an independent
        # image, so the result stays valid after the file is closed.
        return preprocess(img.convert("RGB"))  # tensor [3,H,W]


In [None]:
class FeatureBackbone(nn.Module):
    """Wrap a timm model built without a classifier head as a feature extractor.

    Attributes are set in ``__init__``:
      - ``backbone``: the timm model created with ``num_classes=0``.
      - ``feat_dim``: the model's ``num_features``, i.e. the embedding width.
    """

    def __init__(self, model_name="efficientnet_b3", use_pretrained=True):
        super().__init__()
        # num_classes=0 makes timm drop the classifier, so the model's output
        # is the feature vector rather than class logits.
        self.backbone = timm.create_model(
            model_name,
            pretrained=use_pretrained,
            num_classes=0,
        )
        self.feat_dim = self.backbone.num_features

    def forward(self, x):
        """Map a batch of images [B,3,H,W] to features [B, feat_dim]."""
        features = self.backbone(x)
        return features

# instantiate
# Build the feature extractor either from a user-supplied checkpoint
# (SPATIAL_CKPT) or from timm weights, then move it to DEVICE in eval mode.
print("Loading backbone...")
if SPATIAL_CKPT:
    # The checkpoint supplies the weights, so skip fetching pretrained ones.
    backbone = FeatureBackbone(use_pretrained=False)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only point SPATIAL_CKPT at checkpoints from a source you trust.
    ckpt = torch.load(SPATIAL_CKPT, map_location="cpu")
    # Accept the common wrapper layouts as well as a bare state dict.
    if "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    elif "state_dict" in ckpt:
        state = ckpt["state_dict"]
    else:
        state = ckpt

    target_keys = set(backbone.backbone.state_dict())
    new_state = {}
    for k, v in state.items():
        # Strip only a *leading* "module." (typically added by DataParallel /
        # DistributedDataParallel). str.replace would also rewrite the substring
        # in the middle of a key (e.g. "se_module.x") and corrupt the name.
        nk = k[len("module."):] if k.startswith("module.") else k
        # A checkpoint saved from a wrapper that stores the network under
        # ".backbone" (like FeatureBackbone) prefixes every key with
        # "backbone."; remap it only when that yields a key this model has.
        if nk not in target_keys and nk.startswith("backbone."):
            candidate = nk[len("backbone."):]
            if candidate in target_keys:
                nk = candidate
        new_state[nk] = v

    # With strict=False a checkpoint that matches nothing loads "successfully"
    # and leaves the model randomly initialised; fail loudly instead of
    # silently writing meaningless embeddings.
    matched = target_keys.intersection(new_state)
    if not matched:
        raise RuntimeError(
            f"No parameters in checkpoint {SPATIAL_CKPT!r} match the backbone; "
            "check the checkpoint's key names/architecture."
        )

    # strict=False tolerates head keys present in the checkpoint (reported as
    # unexpected) and backbone keys absent from it (reported as missing).
    missing, unexpected = backbone.backbone.load_state_dict(new_state, strict=False)
    print("Loaded checkpoint:", SPATIAL_CKPT)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)
else:
    backbone = FeatureBackbone(use_pretrained=USE_TIMM_PRETRAINED)
    print("Using timm pretrained EfficientNet-B3.")

backbone = backbone.to(DEVICE)
# eval() puts the module in inference mode for extraction.
backbone.eval()
feat_dim = backbone.feat_dim
print("Backbone feature dim:", feat_dim)


In [None]:
def _frame_sort_key(p: Path):
    """Sort key that orders frame files by numeric index instead of as text."""
    index = p.stem[len("frame_"):]
    if index.isdecimal():
        # Numeric names first, in integer order; name breaks ties (e.g. 01 vs 1).
        return (0, int(index), p.name)
    # Anything without a purely numeric index goes last, ordered by name.
    return (1, 0, p.name)


def extract_embeddings_for_video(video_stem: str, split: str):
    """Compute and save per-frame backbone embeddings for one video.

    Reads ``frame_*.jpg`` from ``FRAMES_ROOT/<split>/<video_stem>`` and writes
    an array of shape (T, feat_dim) to ``EMB_ROOT/<split>/<video_stem>.npy``.

    Args:
        video_stem: Name of the video's frame folder (also the output stem).
        split: Split folder name under FRAMES_ROOT / EMB_ROOT.

    Returns:
        A dict with a "status" key, one of:
          - "missing_frames_folder": the frame folder does not exist.
          - "exists": the output file is already present (resume; nothing done).
          - "no_frames_found": the folder holds no ``frame_*.jpg`` files.
          - "saved": embeddings were written; also carries "shape" and "path".
    """
    frames_dir = FRAMES_ROOT / split / video_stem
    if not frames_dir.exists():
        return {"status":"missing_frames_folder"}
    out_path = EMB_ROOT / split / f"{video_stem}.npy"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # resume if already exists
    if out_path.exists():
        return {"status":"exists"}

    # Order frames by their numeric index: a plain text sort would put
    # frame_100 before frame_11 and scramble the temporal order of the rows.
    frame_files = sorted(frames_dir.glob("frame_*.jpg"), key=_frame_sort_key)
    if not frame_files:
        return {"status":"no_frames_found"}

    embeddings = []
    with torch.no_grad():
        for i in range(0, len(frame_files), BATCH_FRAMES):
            # Decode only the frames of the current batch so peak host memory
            # is bounded by BATCH_FRAMES rather than by the video length.
            batch = torch.stack(
                [load_frame_tensor(p) for p in frame_files[i:i + BATCH_FRAMES]],
                dim=0,
            ).to(DEVICE)  # [b,3,H,W]
            # Mixed precision is enabled only on CUDA.
            # NOTE(review): under autocast the features may come back as
            # float16, so saved dtype can differ between GPU and CPU runs -
            # confirm what downstream code expects.
            with autocast(enabled=(DEVICE.type == "cuda")):
                feats = backbone(batch)  # [b,feat_dim]
            embeddings.append(feats.detach().cpu().numpy())
    embeddings = np.vstack(embeddings)  # shape (T, feat_dim)

    # Write to a temporary file and rename it into place. The resume check
    # above trusts any existing .npy, so an interrupted write must never leave
    # a truncated file under the final name. The ".tmp" suffix also keeps the
    # partial file out of "*.npy" globs; a stale one is overwritten next run.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        np.save(f, embeddings)
    os.replace(tmp_path, out_path)
    return {"status":"saved", "shape": embeddings.shape, "path": str(out_path)}

# Walk every configured split, extract embeddings for each video folder in it,
# and tally the outcomes per split.
summary = {}
for split in SPLITS:
    split_dir = FRAMES_ROOT / split
    if not split_dir.exists():
        print(f"Split frames folder not found: {split_dir} (skipping)")
        continue

    # Each sub-directory of the split folder is one video.
    stems = sorted(entry.name for entry in split_dir.iterdir() if entry.is_dir())
    print(f"Processing split {split} - videos found: {len(stems)}")

    counts = {"total": len(stems), "saved":0, "exists":0, "missing":0}
    summary[split] = counts
    for stem in tqdm(stems):
        status = extract_embeddings_for_video(stem, split)["status"]
        # Every outcome other than "saved"/"exists" is tallied as "missing".
        bucket = status if status in ("saved", "exists") else "missing"
        counts[bucket] += 1

# Persist the tallies next to the embeddings.
EMB_ROOT.mkdir(parents=True, exist_ok=True)
manifest_text = json.dumps(summary, indent=2)
with open(EMB_ROOT / "manifest.json", "w") as f:
    f.write(manifest_text)

print("Done. Summary:")
print(summary)


In [None]:
# Sanity check: load the first embedding file found (splits are tried in
# SPLITS order) and report its shape and dtype, then stop.
for split in SPLITS:
    pdir = EMB_ROOT / split
    if not pdir.exists():
        continue
    files = sorted(pdir.glob("*.npy"))
    if not files:
        continue
    first = files[0]
    arr = np.load(first)
    print("Sample:", first.name, "shape:", arr.shape, "dtype:", arr.dtype)
    break
