# Korean Video Captioning Demo

**목적**: 학습된 모델로 샘플 비디오에 대한 한국어 캡션 생성 시연

**모델 구조**: CLIP-ViT-L/14-336 (Vision Encoder, frozen) + Projector + Qwen3-8B (LLM + LoRA)

**체크포인트**: `siglip_study/results3/` 에서 학습된 모델

**데이터**: AI-Hub 대한민국 배경영상 (aihub_splitted)

## 1. Google Drive 마운트 및 환경 설정

In [None]:
# Mount Google Drive (this cell only works inside Google Colab)
from google.colab import drive
drive.mount('/content/drive')

# Path configuration: everything lives under one Drive folder
DRIVE_ROOT = "/content/drive/MyDrive/mutsa-02"
DATA_PATH = f"{DRIVE_ROOT}/aihub_splitted"
RESULTS_DIR = f"{DRIVE_ROOT}/korean_video_captioning/siglip_study/results3"

import os
from pathlib import Path

print(f"DRIVE_ROOT:   {DRIVE_ROOT}")
print(f"DATA_PATH:    {DATA_PATH}")
print(f"RESULTS_DIR:  {RESULTS_DIR}")

# Inspect the dataset layout
print(f"\n{'='*60}")
print("Data Structure Check:")
print(f"{'='*60}")

if os.path.exists(DATA_PATH):
    for split in ["train", "val"]:
        split_path = Path(DATA_PATH) / split
        if split_path.exists():
            items = list(split_path.iterdir())
            print(f"\n{split}/ ({len(items)} items):")

            # Check whether the split uses dedicated labels/ and videos/ folders
            has_labels = (split_path / "labels").exists()
            has_videos = (split_path / "videos").exists()

            if has_labels and has_videos:
                labels = list((split_path / "labels").glob("*.json"))
                videos = list((split_path / "videos").glob("*.mp4"))
                print(f"  Structure: labels/videos folders")
                print(f"  Labels: {len(labels)}, Videos: {len(videos)}")
            else:
                # Otherwise treat every sub-directory as one sample folder
                sample_dirs = [d for d in items if d.is_dir()]
                print(f"  Structure: {len(sample_dirs)} sample folders")
                if sample_dirs:
                    # Show the contents of the first folder as an example
                    sample = sample_dirs[0]
                    print(f"  Example: {sample.name}/")
                    for f in sample.iterdir():
                        print(f"    - {f.name}")
    # NOTE: printed whenever DATA_PATH exists, even if neither split directory is present
    print(f"\n✅ Data found!")
else:
    print(f"\n❌ Data path not found!")

# List the known experiments and whether their best checkpoint exists
print(f"\n{'='*60}")
print("Available experiments:")
print(f"{'='*60}")
# (experiment directory name, projector type) pairs
EXPERIMENTS = [
    ("E1_v2_linear_optimized", "linear"),
    ("E3_v2_mlp_optimized", "mlp_2l"),
    ("E5_v2_lr_reduced", "c_abstractor"),
    ("E5_v3_epoch_increased", "c_abstractor"),
    ("E7_v2_perceiver_optimized", "perceiver"),
]
for exp_name, proj_type in EXPERIMENTS:
    ckpt_path = f"{RESULTS_DIR}/{exp_name}/checkpoints/best_model.pt"
    status = "✅" if os.path.exists(ckpt_path) else "❌"
    print(f"  {status} {exp_name} ({proj_type})")

In [None]:
# Install the required packages (Colab shell magics, not plain Python)
!pip install -q transformers accelerate bitsandbytes peft
!pip install -q opencv-python-headless

# Report the PyTorch build and whether a CUDA GPU is visible
import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. 모델 정의 (siglip_study와 동일한 구조)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    CLIPVisionModel, CLIPImageProcessor,
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig
)
from peft import get_peft_model, LoraConfig, TaskType
from pathlib import Path
import json
import random
import numpy as np
from PIL import Image
import cv2
from tqdm import tqdm

