In [2]:
# Read the Hugging Face access token from the Colab secret store.
# NOTE(review): `token` is not written to os.environ here, while VideoTextModel
# falls back to os.environ.get("HF_TOKEN") when hf_token is None — pass
# hf_token=token explicitly (or export it) when constructing the model.
from google.colab import userdata
token = userdata.get('HF_TOKEN')

In [4]:
import os
import json
import math
import random
import numpy as np
from PIL import Image
from pathlib import Path
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from decord import VideoReader, cpu

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    CLIPVisionModel,
    CLIPImageProcessor,
)
from einops import rearrange, repeat
from torch import einsum

## Data

In [5]:
class AIHubVideoCaptionDataset(Dataset):
    """Video/caption pairs stored as ``<root_dir>/videos/*.mp4`` and ``<root_dir>/labels/*.json``.

    A video is kept only when a label file with the same stem exists. Each item
    is ``(video_path, frames, caption)`` where ``frames`` is a list of
    ``num_frames`` PIL images sampled uniformly over the clip.

    Args:
        root_dir: Directory containing the ``videos`` and ``labels`` folders.
        num_frames: Number of frames sampled per video.
        language: ``"kr"`` reads ``annotation.description_kr``; any other value
            reads ``annotation.description_en``.
    """

    def __init__(self, root_dir, num_frames=10, language="kr"):
        self.video_dir = os.path.join(root_dir, "videos")
        self.label_dir = os.path.join(root_dir, "labels")
        self.num_frames = num_frames
        self.language = language
        self.samples = []

        # os.listdir returns entries in arbitrary order; sorting keeps sample
        # indices stable across runs and machines.
        for fname in sorted(os.listdir(self.video_dir)):
            if not fname.lower().endswith(".mp4"):
                continue

            video_path = os.path.join(self.video_dir, fname)
            json_name = os.path.splitext(fname)[0] + ".json"
            label_path = os.path.join(self.label_dir, json_name)

            # Videos without a matching label file are skipped silently.
            if os.path.exists(label_path):
                self.samples.append({"video": video_path, "label": label_path})

    def __len__(self):
        """Return the number of (video, label) pairs found."""
        return len(self.samples)

    def load_video_frames(self, video_path):
        """Decode ``self.num_frames`` uniformly spaced frames as PIL images.

        Raises:
            ValueError: If the video reports zero frames.
        """
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(vr)
        if total_frames == 0:
            # linspace(0, -1, n) would otherwise produce negative frame indices.
            raise ValueError(f"Video has no decodable frames: {video_path}")

        # Short clips simply repeat indices after the integer cast.
        indices = np.linspace(0, total_frames - 1, self.num_frames).astype(int)
        frames = vr.get_batch(indices).asnumpy()
        return [Image.fromarray(frame) for frame in frames]

    def __getitem__(self, idx):
        """Return ``(video_path, frames, caption)`` for sample ``idx``."""
        sample = self.samples[idx]
        video_path = sample["video"]
        frames = self.load_video_frames(video_path)

        with open(sample["label"], "r", encoding="utf-8") as f:
            annotation = json.load(f)["annotation"]

        caption_key = "description_kr" if self.language == "kr" else "description_en"
        caption = annotation[caption_key]

        # The video path is returned as well so callers can key a feature cache on it.
        return video_path, frames, caption

In [None]:
# Dataset roots, relative to the notebook's working directory.
TRAIN_ROOT, VAL_ROOT = "train", "val"

# Both splits use the same sampling settings: 10 frames per clip, Korean captions.
train_dataset, val_dataset = (
    AIHubVideoCaptionDataset(root_dir=split_root, num_frames=10, language="kr")
    for split_root in (TRAIN_ROOT, VAL_ROOT)
)

print(f"‚úÖ Train data: {len(train_dataset)}")
print(f"‚úÖ Val data: {len(val_dataset)}")

‚úÖ Train data: 865
‚úÖ Val data: 97


#Adapter
##PerceiverResambler
> ÌîÑÎ†àÏûÑ ÌäπÏßï(Í∞ÄÎ≥Ä Í∏∏Ïù¥) -> Í≥†Ï†ï Í∏∏Ïù¥ `num_queries`Ïùò visual tokensÎ°ú ÏïïÏ∂ïÌïòÎäî Î™®Îìà

