In [None]:
!pip install nltk

!python -m nltk.downloader punkt

!pip install pycocoevalcap

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Collecting pycocoevalcap
  Downloading pycocoevalcap-1.2-py3-none-any.whl.metadata (3.2 kB)
Downloading pycocoevalcap-1.2-py3-none-any.whl (104.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.3/104.3 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: pycocoevalcap
Successfully installed pycocoevalcap-1.2


In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel


class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder features.

    Each encoder position is scored against the current decoder hidden state,
    the scores are softmax-normalised over the positions, and the encoder
    features are averaged with those weights.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # Both inputs are projected into a shared attention space ...
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # ... and then reduced to a single scalar score per position.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalise across encoder positions

    def forward(self, encoder_out, decoder_hidden):
        """Attend over `encoder_out` using `decoder_hidden` as the query.

        Args:
            encoder_out: (B, N, encoder_dim) encoder features, e.g. (B, 197, 768)
                for ViT-Base patch tokens.
            decoder_hidden: (B, decoder_dim) current decoder hidden state.

        Returns:
            context: (B, encoder_dim) attention-weighted sum of encoder features.
            alpha: (B, N, 1) attention weights; they sum to 1 over N.
        """
        # (B, decoder_dim) -> (B, 1, A) so it broadcasts over the N positions.
        query = self.decoder_att(decoder_hidden)[:, None, :]
        energy = self.relu(self.encoder_att(encoder_out) + query)  # (B, N, A)
        alpha = self.softmax(self.full_att(energy))                # (B, N, 1)
        context = torch.sum(alpha * encoder_out, dim=1)            # (B, encoder_dim)
        return context, alpha


class ViT_LSTM_Attention(nn.Module):
    """Image captioning model: ViT encoder + LSTM decoder with additive attention.

    The ViT turns an image into a sequence of token features. Decoding runs an
    ``nn.LSTMCell`` one step at a time. At every step:
      1. the attention module pools the ViT features into a context vector,
         using the previous hidden state as the query;
      2. a sigmoid gate scales that context;
      3. the gated context is concatenated with the current word embedding
         and fed to the LSTM cell.

    Args:
        vocab_size: Number of output tokens.
        pad_idx: Padding token id; stored on the module (not used by the
            computations in this class).
        embed_dim: Word-embedding size.
        hidden_dim: LSTM hidden/cell size.
        attention_dim: Hidden size of the attention scoring MLP.
        unfreeze_layers: How many final ViT encoder layers (plus the final ViT
            layernorm) are trainable; ``0`` keeps the whole ViT frozen.
        dropout: Dropout probability applied to the hidden state before the
            output projection (active only in training mode).
        vit_name: Checkpoint name passed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, vocab_size, pad_idx, embed_dim=512, hidden_dim=512, attention_dim=256,
                 unfreeze_layers=2, dropout=0.3, vit_name="google/vit-base-patch16-224"):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx

        # --- 1. Encoder (ViT) ---
        self.vit = ViTModel.from_pretrained(vit_name)
        # Read the feature size from the model config instead of hard-coding
        # 768, so other ViT variants also work (768 for ViT-Base).
        self.vit_dim = self.vit.config.hidden_size

        # Freeze the whole ViT, then re-enable gradients for the last
        # `unfreeze_layers` encoder blocks and the final layernorm.
        for param in self.vit.parameters():
            param.requires_grad = False
        if unfreeze_layers > 0:
            for layer in self.vit.encoder.layer[-unfreeze_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True
            for param in self.vit.layernorm.parameters():
                param.requires_grad = True

        # --- 2. Decoder (LSTM + attention) ---
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attention = Attention(encoder_dim=self.vit_dim, decoder_dim=hidden_dim, attention_dim=attention_dim)

        # Initial LSTM hidden/cell states are projected from the mean image feature.
        self.init_h = nn.Linear(self.vit_dim, hidden_dim)
        self.init_c = nn.Linear(self.vit_dim, hidden_dim)

        # LSTM input = word embedding concatenated with the attention context.
        self.lstm_cell = nn.LSTMCell(embed_dim + self.vit_dim, hidden_dim)

        # Sigmoid gate that scales the attention context feature-wise.
        self.f_beta = nn.Linear(hidden_dim, self.vit_dim)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def _vit_is_trainable(self):
        """Return True when at least one ViT parameter has requires_grad=True."""
        return any(p.requires_grad for p in self.vit.parameters())

    def _encode(self, images):
        """Run the ViT and return its ``last_hidden_state`` (B, N, vit_dim).

        Autograd is enabled for this pass only when grad mode is already on and
        some ViT parameters are trainable. With ``unfreeze_layers > 0`` the
        unfrozen layers therefore really receive gradients, while a fully
        frozen ViT (or an outer ``torch.no_grad()``) still skips graph building.
        """
        track_grad = torch.is_grad_enabled() and self._vit_is_trainable()
        with torch.set_grad_enabled(track_grad):
            return self.vit(pixel_values=images).last_hidden_state

    def _init_state(self, encoder_out):
        """Project the mean over encoder positions to the initial LSTM (h, c)."""
        mean_encoder_out = encoder_out.mean(dim=1)
        return self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)

    def _decode_step(self, word_emb, encoder_out, h, c):
        """Advance the LSTM by one token and return the new (h, c).

        The attention context is computed from the previous hidden state `h`,
        scaled by sigmoid(f_beta(h)), and concatenated with `word_emb`.
        """
        context, _ = self.attention(encoder_out, h)
        gate = torch.sigmoid(self.f_beta(h))
        lstm_input = torch.cat([word_emb, gate * context], dim=1)
        return self.lstm_cell(lstm_input, (h, c))

    def forward(self, images, captions):
        """Teacher-forced decoding.

        Args:
            images: Pixel tensor for the ViT, e.g. (B, 3, 224, 224).
            captions: (B, T) token ids starting with <bos> (and containing
                <eos>/padding).

        Returns:
            Logits of shape (B, T - 1, vocab_size). Step t is fed captions[:, t]
            and predicts captions[:, t + 1], so compare against captions[:, 1:].
        """
        encoder_out = self._encode(images)
        h, c = self._init_state(encoder_out)
        embeddings = self.embedding(captions)  # (B, T, embed_dim)

        step_logits = []
        for t in range(captions.size(1) - 1):
            h, c = self._decode_step(embeddings[:, t, :], encoder_out, h, c)
            step_logits.append(self.fc_out(self.dropout(h)))

        if not step_logits:
            # A caption with only <bos> has nothing to predict.
            return h.new_zeros((captions.size(0), 0, self.vocab_size))
        return torch.stack(step_logits, dim=1)

    def generate(self, images, bos_idx, eos_idx, max_len=40):
        """Greedy caption decoding.

        Returns a (B, max_len) LongTensor of generated token ids; <bos> is not
        included. Once a sequence emits `eos_idx`, all its later positions are
        `eos_idx` too, and decoding stops early when every sequence has
        finished. The module's train/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                encoder_out = self._encode(images)
                h, c = self._init_state(encoder_out)

                batch_size = images.size(0)
                device = images.device
                tokens = torch.full((batch_size,), bos_idx, dtype=torch.long, device=device)
                finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
                # Prefilled with EOS so columns left after an early stop are EOS.
                generated = torch.full((batch_size, max_len), eos_idx, dtype=torch.long, device=device)

                for step in range(max_len):
                    h, c = self._decode_step(self.embedding(tokens), encoder_out, h, c)
                    next_tokens = self.fc_out(h).argmax(dim=1)
                    # Sequences that already produced EOS keep emitting EOS.
                    next_tokens = next_tokens.masked_fill(finished, eos_idx)
                    generated[:, step] = next_tokens
                    finished |= next_tokens == eos_idx
                    if bool(finished.all()):
                        break
                    tokens = next_tokens

                return generated
        finally:
            self.train(was_training)

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from process_data import create_dataloaders


# --- Configuration ---
SEED = 42                 # fixed seed: reproducible model init and data shuffling
MODEL_DIR = "model_ver5"  # checkpoints and evaluation outputs go here
CSV_PATH = "data/final_combined_ds_tokenized.csv"
VOCAB_PATH = "data/vocab_vi_underthesea.json"
DATA_ROOT = "data"
BATCH_SIZE = 64

torch.manual_seed(SEED)
os.makedirs(MODEL_DIR, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {device}")


# --- Data loading ---
print("Creating dataloaders...")
train_loader, val_loader, test_loader, train_ds, val_ds, test_ds = create_dataloaders(
    csv_path=CSV_PATH,
    vocab_path=VOCAB_PATH,
    data_root=DATA_ROOT,
    batch_size=BATCH_SIZE,
)

vocab_size = len(train_ds.stoi)
pad_idx = train_ds.pad_idx
bos_idx = train_ds.bos_idx
eos_idx = train_ds.eos_idx
print(f"\nDATASET INFO:")
print(f"  Vocab size: {vocab_size}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Pad/BOS/EOS: {pad_idx}/{bos_idx}/{eos_idx}\n")


# --- Model ---
model = ViT_LSTM_Attention(
    vocab_size=vocab_size,
    pad_idx=pad_idx,
    embed_dim=512,
    hidden_dim=512,
    attention_dim=256,
    unfreeze_layers=2,
    dropout=0.3
).to(device)

# --- Loss & optimizer ---
# Padding positions are ignored; label smoothing regularises the targets.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=0.1)

# Two parameter groups: the unfrozen part of the ViT learns slowly, the
# decoder learns faster (LSTM decoders typically use ~4e-4 to 5e-4).
# Frozen ViT parameters are left out, and a set gives O(1) membership tests.
vit_param_ids = {id(p) for p in model.vit.parameters()}
vit_trainable_params = [p for p in model.vit.parameters() if p.requires_grad]
decoder_params = [p for p in model.parameters() if id(p) not in vit_param_ids]

param_groups = [{'params': decoder_params, 'lr': 4e-4}]
if vit_trainable_params:
    param_groups.insert(0, {'params': vit_trainable_params, 'lr': 1e-5})

optimizer = optim.AdamW(param_groups, weight_decay=1e-4)

In [None]:
# --- Training loop with early stopping on validation loss ---
# Uses model, criterion, optimizer, device, vocab_size, train_loader and
# val_loader created in the setup cell.
train_losses = []
val_losses = []
best_val_loss = float('inf')
PATIENCE = 4
patience_counter = 0
EPOCHS = 20
VOCAB_SIZE = vocab_size


def _caption_loss(batch_images, batch_captions):
    """Teacher-forced loss for one batch.

    The model receives the full caption; the targets are the caption shifted
    left by one (the <bos> column is dropped), matching the model's
    (seq_len - 1) prediction steps.
    """
    batch_images = batch_images.to(device)
    batch_captions = batch_captions.to(device)
    next_tokens = batch_captions[:, 1:]
    logits = model(batch_images, batch_captions)
    return criterion(logits.reshape(-1, VOCAB_SIZE), next_tokens.reshape(-1))


for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    epoch_loss_sum = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for batch_images, batch_captions, _ in progress_bar:
        optimizer.zero_grad()
        loss = _caption_loss(batch_images, batch_captions)
        loss.backward()
        # Clip gradients: LSTMs are prone to exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        progress_bar.set_postfix({'loss': batch_loss})

    avg_train_loss = epoch_loss_sum / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        val_loss_sum = sum(
            _caption_loss(batch_images, batch_captions).item()
            for batch_images, batch_captions, _ in val_loader
        )
    avg_val_loss = val_loss_sum / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f} | Val Loss = {avg_val_loss:.4f}")

    # ---- checkpoint the best model so far ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'model_ver5/best_model.pth')
        print(f"  ✓ Best model saved! Val Loss: {avg_val_loss:.4f}")
    else:
        patience_counter += 1
        print(f"  ⚠ No improvement. Patience: {patience_counter}/{PATIENCE}")

    # ---- early stopping ----
    if patience_counter >= PATIENCE:
        print(f"\n⚠ Early stopping triggered after {epoch+1} epochs!")
        print(f"Best Val Loss: {best_val_loss:.4f}")
        break

print("Hoàn tất huấn luyện!")

In [None]:
from evaluate import generate_captions_for_dataset, compute_bleu_meteor, compute_cider
import json

# --- Test-set evaluation of the best checkpoint ---
# The training loop saves the best weights to 'model_ver5/best_model.pth',
# so load from (and write results into) that same directory.
MODEL_DIR = "model_ver5"
BEST_MODEL_PATH = f"{MODEL_DIR}/best_model.pth"

model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))

test_refs, test_hyps = generate_captions_for_dataset(
    model,
    test_ds,
    device,
    bos_idx=bos_idx,
    eos_idx=eos_idx,
    max_len=40,
    batch_size=64,
)

metrics = compute_bleu_meteor(test_refs, test_hyps)

# CIDEr is best-effort: report the failure and keep the other metrics.
try:
    metrics["CIDEr"] = compute_cider(test_refs, test_hyps)
except Exception as e:
    print("Không tính được CIDEr, lỗi:", e)

print("ViT + LSTM + Attention:")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")

# Save each prediction next to its references, keyed like test_hyps.
output_caption_file = f"{MODEL_DIR}/vit_transformer_generated_captions.json"
save_captions = {k: {"pred": v, "ref": test_refs.get(k, [])} for k, v in test_hyps.items()}
with open(output_caption_file, "w", encoding="utf-8") as f:
    json.dump(save_captions, f, indent=2, ensure_ascii=False)