In [1]:
import sys
from pathlib import Path

# Kaggle input datasets: one holds the project source code, the other the data,
# the SentencePiece model and the pretrained checkpoint.
DATA_ROOT = Path("/kaggle/input/envi-nmt-data-p2/data_p2")
CODE_ROOT = Path("/kaggle/input/envi-nmt-code/src")

# Everything below lives under DATA_ROOT.
PROCESSED_DIR = DATA_ROOT / "processed"  # train.en, train.vi, dev.*, test.*
SPM_PATH = DATA_ROOT / "spm" / "spm.model"
CKPT_PATH = DATA_ROOT / "best_transformer_v3.pt"

# Put the project sources first on the import path so the later
# `from tokenizer import ...` style imports resolve to them.
sys.path.insert(0, str(CODE_ROOT))

# Sanity-check the expected inputs before doing any real work.
print("CODE_ROOT:", CODE_ROOT)
print("PROCESSED_DIR:", PROCESSED_DIR)
print("SPM_PATH:", SPM_PATH, "exists:", SPM_PATH.exists())
print("CKPT_PATH:", CKPT_PATH, "exists:", CKPT_PATH.exists())
train_en_path = PROCESSED_DIR / "train.en"
print("train.en exists:", train_en_path.exists())


CODE_ROOT: /kaggle/input/envi-nmt-code/src
PROCESSED_DIR: /kaggle/input/envi-nmt-data-p2/data_p2/processed
SPM_PATH: /kaggle/input/envi-nmt-data-p2/data_p2/spm/spm.model exists: True
CKPT_PATH: /kaggle/input/envi-nmt-data-p2/data_p2/best_transformer_v3.pt exists: True
train.en exists: True


In [2]:
!pip -q install sacrebleu sentencepiece


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import torch
from torch.utils.data import DataLoader

from tokenizer import SubwordTokenizer
from dataset import NMTDataset, collate_fn

from functools import partial

# Data-pipeline settings, named once so the train and dev pipelines cannot drift apart.
MAX_SRC_LEN = 80
MAX_TGT_LEN = 80
BATCH_SIZE = 64
NUM_WORKERS = 2

tok = SubwordTokenizer(str(SPM_PATH))
pad_id = tok.pad_id

train_ds = NMTDataset(str(PROCESSED_DIR), split="train", tokenizer=tok,
                      max_src_len=MAX_SRC_LEN, max_tgt_len=MAX_TGT_LEN)
dev_ds = NMTDataset(str(PROCESSED_DIR), split="dev", tokenizer=tok,
                    max_src_len=MAX_SRC_LEN, max_tgt_len=MAX_TGT_LEN)

# With num_workers > 0 the DataLoader may have to pickle `collate_fn` (this
# happens under the `spawn` start method). A lambda cannot be pickled; a
# `partial` over an importable function can, so this form works on every platform.
collate = partial(collate_fn, pad_id=pad_id)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, collate_fn=collate)
dev_loader = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, collate_fn=collate)

# Smoke test: pull one batch, show tensor shapes, and confirm that the decoder
# input of the first example starts with BOS.
batch = next(iter(train_loader))
print({k: v.shape for k, v in batch.items()})
print("BOS check:", batch["tgt_in_ids"][0,0].item() == tok.bos_id)


{'src_ids': torch.Size([64, 80]), 'tgt_in_ids': torch.Size([64, 80]), 'tgt_out_ids': torch.Size([64, 80]), 'src_padding_mask': torch.Size([64, 80]), 'tgt_padding_mask': torch.Size([64, 80])}
BOS check: True


In [4]:
import torch

def make_src_mask(src_padding_mask: torch.Tensor):
    """
    Turn a source padding mask into a broadcastable attention mask.

    src_padding_mask: (B,S) bool, True at PAD positions
    return: (B,1,1,S) bool, True = allowed (position is not masked out)
    """
    keep = ~src_padding_mask
    # Insert two singleton axes right after the batch axis.
    return keep[:, None, None]

def build_causal_cache(max_tgt_len: int, device: torch.device):
    """
    Build the causal (lower-triangular) mask once so it can be sliced later.

    Returns a (max_tgt_len, max_tgt_len) bool tensor on `device` that is True
    where row index i >= column index j.
    """
    positions = torch.arange(max_tgt_len, device=device)
    # Row i may look at column j only when j does not lie in the future.
    return positions.unsqueeze(1) >= positions.unsqueeze(0)