In [7]:
class PerceiverResampler(nn.Module):
    """Resample a variable number of frame features into ``num_queries`` latent tokens.

    A learned set of latent queries repeatedly cross-attends to the projected
    video features, self-attends, and passes through an MLP; every sub-layer is
    residual.

    Input:
      - x: [B, T, Dv] OR [B, T, F, Dv] (the 4-D form is flattened to [B, T*F, Dv])
    Output:
      - latents: [B, N, Dl] with N = num_queries and Dl = llm_dim
    """
    def __init__(
        self,
        video_dim: int,
        llm_dim: int,
        num_queries: int = 16,
        depth: int = 2,
        heads: int = 8,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.llm_dim = llm_dim

        # Learned query tokens, small-variance init.
        self.latents = nn.Parameter(torch.randn(num_queries, llm_dim) * 0.02)

        # Video features are lifted to the LLM width unless the widths already match.
        if video_dim == llm_dim:
            self.proj_in = nn.Identity()
        else:
            self.proj_in = nn.Linear(video_dim, llm_dim)

        def make_attention():
            return nn.MultiheadAttention(embed_dim=llm_dim, num_heads=heads, batch_first=True)

        # Per block, in order: cross-attn, LN(latents), LN(video), self-attn,
        # LN(latents), pre-norm MLP. forward() unpacks by this position.
        blocks = []
        for _ in range(depth):
            blocks.append(nn.ModuleList([
                make_attention(),
                nn.LayerNorm(llm_dim),
                nn.LayerNorm(llm_dim),

                make_attention(),
                nn.LayerNorm(llm_dim),

                nn.Sequential(
                    nn.LayerNorm(llm_dim),
                    nn.Linear(llm_dim, llm_dim * 4),
                    nn.GELU(),
                    nn.Linear(llm_dim * 4, llm_dim),
                ),
            ]))
        self.layers = nn.ModuleList(blocks)

        self.norm_out = nn.LayerNorm(llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the resampled latents [B, N, Dl] for video features ``x``."""
        if x.ndim == 4:
            x = rearrange(x, 'b t f d -> b (t f) d')

        # Unpacking doubles as a rank check: anything but 3-D fails here.
        batch, _, _ = x.shape
        x = self.proj_in(x)                                        # [B, T, Dl]
        latents = repeat(self.latents, 'n d -> b n d', b=batch)    # [B, N, Dl]

        for block in self.layers:
            cross_attn, ln_q, ln_kv, self_attn, ln_self, ffn = block

            # Latents (queries) read from the video tokens (keys/values).
            queries = ln_q(latents)
            context = ln_kv(x)
            latents = latents + cross_attn(query=queries, key=context, value=context)[0]

            # Latents exchange information among themselves.
            normed = ln_self(latents)
            latents = latents + self_attn(query=normed, key=normed, value=normed)[0]

            # Position-wise MLP (its LayerNorm lives inside `ffn`).
            latents = latents + ffn(latents)

        return self.norm_out(latents)


### CLIPFrameEncoder
> ÎπÑÎîîÏò§ ÌîÑÎ†àÏûÑÏùÑ CLIP Vision EncoderÏóê ÎÑ£Ïñ¥ ÌîÑÎ†àÏûÑÎ≥Ñ feature ÎΩëÍ∏∞

In [8]:
class CLIPFrameEncoder(nn.Module):
    """Vision encoder: video frames -> per-frame features.

    Expects video as float tensor [B, T, 3, H, W] in range [0, 1].
    Produces frame features [B, T, Dv], where Dv is the CLIP hidden size and
    each frame is represented by the first token of ``last_hidden_state``.
    """
    def __init__(self, vision_name: str = 'openai/clip-vit-base-patch32'):
        super().__init__()
        self.vision = CLIPVisionModel.from_pretrained(vision_name)
        self.processor = CLIPImageProcessor.from_pretrained(vision_name)
        self.video_dim = self.vision.config.hidden_size

        # Freeze vision encoder by default (Flamingo-style)
        for p in self.vision.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """Encode every frame of ``video`` ([B, T, 3, H, W]) without tracking gradients."""
        b, t, c, h, w = video.shape

        target = self.processor.size.get('shortest_edge', 224)
        # reshape (not view) so sliced or permuted, non-contiguous batches work too.
        frames = video.reshape(b * t, c, h, w)
        frames = F.interpolate(frames, size=(target, target), mode='bilinear', align_corners=False)

        # Normalise with the processor's statistics.
        mean = torch.tensor(self.processor.image_mean, device=frames.device).view(1, 3, 1, 1)
        std  = torch.tensor(self.processor.image_std,  device=frames.device).view(1, 3, 1, 1)
        pixel_values = (frames - mean) / std

        # Match the encoder's parameter dtype so a half-precision vision tower
        # does not receive float32 pixels (and vice versa).
        vision_dtype = next(self.vision.parameters()).dtype
        pixel_values = pixel_values.to(dtype=vision_dtype)

        out = self.vision(pixel_values=pixel_values)
        feats = out.last_hidden_state[:, 0]  # first token per frame, [B*T, Dv]
        return feats.reshape(b, t, -1)


## Cross Attention + Gating

### 1) MaskedCrossAttention
> 텍스트 hidden(x) -> query
>
> 비주얼 토큰(media) -> key/value
>
> 텍스트가 비주얼 정보를 참조하도록 만듦

### 2) GatedCrossAttentionBlock

In [9]:
class FeedForward(nn.Module):
    """Pre-norm, bias-free two-layer MLP: LayerNorm -> Linear -> GELU -> Linear.

    The hidden width is ``int(dim * mult)``; input and output widths are ``dim``.
    """
    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = int(dim * mult)
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden, bias=False),
            nn.GELU(),
            nn.Linear(hidden, dim, bias=False),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        out = self.net(x)
        return out

class MaskedCrossAttention(nn.Module):
    """Multi-head cross-attention: text hidden states query the visual tokens.

    Every text position attends to every media token (no mask is applied in
    this implementation). Inputs without a batch dimension are accepted and the
    output is returned in the same unbatched form.
    """
    def __init__(self, dim, dim_visual, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, media):
        """Attend from ``x`` ([B, T, D] or [T, D]) to ``media`` ([B, N, Dv], [N, Dv] or [B, Tm, N, Dv])."""
        # Unbatched text (T, D) is promoted to (1, T, D) and demoted again on return.
        was_unbatched = x.dim() == 2
        if was_unbatched:
            x = x.unsqueeze(0)

        # Unbatched media (N, D) is promoted the same way.
        if media.dim() == 2:
            media = media.unsqueeze(0)

        # Unpacking doubles as a rank check: x must be 3-D at this point.
        _batch, _seq_len, _width = x.shape

        # Per-frame media [B, Tm, N, Dv] is flattened into one token axis.
        if media.ndim == 4:
            media = rearrange(media, 'b tm n dv -> b (tm n) dv')

        queries = self.to_q(self.norm(x))
        keys, values = self.to_kv(media).chunk(2, dim=-1)

        # Split the projection width into heads: [B, L, H*Dh] -> [B, H, L, Dh].
        queries, keys, values = (
            rearrange(tensor, 'b n (h dh) -> b h n dh', h=self.heads)
            for tensor in (queries, keys, values)
        )

        queries = queries * self.scale
        scores = einsum('b h i d, b h j d -> b h i j', queries, keys)
        # Subtract the row max before softmax for numerical stability.
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, values)
        mixed = rearrange(mixed, 'b h t dh -> b t (h dh)')
        out = self.to_out(mixed)

        return out.squeeze(0) if was_unbatched else out


class GatedCrossAttentionBlock(nn.Module):
    """Flamingo-style residual block: tanh-gated cross-attention, then tanh-gated FFN.

    Both gates start at 0, so tanh(gate) = 0 and a freshly built block returns
    its input unchanged.
    """
    def __init__(self, dim, dim_visual, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.attn = MaskedCrossAttention(dim=dim, dim_visual=dim_visual, dim_head=dim_head, heads=heads)
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(self, x, media):
        """Mix visual information from ``media`` into ``x`` ([B, T, D] or [T, D])."""
        # An unbatched (T, D) input is processed as a batch of one.
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)

        attn_update = self.attn(x, media) * self.attn_gate.tanh()
        x = attn_update + x

        ff_update = self.ff(x) * self.ff_gate.tanh()
        x = ff_update + x

        return x.squeeze(0) if unbatched else x



## FlamingoDecoderLayerWrapper
> LLMÏùÄ Í∑∏ÎåÄÎ°ú ÎëêÍ≥†, ÎπÑÏ£ºÏñºÏùÑ ÌïÑÏöîÌï† ÎïåÎßå Ï∞∏Ï°∞ÌïòÍ≤å ÎßåÎì§Í∏∞

## VideoTextModel
> 전체 파이프라인 조립

In [10]:
class FlamingoDecoderLayerWrapper(nn.Module):
    def __init__(self, base_layer: nn.Module, xattn_block: nn.Module | None):
        super().__init__()
        self.base_layer = base_layer
        self.xattn_block = xattn_block
        self._media = None

    def set_media(self, media: torch.Tensor):
        self._media = media

    def forward(self, *args, **kwargs):
        outputs = self.base_layer(*args, **kwargs)

        # HF Î™®Îç∏ Î†àÏù¥Ïñ¥Í∞Ä Tensor ÎòêÎäî tupleÏùÑ Î∞òÌôòÌï† Ïàò ÏûàÏùå
        if isinstance(outputs, tuple):
            hidden_states = outputs[0]
            rest = outputs[1:]
        else:
            hidden_states = outputs
            rest = ()

        if (self._media is not None) and (self.xattn_block is not None):
            hidden_states = self.xattn_block(hidden_states, self._media)

        return (hidden_states,) + rest if rest else hidden_states

    def __getattr__(self, name):
        # nn.Module ÏÜçÏÑ± ÌÉêÏÉâ Î®ºÏ†Ä -> ÏóÜÏúºÎ©¥ base_layerÎ°ú ÏúÑÏûÑ
        try:
            return nn.Module.__getattr__(self, name)
        except AttributeError:
            return getattr(self.base_layer, name)

class VideoTextModel(nn.Module):
    """Causal LM with gated cross-attention over video tokens (Flamingo-style).

    Pipeline: video -> CLIPFrameEncoder -> PerceiverResampler -> media tokens,
    which every wrapped decoder layer with a cross-attention block attends to.

    Args:
        llm_name: Hugging Face id of the causal LM.
        vision_name: Hugging Face id of the CLIP vision model.
        perceiver_depth: Number of resampler blocks.
        perceiver_heads: Attention heads in the resampler.
        num_visual_tokens: Number of media tokens produced per video.
        xattn_heads: Heads in each gated cross-attention block.
        xattn_dim_head: Per-head width in each gated cross-attention block.
        xattn_every: A cross-attention block is attached to every layer whose
            index is a multiple of this value; ``<= 0`` attaches none.
        freeze_llm: Freeze the LM weights. The freeze runs before the layers are
            wrapped, so the cross-attention blocks added afterwards stay trainable.
        hf_token: Hugging Face token; defaults to the ``HF_TOKEN`` environment variable.
        torch_dtype: LM dtype; defaults to bf16 when CUDA supports it, else fp16.
    """
    def __init__(
        self,
        llm_name: str = "meta-llama/Llama-3.2-3B",
        vision_name: str = "openai/clip-vit-base-patch32",
        perceiver_depth: int = 2,
        perceiver_heads: int = 8,
        num_visual_tokens: int = 64,
        xattn_heads: int = 8,
        xattn_dim_head: int = 64,
        xattn_every: int = 2,
        freeze_llm: bool = True,
        hf_token: str | None = None,
        torch_dtype: torch.dtype | None = None,
    ):
        super().__init__()

        # Fall back to the HF_TOKEN environment variable when no token is passed.
        if hf_token is None:
            hf_token = os.environ.get("HF_TOKEN")

        # Default dtype: bf16 when supported, otherwise fp16.
        if torch_dtype is None:
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                torch_dtype = torch.bfloat16
            else:
                torch_dtype = torch.float16

        # 1) Tokenizer (authenticated with the HF token)
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name, token=hf_token, use_fast=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2) LLM (authenticated with the HF token)
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_name,
            token=hf_token,
            torch_dtype=torch_dtype,
            device_map=None,  # device placement is left to the caller (.cuda() / .to(device))
        )
        self.llm.config.pad_token_id = self.tokenizer.pad_token_id

        llm_dim = self.llm.config.hidden_size

        if freeze_llm:
            for p in self.llm.parameters():
                p.requires_grad = False

        # 3) Vision + Perceiver
        self.vision_encoder = CLIPFrameEncoder(vision_name=vision_name)
        self.perceiver = PerceiverResampler(
            video_dim=self.vision_encoder.video_dim,
            llm_dim=llm_dim,
            num_queries=num_visual_tokens,
            depth=perceiver_depth,
            heads=perceiver_heads,
        )

        # 4) Wrap layers with xattn blocks
        layers = self.llm.model.layers
        wrapped = []
        for i, layer in enumerate(layers):
            use_xattn = (xattn_every > 0) and ((i % xattn_every) == 0)
            xattn = None
            if use_xattn:
                # Media tokens already have the LLM width, so dim_visual == dim.
                xattn = GatedCrossAttentionBlock(
                    dim=llm_dim,
                    dim_visual=llm_dim,
                    heads=xattn_heads,
                    dim_head=xattn_dim_head,
                )
            wrapped.append(FlamingoDecoderLayerWrapper(layer, xattn))
        self.llm.model.layers = nn.ModuleList(wrapped)

    def _set_media_for_layers(self, media_tokens: torch.Tensor):
        """Hand the same media tokens to every wrapped decoder layer."""
        for layer in self.llm.model.layers:
            layer.set_media(media_tokens)

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode ``video`` ([B, T, 3, H, W]) into media tokens [B, num_visual_tokens, Dl]."""
        frame_feats = self.vision_encoder(video)   # [B,T,Dv]
        media_tokens = self.perceiver(frame_feats) # [B,N,Dl]
        return media_tokens

    def forward(self, video: torch.Tensor, text: list[str], labels: torch.Tensor | None = None):
        """Tokenize ``text``, inject the video's media tokens and run the LM.

        ``labels`` is passed through to the LM unchanged, so it must line up
        with this method's own tokenization of ``text``.
        """
        device = video.device

        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        tokens = {k: v.to(device) for k, v in tokens.items()}

        media_tokens = self.encode_video(video)
        self._set_media_for_layers(media_tokens)

        return self.llm(
            input_ids=tokens["input_ids"],
            attention_mask=tokens.get("attention_mask", None),
            labels=labels,
        )

    @torch.no_grad()
    def generate(self, video: torch.Tensor, prompt: str | list[str], **gen_kwargs):
        """Generate token ids conditioned on ``video`` and ``prompt``.

        ``gen_kwargs`` are forwarded to ``self.llm.generate``.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        device = video.device

        # A decoder-only LM continues from the last position of each row, so
        # batched prompts of different lengths must be padded on the LEFT;
        # right padding would make shorter prompts continue from pad tokens.
        # The previous setting is restored afterwards because training code
        # shares this tokenizer and masks labels assuming right padding.
        previous_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            tokens = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
        finally:
            self.tokenizer.padding_side = previous_side
        tokens = {k: v.to(device) for k, v in tokens.items()}

        media_tokens = self.encode_video(video)
        self._set_media_for_layers(media_tokens)

        return self.llm.generate(**tokens, **gen_kwargs)


## Training

In [11]:
# =========================
# 1) Collate: (frames[List[PIL]], caption[str]) -> video tensor [B,T,3,H,W]
# =========================
def collate_video_caption_with_paths(batch, image_size=224):
    """Collate dataset items into a batched video tensor plus paths and captions.

    batch: List[(video_path:str, frames:List[PIL], caption:str)]
    returns:
      video_paths: List[str]
      video: FloatTensor [B,T,3,H,W], values in [0, 1]
      captions: List[str]

    Frames that are not already ``image_size`` x ``image_size`` are resized
    with bicubic resampling (aspect ratio is not preserved).
    """
    video_paths, frames_list, captions = zip(*batch)
    target_size = (image_size, image_size)

    def clip_to_tensor(frames):
        # One clip: list of PIL frames -> [T,3,H,W]
        tensors = []
        for frame in frames:
            if frame.size != target_size:
                frame = frame.resize(target_size, resample=Image.BICUBIC)
            tensors.append(to_tensor(frame))  # [3,H,W], float in [0, 1]
        return torch.stack(tensors, dim=0)

    video = torch.stack([clip_to_tensor(frames) for frames in frames_list], dim=0)  # [B,T,3,H,W]
    return list(video_paths), video, list(captions)

def _safe_video_id(video_path: str) -> str:
    # ÌååÏùºÎ™Ö Í∏∞Î∞ò (ÌôïÏû•Ïûê Ï†úÍ±∞)
    base = os.path.basename(video_path)
    return os.path.splitext(base)[0]

def cache_path_for(video_path: str, cache_dir: str, num_frames: int, image_size: int, vision_name: str) -> str:
    """Return the cache file path for one video's frame features.

    The file name encodes the video's base name, the frame count, the image
    size and the vision model id (with ``/`` replaced by ``_``). ``cache_dir``
    is created if it does not exist.
    """
    os.makedirs(cache_dir, exist_ok=True)
    model_tag = vision_name.replace("/", "_")
    video_id = _safe_video_id(video_path)
    return os.path.join(cache_dir, f"{video_id}__T{num_frames}_{image_size}__{model_tag}.pt")

@torch.no_grad()
def get_frame_feats_cached(
    model,                 # VideoTextModel (owns vision_encoder + perceiver)
    video_paths: list[str],
    video_tensor: torch.Tensor,  # [B,T,3,H,W]
    cache_dir: str,
    num_frames: int,
    image_size: int,
    vision_name: str,
    dtype_for_cache=torch.float16,
):
    """Return frame features [B,T,Dv] as float16 on ``video_tensor``'s device.

    - cache hit: the per-video tensor is loaded from disk
    - cache miss: ``model.vision_encoder`` runs on the missing videos only and
      each result is written to the cache as ``dtype_for_cache`` on CPU
    """
    device = video_tensor.device

    # 1) Resolve every cache file once and split the batch into hits and misses.
    cache_files = [
        cache_path_for(path, cache_dir, num_frames, image_size, vision_name)
        for path in video_paths
    ]
    feats_by_index = {}
    missing = []
    for index, cache_file in enumerate(cache_files):
        if os.path.exists(cache_file):
            feats_by_index[index] = torch.load(cache_file, map_location="cpu")  # [T,Dv]
        else:
            missing.append(index)

    # 2) Encode only the misses, in one batch.
    if missing:
        encoded = model.vision_encoder(video_tensor[missing])             # [Bm,T,Dv] on device
        encoded_cpu = encoded.detach().to("cpu", dtype=dtype_for_cache)   # stored as cpu + cache dtype

        # 3) Persist each new entry and fill its slot.
        for row, index in enumerate(missing):
            feats = encoded_cpu[row]  # [T,Dv]
            final_path = cache_files[index]

            # Write to a temp name, then rename, so a crash never leaves a
            # truncated file under the final name.
            temp_path = final_path + ".tmp"
            torch.save(feats, temp_path)
            os.replace(temp_path, final_path)

            feats_by_index[index] = feats

    # 4) Reassemble in batch order and move to the model device.
    ordered = [feats_by_index[index] for index in range(len(video_paths))]
    return torch.stack(ordered, dim=0).to(device=device, dtype=torch.float16)


In [None]:

# =========================
# 2) labels ÏÉùÏÑ±: prompt Î∂ÄÎ∂Ñ -100 ÎßàÏä§ÌÇπ (Ï∫°ÏÖòÎßå loss)
# =========================
def build_text_and_labels(model, captions, prompt, max_len, device):
    """Tokenize ``prompt + caption`` and build labels that score caption tokens only.

    Uses ``model.tokenizer`` (setting its pad token to EOS when unset). The
    prompt is tokenized on its own to measure its length; the first
    ``prompt_len`` positions of every row and all padded positions are set to
    -100 so the loss ignores them.

    Returns:
        (texts, input_ids, attention_mask, labels), with the tensors on ``device``.
    """
    tokenizer = model.tokenizer
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    encode_options = dict(
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )

    texts = [prompt + caption for caption in captions]
    full_encoding = tokenizer(texts, **encode_options)
    prompt_encoding = tokenizer([prompt] * len(captions), **encode_options)

    input_ids = full_encoding["input_ids"].to(device)
    attention_mask = full_encoding.get("attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # Number of real (non-pad) tokens in the prompt, per row.
    prompt_lens = prompt_encoding["attention_mask"].sum(dim=1).to(device)  # [B]

    # Mask the prompt: position j of row i is ignored when j < prompt_lens[i].
    positions = torch.arange(input_ids.size(1), device=input_ids.device)
    in_prompt = positions.unsqueeze(0) < prompt_lens.unsqueeze(1)

    labels = input_ids.clone()
    labels[in_prompt] = -100
    if attention_mask is not None:
        labels[attention_mask == 0] = -100

    return texts, input_ids, attention_mask, labels


# =========================
# 3) Gate warmup: warmup ÎèôÏïà gate ÌååÎùºÎØ∏ÌÑ∞ ÎèôÍ≤∞
# =========================
def set_gate_requires_grad(model, flag: bool):
    """Set ``requires_grad`` on the cross-attention gate parameters of every decoder layer.

    For each layer in ``model.llm.model.layers`` that has an ``xattn_block``
    attribute, the block's ``attn_gate`` and ``ff_gate`` (whichever exist) get
    ``requires_grad = flag``. Layers whose block is ``None`` are unaffected.
    """
    gate_names = ("attn_gate", "ff_gate")
    for layer in model.llm.model.layers:
        if not hasattr(layer, "xattn_block"):
            continue
        block = layer.xattn_block
        for gate_name in gate_names:
            if hasattr(block, gate_name):
                getattr(block, gate_name).requires_grad = flag

def apply_gate_warmup(model, global_step: int, warmup_steps: int):
    """Freeze the cross-attention gates for the first ``warmup_steps`` steps.

    Gates are trainable when ``warmup_steps <= 0`` or once
    ``global_step >= warmup_steps``; otherwise they are frozen.

    NOTE(review): GatedCrossAttentionBlock initialises its gates to 0 and
    multiplies each branch by tanh(gate), so while the gates are frozen at 0 the
    branch outputs appear to receive zero gradient as well — confirm that the
    warm-up period is meant to be a no-op for the adapter weights.
    """
    gates_trainable = (warmup_steps <= 0) or (global_step >= warmup_steps)
    set_gate_requires_grad(model, gates_trainable)


# =========================
# 4) AMP / checkpointing ÏÑ∏ÌåÖ
# =========================
def pick_amp_dtype():
    """Return ``torch.bfloat16`` when CUDA is available and supports bf16, else ``torch.float16``."""
    bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    return torch.bfloat16 if bf16_supported else torch.float16

def enable_grad_checkpointing_if_needed(model, enabled: bool):
    """Best-effort switch for gradient checkpointing on ``model.llm``.

    When ``enabled`` is true, calls ``model.llm.gradient_checkpointing_enable()``
    and sets ``model.llm.config.use_cache = False``. Any failure is reported
    with a warning print instead of being raised.
    """
    if enabled:
        try:
            model.llm.gradient_checkpointing_enable()
            model.llm.config.use_cache = False
            print("‚úÖ gradient checkpointing enabled")
        except Exception as e:
            # Deliberately non-fatal: training can proceed without checkpointing.
            print(f"[warn] gradient checkpointing enable failed: {e}")


# =========================
# 5) Train / Val Î£®ÌîÑ
# =========================
def run_train_flamingo_dense(
    model,
    train_dataset,
    val_dataset=None,
    *,
    prompt="Ïù¥ ÎπÑÎîîÏò§Î•º Í∞ÑÎã®ÌïòÍ≤å ÏÑ§Î™ÖÌï¥Ï§ò: ",
    image_size=224,
    max_text_len=128,
    batch_size=2,
    num_workers=2,
    lr=5e-5,
    epochs=10,
    grad_clip=1.0,
    use_amp=True,
    use_grad_ckpt=False,
    gate_warmup_steps=200,
    log_every=10,
    save_path="flamingo_dense_ckpt.pt",
):
    """Train the perceiver and gated cross-attention blocks (no feature cache).

    Each step encodes the video with the vision encoder and perceiver, injects
    the media tokens into the wrapped decoder layers and runs the LM on
    ``prompt + caption`` with the prompt masked out of the loss. After every
    epoch the optional validation loss is printed and the adapter weights are
    written to ``save_path`` (overwriting the previous epoch's file).

    Args:
        model: VideoTextModel-like object exposing ``perceiver``,
            ``encode_video``, ``_set_media_for_layers``, ``tokenizer`` and ``llm``.
        train_dataset: Dataset yielding ``(video_path, frames, caption)``.
        val_dataset: Optional dataset of the same form.
        prompt: Text prepended to every caption.
        image_size: Side length the frames are resized to when collating.
        max_text_len: Tokenizer truncation length.
        batch_size, num_workers: DataLoader settings.
        lr: AdamW learning rate.
        epochs: Number of passes over ``train_dataset``.
        grad_clip: Max gradient norm; falsy or ``<= 0`` disables clipping.
        use_amp: Use autocast (and a GradScaler for fp16) when running on CUDA.
        use_grad_ckpt: Try to enable gradient checkpointing on the LM.
        gate_warmup_steps: Steps during which the gates stay frozen.
        log_every: Every ``log_every`` steps the running average is shown too.
        save_path: Checkpoint file path.

    Returns:
        The trained model (same object, moved to the training device).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    amp_dtype = pick_amp_dtype()
    # Autocast and loss scaling only apply on CUDA; on CPU both are no-ops.
    amp_enabled = bool(use_amp) and device.type == "cuda"
    use_scaler = amp_enabled and amp_dtype == torch.float16

    model = model.to(device)
    model.train()

    # ---- optional gradient checkpointing
    enable_grad_checkpointing_if_needed(model, use_grad_ckpt)

    # ---- DataLoaders
    # The dataset yields (video_path, frames, caption), so the path-aware
    # collate function is used and the paths are ignored in the loops below.
    def collate(batch):
        return collate_video_caption_with_paths(batch, image_size=image_size)

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate,
            pin_memory=True,
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False) if val_dataset is not None else None

    # ---- Trainable params: perceiver + xattn blocks (freezing the LLM is recommended)
    # Layers without cross-attention carry xattn_block = None, so filter on the
    # value rather than on the attribute's existence.
    xattn_blocks = {
        i: layer.xattn_block
        for i, layer in enumerate(model.llm.model.layers)
        if getattr(layer, "xattn_block", None) is not None
    }
    trainable_params = list(model.perceiver.parameters())
    for block in xattn_blocks.values():
        trainable_params += list(block.parameters())

    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    # Loss scaling is only needed for fp16; bf16 has enough dynamic range.
    scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

    def forward_loss(video, captions):
        # Shared by training and validation: build labels, inject media, run the LM.
        _, input_ids, attention_mask, labels = build_text_and_labels(
            model, captions, prompt, max_text_len, device
        )
        # encode_video runs vision + perceiver and yields [B, N, D] media tokens.
        media_tokens = model.encode_video(video)
        model._set_media_for_layers(media_tokens)

        with torch.amp.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp_enabled):
            out = model.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
        return out.loss

    global_step = 0

    print(f"üöÄ Train start | amp={use_amp}({amp_dtype}) | grad_ckpt={use_grad_ckpt} | warmup_steps={gate_warmup_steps}")
    print(f"‚úÖ Train data: {len(train_dataset)}")
    if val_dataset is not None:
        print(f"‚úÖ Val data: {len(val_dataset)}")

    for epoch in range(epochs):
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for step, (_video_paths, video, captions) in enumerate(pbar):
            video = video.to(device, non_blocking=True)

            # ---- gate warmup
            apply_gate_warmup(model, global_step, gate_warmup_steps)

            optimizer.zero_grad(set_to_none=True)

            loss = forward_loss(video, captions)

            # ---- backward + step
            if use_scaler:
                scaler.scale(loss).backward()
                if grad_clip and grad_clip > 0:
                    # Gradients must be unscaled before their norm is measured.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip and grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip)
                optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            global_step += 1

            postfix = {"loss": f"{loss_value:.4f}"}
            if (step + 1) % log_every == 0:
                postfix["avg"] = f"{total_loss/(step+1):.4f}"
            pbar.set_postfix(postfix)

        avg_loss = total_loss / max(1, len(train_loader))
        print(f"Epoch {epoch+1} avg loss: {avg_loss:.4f}")

        # ---- Validation
        if val_loader is not None:
            model.eval()
            vloss = 0.0
            with torch.no_grad():
                for _video_paths, video, captions in tqdm(val_loader, desc="Valid"):
                    video = video.to(device, non_blocking=True)
                    vloss += forward_loss(video, captions).item()

            vavg = vloss / max(1, len(val_loader))
            print(f"Val avg loss: {vavg:.4f}")
            model.train()

        # ---- Save (perceiver + xattn blocks only)
        ckpt = {
            "perceiver": model.perceiver.state_dict(),
            "xattn_blocks": {
                f"layer_{i}": block.state_dict()
                for i, block in xattn_blocks.items()
            },
            "epoch": epoch + 1,
            "global_step": global_step,
        }
        torch.save(ckpt, save_path)
        print(f"üíæ Saved checkpoint: {save_path}")

    print("‚úÖ Training done.")
    return model


In [None]:
def run_train_flamingo_dense_cached(
    model,
    train_dataset,
    val_dataset=None,
    *,
    cache_dir="cache",
    save_dir="checkpoint",
    vision_name="openai/clip-vit-base-patch32",
    num_frames=10,
    image_size=224,
    prompt="Ïù¥ ÎπÑÎîîÏò§Î•º Í∞ÑÎã®ÌïòÍ≤å ÏÑ§Î™ÖÌï¥Ï§ò: ",
    max_text_len=128,
    batch_size=2,
    num_workers=4,
    lr=5e-5,
    epochs=10,
    use_amp=True,
    gate_warmup_steps=200,
):
    """Train with cached vision features: CLIP runs once per video, then is read from disk.

    Every epoch writes ``<save_dir>/epoch_XX.pt``; the epoch with the lowest
    average *training* loss is also written to ``<save_dir>/best.pt``. Each
    checkpoint holds the full ``model.state_dict()`` and the optimizer state.

    Args:
        model: VideoTextModel-like object (``vision_encoder``, ``perceiver``,
            ``_set_media_for_layers``, ``tokenizer``, ``llm``).
        train_dataset: Dataset yielding ``(video_path, frames, caption)``.
        val_dataset: Accepted for signature parity; not used by this function.
        cache_dir: Directory for cached per-video frame features.
        save_dir: Directory for checkpoints (created if missing).
        vision_name, num_frames, image_size: Part of the cache key; they must
            describe how the features are actually produced.
        prompt: Text prepended to every caption.
        max_text_len: Tokenizer truncation length.
        batch_size, num_workers: DataLoader settings.
        lr: AdamW learning rate.
        epochs: Number of passes over ``train_dataset``.
        use_amp: Use autocast (and a GradScaler for fp16) when running on CUDA.
        gate_warmup_steps: Steps during which the gates stay frozen.

    Returns:
        The trained model (same object, moved to the training device).
    """
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=lambda b: collate_video_caption_with_paths(b, image_size=image_size),
        pin_memory=True,
    )

    # A previous (possibly interrupted) run may have left the gates frozen by
    # the warm-up. Re-enable them before filtering on requires_grad, otherwise
    # they would be left out of the optimizer and never trained.
    set_gate_requires_grad(model, True)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr
    )

    amp_dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16
    # Autocast and loss scaling only apply on CUDA; on CPU both are no-ops.
    amp_enabled = bool(use_amp) and device.type == "cuda"
    use_scaler = amp_enabled and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

    # Read the dtype from any perceiver parameter: proj_in is nn.Identity (no
    # .weight) when the video and LLM widths match.
    perceiver_dtype = next(model.perceiver.parameters()).dtype

    global_step = 0
    best_loss = float("inf")

    for epoch in range(epochs):
        epoch_loss_sum = 0.0
        step_count = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for video_paths, video, captions in pbar:
            video = video.to(device, non_blocking=True)

            _, input_ids, attention_mask, labels = build_text_and_labels(
                model, captions, prompt, max_text_len, device
            )

            apply_gate_warmup(model, global_step, gate_warmup_steps)
            optimizer.zero_grad(set_to_none=True)

            # Frozen CLIP features: loaded from cache, or computed and cached on a miss.
            frame_feats = get_frame_feats_cached(
                model=model,
                video_paths=video_paths,
                video_tensor=video,
                cache_dir=cache_dir,
                num_frames=num_frames,
                image_size=image_size,
                vision_name=vision_name,
                dtype_for_cache=torch.float16,
            )

            with torch.amp.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp_enabled):
                # Cached features are fp16; cast them to the perceiver's own dtype.
                media_tokens = model.perceiver(frame_feats.to(dtype=perceiver_dtype))
                model._set_media_for_layers(media_tokens)

                out = model.llm(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                loss = out.loss

            if use_scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            loss_value = loss.item()
            global_step += 1
            epoch_loss_sum += loss_value
            step_count += 1
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        avg_loss = epoch_loss_sum / max(step_count, 1)

        # Build the checkpoint payload once; it is reused for the "best" copy.
        payload = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "avg_epoch_loss": avg_loss,
        }

        # Per-epoch checkpoint
        epoch_path = f"{save_dir}/epoch_{epoch+1:02d}.pt"
        torch.save(payload, epoch_path)
        print(f"[Saved] {epoch_path} (avg_loss={avg_loss:.4f})")

        # Best checkpoint (judged on the average training loss)
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_path = f"{save_dir}/best.pt"
            torch.save(payload, best_path)
            print(f"[Saved BEST] {best_path} (loss={best_loss:.4f})")

    return model