# Settings (stated to match the siglip_study training notebook)
# NOTE(review): "max_length" is not read by any cell visible in this notebook;
# the tokenizer call in CustomVLM.generate uses a literal max_length instead.
CONFIG = {
    "vision_encoder": "openai/clip-vit-large-patch14-336",
    "llm": "Qwen/Qwen3-8B",
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "num_frames": 8,
    "max_length": 768,
    "num_queries": 64,  # used by the C-Abstractor and Perceiver projectors
    "num_heads": 8,
    "num_layers": 2,
    "prompt": "이 영상을 자세히 설명해주세요.",
    "max_new_tokens": 256,
    "repetition_penalty": 1.2,
    "data_path": DATA_PATH,
    "results_dir": RESULTS_DIR,
}

print("Config loaded!")

In [None]:
# ============================================
# Projector 정의 (siglip_study와 동일한 구조)
# ============================================

class LinearProjector(nn.Module):
    """Single affine layer mapping vision features to the LLM embedding width.

    With the default sizes (1024 -> 4096) this holds roughly 4.2M parameters.
    The layer is stored as ``proj`` so checkpoint keys are ``proj.weight`` and
    ``proj.bias``.
    """

    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features=vision_dim, out_features=llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` from ``vision_dim`` to ``llm_dim``."""
        projected = self.proj(x)
        return projected

class MLPProjector(nn.Module):
    """Two-layer MLP projector: Linear -> GELU -> Linear.

    With the default sizes (1024 -> 4096 -> 4096) the two linear layers hold
    roughly 21M parameters. The stack is stored as ``proj`` (an
    ``nn.Sequential``), so checkpoint keys are ``proj.0.*`` and ``proj.2.*``.
    """

    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096, hidden_dim: int = 4096) -> None:
        super().__init__()
        # Built in the same order as before so seeded initialisation is unchanged.
        stages = [
            nn.Linear(vision_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, llm_dim),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` from ``vision_dim`` to ``llm_dim``."""
        projected = self.proj(x)
        return projected

class CAbstractor(nn.Module):
    """Query-based projector labelled "C-Abstractor" in this project.

    A fixed set of learned queries attends to the linearly projected vision
    tokens through a stack of ``nn.TransformerDecoderLayer`` blocks, followed
    by a final LayerNorm. The output always has ``num_queries`` tokens.

    NOTE(review): the original docstring quoted "206M params"; with the default
    sizes the decoder layers alone appear to be considerably larger — confirm
    against the training configuration.
    """

    def __init__(
        self,
        num_queries: int = 64,
        vision_dim: int = 1024,
        llm_dim: int = 4096,
        num_heads: int = 8,
        num_layers: int = 2,
    ) -> None:
        super().__init__()
        # Creation order is kept (queries, vision_proj, layers, norm) so seeded
        # initialisation and state-dict key order stay the same.
        initial_queries = torch.randn(num_queries, llm_dim) * 0.02
        self.queries = nn.Parameter(initial_queries)
        self.vision_proj = nn.Linear(vision_dim, llm_dim)
        decoder_blocks = []
        for _ in range(num_layers):
            decoder_blocks.append(
                nn.TransformerDecoderLayer(
                    d_model=llm_dim,
                    nhead=num_heads,
                    dim_feedforward=llm_dim * 4,
                    batch_first=True,
                )
            )
        self.layers = nn.ModuleList(decoder_blocks)
        self.norm = nn.LayerNorm(llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[batch, num_queries, llm_dim]`` features for ``x`` of shape ``[batch, seq, vision_dim]``."""
        batch_size = x.size(0)
        memory = self.vision_proj(x)
        # Share the learned queries across the batch without copying them.
        tokens = self.queries.unsqueeze(0).expand(batch_size, -1, -1)
        for block in self.layers:
            tokens = block(tokens, memory)
        return self.norm(tokens)