def make_tgt_mask(tgt_padding_mask: torch.Tensor, causal_cache: torch.Tensor):
    """
    Combine the causal mask with key-side padding for decoder self-attention.

    tgt_padding_mask: (B,T) bool, True at PAD positions
    causal_cache: (maxT,maxT) bool lower-triangular mask, built once up front
    return: (B,1,T,T) bool, True = allowed (causal AND the key is not PAD)

    If T is larger than the cached mask, a (T,T) causal mask is built on the
    fly instead of slicing. Slicing a too-small cache would silently yield a
    smaller matrix and then fail when combined with the padding mask.
    """
    _, T = tgt_padding_mask.shape

    if T <= causal_cache.size(0) and T <= causal_cache.size(1):
        causal = causal_cache[:T, :T]             # (T,T)
    else:
        causal = torch.tril(
            torch.ones((T, T), dtype=torch.bool, device=tgt_padding_mask.device)
        )

    causal = causal.unsqueeze(0).unsqueeze(1)     # (1,1,T,T)
    nonpad_k = (~tgt_padding_mask).unsqueeze(1).unsqueeze(2)  # (B,1,1,T), True = real token

    return nonpad_k & causal                      # (B,1,T,T)


In [5]:
import torch
from model import Transformer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture hyper-parameters. The checkpoint below is loaded with
# strict=True, so these have to describe the same architecture it was saved from.
SRC_VOCAB = 8000
TGT_VOCAB = 8000
D_MODEL = 384
N_LAYERS = 4
N_HEADS = 8
D_FF = 1536
DROPOUT = 0.1
MAX_LEN = 5000

MODEL_KWARGS = dict(
    src_vocab_size=SRC_VOCAB,
    tgt_vocab_size=TGT_VOCAB,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_ff=D_FF,
    dropout=DROPOUT,
    max_len=MAX_LEN,
)
model = Transformer(**MODEL_KWARGS).to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state = torch.load(str(CKPT_PATH), map_location=device)
missing, unexpected = model.load_state_dict(state, strict=True)

print("✅ Loaded checkpoint (strict=True)")
print("Missing keys:", len(missing))
print("Unexpected keys:", len(unexpected))


✅ Loaded checkpoint (strict=True)
Missing keys: 0
Unexpected keys: 0


In [6]:
import math, time
import torch
from pathlib import Path
from torch import nn
from torch.optim import AdamW
from tqdm.notebook import tqdm   # 

# Logging cadence (in optimizer steps) and the running step counter that
# run_one_epoch increments across epochs.
PRINT_EVERY = 10
global_step = 0

# Target positions equal to pad_id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

optimizer = AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=0.01,
)

# Causal mask built once and sliced per batch by make_tgt_mask.
MAX_TGT_LEN = 80
causal_cache = build_causal_cache(MAX_TGT_LEN, device)