In [14]:
# NOTE: This patch was for Qwen2-specific incompatibilities. You can skip it when using Llama-family models.
def patch_wrapper_forward_robust(model):
    """
    Qwen2 + wrapper Ìò∏Ìôò:
    - base_layer Ï∂úÎ†• hidden_statesÍ∞Ä 2DÎ°ú ÎÇòÏò§Í±∞ÎÇò([B*T,D], [B,D], [T,D]) ÌòïÌÉúÍ∞Ä ÏÑûÏó¨ÎèÑ
      ÏûÖÎ†• hidden shapeÎ•º Ïù¥Ïö©Ìï¥ ÏµúÎåÄÌïú [B,T,D]Î°ú Î≥µÍµ¨.
    - Î≥µÍµ¨ Î∂àÍ∞ÄÎä•Ìïú ÏòàÏô∏ ÏºÄÏù¥Ïä§Îäî xattnÏùÑ Ïä§ÌÇµÌï¥ÏÑú ÌïôÏäµÏùÑ Í≥ÑÏÜç ÏßÑÌñâ(ÏùëÍ∏â ÏïàÏ†ÑÏû•Ïπò).
    """
    WrapperCls = type(model.llm.model.layers[0])
    if getattr(WrapperCls, "_patched_qwen2_robust", False):
        print("‚úÖ Robust wrapper patch already applied")
        return

    def new_forward(self, *args, **kwargs):
        inp_hidden = args[0] if len(args) > 0 else None  # Î≥¥ÌÜµ [B,T,D]

        outputs = self.base_layer(*args, **kwargs)

        if isinstance(outputs, (tuple, list)):
            hidden_states = outputs[0]
            rest = outputs[1:]
            tuple_mode = True
        else:
            hidden_states = outputs[0]
            tuple_mode = False

        # ---- robust reshape
        if inp_hidden is not None and isinstance(hidden_states, torch.Tensor):
            if inp_hidden.dim() == 3 and hidden_states.dim() == 2:
                B, T, D = inp_hidden.shape

                # (1) [B*T, D]
                if hidden_states.shape == (B * T, D):
                    hidden_states = hidden_states.view(B, T, D)

                # (2) [B, D]  -> [B,1,D]
                elif hidden_states.shape == (B, D):
                    hidden_states = hidden_states.unsqueeze(1)

                # (3) [T, D] -> [1,T,D] ÌòπÏùÄ [B,T,D] Ï§ë Í∞ÄÎä•Ìïú Í≤ÉÏúºÎ°ú
                elif hidden_states.shape == (T, D):
                    hidden_states = hidden_states.unsqueeze(0)  # [1,T,D]
                    if B != 1:
                        # Î∞∞ÏπòÍ∞Ä 1Ïù¥ ÏïÑÎãåÎç∞ [T,D]Î°ú ÏôîÎã§Î©¥ Î≥µÍµ¨ Î∂àÍ∞Ä -> ÏùºÎã® expand ÏãúÎèÑ
                        hidden_states = hidden_states.expand(B, T, D)

                else:
                    # Î™®Î•¥Îäî 2D shape: ÏïàÏ†ÑÌïòÍ≤å [B,1,D]Î°ú ÎßûÏ∂∞Î≥¥Í∏∞ (Í∞ÄÎä•Ìïú Í≤ΩÏö∞)
                    if hidden_states.shape[1] == D:
                        hidden_states = hidden_states[:B].unsqueeze(1)  # [B,1,D] (ÏûòÎùºÏÑúÎùºÎèÑ)
                    # Í∑∏ÎûòÎèÑ Ïïà ÎêòÎ©¥ Í∑∏ÎåÄÎ°ú ÎëêÍ≥† ÏïÑÎûòÏóêÏÑú Ïä§ÌÇµ Ï≤òÎ¶¨

        # ---- media Ï£ºÏûÖ (shapeÍ∞Ä 3D ÏïÑÎãê ÎïåÎäî Ïä§ÌÇµ)
        if getattr(self, "_media", None) is not None:
            if isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 3:
                hidden_states = self.xattn_block(hidden_states, self._media)
            else:
                # ÏùëÍ∏â ÏïàÏ†ÑÏû•Ïπò: shapeÏù¥ Ïù¥ÏÉÅÌïòÎ©¥ xattnÏùÑ Í±¥ÎÑàÎõ∞Í≥† base_layer Ï∂úÎ†•Îßå ÏÇ¨Ïö©
                # (ÌïôÏäµÏùÄ Í≥ÑÏÜç ÏßÑÌñâÎêòÏßÄÎßå, Ìï¥Îãπ stepÏóêÏÑ† Î©ÄÌã∞Î™®Îã¨ Ï£ºÏûÖÏù¥ Îπ†Ïßê)
                pass

        if tuple_mode:
            return (hidden_states,) + rest
        else:
            outputs[0] = hidden_states
            return outputs

    WrapperCls.forward = new_forward
    WrapperCls._patched_qwen2_robust = True
    print("‚úÖ Patched FlamingoDecoderLayerWrapper.forward (ROBUST)")

