In [13]:
# ===== Cell 1: Setup & imports =====
import os, sys, platform, pathlib, math, time, copy
import torch
import torch.nn.functional as F
from tqdm import tqdm

# repo-root import (helper that puts the project root on sys.path)
import autorootcwd

# Work around PosixPath issues when loading a pickled checkpoint on Windows
if platform.system() == 'Windows':
    pathlib.PosixPath = pathlib.WindowsPath

# Import the parse_config module directly ...
import scripts.parse_config as _pc
# ... and register it under the top-level name `parse_config`
# (presumably so pickled objects that reference `parse_config` resolve — confirm)
sys.modules['parse_config'] = _pc

from transformers import AutoTokenizer

# Project modules
from src.data.data_loader import TextVideoDataLoader
from src.model.model import FrozenInTime, compute_similarity
from src.model import metric as module_metric
from src.trainer.trainer import verbose
from src.utils.util import state_dict_data_parallel_fix

# Report the runtime and pick the compute device used by every later cell
print(torch.__version__, platform.system())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


2.5.1+cu121 Windows
Device: cuda


In [14]:
# ===== Cell 2: Configuration =====
# Every tunable of the notebook lives here: dataloader settings, the two
# backbone definitions, the checkpoint path, fusion hyper-parameters and the
# optional batch limit.

# --- Data settings (same split / test-time transforms as test.py) ---
DL_KW = {
    "dataset_name": "NTU",
    "text_params": {"input": "text"},
    # If the FiT checkpoint is a 4-frame STFormer, it is cleanest to have the
    # dataloader emit 4 frames as well.
    "video_params": {
        "extraction_fps": 25,
        "extraction_res": 320,
        "input_res": 224,
        "num_frames": 4,
        "stride": 1,
    },
    "data_dir": "data/nturgbd",
    "metadata_dir": "data/nturgbd",
    "split": "test",
    "tsfm_params": {"input_res": 224, "center_crop": 224},
    "tsfm_split": "test",
    "subsample": 1,
    "sliding_window_stride": -1,
    "reader": "decord",
    "batch_size": 8,
    "num_workers": 2,
    "shuffle": False,
}

TEXT_MODEL_NAME = "distilbert-base-uncased"

# --- Backbone 1: Frozen-in-Time (SpaceTimeTransformer) ---
FIT_ARGS = {
    "video_params": {
        "model": "SpaceTimeTransformer",
        "arch_config": "base_patch16_224",
        "num_frames": 4,  # must match the checkpoint exactly
        "vit_init": "imagenet-21k",
        "attention_style": "frozen-in-time",
        "pretrained": True,
    },
    "text_params": {"model": TEXT_MODEL_NAME, "pretrained": True},
    "projection_dim": 256,
}

# --- Backbone 2: V-JEPA2 (motion-sensitive representation) ---
# If Hugging Face access may be blocked, this still works as-is provided the
# FrozenInTime (VJEPA2) implementation is set up to load from a local weights path.
# NOTE(review): DL_KW above yields 4 frames per clip while this requests 32 —
# confirm how the V-JEPA2 path handles the mismatch.
USE_VJEPA = True
VJEPA_ARGS = {
    "video_params": {"model": "VJEPA2", "num_frames": 32, "pretrained": True},
    "text_params": {"model": TEXT_MODEL_NAME, "pretrained": True},
    "projection_dim": 256,
}

# --- FiT checkpoint (4-frame STFormer); edit the path if needed ---
FIT_CKPT_PATH = "src/exps/pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar"

# --- Fusion settings ---
FUSION_MODE = "weighted"    # either "weighted" or "rrf"
ALPHA = 0.85                # weight of the FiT (text-aligned) score in "weighted" mode
TAU_CLIP = 0.07             # temperature for the FiT score
TAU_MOT = 0.07              # temperature for the JEPA->256 score
RIDGE_LAMBDA = 1e-3         # ridge regularization of the JEPA->256 mapping

# --- Loop control ---
MAX_BATCHES = None          # set to an integer (e.g. 100) to cap batches for a quick test

print('Config OK')


Config OK


In [15]:
# ===== Cell 3: DataLoader & Tokenizer =====
# Build the evaluation loader from the DL_KW settings defined in Cell 2.
data_loader = TextVideoDataLoader(**DL_KW)
print('Dataset length:', len(data_loader.dataset))

# Tokenizer for the text encoder named by TEXT_MODEL_NAME (also used in FIT_ARGS / VJEPA_ARGS).
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
print('Tokenizer loaded:', TEXT_MODEL_NAME)


Dataset length: 600
Tokenizer loaded: distilbert-base-uncased


