In [19]:
# ======================
# Block 1: Libraries & settings
# ======================
# Notebook-wide imports.
# NOTE(review): several of these (sys, numpy, pandas, tqdm, CLIPTokenizer, ConfigParser,
# torch.serialization, module_arch, module_data, compute_similarity, augment_text_labels,
# average_augmented_embeddings) are not referenced by the cells shown — prune once confirmed.
# Keep `autorootcwd` first: it presumably moves the working directory to the repo root so
# the `scripts.*` / `src.*` imports and the relative data/checkpoint paths resolve — TODO confirm.
import autorootcwd
import sys

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTokenizer
from scripts.parse_config import ConfigParser
import torch.serialization
import src.model.model as module_arch
import src.data.data_loader as module_data
import src.model.metric as module_metric
from src.utils.util import state_dict_data_parallel_fix
from src.model.model import compute_similarity
from src.model.text_augmentation import augment_text_labels, average_augmented_embeddings

# Use the GPU when one is visible, otherwise fall back to CPU.
# (The shared-configuration cell recomputes the same `device`.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [21]:
# ===== Block 1: Shared configuration =====
import torch, torch.nn.functional as F
from transformers import AutoTokenizer
from src.data.data_loader import TextVideoDataLoader
from src.model.model import FrozenInTime, compute_similarity

# Compute device for every model/tensor below: GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TextVideoDataLoader keyword arguments: NTU test split (data/nturgbd), 224-px clips,
# deterministic order (shuffle=False) so embeddings line up with labels across passes.
DL_KW = {
    "dataset_name": "NTU",
    "text_params": {"input": "text"},
    "video_params": {
        "extraction_fps": 25,
        "extraction_res": 320,
        "input_res": 224,
        # NOTE(review): this is the clip length every model receives; V-JEPA2 below is
        # configured with num_frames=32 but is fed these clips unchanged — confirm intent.
        "num_frames": 4,
        "stride": 1,
    },
    "data_dir": "data/nturgbd",
    "metadata_dir": "data/nturgbd",
    "split": "test",
    "tsfm_params": {"input_res": 224, "center_crop": 224},
    "tsfm_split": "test",
    "subsample": 1,
    "sliding_window_stride": -1,
    "reader": "decord",
    "batch_size": 8,
    "num_workers": 2,
    "shuffle": False,
}

# Text tower used by both FrozenInTime wrappers and by the tokenizer.
TEXT_MODEL_NAME = "distilbert-base-uncased"

# Frozen-in-Time: SpaceTimeTransformer (base_patch16_224, 4 frames) + text tower,
# both projected into a 256-d joint space.
FIT_ARGS = {
    "video_params": {
        "model": "SpaceTimeTransformer",
        "arch_config": "base_patch16_224",
        "num_frames": 4,
        "vit_init": "imagenet-21k",
        "attention_style": "frozen-in-time",
        "pretrained": True,
    },
    "text_params": {"model": TEXT_MODEL_NAME, "pretrained": True},
    "projection_dim": 256,
    "projection": "minimal",
}

# V-JEPA2 video tower. projection='' turns the projection head off so the raw V-JEPA2
# embeddings are used. To use an hf_repo or local weights, add that key to video_params.
VJEPA_ARGS = {
    "video_params": {"model": "VJEPA2", "num_frames": 32, "pretrained": True},
    "text_params": {"model": TEXT_MODEL_NAME, "pretrained": True},
    "projection_dim": 256,
    "projection": "",
}

# Score fusion weight: S_fused = ALPHA * S_clip + (1 - ALPHA) * S_motion
ALPHA = 0.7


In [None]:
# ===== Block 2: data loader / tokenizer =====
data_loader = TextVideoDataLoader(**DL_KW)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)

# Gather every caption and its video path in loader order.
# `meta['paths']` is wrapped in a one-element list whenever the batch does not carry a list.
# NOTE(review): this full pass decodes every clip just to read text/metadata, and neither
# list is used by the cells shown — consider dropping it or reading dataset metadata directly.
all_texts, all_video_paths = [], []
for batch in data_loader:
    all_texts.extend(batch['text'])
    batch_paths = batch['meta']['paths']
    all_video_paths.extend(batch_paths if isinstance(batch_paths, list) else [batch_paths])

len(all_texts), len(all_video_paths)

(600, 600)

In [22]:
# ===== Block 3: models + FiT checkpoint =====
# Both encoders are built on `device` in eval mode; only FiT is loaded from the checkpoint below.
# NOTE(review): vjepa_model gets no checkpoint here and presumably relies on
# `pretrained=True` in VJEPA_ARGS for its weights — confirm.
fit_model = FrozenInTime(**FIT_ARGS).to(device).eval()
vjepa_model = FrozenInTime(**VJEPA_ARGS).to(device).eval()