In [None]:
# Instantiate the video-to-text model first, then move it to the GPU in a
# separate step.
model = VideoTextModel(
    llm_name="meta-llama/Llama-3.2-3B",
    vision_name="openai/clip-vit-base-patch32",
    num_visual_tokens=64,
    perceiver_depth=2,
    xattn_every=2,
    freeze_llm=True,
)
model = model.cuda()

In [16]:
# 1) Force an `attention_type` property onto the wrapper class; it forwards to
#    the wrapped layer and falls back to "global" when that layer has none.
def _attention_type(self):
    """Return the wrapped layer's ``attention_type``, or "global" if absent."""
    wrapped = self.base_layer
    return getattr(wrapped, "attention_type", "global")

setattr(FlamingoDecoderLayerWrapper, "attention_type", property(_attention_type))

# 2) Sanity check on the first decoder layer of the instantiated model.
print(hasattr(model.llm.model.layers[0], "attention_type"))
print(model.llm.model.layers[0].attention_type)


True
global


In [17]:
def _wrapper_getattr(self, name):
    """
    Attribute fallback for the decoder-layer wrapper.

    The lookup first goes through ``nn.Module.__getattr__``; any name it
    cannot resolve is forwarded to the wrapped ``base_layer`` so the wrapper
    exposes the wrapped layer's attributes as its own.

    Raises:
        AttributeError: if neither the wrapper nor ``base_layer`` provides
            ``name``, or if ``base_layer`` itself cannot be resolved.
    """
    try:
        return nn.Module.__getattr__(self, name)
    except AttributeError:
        # ``base_layer`` is the delegation target. If it is the attribute that
        # is missing (e.g. the instance is only partially initialised),
        # evaluating ``self.base_layer`` below would re-enter this hook again
        # and again until RecursionError. Re-raise the AttributeError instead.
        if name == "base_layer":
            raise
        return getattr(self.base_layer, name)

