# Extract embeddings from trained spatial model
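Saves one array per video to `embeddings/<split>/<video_stem>.npy` with shape `(T, feat_dim)`, where `T` is the number of frames found for that video and `feat_dim` is the backbone's feature width.
Make sure `CHECKPOINT_PATH` in the configuration cell below points to your trained checkpoint (default: `spatial_best_valAUC.pth`).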


In [1]:
from pathlib import Path
import json, os
from PIL import Image
import numpy as np
import torch
from torch.cuda.amp import autocast, GradScaler
from torch import nn
import timm
from torchvision import transforms
from tqdm import tqdm

# --- Configuration -----------------------------------------------------------
# Every path hangs off the parent of the current working directory, so the
# notebook is expected to be launched from a sub-folder of the project root.
ROOT = Path.cwd().parent
FRAMES_ROOT = ROOT.joinpath("preprocessed", "frames")   # input:  <split>/<video_stem>/frame_*.jpg
EMB_ROOT = ROOT.joinpath("embeddings")                  # output: <split>/<video_stem>.npy
CHECKPOINT_DIR = ROOT.joinpath("checkpoints", "spatial")
# Swap the file name for "spatial_last.pth" to use that checkpoint instead.
CHECKPOINT_PATH = CHECKPOINT_DIR.joinpath("spatial_best_valAUC.pth")

SPLITS = ["train", "val", "test"]
IMG_SIZE = 224   # frames are resized to IMG_SIZE x IMG_SIZE
BATCH = 16       # frames per forward pass

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Make sure the output root exists before anything writes into it.
EMB_ROOT.mkdir(parents=True, exist_ok=True)

print("Device:", DEVICE)
print("Checkpoint path:", CHECKPOINT_PATH)


Device: cuda
Checkpoint path: c:\Users\lkmah\OneDrive\Desktop\Lokesh\VS Code\DeepFake_Detection_SIC\checkpoints\spatial\spatial_best_valAUC.pth


In [2]:
# define model architecture (same as during training)
class SpatialModel(nn.Module):
    """A timm backbone followed by a small MLP head that emits a single logit.

    The attribute names (``backbone``, ``head``) and the order of the layers
    inside ``head`` determine the state-dict keys, so they have to stay in
    sync with whatever code produced the checkpoint being loaded.

    Args:
        backbone_name: Model name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
        head_hidden: Width of the hidden layer in the head.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, backbone_name="efficientnet_b3", pretrained=False, head_hidden=512, dropout=0.4):
        super().__init__()
        # num_classes=0 asks timm for a feature extractor instead of a classifier.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        in_features = self.backbone.num_features
        head_layers = [
            nn.Linear(in_features, head_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run backbone then head; the size-1 dimension at index 1 is squeezed away."""
        return self.head(self.backbone(x)).squeeze(1)

# instantiate model and load checkpoint
model = SpatialModel(pretrained=False)
if not CHECKPOINT_PATH.exists():
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT_PATH}")

# NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
ck = torch.load(CHECKPOINT_PATH, map_location="cpu")
# The saver may have wrapped the weights ({"model_state": ...}) or stored them bare.
state = ck.get("model_state", ck)

# Strip a leading "module." (added when the model was wrapped for multi-GPU
# training) BEFORE loading. With strict=False a key-name mismatch does not
# raise, so a try/except around load_state_dict would never get the chance to
# fix the prefix and the model would silently keep its random initialisation.
_PREFIX = "module."
state = {
    (key[len(_PREFIX):] if key.startswith(_PREFIX) else key): value
    for key, value in state.items()
}

load_result = model.load_state_dict(state, strict=False)

# The embeddings come from the backbone alone, so any backbone weight that
# was not restored would make every saved embedding meaningless.
missing_backbone = [k for k in load_result.missing_keys if k.startswith("backbone.")]
if missing_backbone:
    raise RuntimeError(
        f"Checkpoint {CHECKPOINT_PATH} did not provide {len(missing_backbone)} backbone "
        f"weight(s) (e.g. {missing_backbone[:3]}); refusing to extract embeddings "
        "from a partially initialised backbone."
    )
if load_result.missing_keys or load_result.unexpected_keys:
    # Non-backbone mismatches do not affect the embeddings, but surface them anyway.
    print(
        "Warning: checkpoint/model key mismatch.",
        "missing:", list(load_result.missing_keys),
        "unexpected:", list(load_result.unexpected_keys),
    )

model.to(DEVICE)
model.eval()   # inference mode: disables dropout, uses running batch-norm statistics

FEAT_DIM = model.backbone.num_features
print("Loaded model. Backbone feat dim:", FEAT_DIM)


Loaded model. Backbone feat dim: 1536


In [3]:
# define image preprocessing (same as during training)

# Per-channel statistics used for normalisation (the commonly used ImageNet values).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Resize to a square -> convert to tensor -> normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

def load_frame_tensor(p: Path) -> torch.Tensor:
    """Read one frame from disk and return it as a preprocessed tensor.

    Args:
        p: Path to an image file that PIL can open.

    Returns:
        The result of applying ``preprocess`` to the RGB-converted image.
    """
    # The context manager releases the underlying file as soon as decoding
    # finishes (or fails) instead of leaving that to garbage collection.
    with Image.open(p) as img:
        rgb = img.convert("RGB")   # convert() returns an independent in-memory copy
    return preprocess(rgb)


In [4]:
# main function to extract and save embeddings for a single video
def extract_for_video(split, video_stem, batch_size=BATCH):
    """Compute backbone embeddings for every frame of one video and save them.

    Frames are read from ``FRAMES_ROOT/<split>/<video_stem>/frame_*.jpg`` and
    the stacked float32 features are written to
    ``EMB_ROOT/<split>/<video_stem>.npy`` with one row per frame.

    Args:
        split: Name of the split sub-folder.
        video_stem: Name of the per-video frame folder; also the output file stem.
        batch_size: Number of frames pushed through the backbone at once.

    Returns:
        A dict with a ``"status"`` key:
          * ``"missing_frames"`` - the frame folder does not exist,
          * ``"exists"`` - the output file is already on disk (nothing recomputed),
          * ``"no_frames"`` - the folder holds no ``frame_*.jpg`` files,
          * ``"saved"`` - embeddings were written; ``"shape"`` holds the array shape.
    """
    frames_dir = FRAMES_ROOT / split / video_stem
    if not frames_dir.exists():
        return {"status": "missing_frames"}
    out_path = EMB_ROOT / split / f"{video_stem}.npy"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if out_path.exists():
        return {"status": "exists"}

    # NOTE(review): this is a lexicographic sort; it matches temporal order only
    # if the frame indices in the file names are zero-padded - confirm.
    frame_files = sorted(frames_dir.glob("frame_*.jpg"))
    if not frame_files:
        return {"status": "no_frames"}

    use_amp = DEVICE.type == "cuda"   # mixed precision only on CUDA
    chunks = []
    with torch.no_grad():
        for start in range(0, len(frame_files), batch_size):
            # Decode one batch at a time so memory use is bounded by
            # batch_size rather than by the length of the video.
            batch_paths = frame_files[start:start + batch_size]
            batch = torch.stack([load_frame_tensor(p) for p in batch_paths], dim=0).to(DEVICE)
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
                feats = model.backbone(batch)   # [b, feat_dim]
            # .float() undoes any reduced precision before saving.
            chunks.append(feats.detach().cpu().float().numpy())
    embeddings = np.vstack(chunks)  # (T, feat_dim)

    # Write to a temporary file and rename it into place. An interrupted run
    # therefore never leaves a truncated .npy behind that the "exists" check
    # above would treat as finished on the next run.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as fh:
            np.save(fh, embeddings)
        os.replace(tmp_path, out_path)
    finally:
        # Only still present if saving or renaming failed.
        if tmp_path.exists():
            tmp_path.unlink()
    return {"status": "saved", "shape": embeddings.shape}


In [7]:
manifest = {}
# Smoke-test switch: set SMOKE_TEST to True to process only the first
# SMOKE_COUNT videos of each split.
SMOKE_TEST = False
SMOKE_COUNT = 20

# Status string returned by extract_for_video -> manifest counter it increments.
_COUNTER_FOR_STATUS = {
    "saved": "saved",
    "exists": "exists",
    "missing_frames": "missing",
    "no_frames": "no_frames",
}

for split in SPLITS:
    split_dir = FRAMES_ROOT / split
    if not split_dir.exists():
        print("Skipping missing split:", split)
        continue

    # Every sub-directory of the split folder is treated as one video.
    stems = sorted(entry.name for entry in split_dir.iterdir() if entry.is_dir())
    if SMOKE_TEST:
        stems = stems[:SMOKE_COUNT]
    print(f"Extracting embeddings: split={split} videos={len(stems)}")

    counts = {"total": len(stems), "saved": 0, "exists": 0, "missing": 0, "no_frames": 0}
    manifest[split] = counts
    for stem in tqdm(stems):
        outcome = extract_for_video(split, stem, batch_size=BATCH)
        counter = _COUNTER_FOR_STATUS.get(outcome["status"])
        if counter is not None:   # a status without a counter is simply not tallied
            counts[counter] += 1
    print(split, "done:", manifest[split])

# Persist the per-split counters next to the embeddings.
manifest_path = EMB_ROOT / "manifest.json"
with open(manifest_path, "w") as fh:
    json.dump(manifest, fh, indent=2)

print("Wrote embeddings manifest:", manifest_path)


Extracting embeddings: split=train videos=4066


  with autocast(enabled=(DEVICE.type=="cuda")):
100%|██████████| 4066/4066 [01:42<00:00, 39.54it/s]


train done: {'total': 4066, 'saved': 4046, 'exists': 20, 'missing': 0, 'no_frames': 0}
Extracting embeddings: split=val videos=762


100%|██████████| 762/762 [00:20<00:00, 37.34it/s]


val done: {'total': 762, 'saved': 741, 'exists': 20, 'missing': 0, 'no_frames': 1}
Extracting embeddings: split=test videos=255


100%|██████████| 255/255 [00:06<00:00, 40.83it/s]

test done: {'total': 255, 'saved': 235, 'exists': 20, 'missing': 0, 'no_frames': 0}
Wrote embeddings manifest: c:\Users\lkmah\OneDrive\Desktop\Lokesh\VS Code\DeepFake_Detection_SIC\embeddings\manifest.json





In [6]:
# Peek at one saved embedding: the first split (in SPLITS order) that
# contains any .npy file supplies the sample, then the search stops.
for split in SPLITS:
    emb_dir = EMB_ROOT / split
    if not emb_dir.exists():
        continue
    npy_files = sorted(emb_dir.glob("*.npy"))
    if not npy_files:
        continue
    first_file = npy_files[0]
    sample = np.load(first_file)
    print("Sample:", first_file.name, "shape:", sample.shape)
    break


Sample: 000.npy shape: (8, 1536)