ckpt_path = "src/exps/pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar"
import torch
from src.utils.util import state_dict_data_parallel_fix

# Deserialize on CPU. A hard-coded map_location='cuda' crashes on CPU-only machines even
# though `device` falls back to CPU, and it keeps a second copy of every checkpoint tensor
# on the GPU for as long as `ckpt` lives. load_state_dict copies the values into the
# parameters on whatever device the model already occupies, so the result is the same.
# weights_only=False unpickles arbitrary Python objects: only use it for trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# Reconcile checkpoint key names with the model's own state dict (presumably handles
# DataParallel 'module.' prefixes — confirm in src.utils.util), then load strictly.
sd = state_dict_data_parallel_fix(ckpt['state_dict'], fit_model.state_dict())
fit_model.load_state_dict(sd, strict=True)

######USING ATTENTION STYLE:  frozen-in-time


<All keys matched successfully>

In [27]:
# ===== Block 4: embedding extraction (FiT: text & video, V-JEPA2: video only) =====
import torch
import torch.nn.functional as F


def _match_frame_count(video, num_frames):
    """Return `video` with exactly `num_frames` entries along dim 1 (time), contiguous.

    Args:
        video: clip batch with time on dim 1 (assumed (B, T, C, H, W) — TODO confirm).
        num_frames: frame count the FiT model expects.

    Returns:
        Shorter clips are zero-padded at the end. Longer clips are subsampled at evenly
        spaced indices so the whole clip is covered; plain `video[:, :num_frames]` would
        keep only the first, temporally adjacent frames. Equal lengths pass through.
    """
    n_avail = video.shape[1]
    if n_avail < num_frames:
        pad = torch.zeros(video.size(0), num_frames - n_avail, *video.shape[2:],
                          device=video.device, dtype=video.dtype)
        video = torch.cat([video, pad], dim=1)
    elif n_avail > num_frames:
        idx = torch.linspace(0, n_avail - 1, num_frames, device=video.device).round().long()
        video = video.index_select(1, idx)
    return video.contiguous()


def _as_list(value):
    """Wrap a non-list batch field (e.g. a single caption) in a one-element list."""
    return value if isinstance(value, list) else [value]


# Frame count for FiT; loop-invariant, so read once (falls back to 4).
fit_num_frames = fit_model.video_params.get('num_frames', 4)

txt_emb_list, vid_fit_list, vid_jepa_list, labels = [], [], [], []

with torch.no_grad():
    for batch in data_loader:
        video = batch['video'].to(device)     # assumed (B, T, C, H, W)
        labels += _as_list(batch['meta']['raw_captions'])

        toks = tokenizer(batch['text'], return_tensors="pt", padding=True, truncation=True).to(device)

        # FiT: text + fixed-length clip; both embeddings L2-normalized.
        t_fit, v_fit = fit_model({'text': toks, 'video': _match_frame_count(video, fit_num_frames)})
        t_fit = F.normalize(t_fit, dim=-1)
        v_fit = F.normalize(v_fit, dim=-1)

        # V-JEPA2: receives the loader's clip unchanged. NOTE(review): DL_KW requests
        # 4 frames, so despite VJEPA_ARGS num_frames=32 this currently sees 4 frames.
        v_jepa = F.normalize(vjepa_model.compute_video(video), dim=-1)   # (B, Dj)

        # Move to CPU per batch and concatenate afterwards to keep GPU memory flat.
        txt_emb_list.append(t_fit.cpu())
        vid_fit_list.append(v_fit.cpu())
        vid_jepa_list.append(v_jepa.cpu())

txt_emb = torch.cat(txt_emb_list, dim=0)   # [N, D]
vid_fit = torch.cat(vid_fit_list, dim=0)   # [N, D]
vid_jepa = torch.cat(vid_jepa_list, dim=0) # [N, Dj]

# Later cells index embeddings by position in `labels`, so the row counts must agree.
assert len(labels) == txt_emb.shape[0] == vid_fit.shape[0] == vid_jepa.shape[0], (
    f"row mismatch: labels={len(labels)}, txt_emb={txt_emb.shape[0]}, "
    f"vid_fit={vid_fit.shape[0]}, vid_jepa={vid_jepa.shape[0]}"
)