FlamingoDecoderLayerWrapper.__getattr__ = _wrapper_getattr


In [None]:
# Launch training without a validation set (val_dataset=None).
# The training loop writes epoch_XX.pt after every epoch and overwrites
# best.pt whenever the average epoch loss improves, both under save_dir, and
# finally returns the model.
# NOTE(review): cache_dir presumably holds precomputed frame features - confirm
# against run_train_flamingo_dense_cached.
trained_model = run_train_flamingo_dense_cached(
    model,
    train_dataset,
    val_dataset=None,
    cache_dir="cache",
    save_dir="checkpoint",
    vision_name="openai/clip-vit-base-patch32",
    num_frames=10,
    image_size=224,
    batch_size=8,
    epochs=10,
    lr=5e-5,
)

  scaler = torch.cuda.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))
  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [23:42<00:00, 13.05s/it, loss=2.1691]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_01.pt (avg_loss=2.3410)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=2.3410)


Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [16:10<00:00,  8.91s/it, loss=1.8383]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_02.pt (avg_loss=2.3225)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=2.3225)


Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [26:51<00:00, 14.78s/it, loss=1.3693]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_03.pt (avg_loss=1.8795)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=1.8795)


Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [25:32<00:00, 14.06s/it, loss=1.9757]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_04.pt (avg_loss=1.6564)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=1.6564)


Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [26:56<00:00, 14.83s/it, loss=1.6024]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_05.pt (avg_loss=1.4657)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=1.4657)


Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [26:19<00:00, 14.49s/it, loss=1.3322]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_06.pt (avg_loss=1.2593)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=1.2593)


  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [27:04<00:00, 14.90s/it, loss=0.8275]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_07.pt (avg_loss=1.0184)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=1.0184)


Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [26:39<00:00, 14.67s/it, loss=0.5373]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_08.pt (avg_loss=0.7678)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=0.7678)


