In [1]:
!nvidia-smi

Fri Nov 28 15:06:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off |   00000000:00:04.0 Off |                    0 |
| N/A   35C    P0             28W /  250W |       0MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
!pip install -q "scikit-learn<1.6.0" underthesea rouge-score pycocoevalcap vncorenlp timm transformers

# Library

In [3]:
import os
import re
import csv
import json
import time
import random
from typing import List, Dict, Any
from collections import Counter

import numpy as np
import pandas as pd 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

import timm
from transformers import AutoTokenizer, AutoModel

from torchvision import transforms
from PIL import Image

from tqdm.auto import tqdm
from underthesea import word_tokenize

# Metrics
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer

import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("✓ Libraries imported successfully!")



✓ Libraries imported successfully!


# Config

In [4]:
class config:
    # ==== PATH DATA GỐC ====
    TRAIN_IMAGES_PATH = '/kaggle/input/d/mrworldzero/openvivqa/data/train-images/training-images'
    TEST_IMAGES_PATH  = '/kaggle/input/d/mrworldzero/openvivqa/data/test-images/test-images'
    DEV_IMAGES_PATH   = '/kaggle/input/d/mrworldzero/openvivqa/data/dev-images/dev-images'
    
    TRAIN_JSON_PATH   = '/kaggle/input/d/mrworldzero/openvivqa/data/vlsp2023_train_data.json'
    TEST_JSON_PATH    = '/kaggle/input/d/mrworldzero/openvivqa/data/vlsp2023_test_data.json'
    DEV_JSON_PATH     = '/kaggle/input/d/mrworldzero/openvivqa/data/vlsp2023_dev_data.json'
    
    # ==== JSON FLAT SAU KHI CHUYỂN ====
    TRAIN_JSON_FLAT = "/kaggle/working/data/train_flat.json"
    DEV_JSON_FLAT   = "/kaggle/working/data/val_flat.json"
    TEST_JSON_FLAT  = "/kaggle/working/data/test_flat.json"
    
    # ==== TRAINING ====
    SEED = 42
    IMAGE_SIZE = 224
    BATCH_SIZE = 8
    NUM_WORKERS = 4
    NUM_EPOCHS = 10
    
    LR_VIT_PHOBERT = 2e-5      # LR cho encoder (ViT + PhoBERT)
    LR_DECODER = 5e-4          # LR cho decoder + fusion
    WEIGHT_DECAY = 1e-5
    
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # ==== MÔ HÌNH ====
    TEXT_ENCODER_NAME = "vinai/phobert-base-v2"
    VISION_NAME = "vit_base_patch16_224"
    
    MAX_QUESTION_LEN = 64
    MAX_ANSWER_LEN = 15
    DEC_HIDDEN_SIZE = 256
    MIN_ANSWER_FREQ = 3  # min freq để đưa token vào vocab
    
    NUM_BEAMS = 5  # ở đây mình dùng greedy; nếu muốn beam search thì code thêm
    
    OUT_DIR = "outputs_vqa"
    NUM_VIZ_EXAMPLES = 5

cfg = config()
os.makedirs(cfg.OUT_DIR, exist_ok=True)

# Seed
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(cfg.SEED)
print(f"Using device: {cfg.DEVICE}")

# NLTK & ROUGE
try:
    nltk.data.find("corpora/wordnet")
except LookupError:
    nltk.download("wordnet")

smooth_fn = SmoothingFunction().method1
rouge_scorer_obj = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)


Using device: cuda