In [16]:
# ===== Cell 4: Load models =====
# Frozen-in-Time backbone, moved to the selected device and put in eval mode.
fit_model = FrozenInTime(**FIT_ARGS).to(device).eval()

# Optional V-JEPA2 backbone. Construction is best-effort: on any failure the
# error is printed and vjepa_model stays None, so later cells run FiT-only.
vjepa_model = None
if USE_VJEPA:
    try:
        vjepa_model = FrozenInTime(**VJEPA_ARGS).to(device).eval()
        print('[V-JEPA2] model ready')
    except Exception as e:
        print('[V-JEPA2] init failed:', e)
        vjepa_model = None

# Load the FiT checkpoint (intended to match the 4-frame STFormer).
# The load itself is non-strict (strict=False); the report below tells whether
# every key actually matched.
print('Loading FiT checkpoint:', FIT_CKPT_PATH)
# NOTE(review): weights_only=False unpickles arbitrary objects — only load
# checkpoints from a trusted source.
ckpt = torch.load(FIT_CKPT_PATH, map_location='cpu', weights_only=False)

# Reconcile checkpoint keys with the model's keys via the project helper
# (presumably handles the DataParallel `module.` prefix — confirm in src.utils.util).
sd = state_dict_data_parallel_fix(ckpt['state_dict'], fit_model.state_dict())
missing, unexpected = fit_model.load_state_dict(sd, strict=False)
if missing or unexpected:
    print('[FiT] load_state_dict non-strict:', 'missing', len(missing), 'unexpected', len(unexpected))
else:
    print('[FiT] checkpoint loaded strictly')


######USING ATTENTION STYLE:  frozen-in-time
[V-JEPA2] model ready
Loading FiT checkpoint: src/exps/pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar
[FiT] checkpoint loaded strictly


In [17]:
# ===== Cell 5: Extract embeddings (FiT text/video, V-JEPA2 video) =====
# Accumulators: one CPU tensor per batch, concatenated after the loop.
txt_emb_list, vid_fit_list, vid_jepa_list, labels = [], [], [], []

# Number of frames fed to FiT (falls back to 4 if the config omits it).
fit_frames = FIT_ARGS['video_params'].get('num_frames', 4)

with torch.no_grad():
    for bi, batch in enumerate(tqdm(data_loader, desc='Embedding')):
        # Optional early stop for quick tests (MAX_BATCHES is set in Cell 2).
        if MAX_BATCHES is not None and bi >= MAX_BATCHES:
            break

        # Text
        texts = batch['text']
        toks = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).to(device)

        # Video
        video = batch['video'].to(device)                     # (B, T, C, H, W)
        video_fit = video[:, :fit_frames].contiguous()        # keep at most the frames FiT expects, then make contiguous

        # Labels: raw captions, collected in the same order as the embeddings
        caps = batch['meta']['raw_captions']
        if isinstance(caps, list):
            labels.extend(caps)
        else:
            labels.append(caps)

        # FiT: text & video, both L2-normalized before being stored
        out_t, out_v = fit_model({'text': toks, 'video': video_fit})
        out_t = F.normalize(out_t, dim=-1)
        out_v = F.normalize(out_v, dim=-1)
        txt_emb_list.append(out_t.cpu())
        vid_fit_list.append(out_v.cpu())

        # V-JEPA2: video only (stored un-normalized; Cell 6 normalizes after mapping)
        # NOTE(review): `video` comes from a loader configured for 4 frames while
        # VJEPA_ARGS requests 32 — confirm the model handles this as intended.
        if vjepa_model is not None:
            vj = vjepa_model.compute_video(video)  # per the author, FrozenInTime routes this to the V-JEPA2 path internally
            vid_jepa_list.append(vj.cpu())

txt_emb  = torch.cat(txt_emb_list, dim=0)    # [N, 256]
vid_fit  = torch.cat(vid_fit_list, dim=0)    # [N, 256]
# None when no V-JEPA2 model was available
vid_jepa = torch.cat(vid_jepa_list, dim=0) if len(vid_jepa_list) > 0 else None

print('Shapes -> txt:', tuple(txt_emb.shape), 'fit_vid:', tuple(vid_fit.shape), 'jepa_vid:', None if vid_jepa is None else tuple(vid_jepa.shape))
print('Samples (labels):', len(labels))


Embedding: 100%|██████████| 75/75 [01:20<00:00,  1.07s/it]

Shapes -> txt: (600, 256) fit_vid: (600, 256) jepa_vid: (600, 256)
Samples (labels): 600





In [18]:
# ===== Cell 6: JEPA -> 256 mapping (ridge regression with centering) =====
# Fits a linear map W from the V-JEPA2 video embeddings onto the FiT video
# embeddings, then re-expresses every V-JEPA2 embedding in FiT's space.
# NOTE(review): the map is fit on the same (test-split) clips it is later
# scored on — confirm this transductive setup is intended.
vid_jepa_256 = None
W = None