Epoch 9/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [26:12<00:00, 14.42s/it, loss=0.6122]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_09.pt (avg_loss=0.5291)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=0.5291)


Epoch 10/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [27:01<00:00, 14.88s/it, loss=0.3696]


[Saved] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/epoch_10.pt (avg_loss=0.3322)
[Saved BEST] /content/drive/MyDrive/·ÑÜ·Ö•·Ü∫·Ñå·Ö¢·Üº·Ñã·Öµ ·Ñâ·Ö°·Ñå·Ö°·Ñé·Ö•·ÑÖ·Ö•·Ü∑ AI NLP /·Ñâ·Öµ·ÜØ·Ñå·Ö•·Ü´ ·Ñë·Ö≥·ÑÖ·Ö©·Ñå·Ö¶·Ü®·Ñê·Ö≥2/checkpoint/best.pt (loss=0.3322)


In [None]:
# Hand unused cached GPU memory from the training run back to the CUDA driver.
torch.cuda.empty_cache()

In [18]:
# Download the NLTK resources (WordNet and the Punkt tokenizer data).
# NOTE(review): presumably intended for METEOR scoring; the METEOR import in
# the evaluation cell is commented out, so confirm these are still needed.
import nltk
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

### Load Best Trained Model


In [None]:
# Restore the best checkpoint into the existing `model` and prepare it for
# inference.
checkpoint_path = 'best.pt'
if not os.path.exists(checkpoint_path):
    # The training cell saves to f"{save_dir}/best.pt" with
    # save_dir="checkpoint"; use that location when no copy of the file sits
    # in the working directory.
    checkpoint_path = os.path.join("checkpoint", "best.pt")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Deserialize onto the CPU: the checkpoint also contains the optimizer state,