[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [5]:
def text_normalize_simple(s: str) -> str:
    s = str(s).strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def vi_seg(s: str) -> List[str]:
    """Segment tiếng Việt bằng underthesea → list token."""
    s = text_normalize_simple(s)
    return word_tokenize(s, format="text").split()

# Data Preparation

In [6]:
def convert_json_to_flat_json(input_file, folder_path, output_file, split="train"):
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    flat_list = []
    skipped = 0

    for anno_id, annotation in data["annotations"].items():
        image_id = annotation["image_id"]
        question = annotation["question"]

        image_name = data["images"].get(str(image_id), "")
        image_path = os.path.join(folder_path, image_name)

        if split in ["train", "dev"]:
            answer = str(annotation.get("answer", "")).strip()

            if (answer == "") or (answer.lower() == "your answer"):
                skipped += 1
                continue

            flat_item = {
                "image_path": image_path,
                "question": question,
                "answer": answer,
            }
        else:
            flat_item = {
                "image_path": image_path,
                "question": question,
            }

        flat_list.append(flat_item)

    out_dir = os.path.dirname(output_file)
    if out_dir != "":
        os.makedirs(out_dir, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(flat_list, f, ensure_ascii=False, indent=2)

    print(f"[{split}] Saved {len(flat_list)} valid items to {output_file}")
    if split in ["train", "dev"] and skipped > 0:
        print(f"[{split}] Skipped {skipped} invalid items (empty or 'your answer')")

# Chạy nếu chưa có file flat
convert_json_to_flat_json(
    cfg.TRAIN_JSON_PATH,
    cfg.TRAIN_IMAGES_PATH,
    cfg.TRAIN_JSON_FLAT,
    split="train"
)

convert_json_to_flat_json(
    cfg.DEV_JSON_PATH,
    cfg.DEV_IMAGES_PATH,
    cfg.DEV_JSON_FLAT,
    split="dev"
)

convert_json_to_flat_json(
    cfg.TEST_JSON_PATH,
    cfg.TEST_IMAGES_PATH,
    cfg.TEST_JSON_FLAT,
    split="test"
)

[train] Saved 30833 valid items to /kaggle/working/data/train_flat.json
[dev] Saved 3545 valid items to /kaggle/working/data/val_flat.json
[test] Saved 14035 valid items to /kaggle/working/data/test_flat.json


# Build Vocabulary

In [7]:
def build_answer_vocab(train_json_flat: str, min_freq: int = 1):
    with open(train_json_flat, "r", encoding="utf-8") as f:
        data = json.load(f)
    
    counter = Counter()
    for item in data:
        ans = item["answer"]
        tokens = vi_seg(ans)
        counter.update(tokens)
    
    # SPECIAL
    specials = ["<pad>", "<bos>", "<eos>", "<unk>"]
    stoi = {}
    itos = []
    
    for sp in specials:
        idx = len(itos)
        itos.append(sp)
        stoi[sp] = idx
    
    for tok, freq in counter.items():
        if freq >= min_freq and tok not in stoi:
            idx = len(itos)
            itos.append(tok)
            stoi[tok] = idx
    
    print(f"Vocab size (answer): {len(itos)} (min_freq={min_freq})")
    return stoi, itos

answer_stoi, answer_itos = build_answer_vocab(cfg.TRAIN_JSON_FLAT, cfg.MIN_ANSWER_FREQ)

PAD_ID = answer_stoi["<pad>"]
BOS_ID = answer_stoi["<bos>"]
EOS_ID = answer_stoi["<eos>"]
UNK_ID = answer_stoi["<unk>"]

def encode_answer(text: str) -> torch.Tensor:
    tokens = vi_seg(text)
    ids = [BOS_ID] + [answer_stoi.get(t, UNK_ID) for t in tokens] + [EOS_ID]
    return torch.tensor(ids, dtype=torch.long)

def decode_answer_ids(ids: List[int]) -> str:
    tokens = []
    for i in ids:
        if i == EOS_ID:
            break
        if i in [PAD_ID, BOS_ID]:
            continue
        if 0 <= i < len(answer_itos):
            tokens.append(answer_itos[i])
    return " ".join(tokens)


Vocab size (answer): 3498 (min_freq=3)


# Dataset + DataLoader

In [8]:
# PhoBERT tokenizer cho câu hỏi
q_tokenizer = AutoTokenizer.from_pretrained(cfg.TEXT_ENCODER_NAME)

# Image transform cho ViT
image_transform = transforms.Compose([
    transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
    transforms.RandomResizedCrop(cfg.IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(
        brightness=0.1, contrast=0.1, saturation=0.1
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std =[0.229, 0.224, 0.225]
    )
])


class VQADataset(Dataset):
    def __init__(self, json_flat_path: str, tokenizer, transform, has_answer: bool = True):
        super().__init__()
        with open(json_flat_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.transform = transform
        self.has_answer = has_answer
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        img_path = item["image_path"]
        question = item["question"]
        
        # Load image
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        
        # Encode question bằng PhoBERT tokenizer
        encoded = self.tokenizer(
            text_normalize_simple(question),
            max_length=cfg.MAX_QUESTION_LEN,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        q_input_ids = encoded["input_ids"].squeeze(0)      # (L,)
        q_attention = encoded["attention_mask"].squeeze(0) # (L,)
        
        sample = {
            "image": img,
            "question": question,
            "q_input_ids": q_input_ids,
            "q_attention_mask": q_attention
        }
        
        if self.has_answer:
            answer = item["answer"]
            ans_ids = encode_answer(answer)
            sample["answer"] = answer
            sample["answer_ids"] = ans_ids
        
        return sample

def vqa_collate_fn(batch):
    images = torch.stack([b["image"] for b in batch], dim=0)
    q_input_ids = torch.stack([b["q_input_ids"] for b in batch], dim=0)
    q_attention = torch.stack([b["q_attention_mask"] for b in batch], dim=0)
    
    out = {
        "images": images,
        "q_input_ids": q_input_ids,
        "q_attention_mask": q_attention,
        "questions": [b["question"] for b in batch],
    }
    
    if "answer_ids" in batch[0]:
        ans_seqs = [b["answer_ids"] for b in batch]
        ans_padded = pad_sequence(ans_seqs, batch_first=True, padding_value=PAD_ID)
        out["answer_ids"] = ans_padded
        out["answers"] = [b["answer"] for b in batch]
    
    return out

train_dataset = VQADataset(cfg.TRAIN_JSON_FLAT, q_tokenizer, image_transform, has_answer=True)
dev_dataset   = VQADataset(cfg.DEV_JSON_FLAT,   q_tokenizer, image_transform, has_answer=True)
test_dataset  = VQADataset(cfg.TEST_JSON_FLAT,  q_tokenizer, image_transform, has_answer=False)

train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=True,
    num_workers=cfg.NUM_WORKERS,
    collate_fn=vqa_collate_fn
)

dev_loader = DataLoader(
    dev_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=False,
    num_workers=cfg.NUM_WORKERS,
    collate_fn=vqa_collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=False,
    num_workers=cfg.NUM_WORKERS,
    collate_fn=vqa_collate_fn
)

print(len(train_dataset), len(dev_dataset), len(test_dataset))


30833 3545 14035


# Encoder ViT + PhoBERT và Decoder LSTM

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from transformers import AutoModel

# Giả sử đã có PAD_ID, BOS_ID, EOS_ID, decode_answer_ids, cfg.DEC_HIDDEN_SIZE trong môi trường.


class ViTEncoderTokens(nn.Module):
    def __init__(self, model_name: str = "vit_base_patch16_224"):
        super().__init__()
        # global_pool="" để trả ra chuỗi token patch
        self.backbone = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            global_pool=""
        )
        self.out_dim = self.backbone.num_features  # D

    def forward(self, x):
        """
        x: (B,3,H,W)
        return: (B, N_img, D)
        """
        feats = self.backbone(x)  # (B, N_img, D)
        return feats


class PhoBERTEncoderSeq(nn.Module):
    def __init__(self, model_name: str):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.model.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """
        return: last_hidden_state (B, L, H)
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        last_hidden = outputs.last_hidden_state  # (B, L, H)
        return last_hidden


class VQAModel(nn.Module):
    def __init__(self, cfg, vocab_size: int):
        super().__init__()
        self.cfg = cfg

        # Encoders
        self.vit = ViTEncoderTokens(cfg.VISION_NAME)
        self.text_encoder = PhoBERTEncoderSeq(cfg.TEXT_ENCODER_NAME)

        ctx_dim = cfg.DEC_HIDDEN_SIZE
        self.ctx_dim = ctx_dim

        # Project image & text tokens về cùng chiều
        self.img_proj = nn.Linear(self.vit.out_dim, ctx_dim)
        self.txt_proj = nn.Linear(self.text_encoder.hidden_size, ctx_dim)

        # Fusion global context (img_ctx + txt_mean) để init decoder
        self.global_fuse = nn.Linear(ctx_dim * 2, ctx_dim)
        self.global_ln = nn.LayerNorm(ctx_dim)
        self.global_dropout = nn.Dropout(0.3)

        # Decoder LSTM (input = [token_emb, img_ctx])
        self.embed_ans = nn.Embedding(vocab_size, ctx_dim, padding_idx=PAD_ID)
        self.emb_dropout = nn.Dropout(0.3)

        self.decoder = nn.LSTM(
            input_size=ctx_dim * 2,   # token_emb + img_ctx
            hidden_size=ctx_dim,
            num_layers=1,
            batch_first=True
        )

        # Cross-attn lên chuỗi PhoBERT (text memory)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=ctx_dim,
            num_heads=4,
            dropout=0.3,
            batch_first=True
        )

        # Gating cho ảnh + fusion output
        self.gate_ff = nn.Linear(ctx_dim * 3, ctx_dim)   # [h_t, text_ctx, img_ctx]
        self.fusion_ff = nn.Linear(ctx_dim * 3, ctx_dim)
        self.fusion_dropout = nn.Dropout(0.3)
        self.fusion_ln = nn.LayerNorm(ctx_dim)

        # Weight tying: out_proj.weight = embed_ans.weight
        self.out_proj = nn.Linear(ctx_dim, vocab_size, bias=False)
        self.out_proj.weight = self.embed_ans.weight

    # ====== BUILD MEMORY (text) + GLOBAL CONTEXT ======
    def build_memory(self, images, q_input_ids, q_attention_mask):
        """
        images: (B,3,H,W)
        q_input_ids: (B,L)
        q_attention_mask: (B,L)

        return:
            img_ctx:     (B, C)
            memory:      (B, L_txt, C)  (PhoBERT tokens)
            memory_mask: (B, L_txt)     (1=valid, 0=pad)
            global_ctx:  (B, C)
        """
        device = images.device
        B = images.size(0)

        # Image tokens -> mean pool
        img_tokens = self.vit(images)                  # (B, N_img, D_v)
        img_feats = torch.tanh(self.img_proj(img_tokens))  # (B, N_img, C)
        img_ctx = img_feats.mean(dim=1)                # (B, C)

        # Text tokens
        txt_tokens = self.text_encoder(q_input_ids, q_attention_mask)  # (B,L,H_t)
        txt_feats = torch.tanh(self.txt_proj(txt_tokens))              # (B,L,C)
        txt_mask = q_attention_mask                                    # (B,L)

        # mean pooling text để tạo txt_ctx
        mask_exp = txt_mask.unsqueeze(-1)                      # (B,L,1)
        txt_sum = (txt_feats * mask_exp).sum(dim=1)            # (B,C)
        lengths = mask_exp.sum(dim=1).clamp(min=1)             # (B,1)
        txt_ctx = txt_sum / lengths                            # (B,C)

        # global context = fusion(img_ctx, txt_ctx)
        global_ctx = torch.tanh(self.global_fuse(
            torch.cat([img_ctx, txt_ctx], dim=-1)
        ))                                                     # (B,C)
        global_ctx = self.global_ln(global_ctx)
        global_ctx = self.global_dropout(global_ctx)

        memory = txt_feats                                     # (B,L,C)
        memory_mask = txt_mask                                 # (B,L)

        return img_ctx, memory, memory_mask, global_ctx

    # ====== TRAIN FORWARD ======
    def forward(self, images, q_input_ids, q_attention_mask, ans_input_ids):
        """
        ans_input_ids: (B, L_ans) với [BOS, ..., EOS]
        """
        img_ctx, memory, memory_mask, global_ctx = self.build_memory(
            images, q_input_ids, q_attention_mask
        )  # img_ctx:(B,C), memory:(B,Lm,C), mask:(B,Lm), global_ctx:(B,C)

        # Decoder inputs
        dec_in_ids = ans_input_ids[:, :-1]  # (B, L-1)
        targets    = ans_input_ids[:, 1:]   # (B, L-1)

        tok_emb = self.embed_ans(dec_in_ids)   # (B, L-1, C)
        tok_emb = self.emb_dropout(tok_emb)

        # lặp img_ctx theo thời gian
        img_ctx_exp = img_ctx.unsqueeze(1).expand(-1, tok_emb.size(1), -1)  # (B,L-1,C)
        dec_input = torch.cat([tok_emb, img_ctx_exp], dim=-1)               # (B,L-1,2C)

        # Init hidden/cell
        h0 = global_ctx.unsqueeze(0)           # (1,B,C)
        c0 = torch.zeros_like(h0)

        dec_out, (h_n, c_n) = self.decoder(dec_input, (h0, c0))   # (B,L-1,C)

        # Cross-attention lên text memory
        attn_ctx, _ = self.cross_attn(
            dec_out,              # query
            memory, memory,       # key, value
            key_padding_mask=(memory_mask == 0)
        )  # (B,L-1,C)

        # Gating với ảnh
        img_ctx_time = img_ctx.unsqueeze(1).expand_as(dec_out)  # (B,L-1,C)
        gate_input = torch.cat([dec_out, attn_ctx, img_ctx_time], dim=-1)  # (B,L-1,3C)
        gate = torch.sigmoid(self.gate_ff(gate_input))                      # (B,L-1,C)
        visual_ctx = gate * img_ctx_time                                    # (B,L-1,C)

        fused = torch.cat([dec_out, attn_ctx, visual_ctx], dim=-1)          # (B,L-1,3C)
        fused = torch.tanh(self.fusion_ff(fused))                           # (B,L-1,C)
        fused = self.fusion_dropout(fused)
        fused = self.fusion_ln(fused)

        logits = self.out_proj(fused)                                       # (B,L-1,V)
        return logits, targets

    # ====== 1 STEP DECODE (cho greedy / beam) ======
    def _decode_step(self, prev_token, h, c, memory, memory_mask, img_ctx):
        """
        prev_token: (B,)
        h, c:      (1,B,C)
        memory:    (B,Lm,C)
        memory_mask:(B,Lm)
        img_ctx:   (B,C)
        """
        tok_emb = self.embed_ans(prev_token).unsqueeze(1)  # (B,1,C)
        tok_emb = self.emb_dropout(tok_emb)

        img_ctx_exp = img_ctx.unsqueeze(1)                 # (B,1,C)
        dec_input = torch.cat([tok_emb, img_ctx_exp], dim=-1)  # (B,1,2C)

        dec_out, (h, c) = self.decoder(dec_input, (h, c))      # (B,1,C)

        attn_ctx, _ = self.cross_attn(
            dec_out, memory, memory,
            key_padding_mask=(memory_mask == 0)
        )  # (B,1,C)

        gate_input = torch.cat([dec_out, attn_ctx, img_ctx_exp], dim=-1)    # (B,1,3C)
        gate = torch.sigmoid(self.gate_ff(gate_input))                       # (B,1,C)
        visual_ctx = gate * img_ctx_exp                                     # (B,1,C)

        fused = torch.cat([dec_out, attn_ctx, visual_ctx], dim=-1)          # (B,1,3C)
        fused = torch.tanh(self.fusion_ff(fused))                           # (B,1,C)
        fused = self.fusion_dropout(fused)
        fused = self.fusion_ln(fused)

        logits = self.out_proj(fused.squeeze(1))                            # (B,V)
        return logits, h, c

    # ====== GREEDY DECODE ======
    def generate_greedy(self, images, q_input_ids, q_attention_mask,
                        max_len=32, bos_id=BOS_ID, eos_id=EOS_ID):
        self.eval()
        decoded = []
        with torch.no_grad():
            img_ctx, memory, memory_mask, global_ctx = self.build_memory(
                images, q_input_ids, q_attention_mask
            )  # img_ctx:(B,C), memory:(B,Lm,C), mask:(B,Lm), global:(B,C)
            B = images.size(0)

            h = global_ctx.unsqueeze(0)              # (1,B,C)
            c = torch.zeros_like(h)
            prev_tokens = torch.full(
                (B,), bos_id, dtype=torch.long, device=images.device
            )
            finished = torch.zeros(B, dtype=torch.bool, device=images.device)
            sequences = [[] for _ in range(B)]

            for _ in range(max_len):
                logits, h, c = self._decode_step(
                    prev_tokens, h, c, memory, memory_mask, img_ctx
                )
                next_tokens = logits.argmax(dim=-1)   # (B,)

                for i in range(B):
                    if finished[i]:
                        continue
                    tid = next_tokens[i].item()
                    if tid == eos_id:
                        finished[i] = True
                    else:
                        sequences[i].append(tid)

                if finished.all():
                    break

                prev_tokens = next_tokens

            decoded = [decode_answer_ids(seq) for seq in sequences]
        return decoded

    # ====== BEAM SEARCH CHO 1 SAMPLE ======
    def _generate_one_beam(self, img_ctx, memory, memory_mask, global_ctx,
                           max_len=32, bos_id=BOS_ID, eos_id=EOS_ID, num_beams=3):
        """
        img_ctx:    (1,C)
        memory:     (1,Lm,C)
        memory_mask:(1,Lm)
        global_ctx: (1,C)
        """
        device = img_ctx.device
        C = img_ctx.size(-1)
        Lm = memory.size(1)
        beam_size = num_beams

        img_beam = img_ctx.expand(beam_size, C)            # (beam,C)
        mem_beam = memory.expand(beam_size, Lm, C)         # (beam,Lm,C)
        mask_beam = memory_mask.expand(beam_size, Lm)      # (beam,Lm)
        h = global_ctx.expand(1, beam_size, C).contiguous()# (1,beam,C)
        c = torch.zeros_like(h)

        prev_tokens = torch.full(
            (beam_size,), bos_id, dtype=torch.long, device=device
        )

        beams = [{
            "tokens": [],
            "log_prob": 0.0,
            "finished": False
        }] + [{
            "tokens": [],
            "log_prob": float("-inf"),
            "finished": True
        } for _ in range(beam_size - 1)]

        for _ in range(max_len):
            logits, h, c = self._decode_step(
                prev_tokens, h, c, mem_beam, mask_beam, img_beam
            )  # logits: (beam,V)

            log_probs = F.log_softmax(logits, dim=-1)  # (beam,V)

            # beam đã finish → chỉ cho phép EOS
            for i, beam in enumerate(beams):
                if beam["finished"]:
                    log_probs[i, :] = float("-inf")
                    log_probs[i, eos_id] = 0.0

            beam_log_probs = torch.tensor(
                [b["log_prob"] for b in beams], device=device
            ).unsqueeze(1)  # (beam,1)
            total_log_probs = log_probs + beam_log_probs       # (beam,V)

            flat = total_log_probs.view(-1)                    # (beam*V,)
            topk_log_probs, topk_indices = torch.topk(flat, beam_size)

            new_beams = []
            new_prev_tokens = torch.zeros_like(prev_tokens)

            V = logits.size(-1)
            for new_i, (lp, idx) in enumerate(zip(topk_log_probs, topk_indices)):
                beam_idx = (idx // V).item()
                token_id = (idx % V).item()

                old_beam = beams[beam_idx]
                new_tokens = old_beam["tokens"].copy()
                finished = old_beam["finished"]

                if not finished:
                    if token_id == eos_id:
                        finished = True
                    else:
                        new_tokens.append(token_id)

                new_beams.append({
                    "tokens": new_tokens,
                    "log_prob": lp.item(),
                    "finished": finished
                })
                new_prev_tokens[new_i] = token_id

            beams = new_beams
            prev_tokens = new_prev_tokens

            if all(b["finished"] for b in beams):
                break

        finished_beams = [b for b in beams if b["finished"] and len(b["tokens"]) > 0]
        if len(finished_beams) == 0:
            finished_beams = beams
        best_beam = max(finished_beams, key=lambda b: b["log_prob"])
        return best_beam["tokens"]

    def generate_beam(self, images, q_input_ids, q_attention_mask,
                      max_len=32, bos_id=BOS_ID, eos_id=EOS_ID, num_beams=None):
        if num_beams is None:
            num_beams = getattr(self.cfg, "NUM_BEAMS", 3)

        self.eval()
        decoded = []
        with torch.no_grad():
            img_ctx, memory, memory_mask, global_ctx = self.build_memory(
                images, q_input_ids, q_attention_mask
            )  # img_ctx:(B,C), memory:(B,Lm,C), mask:(B,Lm), global:(B,C)
            B = images.size(0)
            for i in range(B):
                img_i = img_ctx[i:i+1]
                mem_i = memory[i:i+1]
                mask_i = memory_mask[i:i+1]
                ctx_i = global_ctx[i:i+1]
                token_ids = self._generate_one_beam(
                    img_i, mem_i, mask_i, ctx_i,
                    max_len=max_len,
                    bos_id=bos_id,
                    eos_id=eos_id,
                    num_beams=num_beams
                )
                decoded.append(decode_answer_ids(token_ids))
        return decoded


# Hàm tính loss, train, dev loss

In [None]:
model = VQAModel(cfg, vocab_size=len(answer_itos)).to(cfg.DEVICE)

# Freeze encoder (ViT + PhoBERT)
for p in model.vit.parameters():
    p.requires_grad = False
for p in model.text_encoder.parameters():
    p.requires_grad = False

for p in model.vit.parameters():
    p.requires_grad = False
for p in model.text_encoder.parameters():
    p.requires_grad = False

decoder_params = [p for n, p in model.named_parameters()
                  if not n.startswith("vit.") and not n.startswith("text_encoder.")]

optimizer = torch.optim.AdamW(
    decoder_params,
    lr=cfg.LR_DECODER,
    weight_decay=cfg.WEIGHT_DECAY
)

def compute_loss(logits, targets, smoothing=0.1):
    B, Lm1, V = logits.size()
    logits = logits.reshape(-1, V)        # (N, V)
    targets = targets.reshape(-1)         # (N,)

    # bỏ pad
    mask = (targets != PAD_ID)
    if mask.sum() == 0:
        return torch.tensor(0.0, device=logits.device)

    logits = logits[mask]
    targets = targets[mask]

    with torch.no_grad():
        true_dist = torch.zeros_like(logits)
        true_dist.fill_(smoothing / (V - 1))
        true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)

    log_probs = F.log_softmax(logits, dim=-1)
    loss = -(true_dist * log_probs).sum(dim=-1).mean()
    return loss


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=1,        
    verbose=True
)

train_epoch_losses = []
dev_epoch_losses = []
best_dev_loss = float("inf")
best_dev = float("inf")
global_step = 0

for epoch in range(cfg.NUM_EPOCHS):
    model.train()
    train_loss_sum = 0.0
    num_train_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.NUM_EPOCHS}")
    
    for batch in pbar:
        images = batch["images"].to(cfg.DEVICE)
        q_ids  = batch["q_input_ids"].to(cfg.DEVICE)
        q_mask = batch["q_attention_mask"].to(cfg.DEVICE)
        ans_ids= batch["answer_ids"].to(cfg.DEVICE)
        
        optimizer.zero_grad()
        logits, targets = model(images, q_ids, q_mask, ans_ids)
        loss = compute_loss(logits, targets, smoothing=0.1)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        global_step += 1
        train_loss_sum += loss.item()
        num_train_batches += 1
        avg_loss = train_loss_sum / num_train_batches   # avg theo epoch, không phải global
        
        pbar.set_postfix({"loss": loss.item(), "avg_loss": avg_loss})
    
    # ======= Train loss epoch này =======
    epoch_train_loss = train_loss_sum / max(1, num_train_batches)
    train_epoch_losses.append(epoch_train_loss)
    print(f"==> Epoch {epoch+1} Train Loss: {epoch_train_loss:.4f}")
    
    # ======= Dev loss (KHÔNG tính metric trong lúc train) =======
    model.eval()
    dev_loss_sum = 0.0
    dev_count = 0
    with torch.no_grad():
        for batch in dev_loader:
            images = batch["images"].to(cfg.DEVICE)
            q_ids  = batch["q_input_ids"].to(cfg.DEVICE)
            q_mask = batch["q_attention_mask"].to(cfg.DEVICE)
            ans_ids= batch["answer_ids"].to(cfg.DEVICE)
            
            logits, targets = model(images, q_ids, q_mask, ans_ids)
            loss = compute_loss(logits, targets)
            dev_loss_sum += loss.item() * images.size(0)
            dev_count += images.size(0)
    
    dev_loss = dev_loss_sum / dev_count
    dev_epoch_losses.append(dev_loss)
    print(f"==> Epoch {epoch+1} Dev Loss: {dev_loss:.4f}")

    scheduler.step(dev_loss)

    if dev_loss < best_dev_loss - 1e-3:
        best_dev_loss = dev_loss
        no_imp = 0
        torch.save(model.state_dict(), os.path.join(cfg.OUT_DIR, "best_model.pt"))
    else:
        no_imp += 1
        if no_imp >= cfg.EARLY_STOP_PATIENCE:
            print("Early stopping.")
            break
    
    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        save_path = os.path.join(cfg.OUT_DIR, "best_model.pt")
        torch.save(model.state_dict(), save_path)
        print(f"    ✓ Saved best model to {save_path}")


Some weights of RobertaModel were not initialized from the model checkpoint at vinai/phobert-base-v2 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/10:   0%|          | 0/3855 [00:00<?, ?it/s]

==> Epoch 1 Train Loss: 9.7828
==> Epoch 1 Dev Loss: 4.4111


Epoch 2/10:   0%|          | 0/3855 [00:00<?, ?it/s]

==> Epoch 2 Train Loss: 4.3671
==> Epoch 2 Dev Loss: 4.0187


Epoch 3/10:   0%|          | 0/3855 [00:00<?, ?it/s]

==> Epoch 3 Train Loss: 4.0463
==> Epoch 3 Dev Loss: 3.8716


Epoch 4/10:   0%|          | 0/3855 [00:00<?, ?it/s]

==> Epoch 4 Train Loss: 3.8186
==> Epoch 4 Dev Loss: 3.7417


Epoch 5/10:   0%|          | 0/3855 [00:00<?, ?it/s]

==> Epoch 5 Train Loss: 3.6443
==> Epoch 5 Dev Loss: 3.6878


Epoch 6/10:   0%|          | 0/3855 [00:00<?, ?it/s]

# Sau khi train xong: Load best model + tính BLEU, METEOR, ROUGE-L

In [None]:
# Load lại best model (phòng trường hợp bạn chạy nhiều cell)
best_model_path = os.path.join(cfg.OUT_DIR, "best_model.pt")
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=cfg.DEVICE))
    print(f"Loaded best model from {best_model_path}")
else:
    print("Best model not found, dùng model hiện tại.")

def compute_text_metrics(refs: List[str], hyps: List[str]) -> Dict[str, float]:
    assert len(refs) == len(hyps)
    n = len(refs)
    
    bleu1 = bleu2 = bleu3 = bleu4 = 0.0
    meteor = 0.0
    rougeL = 0.0
    
    for r, h in zip(refs, hyps):
        ref_tokens = vi_seg(r)
        hyp_tokens = vi_seg(h)
        
        if len(hyp_tokens) == 0:
            continue
        
        bleu1 += sentence_bleu([ref_tokens], hyp_tokens, weights=(1,0,0,0), smoothing_function=smooth_fn)
        bleu2 += sentence_bleu([ref_tokens], hyp_tokens, weights=(0.5,0.5,0,0), smoothing_function=smooth_fn)
        bleu3 += sentence_bleu([ref_tokens], hyp_tokens, weights=(1/3,1/3,1/3,0), smoothing_function=smooth_fn)
        bleu4 += sentence_bleu([ref_tokens], hyp_tokens, weights=(0.25,0.25,0.25,0.25), smoothing_function=smooth_fn)
        
        meteor += meteor_score([ref_tokens], hyp_tokens)
        
        rouge = rouge_scorer_obj.score(" ".join(ref_tokens), " ".join(hyp_tokens))["rougeL"].fmeasure
        rougeL += rouge
    
    # trung bình
    return {
        "BLEU-1": bleu1 / n,
        "BLEU-2": bleu2 / n,
        "BLEU-3": bleu3 / n,
        "BLEU-4": bleu4 / n,
        "METEOR": meteor / n,
        "ROUGE-L": rougeL / n
    }

# Sinh dự đoán trên Dev
model.eval()
dev_refs = []
dev_hyps = []

with torch.no_grad():
    for batch in tqdm(dev_loader, desc="Eval Dev (metrics)"):
        images = batch["images"].to(cfg.DEVICE)
        q_ids  = batch["q_input_ids"].to(cfg.DEVICE)
        q_mask = batch["q_attention_mask"].to(cfg.DEVICE)
        
        preds = model.generate_beam(
            images, q_ids, q_mask,
            max_len=cfg.MAX_ANSWER_LEN,
            num_beams=cfg.NUM_BEAMS  
        )

        dev_hyps.extend(preds)
        dev_refs.extend(batch["answers"])

metrics = compute_text_metrics(dev_refs, dev_hyps)
print("=== DEV METRICS ===")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")


Loaded best model from outputs_vqa/best_model.pt


Eval Dev (metrics):   0%|          | 0/444 [00:00<?, ?it/s]

=== DEV METRICS ===
BLEU-1: 0.3057
BLEU-2: 0.2383
BLEU-3: 0.1922
BLEU-4: 0.1590
METEOR: 0.3151
ROUGE-L: 0.4377