In [28]:
# ===== Block 5: V-JEPA2 embedding standardization + L2 normalization =====
# Z-score every feature dimension using statistics over the whole dataset
# (1e-6 is added to the std to avoid division by zero), then L2-normalize each row again.
# NOTE(review): the fusion cell reads `vid_jepa_256`, not `vid_jepa_std`; the step that
# maps this tensor into the text-embedding space is not shown.
vj_mean, vj_std = (
    vid_jepa.mean(dim=0, keepdim=True),
    vid_jepa.std(dim=0, keepdim=True) + 1e-6,
)
vid_jepa_std = F.normalize((vid_jepa - vj_mean) / vj_std, dim=-1)


In [34]:
import torch
import torch.nn.functional as F
from collections import defaultdict

# Fuse two label->video score matrices and convert to numpy for the retrieval metrics:
#   S_fused = ALPHA * S_clip + (1 - ALPHA) * S_motion
# ALPHA comes from the shared configuration cell. It is deliberately not re-assigned here:
# a local `ALPHA = 0.7` silently overrode any value tuned in that cell.
#
# Inputs produced by earlier cells:
#   txt_emb:      [N, 256] per-sample FiT text embeddings
#   vid_fit:      [N, 256] per-sample FiT video embeddings
#   vid_jepa_256: [N, 256] per-sample V-JEPA2 embeddings mapped into the 256-d text space
#   labels:       length-N list with each sample's action label string
# NOTE(review): no cell shown defines `vid_jepa_256`. V-JEPA2 runs with projection='', so its
# raw width is generally not 256. That mapping must be created first, and the motion score is
# only meaningful if the mapping aligns V-JEPA2 features with the text encoder.
if "vid_jepa_256" not in globals():
    raise NameError(
        "vid_jepa_256 is not defined: map the V-JEPA2 embeddings (e.g. vid_jepa_std) into "
        "the text-embedding space in an earlier cell before fusing scores."
    )
assert len(labels) == txt_emb.shape[0] == vid_fit.shape[0] == vid_jepa_256.shape[0], (
    f"row mismatch: labels={len(labels)}, txt_emb={txt_emb.shape[0]}, "
    f"vid_fit={vid_fit.shape[0]}, vid_jepa_256={vid_jepa_256.shape[0]}"
)
assert txt_emb.shape[1] == vid_fit.shape[1] == vid_jepa_256.shape[1], (
    f"embedding width mismatch: txt_emb={txt_emb.shape[1]}, vid_fit={vid_fit.shape[1]}, "
    f"vid_jepa_256={vid_jepa_256.shape[1]}"
)

# 1) label -> indices of the samples carrying that label
idxs_per_label = defaultdict(list)
for i, lab in enumerate(labels):
    idxs_per_label[lab].append(i)

# 2) One text embedding per label: mean of its samples' text embeddings, re-L2-normalized
unique_labels = sorted(idxs_per_label)
T_cls = torch.cat(
    [
        F.normalize(txt_emb[idxs_per_label[lab]].mean(dim=0, keepdim=True), dim=-1)
        for lab in unique_labels
    ],
    dim=0,
).to(txt_emb.device)                                           # [L, 256], L = number of unique labels

# 3) L2-normalize the video embeddings (no-op if they already are)
V_fit = F.normalize(vid_fit, dim=-1).to(T_cls.device)         # [N, 256]
V_j256 = F.normalize(vid_jepa_256, dim=-1).to(T_cls.device)   # [N, 256]

# 4) label <-> video similarities
S_clip_cls = T_cls @ V_fit.T          # CLIP-style: label text vs FiT video, [L, N]
S_motion_cls = T_cls @ V_j256.T       # motion: label text vs mapped V-JEPA2 video, [L, N]
S_fused_cls = ALPHA * S_clip_cls + (1 - ALPHA) * S_motion_cls  # [L, N]

# 5) numpy for the metrics; rows follow `unique_labels`, columns follow `labels`
sims_np = S_fused_cls.detach().cpu().numpy()


In [39]:
from src.model import metric as module_metric
from src.trainer.trainer import verbose

# Retrieval metrics on the fused scores (sims_np: [L, N] label x video; labels: length N).
# NOTE(review): the last run raised
#   AssertionError: query count (600) does not match unique-label count (60)
# presumably from the first call (ntu_t2v_metrics). The functions likely expect a different
# sims/labels layout than the one passed here — check their signatures in src.model.metric.
metric_fns = [getattr(module_metric, fn_name) for fn_name in ("ntu_t2v_metrics", "ntu_v2t_metrics")]

nested_metrics = {}
for metric_fn in metric_fns:
    metric_name = metric_fn.__name__
    metric_res = metric_fn(sims_np, labels, query_masks=None)
    verbose(epoch=0, metrics=metric_res, name="", mode=metric_name)
    nested_metrics[metric_name] = metric_res

nested_metrics


AssertionError: 쿼리 수(600)와 유니크 라벨 수(60)가 맞지 않음