# which is not needed for inference and would otherwise take up GPU memory.
# load_state_dict() copies the values into the model's existing parameters,
# so the model stays on whatever device it already lives on.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

model.load_state_dict(checkpoint['model'])
model.eval()

# Ensure gate parameters in GatedCrossAttentionBlock match the LLM's dtype
llm_dtype = model.llm.dtype

for layer in model.llm.model.layers:
    if hasattr(layer, "xattn_block") and layer.xattn_block is not None:
        if hasattr(layer.xattn_block, "attn_gate"):
            layer.xattn_block.attn_gate.data = layer.xattn_block.attn_gate.data.to(llm_dtype)
        if hasattr(layer.xattn_block, "ff_gate"):
            layer.xattn_block.ff_gate.data = layer.xattn_block.ff_gate.data.to(llm_dtype)

print(f"‚úÖ Model weights loaded from {checkpoint_path} and set to evaluation mode successfully.")
print("‚úÖ xattn_block gate parameters cast to LLM's dtype.")

### 영상에 대한 LLM 결과 추출


In [None]:
import torch
import nltk
from torch.utils.data import DataLoader
from tqdm import tqdm
# from nltk.translate.meteor_score import meteor_score # Removed as per user request


def evaluate_model(
    model,
    val_dataset,
    prompt,
    image_size=224,
    batch_size=4,
    num_workers=2,
    gen_kwargs=None
):
    """
    Generate one caption per video in ``val_dataset`` and collect the results.

    Every batch is encoded with ``model.encode_video``, the resulting media
    tokens are registered on the decoder layers through
    ``model._set_media_for_layers``, and ``model.llm.generate`` continues the
    same text ``prompt`` for each sample.

    Args:
        model: model exposing ``encode_video``, ``_set_media_for_layers``,
            ``tokenizer`` and ``llm``; it is switched to eval mode.
        val_dataset: dataset whose items ``collate_video_caption_with_paths``
            can batch into ``(video_paths, video_tensor, captions)``.
        prompt: text prompt prepended to every sample.
        image_size: forwarded to the collate function.
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.
        gen_kwargs: keyword arguments for ``model.llm.generate``. Defaults to
            beam search with 5 beams and at most 100 new tokens.
            NOTE(review): the per-sample bookkeeping below assumes exactly
            one returned sequence per input.

    Returns:
        Tuple ``(generated_captions, reference_captions, video_path_list)`` of
        three parallel lists in dataset order.
    """
    device = next(model.parameters()).device
    model.eval()

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=lambda b: collate_video_caption_with_paths(b, image_size=image_size),
        pin_memory=True,
    )

    generated_captions = []
    reference_captions = []
    video_path_list = []

    if gen_kwargs is None:
        gen_kwargs = {
            "max_new_tokens": 100,
            "num_beams": 5,
            "do_sample": False,
            "early_stopping": True,
        }

    amp_dtype = (
        torch.bfloat16
        if (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
        else torch.float16
    )

    print("Starting evaluation...")
    with torch.no_grad():
        for video_paths, video_tensor, captions in tqdm(val_loader, desc="Evaluating"):
            video_tensor = video_tensor.to(device, non_blocking=True)
            B = video_tensor.shape[0]

            # Encode video
            media_tokens = model.encode_video(video_tensor)  # [B, N, D]

            # Beam search expands every sample to num_beams rows, so the
            # media tokens have to be repeated the same way.
            num_beams = gen_kwargs.get("num_beams", 1)
            if num_beams > 1:
                media_tokens = media_tokens.repeat_interleave(num_beams, dim=0)

            model._set_media_for_layers(media_tokens)

            # Prompt tokenize
            input_tokenized = model.tokenizer(
                [prompt] * B,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)
            prompt_len = input_tokenized["input_ids"].shape[1]

            with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
                output_ids = model.llm.generate(
                    **input_tokenized,
                    **gen_kwargs
                )

            # generate() returns prompt + continuation for every row. Cut the
            # prompt off at token level rather than by matching the decoded
            # string: string matching silently leaves the whole prompt in the
            # caption whenever decoding does not reproduce the prompt text
            # character for character.
            generated_texts = model.tokenizer.batch_decode(
                output_ids[:, prompt_len:], skip_special_tokens=True
            )

            # Store each result together with its reference and video path.
            for i, text in enumerate(generated_texts):
                generated_captions.append(text.strip())
                reference_captions.append(captions[i])
                video_path_list.append(video_paths[i])

    print(f"\nEvaluation complete. Generated {len(generated_captions)} captions.")

    return generated_captions, reference_captions, video_path_list

In [33]:
validation_prompt = """
ÎãπÏã†ÏùÄ ÏòÅÏÉÅÏóê Ïã§Ï†úÎ°ú Î≥¥Ïù¥Îäî Ïû•Î©¥ÎßåÏùÑ ÏÇ¨Ïã§Ï†ÅÏúºÎ°ú ÏÑ§Î™ÖÌïòÎäî Ïñ¥ÏãúÏä§ÌÑ¥Ìä∏ÏûÖÎãàÎã§.
Ï∂îÏ∏°, Ìï¥ÏÑù, ÏùºÎ∞òÌôî ÏóÜÏù¥ ÌôîÎ©¥Ïóê Í¥ÄÏ∞∞ÎêòÎäî ÏöîÏÜåÎßå Í∏∞Ïà†ÌïòÏÑ∏Ïöî.

[ÏûëÏÑ± Í∑úÏπô ‚Äî Î∞òÎìúÏãú Î™®Îëê Îî∞Î•¥ÏÑ∏Ïöî]
1. Ï¥¨ÏòÅ Ïó∞ÎèÑ, ÏãúÎåÄ, Í≥ºÍ±∞/ÌòÑÎåÄ, Í≥ÑÏ†à Ï∂îÏ†ï ÌëúÌòÑÏùÑ Ï†àÎåÄ ÏÇ¨Ïö©ÌïòÏßÄ ÎßàÏÑ∏Ïöî.
2. ÏòÅÏÉÅÏóêÏÑú Ïã§Ï†úÎ°ú Î≥¥Ïù¥ÏßÄ ÏïäÎäî Ïû•ÏÜå(Í≥µÏõê, ÎèÑÏã¨, Ïà≤ Îì±)Î•º ÏûÑÏùòÎ°ú ÌåêÎã®ÌïòÏßÄ ÎßàÏÑ∏Ïöî.
3. ÌôîÎ©¥Ïóê Î≥¥Ïù¥Îäî Í∞ùÏ≤¥, ÏÉâÏÉÅ, Î∞∞Ïπò, ÏõÄÏßÅÏûÑÎßå Î¨òÏÇ¨ÌïòÏÑ∏Ïöî.
4. Í±¥Î¨º Ïù¥Î¶Ñ, Í∞ÑÌåê, Î°úÍ≥†, ÌÖçÏä§Ìä∏Îäî ÌôîÎ©¥Ïóê Î™ÖÌôïÌûà Î≥¥Ïùº Í≤ΩÏö∞ÏóêÎßå Í∑∏ÎåÄÎ°ú ÏûëÏÑ±ÌïòÏÑ∏Ïöî.
5. Î™®Îì† Î¨∏Ïû•ÏùÄ ÌôîÎ©¥ ÏúÑÏπòÎ•º Ìè¨Ìï®Ìï¥Ïïº Ìï©ÎãàÎã§.
6. Î¨∏Ïû•ÏùÄ Ï§ëÍ∞ÑÏóê ÎÅäÍ∏∞ÏßÄ ÏïäÎèÑÎ°ù ÎÅùÍπåÏßÄ ÏôÑÏÑ±ÌïòÏÑ∏Ïöî.

[ÏÑúÏà† ÏàúÏÑú]
‚ë† ÎÇ†Ïî®ÏôÄ ÌïòÎäò ÏÉÅÌÉú
‚ë° ÌôîÎ©¥ Ï§ëÏïôÏùò Ï£ºÏöî Í∞ùÏ≤¥
‚ë¢ Ï¢åÏ∏°¬∑Ïö∞Ï∏°¬∑Ï†ÑÎ©¥¬∑Î∞∞Í≤Ω ÏöîÏÜå
‚ë£ ÏÇ¨Îûå¬∑Ï∞®Îüâ¬∑ÏûêÏó∞Î¨º
‚ë§ Ïπ¥Î©îÎùº ÏõÄÏßÅÏûÑ ÎòêÎäî Í≥†Ï†ï Ïó¨Î∂Ä

[Ï∂úÎ†• ÌòïÏãù]
- 500ÏûêÎ°ú ÏûëÏÑ±
- Í∞Å Î¨∏Ïû•ÏùÄ Í¥ÄÏ∞∞ ÏÇ¨Ïã§ ÌïòÎÇòÎßå Ìè¨Ìï®
- Î™®Îì† Î¨∏Ïû•ÏùÄ ÏôÑÍ≤∞Ìòï

[ÏïÑÎûò ÏòÅÏÉÅÏóê ÎåÄÌïú ÏÑ§Î™Ö]
"""

# Generate captions for the whole validation set with the prompt defined above.
generated, references, video_paths = evaluate_model(
    model, val_dataset,
    prompt=validation_prompt,
    image_size=224, batch_size=8, num_workers=2,
)

# Print at most the first five samples next to their references.
print("\n--- Sample Generations ---")
for i in range(len(generated))[:5]:
    # Keep only the file name, not the directory part.
    video_name = os.path.split(video_paths[i])[1]
    print(f"Video name: {video_name}")
    print(f"Reference: {references[i]}")
    print(f"Generated: {generated[i]}\n")

Starting evaluation...


Evaluating:   0%|          | 0/13 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Evaluating:   8%|‚ñä         | 1/13 [00:24<04:51, 24.29s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Evaluating:  15%|‚ñà‚ñå        | 2/13 [00:31<02:39, 14.51s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Evaluating:  23%|‚ñà‚ñà‚ñé       | 3/13 [00:40<01:56, 11.62s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Evaluating:  31%|‚ñà‚ñà‚ñà       | 4/13 [00:47<01:30, 10.03s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Evaluating:  38%|‚ñà‚ñà‚ñà‚ñä      | 5/13 [00:55<01:14,  9.34s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Evaluating:  46%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 6/13 [01:03<01:01,  8.73s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Evaluating:  54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 7/13 [01:11<00


Evaluation complete. Generated 97 captions.

--- Sample Generations ---
Video name: architecture-mid-NEWS01892-0001.mp4
Reference: Ïù¥ ÏòÅÏÉÅÏùÄ ÎßëÏùÄ ÎÇ†Ïî®Ïóê Ï¥¨ÏòÅÎêòÏóàÏäµÎãàÎã§. ÌôîÎ©¥ Ï§ëÏïôÏóêÎäî "ÌïúÍµ≠ÌïôÏà†ÏßÑÌù•Ïû¨Îã®" Í±¥Î¨ºÏù¥ ÏûêÎ¶¨ Ïû°Í≥† ÏûàÏäµÎãàÎã§. Ïù¥ Í±¥Î¨ºÏùÄ Î≤†Ïù¥ÏßÄÏÉâ Ïô∏Î≤ΩÍ≥º ÏàòÏßÅÏúºÎ°ú Î∞∞Ïó¥Îêú Ï∞ΩÎ¨∏Îì§Ïù¥ ÌäπÏßïÏù¥Î©∞, 5Ï∏µ ÎÜíÏù¥Î°ú Î≥¥ÏûÖÎãàÎã§. Í±¥Î¨ºÏùò ÏÉÅÎã®ÏóêÎäî Í≤ÄÏùÄÏÉâ Î∞∞Í≤ΩÏóê Ìù∞ Í∏ÄÏî®Î°ú Ïì∞Ïù∏ Í±¥Î¨ºÎ™ÖÏù¥ ÏÑ†Î™ÖÌïòÍ≤å ÎÇòÌÉÄÎÇò ÏûàÏäµÎãàÎã§. Í±¥Î¨ºÏùò Ï¢åÏ∏°ÏúºÎ°úÎäî Ïú†Î¶¨Î°ú ÎßàÍ∞êÎêú Í≥†Ï∏µ Í±¥Î¨ºÏù¥ ÏùºÎ∂Ä Î≥¥ÏûÖÎãàÎã§. ÌôîÎ©¥ Ï§ëÏïô Í±¥Î¨º ÏïûÏóêÎäî Î™á ÎåÄÏùò Ï∞®ÎüâÏù¥ Ï†ïÏ∞®Ìï¥ ÏûàÎäî Î™®ÏäµÏù¥ Î≥¥ÏûÖÎãàÎã§. Í±¥Î¨ºÏùò Ïö∞Ï∏°ÏóêÎäî ÎÇòÎ¨¥Í∞Ä Ïã¨Ïñ¥ÏßÑ Ïñ∏ÎçïÏù¥ Ïù¥Ïñ¥Ï†∏ ÏûàÏäµÎãàÎã§. Ïñ∏Îçï ÏúÑ ÎÇòÎ¨¥Îì§ÏùÄ ÏûéÏùÄ Í∞àÏÉâÍ≥º Ï¥àÎ°ùÏÉâÏùÑ Îù†Í≥† ÏûàÏäµÎãàÎã§. ÌôîÎ©¥ Ï†ÑÎ©¥ÏóêÎäî ÎèÑÎ°úÍ∞Ä Î≥¥Ïù¥Î©∞, ÎèÑÎ°ú ÏúÑÎ°úÎäî Ï∞®Îì§Ïù¥ Ïò§Í∞ÄÎäî Î™®ÏäµÏù¥ Îã¥Í≤® ÏûàÏäµÎãàÎã§. ÏäπÏö©Ï∞®ÏôÄ ÌôîÎ¨ºÏ∞®Í∞Ä Ïù¥Îèô Ï§ëÏù¥Î©∞, ÎèÑÎ°ú Ï§ëÏïôÏóêÎ