if vid_jepa is not None:
    X = vid_jepa.to(device)           # [N, Dj]   (JEPA source space)
    Y = vid_fit.to(device)            # [N, 256]  (FiT video embeddings)

    # Centering: subtract the per-dimension mean of each embedding set
    Xm = X.mean(dim=0, keepdim=True)
    Ym = Y.mean(dim=0, keepdim=True)
    Xc = X - Xm
    Yc = Y - Ym

    # Ridge normal equations: (Xc^T Xc + lam * I) W = Xc^T Yc
    lam = RIDGE_LAMBDA
    XT = Xc.t()
    XtX = XT @ Xc
    reg = lam * torch.eye(XtX.size(0), device=X.device, dtype=X.dtype)
    W = torch.linalg.solve(XtX + reg, XT @ Yc)  # [Dj, 256]

    # Apply the map (re-adding the FiT mean) and L2-normalize each row
    vid_jepa_256 = (X - Xm) @ W + Ym
    vid_jepa_256 = F.normalize(vid_jepa_256, dim=-1)

print('JEPA->256:', None if vid_jepa_256 is None else tuple(vid_jepa_256.shape))


JEPA->256: (600, 256)


In [19]:
# ===== Cell 7: Score computation & fusion =====

def zscore_rows(M: torch.Tensor) -> torch.Tensor:
    """Standardize every row of a 2-D score matrix.

    Each row has its own mean subtracted and is divided by its own standard
    deviation (``Tensor.std`` default estimator). The deviation is floored at
    1e-6 so a constant row yields zeros rather than a division by zero.

    Args:
        M: matrix whose rows are standardized along dim 1.

    Returns:
        Tensor of the same shape as ``M``.
    """
    row_mean = M.mean(dim=1, keepdim=True)
    row_std = M.std(dim=1, keepdim=True).clamp_min(1e-6)
    centered = M - row_mean
    return centered / row_std

def rrf(sim: torch.Tensor, k: int = 60) -> torch.Tensor:
    """Reciprocal Rank Fusion scores (a scale-independent late-fusion signal).

    For every row of ``sim`` the columns are ranked by descending similarity
    and the column at rank ``r`` (1-based) receives the score ``1 / (k + r)``.

    Args:
        sim: 2-D similarity matrix; ranking is done independently per row.
        k: RRF smoothing constant added to the rank.

    Returns:
        Tensor shaped like ``sim`` holding the reciprocal-rank score of each
        entry within its row.
    """
    n_cols = sim.size(1)
    # ranks[i, r] is the column index that holds rank r in row i.
    ranks = sim.argsort(dim=1, descending=True)
    # Score for each rank position 1..n_cols, shared by all rows.
    inv = 1.0 / (k + torch.arange(1, n_cols + 1, device=sim.device, dtype=sim.dtype))
    # One vectorized scatter writes inv[r] into column ranks[i, r] of every row,
    # replacing the former per-row Python loop.
    return torch.zeros_like(sim).scatter_(1, ranks, inv.expand_as(sim))

# Instance-level text and FiT video embeddings, L2-normalized on the compute device
txt_n = F.normalize(txt_emb.to(device), dim=-1)
fit_n = F.normalize(vid_fit.to(device), dim=-1)

# (1) CLIP-like score (FiT text ↔ FiT video), scaled by the temperature TAU_CLIP
S_clip = (txt_n @ fit_n.t()) / TAU_CLIP

# (2) Motion score (FiT text ↔ (JEPA→256) video), scaled by TAU_MOT, then fused
if vid_jepa_256 is not None:
    S_motion = (txt_n @ vid_jepa_256.t()) / TAU_MOT
    if FUSION_MODE == "weighted":
        # Row-wise z-scores put both score matrices on a comparable scale
        # before the ALPHA-weighted sum.
        S_fused = ALPHA * zscore_rows(S_clip) + (1 - ALPHA) * zscore_rows(S_motion)
    elif FUSION_MODE == "rrf":
        # Rank-based fusion: sum of the per-row reciprocal-rank scores
        S_fused = rrf(S_clip) + rrf(S_motion)
    else:
        raise ValueError(f"Unknown FUSION_MODE: {FUSION_MODE}")
else:
    # No mapped JEPA embeddings available: fall back to the FiT scores alone
    print('[WARN] No JEPA embeddings — using FiT-only scores.')
    S_motion = None
    S_fused = S_clip

print('Score shapes:', tuple(S_clip.shape), None if S_motion is None else tuple(S_motion.shape), tuple(S_fused.shape))