# ============================================
# Perceiver Resampler (siglip_study 원본 구조)
# ============================================
class PerceiverLayer(nn.Module):
    """One resampler block: self-attention, cross-attention, then a feed-forward net.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm
    (post-norm), in that order.
    """

    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.1) -> None:
        super().__init__()
        # Sub-modules are created in the original order so seeded
        # initialisation and state-dict keys are unchanged.
        attn_options = {"dropout": dropout, "batch_first": True}
        self.self_attn = nn.MultiheadAttention(dim, num_heads, **attn_options)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, **attn_options)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        inner_dim = dim * 4
        feed_forward = [
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward)

    def forward(self, queries: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Update ``queries`` by attending to themselves and then to ``context``."""
        # nn.MultiheadAttention returns (output, weights); only the output is used.
        self_out, _ = self.self_attn(queries, queries, queries)
        hidden = self.norm1(queries + self_out)
        cross_out, _ = self.cross_attn(hidden, context, context)
        hidden = self.norm2(hidden + cross_out)
        return self.norm3(hidden + self.ffn(hidden))

class PerceiverResampler(nn.Module):
    """Perceiver-style projector that compresses vision tokens into learned queries.

    Vision features are linearly projected to ``llm_dim`` and used as the
    context for a stack of ``PerceiverLayer`` blocks; the result is passed
    through a final LayerNorm and always has ``num_queries`` tokens.
    """

    def __init__(
        self,
        vision_dim: int = 1024,
        llm_dim: int = 4096,
        num_queries: int = 64,
        num_heads: int = 8,
        num_layers: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Creation order is kept (queries, input_proj, layers, output_norm) so
        # seeded initialisation and state-dict key order stay the same.
        initial_queries = torch.randn(num_queries, llm_dim) * 0.02
        self.queries = nn.Parameter(initial_queries)
        self.input_proj = nn.Linear(vision_dim, llm_dim)
        blocks = []
        for _ in range(num_layers):
            blocks.append(PerceiverLayer(llm_dim, num_heads, dropout))
        self.layers = nn.ModuleList(blocks)
        self.output_norm = nn.LayerNorm(llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[B, seq_len, vision_dim]`` features to ``[B, num_queries, llm_dim]``."""
        batch_size = x.size(0)
        context = self.input_proj(x)  # [B, seq_len, llm_dim]
        # Share the learned queries across the batch without copying them.
        latents = self.queries.unsqueeze(0).expand(batch_size, -1, -1)
        for block in self.layers:
            latents = block(latents, context)
        return self.output_norm(latents)


def create_projector(projector_type, vision_dim=1024, llm_dim=4096, config=None):
    """Build the projector module for the given type name.

    Args:
        projector_type: One of ``"linear"``, ``"mlp_2l"``, ``"c_abstractor"``
            or ``"perceiver"``.
        vision_dim: Feature width produced by the vision encoder.
        llm_dim: Embedding width expected by the language model.
        config: Optional mapping that may provide ``num_queries``,
            ``num_heads`` and ``num_layers`` for the query-based projectors.
            ``None`` means "use the defaults" (64 / 8 / 2).

    Returns:
        A freshly initialised ``nn.Module`` projector.

    Raises:
        ValueError: If ``projector_type`` is not a known name.
    """
    # The signature defaults config to None; without this guard the
    # query-based branches crashed with AttributeError on `None.get`.
    settings = {} if config is None else config

    if projector_type == "linear":
        return LinearProjector(vision_dim, llm_dim)
    if projector_type == "mlp_2l":
        return MLPProjector(vision_dim, llm_dim)
    if projector_type == "c_abstractor":
        return CAbstractor(
            num_queries=settings.get("num_queries", 64),
            vision_dim=vision_dim,
            llm_dim=llm_dim,
            num_heads=settings.get("num_heads", 8),
            num_layers=settings.get("num_layers", 2),
        )
    if projector_type == "perceiver":
        return PerceiverResampler(
            vision_dim=vision_dim,
            llm_dim=llm_dim,
            num_queries=settings.get("num_queries", 64),
            num_heads=settings.get("num_heads", 8),
            num_layers=settings.get("num_layers", 2),
        )
    raise ValueError(f"Unknown projector type: {projector_type}")

print("Projector classes defined!")

In [None]:
# ============================================
# CustomVLM 정의 (dtype 처리 포함)
# ============================================

class CustomVLM(nn.Module):
    """Vision-language model wrapper: vision encoder + projector + causal LLM.

    The wrapper only implements inference (``generate``); the vision encoder
    is always run without gradients.
    """

    def __init__(self, vision_encoder, projector, llm, tokenizer):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm
        self.tokenizer = tokenizer

    def get_vision_features(self, pixel_values):
        """Run the vision encoder and return its patch features.

        The first token of ``last_hidden_state`` is dropped (presumably the
        [CLS] token). With the CLIP-L/14-336 encoder configured in this
        notebook the result is expected to be ``[B, 576, 1024]`` — TODO confirm.
        """
        with torch.no_grad():
            outputs = self.vision_encoder(pixel_values=pixel_values)
            features = outputs.last_hidden_state[:, 1:, :]
        return features

    @torch.no_grad()
    def generate(self, pixel_values, prompt, max_new_tokens=256, **kwargs):
        """Generate a caption for one video.

        Args:
            pixel_values: Preprocessed frames, one row per frame
                (``[num_frames, C, H, W]``). All frames are treated as a single
                sample.
            prompt: Text prompt appended after the visual tokens.
            max_new_tokens: Upper bound on generated tokens.
            **kwargs: Extra generation options (e.g. ``repetition_penalty``)
                forwarded to ``llm.generate``; they override the defaults
                set here.

        Returns:
            The list of decoded strings produced by ``tokenizer.batch_decode``.
        """
        device = pixel_values.device

        # 1. Per-frame vision features.
        vision_features = self.get_vision_features(pixel_values)

        # 2. Concatenate every frame's tokens into one sequence of batch size 1.
        vision_features = vision_features.reshape(1, -1, vision_features.size(-1))

        # 3. Project into the LLM embedding space (projector runs in float32).
        projected = self.projector(vision_features.float())

        # 4. Tokenize the prompt.
        text_inputs = self.tokenizer(
            [prompt],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        ).to(device)

        # 5. Prompt embeddings from the LLM's own embedding table.
        text_embeds = self.llm.get_input_embeddings()(text_inputs.input_ids)

        # 6. Visual tokens first, then the prompt; match the LLM embedding dtype.
        projected = projected.to(text_embeds.dtype)
        inputs_embeds = torch.cat([projected, text_embeds], dim=1)

        # 7. Every visual token is attended to.
        vision_mask = torch.ones(
            1, projected.size(1), device=device, dtype=text_inputs.attention_mask.dtype
        )
        attention_mask = torch.cat([vision_mask, text_inputs.attention_mask], dim=1)

        # 8. Greedy decoding by default.
        gen_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Forward caller options; previously **kwargs was accepted but silently
        # dropped, so a requested repetition_penalty never reached the LLM.
        gen_kwargs.update(kwargs)

        outputs = self.llm.generate(**gen_kwargs)

        # 9. Decode token ids back to text.
        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions

print("CustomVLM class defined!")

## 3. 모델 로드 함수

In [None]:
def build_model(config, projector_type, device, resume_path=None):
    """
    Build the full captioning model and optionally restore a checkpoint.

    Stated to follow the same procedure as the siglip_study notebook.

    Args:
        config: Mapping providing "vision_encoder", "llm", "lora_r",
            "lora_alpha", "lora_dropout" plus any projector options read by
            create_projector.
        projector_type: Projector name understood by create_projector.
        device: Device for the vision encoder and projector. The LLM is placed
            by device_map="auto" instead of this argument.
        resume_path: Optional checkpoint path. When it is falsy or the file
            does not exist, the projector and LoRA weights stay at their
            fresh initialisation and a warning is printed.

    Returns:
        Tuple (model, vision_encoder, image_processor, tokenizer); the model is
        in eval mode.
    """
    print(f"\n{'='*60}")
    print(f"Building model with {projector_type} projector...")
    print(f"{'='*60}")
    
    # 1. Vision encoder (CLIP), frozen: eval mode and no gradients
    print("Loading Vision Encoder...")
    vision_encoder = CLIPVisionModel.from_pretrained(config["vision_encoder"]).to(device)
    image_processor = CLIPImageProcessor.from_pretrained(config["vision_encoder"])
    vision_encoder.eval()
    for param in vision_encoder.parameters():
        param.requires_grad = False
    print(f"  ✅ Vision Encoder loaded (frozen)")
    
    # 2. LLM loaded with 4-bit NF4 quantization (bfloat16 compute, double quant)
    print("Loading LLM...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    llm = AutoModelForCausalLM.from_pretrained(
        config["llm"],
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config["llm"], trust_remote_code=True)
    # Fall back to EOS as the padding token when the tokenizer defines none
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"  ✅ LLM loaded (4-bit quantized)")
    print(f"  PAD: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
    print(f"  EOS: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
    
    # 3. Create the projector; its output width follows the LLM hidden size
    print(f"Creating {projector_type} Projector...")
    projector = create_projector(
        projector_type,
        vision_dim=1024,
        llm_dim=llm.config.hidden_size,
        config=config
    ).to(device)
    proj_params = sum(p.numel() for p in projector.parameters())
    print(f"  ✅ Projector created ({proj_params:,} params)")
    
    # 4. Wrap the LLM with LoRA adapters on the attention and MLP projections
    print("Applying LoRA...")
    lora_config = LoraConfig(
        r=config["lora_r"],
        lora_alpha=config["lora_alpha"],
        lora_dropout=config["lora_dropout"],
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    llm = get_peft_model(llm, lora_config)
    llm.print_trainable_parameters()
    
    # 5. Assemble the CustomVLM wrapper
    model = CustomVLM(vision_encoder, projector, llm, tokenizer)
    
    # 6. Restore the checkpoint, if one was given and exists
    if resume_path and Path(resume_path).exists():
        print(f"\nLoading checkpoint: {resume_path}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from a trusted source.
        checkpoint = torch.load(resume_path, map_location=device, weights_only=False)
        
        # Load the projector weights (strict: keys must match exactly)
        if "projector_state_dict" in checkpoint:
            model.projector.load_state_dict(checkpoint["projector_state_dict"])
            print("  ✅ Projector weights loaded")
        
        # Load the LoRA weights
        if "lora_state_dict" in checkpoint:
            # Apply the saved LoRA tensors to the PEFT model. strict=False means
            # key mismatches are only counted, not raised.
            # NOTE(review): a non-zero "unexpected" count would suggest the saved
            # key names do not match this PEFT wrapper — verify when it appears.
            missing, unexpected = model.llm.load_state_dict(checkpoint["lora_state_dict"], strict=False)
            print(f"  ✅ LoRA weights loaded (missing: {len(missing)}, unexpected: {len(unexpected)})")
        
        # NOTE: printed even when neither key above was present in the checkpoint
        print("  ✅ Checkpoint loaded successfully!")
    else:
        print(f"\n⚠️ No checkpoint found at {resume_path}")
        print("  Using randomly initialized weights...")
    
    model.eval()
    print(f"\n{'='*60}")
    print("✅ Model ready for inference!")
    print(f"{'='*60}")
    
    return model, vision_encoder, image_processor, tokenizer


def load_trained_model(exp_name, projector_type, config):
    """
    Convenience wrapper that loads the best checkpoint of one experiment.

    The checkpoint is expected at
    ``<config["results_dir"]>/<exp_name>/checkpoints/best_model.pt``. When it
    is missing a message is printed and four ``None`` values are returned;
    otherwise the model is built on "cuda" and returned in eval mode together
    with the vision encoder, image processor and tokenizer.
    """
    checkpoint_file = Path(config["results_dir"]).joinpath(
        exp_name, "checkpoints", "best_model.pt"
    )

    if checkpoint_file.exists():
        vlm, encoder, processor, text_tokenizer = build_model(
            config, projector_type, "cuda", checkpoint_file
        )
        vlm.eval()
        return vlm, encoder, processor, text_tokenizer

    print(f"❌ Checkpoint not found: {checkpoint_file}")
    return None, None, None, None

print("Model loading functions defined!")

## 4. 데이터 로드 유틸리티

In [None]:
import matplotlib.pyplot as plt

def extract_frames(video_path, num_frames=8):
    """Sample frames at evenly spaced positions from a video.

    Args:
        video_path: Path (or path-like) of the video file.
        num_frames: Number of positions to sample between the first and last
            frame.

    Returns:
        A list of RGB ``PIL.Image`` frames. The list is empty when the video
        reports no frames (e.g. it cannot be opened) or ``num_frames`` is not
        positive, and may be shorter than ``num_frames`` because positions
        that fail to decode are skipped.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report 0 or a negative count; treat both as "no frames".
        if total_frames <= 0 or num_frames <= 0:
            return []

        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; PIL expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
        return frames
    finally:
        # Always release the capture handle, even if decoding raised.
        cap.release()


def _first_match(directory, patterns):
    """Return the first file in *directory* matching *patterns*, or None.

    Patterns are tried in order; within one pattern the alphabetically first
    match is used so the choice is deterministic.
    """
    for pattern in patterns:
        matches = sorted(directory.glob(pattern))
        if matches:
            return matches[0]
    return None


def _caption_from_label(label_data):
    """Return the first non-empty caption string found in a label, else ""."""
    if not isinstance(label_data, dict):
        return ""
    candidates = []
    annotation = label_data.get("annotation")
    # "annotation" may be missing, null or not a mapping; only query it when
    # it is a dict so the flat fallback keys are still consulted.
    if isinstance(annotation, dict):
        candidates.append(annotation.get("description_kr"))
    for key in ("description_kr", "description", "caption"):
        candidates.append(label_data.get(key))
    for candidate in candidates:
        if isinstance(candidate, str) and candidate:
            return candidate
    return ""


def load_sample(data_path, split="val", idx=0):
    """
    Locate one sample (video + ground-truth caption) in a dataset split.

    Two on-disk layouts are supported:

    Layout 1 (siglip_study default):
      - {split}/labels/*.json
      - {split}/videos/*.mp4   (video name = label file stem)

    Layout 2 (one folder per sample):
      - {split}/{sample_id}/video.mp4 or any *.mp4 / *.avi
      - {split}/{sample_id}/label.json or any *.json

    Args:
        data_path: Dataset root directory.
        split: Split sub-directory name.
        idx: Index into the alphabetically sorted labels (layout 1) or sample
            folders (layout 2).

    Returns:
        Tuple (video_path, gt_caption, video_name). All three are None when the
        split, index or video cannot be resolved. gt_caption is "" when no
        label or caption could be read.
    """
    split_path = Path(data_path) / split

    if not split_path.exists():
        print(f"Split path not found: {split_path}")
        return None, None, None

    label_dir = split_path / "labels"
    video_dir = split_path / "videos"

    if label_dir.exists() and video_dir.exists():
        # Layout 1: dedicated labels/ and videos/ folders
        print("Using structure 1: labels/videos folders")
        label_files = sorted(label_dir.glob("*.json"))

        if idx >= len(label_files):
            print(f"Index {idx} out of range (total: {len(label_files)})")
            return None, None, None

        label_path = label_files[idx]
        video_name = label_path.stem
        video_path = video_dir / f"{video_name}.mp4"
    else:
        # Layout 2: every sub-directory is one sample
        print("Using structure 2: sample folders")
        sample_dirs = sorted(d for d in split_path.iterdir() if d.is_dir())

        if idx >= len(sample_dirs):
            print(f"Index {idx} out of range (total: {len(sample_dirs)})")
            return None, None, None

        sample_dir = sample_dirs[idx]
        video_name = sample_dir.name

        video_path = _first_match(sample_dir, ["video.mp4", "*.mp4", "video.avi", "*.avi"])
        if video_path is None:
            print(f"No video found in {sample_dir}")
            return None, None, None

        label_path = _first_match(sample_dir, ["label.json", "*.json"])

    if video_path is None or not video_path.exists():
        print(f"Video not found: {video_path}")
        return None, None, None

    # A missing or unreadable label is not fatal: the caption stays empty.
    gt_caption = ""
    if label_path and label_path.exists():
        try:
            with open(label_path, "r", encoding="utf-8") as f:
                label_data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both invalid JSON and undecodable bytes.
            print(f"Error loading label: {e}")
        else:
            gt_caption = _caption_from_label(label_data)

    print(f"Loaded: {video_name}")
    print(f"  Video: {video_path}")
    print(f"  Caption length: {len(gt_caption)} chars")

    return video_path, gt_caption, video_name


def display_result(video_name, frames, generated_caption, gt_caption=None):
    """Show up to eight frames in a row and print the captions below them.

    Prints a notice and returns early when ``frames`` is empty. The ground
    truth caption, if given, is truncated to 500 characters for display.
    """
    shown_frames = frames[:8]
    n_frames = len(shown_frames)
    if n_frames == 0:
        print("No frames to display!")
        return

    fig, axes = plt.subplots(1, n_frames, figsize=(3*n_frames, 3))
    # A single subplot comes back as one Axes object rather than an array.
    axes_row = [axes] if n_frames == 1 else axes

    for number, (ax, frame) in enumerate(zip(axes_row, shown_frames), start=1):
        ax.imshow(frame)
        ax.set_title(f'Frame {number}', fontsize=10)
        ax.axis('off')

    plt.suptitle(f'Video: {video_name}', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()

    heavy_rule = "="*70
    light_rule = "-"*70

    # Generated caption section
    print("\n" + heavy_rule)
    print("📝 Generated Caption:")
    print(light_rule)
    print(generated_caption)

    # Optional ground-truth section
    if gt_caption:
        print("\n" + heavy_rule)
        print("📋 Ground Truth Caption:")
        print(light_rule)
        too_long = len(gt_caption) > 500
        print(gt_caption[:500] + "..." if too_long else gt_caption)
    print(heavy_rule + "\n")

print("Data utilities defined!")

## 5. 모델 로드 및 Inference

In [None]:
# ============================================
# Choose the experiment to use (pick one of the available models)
# ============================================

# Option 1: C-Abstractor (reduced learning-rate run)
EXP_NAME = "E5_v2_lr_reduced"
PROJECTOR_TYPE = "c_abstractor"

# Option 2: C-Abstractor (increased-epoch run)
# EXP_NAME = "E5_v3_epoch_increased"
# PROJECTOR_TYPE = "c_abstractor"

# Option 3: Perceiver Resampler
# EXP_NAME = "E7_v2_perceiver_optimized"
# PROJECTOR_TYPE = "perceiver"

# ============================================
# ⚠️ Notes:
# - E1_v2 (Linear) and E3_v2 (MLP) have not been trained yet
# - C-Abstractor may suffer from mode collapse
# ============================================

print(f"Selected experiment: {EXP_NAME}")
print(f"Projector type: {PROJECTOR_TYPE}")

In [None]:
# Load the selected model; all four names are None when the checkpoint is missing
model, vision_encoder, image_processor, tokenizer = load_trained_model(
    EXP_NAME, PROJECTOR_TYPE, CONFIG
)

## 6. 샘플 비디오 캡셔닝 (2개)

In [None]:
def generate_caption_for_video(model, image_processor, video_path, config):
    """Caption one video file.

    Frames are sampled with ``extract_frames`` (``config["num_frames"]``),
    preprocessed by ``image_processor`` and passed to ``model.generate`` as a
    ``[num_frames, C, H, W]`` tensor together with the prompt,
    ``max_new_tokens`` and ``repetition_penalty`` from ``config``.

    Returns:
        Tuple ``(caption, frames)``. The caption is an "Error: ..." string and
        the frame list is empty when no frame could be extracted.
    """
    # Run on whichever device holds the model's first parameter.
    target_device = next(model.parameters()).device

    frames = extract_frames(video_path, num_frames=config["num_frames"])
    if not frames:
        return "Error: Could not extract frames", []

    # The processor returns one row per frame; no extra batch axis is added
    # because model.generate flattens all frames into a single sample.
    processed = image_processor(images=frames, return_tensors="pt")
    pixel_values = processed.pixel_values.to(target_device)

    with torch.no_grad():
        captions = model.generate(
            pixel_values,
            prompt=config["prompt"],
            max_new_tokens=config["max_new_tokens"],
            repetition_penalty=config["repetition_penalty"],
        )

    if captions:
        caption = captions[0]
    else:
        caption = "Error: No caption generated"
    return caption, frames

In [None]:
# Sample 1: the validation entry at index 0
print("\n" + "#"*70)
print("# SAMPLE 1")
print("#"*70)

video_path, gt_caption, video_name = load_sample(CONFIG["data_path"], "val", idx=0)

# load_sample returns None for video_path when the sample cannot be resolved
if video_path and video_path.exists():
    print(f"Video: {video_name}")
    generated, frames = generate_caption_for_video(model, image_processor, video_path, CONFIG)
    display_result(video_name, frames, generated, gt_caption)
else:
    print(f"Video not found!")

In [None]:
# Sample 2: the validation entry at index 2 (note: index 1 is skipped)
print("\n" + "#"*70)
print("# SAMPLE 2")
print("#"*70)

video_path, gt_caption, video_name = load_sample(CONFIG["data_path"], "val", idx=2)

# load_sample returns None for video_path when the sample cannot be resolved
if video_path and video_path.exists():
    print(f"Video: {video_name}")
    generated, frames = generate_caption_for_video(model, image_processor, video_path, CONFIG)
    display_result(video_name, frames, generated, gt_caption)
else:
    print(f"Video not found!")

## 7. (선택) 다른 실험 모델 비교

In [None]:
# 여러 모델 비교 (선택 사항)
def compare_models(video_idx=0):
    """Caption one validation video with several projector models.

    Each experiment's model is loaded, used once and released before the next
    one is loaded, so only one model occupies GPU memory at a time.

    Args:
        video_idx: Index of the validation sample passed to ``load_sample``.

    Returns:
        Dict mapping experiment name to its generated caption. Experiments
        whose checkpoint is missing are skipped; the dict is empty when the
        video cannot be found or decoded.
    """
    import gc  # local import so the function does not rely on a later cell

    experiments = [
        ("E1_v2_linear_optimized", "linear"),
        ("E3_v2_mlp_optimized", "mlp_2l"),
        # ("E5_v2_lr_reduced", "c_abstractor"),  # risk of mode collapse
    ]

    # Resolve the video; bail out early instead of crashing on a missing sample.
    video_path, gt_caption, video_name = load_sample(CONFIG["data_path"], "val", idx=video_idx)
    if video_path is None:
        print("Video not found!")
        return {}

    frames = extract_frames(video_path, num_frames=CONFIG["num_frames"])
    if not frames:
        print("Error: Could not extract frames")
        return {}

    print(f"\n{'='*70}")
    print(f"Comparing models on: {video_name}")
    print(f"{'='*70}")

    # Preview up to four frames. squeeze=False keeps `axes` two-dimensional so
    # iteration also works when only a single frame is shown.
    n_shown = min(4, len(frames))
    fig, axes = plt.subplots(1, n_shown, figsize=(12, 3), squeeze=False)
    for ax, frame in zip(axes[0], frames):
        ax.imshow(frame)
        ax.axis('off')
    plt.suptitle(f'{video_name}', fontweight='bold')
    plt.show()

    results = {}
    for exp_name, proj_type in experiments:
        print(f"\n--- {exp_name} ({proj_type}) ---")

        m, ve, ip, tok = load_trained_model(exp_name, proj_type, CONFIG)
        if m is None:
            print(f"  Skipped (checkpoint not found)")
            continue

        try:
            generated, _ = generate_caption_for_video(m, ip, video_path, CONFIG)
        finally:
            # Drop every reference to this model (the vision encoder is also
            # held by `ve`) before the next model is loaded.
            del m, ve, ip, tok
            gc.collect()
            torch.cuda.empty_cache()

        results[exp_name] = generated
        print(f"  Caption: {generated[:200]}..." if len(generated) > 200 else f"  Caption: {generated}")

    if gt_caption:
        print(f"\n--- Ground Truth ---")
        print(f"  {gt_caption[:200]}..." if len(gt_caption) > 200 else f"  {gt_caption}")

    return results

# 실행하려면 아래 주석 해제
# results = compare_models(video_idx=5)

## 8. 정리

In [None]:
# Free GPU memory
import gc

# Only delete the notebook globals if an earlier cell actually created them
if 'model' in dir():
    del model
if 'vision_encoder' in dir():
    del vision_encoder

# Collect garbage first so the freed tensors can be returned by the CUDA cache
gc.collect()
torch.cuda.empty_cache()

print("✅ Cleanup complete!")