def run_one_epoch(model, loader, train: bool, epoch: int, epochs: int):
    """
    Run one full pass over `loader`, updating the model only when `train` is True.

    Args:
        model: called as model(src, tgt_in, src_mask, tgt_mask) and expected to
            return logits whose last dimension is the vocabulary.
        loader: iterable of batches (dicts) with keys "src_ids", "tgt_in_ids",
            "tgt_out_ids", "src_padding_mask" and "tgt_padding_mask".
        train: True -> train mode with backward pass, gradient clipping and an
            optimizer step per batch; False -> eval mode with gradients disabled.
        epoch, epochs: current and total epoch numbers, used for logging only.

    Returns:
        (avg_loss, ppl): token-weighted mean loss over the pass and
        exp(min(20, avg_loss)).

    Reads the module-level `device`, `criterion`, `optimizer`, `pad_id`,
    `causal_cache` and `PRINT_EVERY`, and increments `global_step` once per
    training batch.
    """
    global global_step

    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_tokens = 0
    start_time = time.time()
    n_batches = len(loader)

    mode = "train" if train else "dev"
    pbar = tqdm(enumerate(loader, 1), total=n_batches, leave=False, mininterval=0.5,
                desc=f"{mode} {epoch}/{epochs}")

    for bi, batch in pbar:
        if bi == 1:
            print(f"✅ {mode} epoch {epoch}: first batch ok")

        src     = batch["src_ids"].to(device)
        tgt_in  = batch["tgt_in_ids"].to(device)
        tgt_out = batch["tgt_out_ids"].to(device)

        src_pad = batch["src_padding_mask"].to(device).bool()
        tgt_pad = batch["tgt_padding_mask"].to(device).bool()

        src_mask = make_src_mask(src_pad)                 # (B,1,1,S)
        tgt_mask = make_tgt_mask(tgt_pad, causal_cache)   # (B,1,T,T)

        # One forward/loss path for both modes; autograd is recorded only when training.
        with torch.set_grad_enabled(train):
            logits = model(src, tgt_in, src_mask, tgt_mask)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        if train:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            global_step += 1

        # Read the scalar loss once; every .item() call is a device-to-host transfer.
        batch_loss = loss.item()

        # The criterion averages over non-PAD targets, so weight the batch loss
        # by that count to obtain a token-level average over the whole pass.
        n_tokens = (tgt_out != pad_id).sum().item()
        total_loss   += batch_loss * n_tokens
        total_tokens += n_tokens

        avg_loss_so_far = total_loss / max(1, total_tokens)
        ppl_so_far = math.exp(min(20, avg_loss_so_far))  # exponent capped at 20
        lr = optimizer.param_groups[0]["lr"]

        elapsed = time.time() - start_time
        it_s = bi / max(1e-9, elapsed)
        eta_s = (n_batches - bi) / max(1e-9, it_s)

        pbar.set_postfix({
            "step": global_step if train else "-",
            "loss": f"{batch_loss:.4f}",
            "avg":  f"{avg_loss_so_far:.4f}",
            "ppl":  f"{ppl_so_far:.2f}",
            "lr":   f"{lr:.2e}",
            "eta":  f"{eta_s/60:.1f}m"
        })

        if train and (global_step % PRINT_EVERY == 0):
            print(f"[epoch {epoch}/{epochs}] step={global_step} batch={bi}/{n_batches} "
                  f"loss={batch_loss:.4f} avg={avg_loss_so_far:.4f} "
                  f"ppl={ppl_so_far:.2f} lr={lr:.2e}")

    avg_loss = total_loss / max(1, total_tokens)
    ppl = math.exp(min(20, avg_loss))
    return avg_loss, ppl


# Fine-tuning driver: train, evaluate on dev, and keep the weights of the
# epoch with the lowest dev loss seen so far.
EPOCHS = 3
best_dev = 1e9

OUT_DIR = Path("/kaggle/working/vlsp_finetune")
OUT_DIR.mkdir(parents=True, exist_ok=True)

for epoch in range(1, EPOCHS + 1):
    train_loss, train_ppl = run_one_epoch(
        model, train_loader, train=True, epoch=epoch, epochs=EPOCHS
    )
    dev_loss, dev_ppl = run_one_epoch(
        model, dev_loader, train=False, epoch=epoch, epochs=EPOCHS
    )

    print(f"\nEpoch {epoch} DONE | train_loss={train_loss:.4f} ppl={train_ppl:.2f} | dev_loss={dev_loss:.4f} ppl={dev_ppl:.2f}\n")

    # Nothing to save unless dev loss strictly improved.
    if not (dev_loss < best_dev):
        continue

    best_dev = dev_loss
    best_ckpt_path = OUT_DIR / "best_finetune.pt"
    torch.save(model.state_dict(), best_ckpt_path)
    print("✅ saved best:", best_ckpt_path)


train 1/3:   0%|          | 0/7650 [00:00<?, ?it/s]

✅ train epoch 1: first batch ok
[epoch 1/3] step=10 batch=10/7650 loss=7.1836 avg=7.9732 ppl=2902.04 lr=3.00e-04
[epoch 1/3] step=20 batch=20/7650 loss=6.8976 avg=7.4648 ppl=1745.58 lr=3.00e-04
[epoch 1/3] step=30 batch=30/7650 loss=6.6993 avg=7.2193 ppl=1365.49 lr=3.00e-04
[epoch 1/3] step=40 batch=40/7650 loss=6.2765 avg=7.0484 ppl=1151.07 lr=3.00e-04
[epoch 1/3] step=50 batch=50/7650 loss=6.4063 avg=6.9228 ppl=1015.16 lr=3.00e-04
[epoch 1/3] step=60 batch=60/7650 loss=6.1744 avg=6.8044 ppl=901.84 lr=3.00e-04
[epoch 1/3] step=70 batch=70/7650 loss=5.9963 avg=6.7047 ppl=816.19 lr=3.00e-04
[epoch 1/3] step=80 batch=80/7650 loss=5.8949 avg=6.6154 ppl=746.52 lr=3.00e-04
[epoch 1/3] step=90 batch=90/7650 loss=5.6300 avg=6.5291 ppl=684.76 lr=3.00e-04
[epoch 1/3] step=100 batch=100/7650 loss=6.0306 avg=6.4560 ppl=636.53 lr=3.00e-04
[epoch 1/3] step=110 batch=110/7650 loss=5.6456 avg=6.3803 ppl=590.09 lr=3.00e-04
[epoch 1/3] step=120 batch=120/7650 loss=5.5894 avg=6.3130 ppl=551.68 lr=3.00e-