Score shapes: (600, 600) (600, 600) (600, 600)


In [20]:
# === (Fix) NTU: collapse per-instance captions into one text query per label ===
# torch, F, module_metric and verbose are already imported in Cell 1, and
# zscore_rows / rrf are the helpers defined in Cell 7; they are reused here
# instead of being re-imported and re-implemented, so both fusion paths share
# a single definition.

# 1) Collect the sample indices that belong to each label
unique_labels = sorted(set(labels))
lab2idx = {lab: [] for lab in unique_labels}
for i, lab in enumerate(labels):
    lab2idx[lab].append(i)

# 2) Per-label mean of the text embeddings (txt_emb holds one row per instance)
txt_lab = torch.stack(
    [txt_emb[torch.tensor(lab2idx[lab], dtype=torch.long)].mean(dim=0) for lab in unique_labels],
    dim=0,
)   # [L, 256]  (L = number of distinct labels)
txt_lab = F.normalize(txt_lab, dim=-1).to(device)

# 3) Normalize the video embeddings (harmless if they already are normalized)
fit_n = F.normalize(vid_fit.to(device), dim=-1)
# Tolerate vid_jepa_256 being undefined or None (FiT-only evaluation)
vjepa_n = F.normalize(vid_jepa_256.to(device), dim=-1) if 'vid_jepa_256' in globals() and vid_jepa_256 is not None else None

# 4) Recompute the scores at label level
S_clip_lab = (txt_lab @ fit_n.t()) / TAU_CLIP                # [L, N]
if vjepa_n is not None:
    S_motion_lab = (txt_lab @ vjepa_n.t()) / TAU_MOT         # [L, N]
    if FUSION_MODE == "weighted":
        # z-score per row (i.e. per query) before the ALPHA-weighted sum
        S_fused_lab = ALPHA * zscore_rows(S_clip_lab) + (1 - ALPHA) * zscore_rows(S_motion_lab)
    elif FUSION_MODE == "rrf":
        S_fused_lab = rrf(S_clip_lab) + rrf(S_motion_lab)
    else:
        raise ValueError(f"Unknown FUSION_MODE: {FUSION_MODE}")
else:
    S_fused_lab = S_clip_lab

print('Shapes(label-level):', S_clip_lab.shape, S_fused_lab.shape)  # [L, N]

# 5) Metric inputs: the label-level similarity matrix and the per-video label list
sims_np = S_fused_lab.detach().cpu().numpy()
action_labels = labels  # length N (one entry per video column)

# Run both retrieval metrics and keep their results keyed by metric name
metric_fns = [getattr(module_metric, m) for m in ["ntu_t2v_metrics", "ntu_v2t_metrics"]]
nested_metrics = {}
for met in metric_fns:
    name = met.__name__
    res = met(sims_np, action_labels, query_masks=None)
    verbose(epoch=0, metrics=res, name="", mode=name)
    nested_metrics[name] = res

nested_metrics


Shapes(label-level): torch.Size([60, 600]) torch.Size([60, 600])
[ntu_t2v_metrics] epoch 0, R@1: 10.0, R@5: 25.0, R@10 36.7, R@50 78.3MedR: 20, MeanR: 32.8
[ntu_v2t_metrics] epoch 0, R@1: 5.2, R@5: 15.5, R@10 25.8, R@50 88.2MedR: 23, MeanR: 25.4


{'ntu_t2v_metrics': {'R1': np.float64(10.0),
  'R5': np.float64(25.0),
  'R10': np.float64(36.666666666666664),
  'R50': np.float64(78.33333333333333),
  'MedR': np.float64(20.0),
  'MeanR': np.float64(32.766666666666666),
  'query_ranks': array([  7,  35, 155,  87,  20,   9,   3,  15,  53,  43,  49,  20,  23,
           2,  55,   8,  10,   1,  43,  42,  52,  54,  18,  38,  43,   3,
          14,  23, 332,  40,  53,  72,   8,   1,   1,  25,  24,  21,   2,
           7, 111,  55,  12,   7,  12,  30,  17,   2,   2,   1,  14,  29,
          54,   5,  68,   2,   1,  28,   4,   1])},
 'ntu_v2t_metrics': {'R1': np.float64(5.166666666666667),
  'R5': np.float64(15.5),
  'R10': np.float64(25.833333333333336),
  'R50': np.float64(88.16666666666667),
  'MedR': np.float64(23.0),
  'MeanR': np.float64(25.421666666666667),
  'query_ranks': array([21, 10, 38, 57, 42, 17,  3, 21, 21, 35, 38, 23, 53, 30, 42,  5,  1,
         37, 35,  8, 53, 55, 28, 10, 49, 39, 53, 39, 55, 58, 36, 26, 32, 24,
         