dev 1/3:   0%|          | 0/157 [00:00<?, ?it/s]

✅ dev epoch 1: first batch ok

Epoch 1 DONE | train_loss=2.4884 ppl=12.04 | dev_loss=1.6929 ppl=5.44

✅ saved best: /kaggle/working/vlsp_finetune/best_finetune.pt


train 2/3:   0%|          | 0/7650 [00:00<?, ?it/s]

✅ train epoch 2: first batch ok
[epoch 2/3] step=7660 batch=10/7650 loss=1.7755 avg=1.8079 ppl=6.10 lr=3.00e-04
[epoch 2/3] step=7670 batch=20/7650 loss=1.8677 avg=1.8112 ppl=6.12 lr=3.00e-04
[epoch 2/3] step=7680 batch=30/7650 loss=1.8308 avg=1.8108 ppl=6.12 lr=3.00e-04
[epoch 2/3] step=7690 batch=40/7650 loss=1.9466 avg=1.8014 ppl=6.06 lr=3.00e-04
[epoch 2/3] step=7700 batch=50/7650 loss=1.7230 avg=1.8039 ppl=6.07 lr=3.00e-04
[epoch 2/3] step=7710 batch=60/7650 loss=1.6065 avg=1.8056 ppl=6.08 lr=3.00e-04
[epoch 2/3] step=7720 batch=70/7650 loss=1.6489 avg=1.7936 ppl=6.01 lr=3.00e-04
[epoch 2/3] step=7730 batch=80/7650 loss=1.5639 avg=1.7943 ppl=6.02 lr=3.00e-04
[epoch 2/3] step=7740 batch=90/7650 loss=1.5829 avg=1.7937 ppl=6.01 lr=3.00e-04
[epoch 2/3] step=7750 batch=100/7650 loss=1.6540 avg=1.7902 ppl=5.99 lr=3.00e-04
[epoch 2/3] step=7760 batch=110/7650 loss=1.7773 avg=1.7918 ppl=6.00 lr=3.00e-04
[epoch 2/3] step=7770 batch=120/7650 loss=1.8820 avg=1.7974 ppl=6.03 lr=3.00e-04
[epoc

dev 2/3:   0%|          | 0/157 [00:00<?, ?it/s]

✅ dev epoch 2: first batch ok

Epoch 2 DONE | train_loss=1.7001 ppl=5.47 | dev_loss=1.4769 ppl=4.38

✅ saved best: /kaggle/working/vlsp_finetune/best_finetune.pt


train 3/3:   0%|          | 0/7650 [00:00<?, ?it/s]

✅ train epoch 3: first batch ok
[epoch 3/3] step=15310 batch=10/7650 loss=1.5533 avg=1.5292 ppl=4.61 lr=3.00e-04
[epoch 3/3] step=15320 batch=20/7650 loss=1.5537 avg=1.5291 ppl=4.61 lr=3.00e-04
[epoch 3/3] step=15330 batch=30/7650 loss=1.4956 avg=1.5505 ppl=4.71 lr=3.00e-04
[epoch 3/3] step=15340 batch=40/7650 loss=1.4473 avg=1.5393 ppl=4.66 lr=3.00e-04
[epoch 3/3] step=15350 batch=50/7650 loss=1.4484 avg=1.5363 ppl=4.65 lr=3.00e-04
[epoch 3/3] step=15360 batch=60/7650 loss=1.4628 avg=1.5322 ppl=4.63 lr=3.00e-04
[epoch 3/3] step=15370 batch=70/7650 loss=1.6011 avg=1.5375 ppl=4.65 lr=3.00e-04
[epoch 3/3] step=15380 batch=80/7650 loss=1.3176 avg=1.5366 ppl=4.65 lr=3.00e-04
[epoch 3/3] step=15390 batch=90/7650 loss=1.3634 avg=1.5351 ppl=4.64 lr=3.00e-04
[epoch 3/3] step=15400 batch=100/7650 loss=1.6328 avg=1.5365 ppl=4.65 lr=3.00e-04
[epoch 3/3] step=15410 batch=110/7650 loss=1.3519 avg=1.5358 ppl=4.64 lr=3.00e-04
[epoch 3/3] step=15420 batch=120/7650 loss=1.5937 avg=1.5369 ppl=4.65 lr=3.

dev 3/3:   0%|          | 0/157 [00:00<?, ?it/s]

✅ dev epoch 3: first batch ok

Epoch 3 DONE | train_loss=1.5268 ppl=4.60 | dev_loss=1.3663 ppl=3.92

✅ saved best: /kaggle/working/vlsp_finetune/best_finetune.pt


In [7]:
import torch
from tqdm.notebook import tqdm
import sacrebleu

# =========================
# Utils
# =========================
def load_lines(p):
    """Read the UTF-8 text file `p` and return its non-blank lines, stripped of surrounding whitespace."""
    kept = []
    with open(p, encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                kept.append(text)
    return kept

# =========================
# Causal cache for decoding
# =========================
# Decoding is allowed to run longer than the training targets, so a larger
# causal mask is built here (with a few positions of slack).
# NOTE: this rebinds the module-level `causal_cache` that run_one_epoch also reads.
DECODE_MAX_LEN = 120
causal_cache = build_causal_cache(DECODE_MAX_LEN + 5, device)

# =========================
# Greedy decoding
# =========================
@torch.no_grad()
def greedy_translate_one(src_ids: torch.Tensor, max_len=DECODE_MAX_LEN) -> str:
    """
    Translate one source sentence by greedy (arg-max) decoding.

    Args:
        src_ids: 1-D tensor of source token ids for a single sentence. A batch
            dimension is added here and no position is treated as padding.
        max_len: maximum number of tokens to generate after BOS.

    Returns:
        `tok.decode` applied to the whole id sequence: BOS, the generated ids,
        and EOS if it was produced.

    Uses the module-level `model`, `tok`, `pad_id`, `device` and `causal_cache`.
    """
    model.eval()
    src_ids = src_ids.unsqueeze(0).to(device)  # (1,S)

    # A single sentence has no PAD positions.
    src_pad_mask = torch.zeros_like(src_ids, dtype=torch.bool, device=device)  # (1,S)
    src_mask = make_src_mask(src_pad_mask)  # (1,1,1,S)

    enc = model.encode(src_ids, src_mask)

    # The longest prefix fed to the decoder has `max_len` tokens. If the shared
    # cache is smaller than that (large `max_len`, or the cache was rebuilt at
    # a smaller size elsewhere), build a big-enough one for this call.
    mask_cache = causal_cache
    if mask_cache.size(0) < max_len:
        mask_cache = build_causal_cache(max_len, device)

    ys = torch.tensor([[tok.bos_id]], device=device, dtype=torch.long)  # (1,1)
    for _ in range(max_len):
        tgt_pad = (ys == pad_id)  # (1,T)
        tgt_mask = make_tgt_mask(tgt_pad, mask_cache)  # (1,1,T,T)

        dec = model.decode(ys, enc, src_mask, tgt_mask)
        logits = model.projection(dec)  # (1,T,V)

        # Only the distribution at the last position decides the next token.
        next_id = torch.argmax(logits[:, -1, :], dim=-1).item()
        ys = torch.cat(
            [ys, torch.tensor([[next_id]], device=device, dtype=torch.long)], dim=1
        )

        if next_id == tok.eos_id:
            break

    return tok.decode(ys.squeeze(0).tolist())

# =========================
# Beam search decoding
# =========================
@torch.no_grad()
def beam_translate_one(
    src_ids: torch.Tensor,
    beam_size: int = 4,
    max_len: int = DECODE_MAX_LEN,
    len_norm_alpha: float = 0.6,
) -> str:
    """
    Translate one source sentence with beam search.

    Args:
        src_ids: 1-D tensor of source token ids for a single sentence.
        beam_size: number of hypotheses kept after every step (must be >= 1).
        max_len: maximum number of tokens to generate after BOS.
        len_norm_alpha: exponent of the GNMT-style length penalty used to rank
            hypotheses; a value <= 0 ranks by the raw summed log-probability.

    Returns:
        `tok.decode` applied to the best-scoring hypothesis (ids include BOS,
        and EOS if the hypothesis ended).

    Raises:
        ValueError: if `beam_size` is smaller than 1.

    Uses the module-level `model`, `tok`, `pad_id`, `device` and `causal_cache`.
    """
    if beam_size < 1:
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")

    model.eval()
    src_ids = src_ids.unsqueeze(0).to(device)  # (1,S)

    # A single sentence has no PAD positions.
    src_pad_mask = torch.zeros_like(src_ids, dtype=torch.bool, device=device)
    src_mask = make_src_mask(src_pad_mask)

    enc = model.encode(src_ids, src_mask)

    # The longest hypothesis fed to the decoder has `max_len` tokens. Build a
    # big-enough causal mask locally if the shared cache is smaller than that.
    mask_cache = causal_cache
    if mask_cache.size(0) < max_len:
        mask_cache = build_causal_cache(max_len, device)

    # beam item: (tokens_list, sum_logprob, ended_bool)
    beams = [([tok.bos_id], 0.0, False)]

    def score(sum_logprob: float, length: int) -> float:
        # GNMT length normalization; `length` counts every token, BOS included.
        if len_norm_alpha <= 0:
            return sum_logprob
        denom = ((5.0 + length) / 6.0) ** len_norm_alpha
        return sum_logprob / denom

    for _ in range(max_len):
        candidates = []

        for tokens, sum_lp, ended in beams:
            if ended:
                # Finished hypotheses compete unchanged with the new expansions.
                candidates.append((tokens, sum_lp, True))
                continue

            ys = torch.tensor(tokens, device=device, dtype=torch.long).unsqueeze(0)  # (1,T)
            tgt_pad = (ys == pad_id)
            tgt_mask = make_tgt_mask(tgt_pad, mask_cache)

            dec = model.decode(ys, enc, src_mask, tgt_mask)
            logits = model.projection(dec)  # (1,T,V)
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)  # (V,)

            # topk raises when k exceeds the number of entries, so clamp it.
            k = min(beam_size, log_probs.numel())
            topk_lp, topk_id = torch.topk(log_probs, k=k)
            topk_lp = topk_lp.tolist()
            topk_id = topk_id.tolist()

            for lp, tid in zip(topk_lp, topk_id):
                new_tokens = tokens + [tid]
                new_sum_lp = sum_lp + lp
                new_ended = (tid == tok.eos_id)
                candidates.append((new_tokens, new_sum_lp, new_ended))

        candidates.sort(key=lambda x: score(x[1], len(x[0])), reverse=True)
        beams = candidates[:beam_size]

        if all(b[2] for b in beams):  # all ended
            break

    best_tokens, _, _ = max(beams, key=lambda x: score(x[1], len(x[0])))
    return tok.decode(best_tokens)

# =========================
# Evaluate BLEU on dev
# =========================
dev_src = load_lines(PROCESSED_DIR / "dev.en")
dev_ref = load_lines(PROCESSED_DIR / "dev.vi")

# load_lines drops blank lines from each file independently, so a blank line on
# only one side would shift every later sentence pair. Fail loudly rather than
# report a BLEU score computed on misaligned data.
if len(dev_src) != len(dev_ref):
    raise ValueError(
        f"dev.en has {len(dev_src)} non-blank lines but dev.vi has {len(dev_ref)}; "
        "source and reference are not aligned"
    )

N = None  # set to an int to score only the first N sentences
M = min(N, len(dev_src)) if N is not None else len(dev_src)

# Beam-search settings for this evaluation; the printed label is derived from
# them so it cannot disagree with what was actually run.
EVAL_BEAM_SIZE = 6
EVAL_LEN_NORM_ALPHA = 1.0

preds_greedy = []
preds_beam   = []

for i in tqdm(range(M), desc="Translating dev (greedy+beam)"):
    # The source is encoded without BOS and with a trailing EOS.
    src_ids = torch.tensor(
        tok.encode_src(dev_src[i], add_bos=False, add_eos=True),
        dtype=torch.long
    )
    preds_greedy.append(greedy_translate_one(src_ids))
    preds_beam.append(
        beam_translate_one(src_ids, beam_size=EVAL_BEAM_SIZE, len_norm_alpha=EVAL_LEN_NORM_ALPHA)
    )

bleu_g = sacrebleu.corpus_bleu(preds_greedy, [dev_ref[:M]])
bleu_b = sacrebleu.corpus_bleu(preds_beam,   [dev_ref[:M]])

print("DEV BLEU (greedy):", bleu_g.score)
print(f"DEV BLEU (beam{EVAL_BEAM_SIZE}) :", bleu_b.score)


Translating dev (greedy+beam):   0%|          | 0/9992 [00:00<?, ?it/s]

DEV BLEU (greedy): 41.08170874469103
DEV BLEU (beam4) : 42.26405444624932
