# VLSP 2025 MT — Vi → En Transformer (from scratch) with SentencePiece  
## Final (2×T4 + Token-based Dynamic Batching)

Notebook này:
- Chia dữ liệu **9/1 train/valid**
- SentencePiece **Unigram + byte_fallback**
- Transformer **from scratch** (Pre-LN + GEGLU + weight tying)
- Training: **AMP + Warmup+Cosine + EMA** (EMA là tùy chọn; mặc định đang tắt với `use_ema=False`)
- DataLoader: **Token-based dynamic batching** (tối ưu cho **2×T4**)
- Decode: **Beam search**
- Metrics: **BLEU / TER / METEOR** + error analysis


In [1]:
import os
# Configure PyTorch's CUDA caching allocator through its environment variable.
# This cell runs before the cell that imports torch.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


In [2]:
!pip -q install sentencepiece sacrebleu nltk tqdm regex

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import os, re, math, json, random, unicodedata
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm

import sentencepiece as spm
import sacrebleu
from sacrebleu.metrics import TER

import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.translate.meteor_score import meteor_score

# Single seed shared by Python's `random`, NumPy and PyTorch (CPU and all GPUs).
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
if torch.cuda.is_available():
    print('GPU0:', torch.cuda.get_device_name(0))
print('GPU count:', torch.cuda.device_count())


[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...


Device: cuda
GPU0: Tesla T4
GPU count: 2


## 1) Data (9/1 train/valid)

In [4]:
# Locations of the parallel training corpus and the public test set.
DATA_DIR = "/kaggle/input/dataset2/MedicalDataset_VLSP"
TRAIN_VI = os.path.join(DATA_DIR, "train.vi.txt")
TRAIN_EN = os.path.join(DATA_DIR, "train.en.txt")
PUB_VI   = os.path.join(DATA_DIR, "public_test.vi.txt")
PUB_EN   = os.path.join(DATA_DIR, "public_test.en.txt")

# Report which files are present; only the training pair is mandatory here.
for p in [TRAIN_VI, TRAIN_EN, PUB_VI, PUB_EN]:
    print(p, "=>", "OK" if os.path.exists(p) else "NOT FOUND")

assert os.path.exists(TRAIN_VI) and os.path.exists(TRAIN_EN)


/kaggle/input/dataset2/MedicalDataset_VLSP/train.vi.txt => OK
/kaggle/input/dataset2/MedicalDataset_VLSP/train.en.txt => OK
/kaggle/input/dataset2/MedicalDataset_VLSP/public_test.vi.txt => OK
/kaggle/input/dataset2/MedicalDataset_VLSP/public_test.en.txt => OK


In [5]:
def normalize_vi(x: str) -> str:
    """Return `x` in NFC form with tidy whitespace.

    Runs of whitespace collapse to one space, the ends are trimmed, and any
    whitespace directly before one of ``, . ; : ! ?`` is removed.
    """
    composed = unicodedata.normalize("NFC", x)
    single_spaced = re.sub(r"\s+", " ", composed).strip()
    return re.sub(r"\s+([,.;:!?])", r"\1", single_spaced)

def normalize_en(x: str) -> str:
    """Return `x` in NFC form with tidy whitespace.

    Runs of whitespace collapse to one space, the ends are trimmed, and any
    whitespace directly before one of ``, . ; : ! ?`` is removed.
    """
    composed = unicodedata.normalize("NFC", x)
    single_spaced = re.sub(r"\s+", " ", composed).strip()
    return re.sub(r"\s+([,.;:!?])", r"\1", single_spaced)

def load_parallel(path_vi: str, path_en: str) -> Tuple[List[str], List[str]]:
    """Read two UTF-8 text files line by line and return normalized lists.

    Each line is stripped and passed through `normalize_vi` / `normalize_en`.
    Both lists are cut to the length of the shorter file, so entry ``i`` of
    one list is paired with entry ``i`` of the other.
    """
    with open(path_vi, encoding="utf-8") as vi_file, open(path_en, encoding="utf-8") as en_file:
        vi_lines = list(map(normalize_vi, map(str.strip, vi_file)))
        en_lines = list(map(normalize_en, map(str.strip, en_file)))
    keep = min(len(vi_lines), len(en_lines))
    return vi_lines[:keep], en_lines[:keep]

# Load and normalize the full training corpus (both sides, same length).
vi_all, en_all = load_parallel(TRAIN_VI, TRAIN_EN)
print("Total pairs:", len(vi_all))


Total pairs: 500000


In [6]:
# Shuffle sentence indices with a seeded generator so the split is reproducible.
idx = np.arange(len(vi_all))
rng = np.random.default_rng(SEED)
rng.shuffle(idx)

# First 90% of the shuffled indices go to training, the remainder to validation.
split = int(0.9 * len(idx))
train_idx, valid_idx = idx[:split], idx[split:]

train_vi = [vi_all[i] for i in train_idx]
train_en = [en_all[i] for i in train_idx]
valid_vi = [vi_all[i] for i in valid_idx]
valid_en = [en_all[i] for i in valid_idx]

print("Train:", len(train_vi), "Valid:", len(valid_vi))


Train: 450000 Valid: 50000


## 2) SentencePiece (Unigram + byte_fallback)

In [7]:
# Working directory for tokenizer artifacts (text dumps, models, caches).
SP_DIR = "./spm"
os.makedirs(SP_DIR, exist_ok=True)

SRC_MODEL_PREFIX = os.path.join(SP_DIR, "spm_vi")
TRG_MODEL_PREFIX = os.path.join(SP_DIR, "spm_en")

SRC_VOCAB_SIZE = 24000
TRG_VOCAB_SIZE = 24000

vi_txt = os.path.join(SP_DIR, "train_vi.txt")
en_txt = os.path.join(SP_DIR, "train_en.txt")

# Dump the training split (only) to disk as trainer input.
# Existing files are reused as-is, so delete them if the split changes.
if not os.path.exists(vi_txt):
    with open(vi_txt, "w", encoding="utf-8") as f:
        f.write("\n".join(train_vi))
if not os.path.exists(en_txt):
    with open(en_txt, "w", encoding="utf-8") as f:
        f.write("\n".join(train_en))

# Train the source (Vietnamese) tokenizer unless a model file already exists.
if not os.path.exists(SRC_MODEL_PREFIX + ".model"):
    spm.SentencePieceTrainer.Train(
        input=vi_txt,
        model_prefix=SRC_MODEL_PREFIX,
        vocab_size=SRC_VOCAB_SIZE,
        model_type="unigram",
        character_coverage=0.9995,
        byte_fallback=True,
        pad_id=0, unk_id=1, bos_id=2, eos_id=3
    )

# Train the target (English) tokenizer unless a model file already exists.
if not os.path.exists(TRG_MODEL_PREFIX + ".model"):
    spm.SentencePieceTrainer.Train(
        input=en_txt,
        model_prefix=TRG_MODEL_PREFIX,
        vocab_size=TRG_VOCAB_SIZE,
        model_type="unigram",
        character_coverage=1.0,
        byte_fallback=True,
        pad_id=0, unk_id=1, bos_id=2, eos_id=3
    )

sp_src = spm.SentencePieceProcessor(model_file=SRC_MODEL_PREFIX + ".model")
sp_trg = spm.SentencePieceProcessor(model_file=TRG_MODEL_PREFIX + ".model")

# Special-token ids; these mirror the pad/unk/bos/eos ids given to the trainer above.
PAD, UNK, BOS, EOS = 0, 1, 2, 3

print("SRC vocab:", sp_src.get_piece_size(), "TRG vocab:", sp_trg.get_piece_size())


SRC vocab: 24000 TRG vocab: 24000


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: ./spm/train_vi.txt
  input_format: 
  model_prefix: ./spm/spm_vi
  model_type: UNIGRAM
  vocab_size: 24000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 1
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy:

## 3) Dataset + Token-based Dynamic Batching (2×T4 preset)

Thiết lập token budget lớn hơn để tận dụng 2 GPU.  
Nếu gặp OOM, giảm `MAX_TOKENS_TRAIN` theo bậc 4000.


In [8]:
# Maximum sequence length in tokens, including the BOS and EOS markers.
MAX_LEN = 80

class SPMT(Dataset):
    """Parallel-text dataset that tokenizes sentence pairs on access.

    Items are `(src_ids, trg_ids)` long tensors. Each side is encoded with
    its SentencePiece model, cut to `max_len - 2` pieces, then wrapped in
    BOS/EOS.
    """

    def __init__(self, src_texts, trg_texts, max_len):
        self.src, self.trg = src_texts, trg_texts
        self.max_len = max_len

    def __len__(self):
        # The source side defines the dataset size.
        return len(self.src)

    def __getitem__(self, i):
        budget = self.max_len - 2  # leave room for BOS and EOS
        src_piece_ids = sp_src.encode(self.src[i], out_type=int)[:budget]
        trg_piece_ids = sp_trg.encode(self.trg[i], out_type=int)[:budget]
        return (
            torch.tensor([BOS, *src_piece_ids, EOS], dtype=torch.long),
            torch.tensor([BOS, *trg_piece_ids, EOS], dtype=torch.long),
        )

def collate_fn(batch):
    """Pad a list of `(src, trg)` tensor pairs into two batch-first tensors.

    Shorter sequences are right-padded with `PAD` up to the longest sequence
    on their own side of the batch.
    """
    src_seqs, trg_seqs = zip(*batch)
    return (
        pad_sequence(src_seqs, batch_first=True, padding_value=PAD),
        pad_sequence(trg_seqs, batch_first=True, padding_value=PAD),
    )

# Lazily-tokenized datasets for the two splits; both share the MAX_LEN cap.
train_ds = SPMT(train_vi, train_en, MAX_LEN)
valid_ds = SPMT(valid_vi, valid_en, MAX_LEN)


In [9]:
def build_lengths_cache(prefix: str, src_texts: List[str], trg_texts: List[str]):
    """Return the per-pair token cost used for token-based batching.

    The cost of pair ``i`` is the source length plus the target length, each
    capped at ``MAX_LEN - 2`` pieces, plus 4 for the two BOS/EOS pairs.
    Results are cached in ``<SP_DIR>/<prefix>_toklen.npy``.

    A cached file is reused only when it holds exactly one entry per input
    pair. A cache of any other shape belongs to a different corpus or split
    and is recomputed rather than silently used. Changes to the tokenizer or
    to MAX_LEN cannot be detected this way; delete the file in that case.

    Args:
        prefix: Name of the split, used in the cache file name and progress bar.
        src_texts: Source sentences.
        trg_texts: Target sentences, aligned with `src_texts`.

    Returns:
        An int32 NumPy array with one token cost per sentence pair.
    """
    path = os.path.join(SP_DIR, f"{prefix}_toklen.npy")
    if os.path.exists(path):
        cached = np.load(path)
        if cached.shape == (len(src_texts),):
            return cached
        print(f"[build_lengths_cache] stale cache {path} "
              f"(shape {cached.shape}, expected ({len(src_texts)},)); rebuilding")

    cap = MAX_LEN - 2  # same truncation the dataset applies before adding BOS/EOS
    lens = np.zeros(len(src_texts), dtype=np.int32)
    for i in tqdm(range(len(src_texts)), desc=f"TokLen {prefix}"):
        s = sp_src.encode(src_texts[i], out_type=int)
        t = sp_trg.encode(trg_texts[i], out_type=int)
        lens[i] = min(len(s), cap) + min(len(t), cap) + 4
    np.save(path, lens)
    return lens

# Per-pair token costs for both splits (computed once, then read from the cache).
train_lens = build_lengths_cache("train", train_vi, train_en)
valid_lens = build_lengths_cache("valid", valid_vi, valid_en)

print("Train avg tok/sample:", float(train_lens.mean()))
print("Valid avg tok/sample:", float(valid_lens.mean()))


TokLen train:   0%|          | 0/450000 [00:00<?, ?it/s]

TokLen valid:   0%|          | 0/50000 [00:00<?, ?it/s]

Train avg tok/sample: 64.23703333333333
Valid avg tok/sample: 64.20196


In [10]:
class TokenBatchSampler(Sampler):
    """Batch sampler that packs samples until a token budget is reached.

    Each epoch the indices are optionally shuffled, cut into chunks of
    `bucket_size`, and sorted by length inside each chunk so neighbours have
    similar lengths. Batches are then filled greedily while the summed
    `lengths` stay within `max_tokens`.

    `len(sampler)` is exact for the next iteration: the upcoming epoch's
    batches are built once, cached, and handed out by the next `__iter__`
    call. Counting over the raw, unshuffled order would not match what
    iteration produces.

    Args:
        lengths: Per-sample token cost, array-like of ints.
        max_tokens: Upper bound on the summed cost of a batch.
        shuffle: Shuffle sample order at the start of every epoch.
        drop_last: When true, drop the final (possibly under-filled) batch
            and any single sample whose cost alone exceeds `max_tokens`.
            When false, such oversize samples form batches of one.
        bucket_size: Chunk size for length sorting; 0 or less disables it.
        seed: Seed for the private NumPy generator used for shuffling.
    """

    def __init__(self, lengths, max_tokens, shuffle=True, drop_last=False, bucket_size=4096, seed=42):
        self.lengths = np.asarray(lengths, dtype=np.int64)
        self.max_tokens = int(max_tokens)
        self.shuffle = bool(shuffle)
        self.drop_last = bool(drop_last)
        self.bucket_size = int(bucket_size)
        self.rng = np.random.default_rng(seed)
        # Batches prepared by __len__ for the next __iter__ call (None = not built yet).
        self._pending = None

    def _build_batches(self):
        """Materialize one epoch of batches as a list of index lists."""
        idx = np.arange(len(self.lengths))
        if self.shuffle:
            self.rng.shuffle(idx)

        # Guard on len(idx): np.concatenate rejects an empty list of chunks.
        if self.bucket_size > 0 and len(idx) > 0:
            chunks = []
            for start in range(0, len(idx), self.bucket_size):
                chunk = idx[start:start + self.bucket_size]
                chunks.append(chunk[np.argsort(self.lengths[chunk])])
            idx = np.concatenate(chunks, axis=0)

        batches = []
        batch = []
        tok = 0
        for i in idx:
            L = int(self.lengths[i])
            if L > self.max_tokens:
                # A sample that can never fit: its own batch, or dropped.
                if not self.drop_last:
                    batches.append([int(i)])
                continue
            if batch and tok + L > self.max_tokens:
                batches.append(batch)
                batch = [int(i)]
                tok = L
            else:
                batch.append(int(i))
                tok += L
        if batch and not self.drop_last:
            batches.append(batch)
        return batches

    def __iter__(self):
        # Consume the batches prepared by __len__ (if any) so the reported
        # length and the iteration agree; otherwise build a fresh epoch.
        batches = self._pending if self._pending is not None else self._build_batches()
        self._pending = None
        return iter(batches)

    def __len__(self):
        if self._pending is None:
            self._pending = self._build_batches()
        return len(self._pending)


In [11]:
NUM_WORKERS = 2

# Token budgets per batch (sum of the cached per-pair costs), one per split.
MAX_TOKENS_TRAIN = 6000
MAX_TOKENS_VALID = 8000

# Training: shuffled, length-bucketed, incomplete tail dropped.
# Validation: fixed order, no bucketing, nothing dropped.
train_batch_sampler = TokenBatchSampler(train_lens, MAX_TOKENS_TRAIN, shuffle=True, drop_last=True, bucket_size=4096, seed=SEED)
valid_batch_sampler = TokenBatchSampler(valid_lens, MAX_TOKENS_VALID, shuffle=False, drop_last=False, bucket_size=0, seed=SEED)

train_loader = DataLoader(train_ds, batch_sampler=train_batch_sampler, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_ds, batch_sampler=valid_batch_sampler, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)

print("Train batches:", len(train_loader), "Valid batches:", len(valid_loader))


Train batches: 4849 Valid batches: 404


## 4) Transformer from scratch (Pre-LN + GEGLU + weight tying)

In [12]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch-first input, then apply dropout.

    Even feature columns hold ``sin(pos * div)`` and odd columns hold
    ``cos(pos * div)``. The table is precomputed for `max_len` positions and
    stored as the buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Feature size of the input; may be even or odd.
        max_len: Number of positions to precompute.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the cosine table to the number of odd slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        # Slice the table to the sequence length of x (dimension 1).
        return self.drop(x + self.pe[:, :x.size(1)].to(x.device))

class TokenEmbedding(nn.Module):
    """Token lookup scaled by ``sqrt(d_model)``, followed by positional encoding."""

    def __init__(self, vocab_size, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model, max_len=max_len, dropout=dropout)
        self.scale = math.sqrt(d_model)

    def forward(self, x):
        scaled = self.emb(x) * self.scale
        return self.pe(scaled)

class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over `n_heads` heads.

    `forward(q, k, v, mask)` takes batch-first tensors whose last dimension
    is `d_model`. `mask` is a boolean tensor broadcastable to the score
    tensor; positions where it is False get the most negative finite value
    before the softmax.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.h = n_heads
        self.d = d_model // n_heads
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch, q_len, width = q.size()
        heads, head_dim = self.h, self.d

        # Project, then reshape (B, T, D) -> (B, heads, T, head_dim).
        q_heads = self.wq(q).view(batch, -1, heads, head_dim).transpose(1, 2)
        k_heads = self.wk(k).view(batch, -1, heads, head_dim).transpose(1, 2)
        v_heads = self.wv(v).view(batch, -1, heads, head_dim).transpose(1, 2)

        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / math.sqrt(head_dim)
        if mask is not None:
            scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)

        weights = self.drop(torch.softmax(scores, dim=-1))
        # Merge the heads back into a (B, Tq, D) tensor.
        merged = torch.matmul(weights, v_heads).transpose(1, 2).contiguous().view(batch, q_len, width)
        return self.wo(merged)

class GEGLU(nn.Module):
    """Gated feed-forward block: ``fc2(dropout(gelu(u) * v))``.

    ``u`` and ``v`` are the two halves of a single ``d_model -> 2 * d_ff``
    projection.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, 2*d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        gate, value = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(self.drop(F.gelu(gate) * value))

class EncoderLayer(nn.Module):
    """Pre-LN encoder layer: self-attention then GEGLU, each in a residual branch."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = GEGLU(d_model, d_ff, dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        # Normalization happens before each sub-layer; the residual is added after.
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed, src_mask)
        x = x + self.drop(attended)
        x = x + self.drop(self.ffn(self.norm2(x)))
        return x

class DecoderLayer(nn.Module):
    """Pre-LN decoder layer: masked self-attention, cross-attention, then GEGLU."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = GEGLU(d_model, d_ff, dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, enc, trg_mask, src_mask):
        # 1) self-attention over the target prefix
        normed = self.norm1(x)
        x = x + self.drop(self.self_attn(normed, normed, normed, trg_mask))
        # 2) attention over the encoder output (keys/values come from `enc`)
        queries = self.norm2(x)
        x = x + self.drop(self.cross_attn(queries, enc, enc, src_mask))
        # 3) position-wise feed-forward
        x = x + self.drop(self.ffn(self.norm3(x)))
        return x

def make_src_mask(src):
    """Return a boolean mask, True where `src` is not PAD, with two singleton dims inserted at positions 1 and 2."""
    not_pad = src.ne(PAD)
    return not_pad.unsqueeze(1).unsqueeze(2)

def make_trg_mask(trg):
    """Return the decoder self-attention mask: non-PAD keys AND a lower-triangular (causal) pattern."""
    steps = trg.size(1)
    lower_tri = torch.tril(torch.ones(steps, steps, device=trg.device, dtype=torch.bool))
    not_pad = trg.ne(PAD).unsqueeze(1).unsqueeze(2)
    return not_pad & lower_tri.unsqueeze(0).unsqueeze(1)

class Transformer(nn.Module):
    """Encoder-decoder Transformer with a final LayerNorm on each stack.

    The output projection shares its weight with the target embedding
    (weight tying). Every parameter with more than one dimension is
    Xavier-uniform initialized.
    """

    def __init__(self, src_vocab, trg_vocab, d_model, n_heads, d_ff, n_enc, n_dec, dropout, max_len):
        super().__init__()
        self.src_emb = TokenEmbedding(src_vocab, d_model, max_len=max_len, dropout=dropout)
        self.trg_emb = TokenEmbedding(trg_vocab, d_model, max_len=max_len, dropout=dropout)
        self.enc = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_enc)])
        self.dec = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_dec)])
        self.norm_e = nn.LayerNorm(d_model)
        self.norm_d = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, trg_vocab, bias=False)
        # Weight tying: the logits layer reuses the target embedding matrix.
        self.fc_out.weight = self.trg_emb.emb.weight
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def encode(self, src):
        """Return `(encoder_output, src_mask)` for a batch of source ids."""
        src_mask = make_src_mask(src)
        hidden = self.src_emb(src)
        for enc_layer in self.enc:
            hidden = enc_layer(hidden, src_mask)
        return self.norm_e(hidden), src_mask

    def decode(self, trg, enc, src_mask):
        """Return vocabulary logits for every position of the target prefix `trg`."""
        trg_mask = make_trg_mask(trg)
        hidden = self.trg_emb(trg)
        for dec_layer in self.dec:
            hidden = dec_layer(hidden, enc, trg_mask, src_mask)
        return self.fc_out(self.norm_d(hidden))

    def forward(self, src, trg_in):
        memory, src_mask = self.encode(src)
        return self.decode(trg_in, memory, src_mask)


## 5) Training (AMP + Warmup+Cosine + EMA)

In [13]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy mixed with a uniform-distribution term.

    The result is ``(1 - smoothing) * nll + smoothing * smooth``. ``nll`` is
    the negative log-likelihood of the target and ``smooth`` is the negative
    mean log-probability over the vocabulary. Both are averaged over
    positions whose target differs from `ignore_index`.
    """

    def __init__(self, smoothing: float, vocab_size: int, ignore_index: int = PAD):
        super().__init__()
        self.smoothing = float(smoothing)
        self.vocab_size = int(vocab_size)
        self.ignore_index = int(ignore_index)

    def forward(self, logits, target):
        log_probs = F.log_softmax(logits, dim=-1)
        per_tok_nll = F.nll_loss(log_probs, target, reduction='none', ignore_index=self.ignore_index)
        per_tok_uniform = -log_probs.mean(dim=-1)

        keep = (target != self.ignore_index).float()
        nll = (per_tok_nll * keep).sum() / keep.sum().clamp_min(1.0)
        smooth = (per_tok_uniform * keep).sum() / keep.sum().clamp_min(1.0)
        return (1.0 - self.smoothing) * nll + self.smoothing * smooth

@dataclass
class HParams:
    # Model and training hyperparameters for this run, grouped in one place.

    # --- model size (passed to Transformer) ---
    d_model: int = 512
    n_heads: int = 8
    d_ff: int = 2048
    n_enc: int = 6
    n_dec: int = 6
    dropout: float = 0.12

    # --- optimizer and LR schedule (AdamW + lr_lambda) ---
    lr: float = 5e-4
    weight_decay: float = 0.01
    warmup_steps: int = 4000      # counted in optimizer steps
    min_lr_ratio: float = 0.1     # LR floor as a fraction of `lr`

    # --- training loop ---
    epochs: int = 15
    grad_clip: float = 1.0        # max gradient norm
    accum_steps: int = 4          # micro-batches per optimizer step
    label_smooth: float = 0.1

    # --- exponential moving average of weights (off by default) ---
    use_ema: bool = False
    ema_decay: float = 0.999

hp = HParams()

# Vocabulary sizes come straight from the trained SentencePiece models.
SRC_V = sp_src.get_piece_size()
TRG_V = sp_trg.get_piece_size()

base_model = Transformer(SRC_V, TRG_V, hp.d_model, hp.n_heads, hp.d_ff, hp.n_enc, hp.n_dec, hp.dropout, max_len=4096).to(device)
# Wrap in DataParallel only when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(base_model).to(device)
else:
    model = base_model

def unwrap(m):
    """Return the module wrapped by `nn.DataParallel`, or `m` itself when it is not wrapped."""
    if isinstance(m, nn.DataParallel):
        return m.module
    return m

criterion = LabelSmoothingLoss(hp.label_smooth, TRG_V, ignore_index=PAD)
optimizer = torch.optim.AdamW(model.parameters(), lr=hp.lr, betas=(0.9, 0.98), eps=1e-9, weight_decay=hp.weight_decay)

# One optimizer step per `accum_steps` micro-batches; both counts are at least 1.
steps_per_epoch = max(1, len(train_loader) // hp.accum_steps)
total_steps = max(1, hp.epochs * steps_per_epoch)

def lr_lambda(step):
    """Return the LR multiplier for optimizer step `step` (linear warmup, then cosine decay).

    The multiplier rises linearly to 1.0 over `hp.warmup_steps`, then follows
    a half cosine down to `hp.min_lr_ratio` at `total_steps`. Progress is
    clamped to 1.0: without the clamp, any step past `total_steps` would send
    the cosine back up and raise the learning rate at the end of training.
    """
    step = max(1, step)  # step 0 would give a zero learning rate
    if step <= hp.warmup_steps:
        return step / float(hp.warmup_steps)
    progress = (step - hp.warmup_steps) / float(max(1, total_steps - hp.warmup_steps))
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return hp.min_lr_ratio + (1.0 - hp.min_lr_ratio) * cosine

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
# `torch.cuda.amp.GradScaler(...)` is deprecated (see the warning in this
# notebook's output); the device-agnostic constructor is its replacement.
# Loss scaling is only enabled when CUDA is available.
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

class EMA:
    """Exponential moving average of a model's trainable parameters.

    `shadow` holds the averaged copy keyed by parameter name. `apply` swaps
    the averaged weights into the model after saving the live ones in
    `backup`, and `restore` swaps the live weights back.
    """

    def __init__(self, model, decay):
        self.decay = float(decay)
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()
        self.backup = {}

    @torch.no_grad()
    def update(self, model):
        """Blend the current weights into the shadow: ``s = decay * s + (1 - decay) * p``."""
        keep, blend = self.decay, 1.0 - self.decay
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.shadow[name].mul_(keep).add_(param.detach(), alpha=blend)

    @torch.no_grad()
    def apply(self, model):
        """Save the live weights into `backup` and load the shadow weights into `model`."""
        self.backup = {}
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.backup[name] = param.detach().clone()
            param.data.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model):
        """Undo `apply`: copy the saved live weights back and clear `backup`."""
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            param.data.copy_(self.backup[name])
        self.backup = {}

# EMA tracks the unwrapped model so parameter names carry no DataParallel prefix; None disables it.
ema = EMA(unwrap(model), hp.ema_decay) if hp.use_ema else None

print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)


Params (M): 81.324032


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


In [14]:
def run_epoch(model, loader, train: bool):
    """Run one pass over `loader` and return the mean loss per non-PAD target token.

    In training mode, gradients are accumulated over `hp.accum_steps`
    micro-batches. Each optimizer step then unscales, clips to
    `hp.grad_clip`, steps the optimizer and scheduler, and updates the EMA
    when one is configured. In evaluation mode the pass runs under
    `torch.no_grad()`.

    Mixed precision uses `torch.autocast`, the supported replacement for the
    deprecated `torch.cuda.amp.autocast`. It is enabled exactly when the
    gradient scaler is enabled.

    Args:
        model: The (possibly DataParallel-wrapped) Transformer.
        loader: DataLoader yielding padded `(src, trg)` id tensors.
        train: True for a training pass, False for evaluation.

    Returns:
        The token-weighted average of the per-batch losses, as a float.
    """
    total_loss, total_tok = 0.0, 0
    if train:
        model.train()
    else:
        model.eval()

    optimizer.zero_grad(set_to_none=True)

    ctx = torch.enable_grad() if train else torch.no_grad()
    with ctx:
        for step, (src, trg) in enumerate(tqdm(loader, leave=False), start=1):
            src = src.to(device, non_blocking=True)
            trg = trg.to(device, non_blocking=True)
            # Teacher forcing: feed tokens [0, T-1), predict tokens [1, T).
            trg_in = trg[:, :-1]
            trg_out = trg[:, 1:]

            with torch.autocast(device_type=device.type, enabled=scaler.is_enabled()):
                logits = model(src, trg_in)
                # The loss is computed in float32 on the flattened positions.
                loss = criterion(logits.reshape(-1, logits.size(-1)).float(), trg_out.reshape(-1))
                if train:
                    # Scale so the accumulated gradient is an average over micro-batches.
                    loss = loss / hp.accum_steps

            if train:
                scaler.scale(loss).backward()
                if step % hp.accum_steps == 0:
                    scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()
                    if ema is not None:
                        ema.update(unwrap(model))

            non_pad = (trg_out != PAD).sum().item()
            # Undo the accumulation scaling so the reported loss is per token.
            total_loss += (loss.item() * (hp.accum_steps if train else 1.0)) * non_pad
            total_tok += non_pad

    return total_loss / max(1, total_tok)

def save_ckpt(path, model, best_val, epoch, step):
    """Write a checkpoint to `path`.

    The checkpoint holds the unwrapped model's weights, the hyperparameters,
    the best validation loss, the epoch/step counters, and the two
    SentencePiece model paths.
    """
    torch.save(
        {
            "model": unwrap(model).state_dict(),
            "hparams": hp.__dict__,
            "best_val": float(best_val),
            "epoch": int(epoch),
            "step": int(step),
            "spm_vi": SRC_MODEL_PREFIX + ".model",
            "spm_en": TRG_MODEL_PREFIX + ".model",
        },
        path,
    )

def load_ckpt(path, model):
    """Load the weights stored at `path` into `model` (strict key match) and return the checkpoint dict."""
    checkpoint = torch.load(path, map_location="cpu")
    target = unwrap(model)
    target.load_state_dict(checkpoint["model"], strict=True)
    return checkpoint


In [15]:
# Track the lowest validation loss seen so far and where its checkpoint lives.
best_val = 1e9
best_path = "best_vlsp_transformer_spm.pt"
global_step = 0

for epoch in range(1, hp.epochs + 1):
    tr_loss = run_epoch(model, train_loader, train=True)

    # Validate with the averaged weights when EMA is enabled ...
    if ema is not None:
        ema.apply(unwrap(model))

    va_loss = run_epoch(model, valid_loader, train=False)

    # ... then put the live training weights back.
    if ema is not None:
        ema.restore(unwrap(model))

    lr = optimizer.param_groups[0]["lr"]
    # Nominal step count (steps_per_epoch per epoch), recorded in the checkpoint.
    global_step += steps_per_epoch

    print(f"Epoch {epoch}/{hp.epochs} | train_loss {tr_loss:.4f} | val_loss {va_loss:.4f} | lr {lr:.2e}")

    # Keep only the checkpoint with the best validation loss.
    if va_loss < best_val:
        best_val = va_loss
        save_ckpt(best_path, model, best_val, epoch, global_step)
        print("  ✔ Saved best:", best_path)

print("Best val loss:", best_val)


  0%|          | 0/4849 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=scaler.is_enabled()):


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 1/15 | train_loss 7.5251 | val_loss 6.1209 | lr 1.51e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 2/15 | train_loss 5.1713 | val_loss 4.2591 | lr 3.03e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 3/15 | train_loss 3.8832 | val_loss 3.4207 | lr 4.55e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7baf1e9865c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7baf1e9865c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 4/15 | train_loss 3.3302 | val_loss 3.0937 | lr 4.96e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 5/15 | train_loss 3.0459 | val_loss 2.9234 | lr 4.77e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 6/15 | train_loss 2.8688 | val_loss 2.8192 | lr 4.43e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 7/15 | train_loss 2.7363 | val_loss 2.7358 | lr 3.98e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 8/15 | train_loss 2.6261 | val_loss 2.6724 | lr 3.43e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 9/15 | train_loss 2.5296 | val_loss 2.6180 | lr 2.84e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7baf1e9865c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7baf1e9865c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 10/15 | train_loss 2.4447 | val_loss 2.5771 | lr 2.24e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 11/15 | train_loss 2.3702 | val_loss 2.5433 | lr 1.68e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 12/15 | train_loss 2.3081 | val_loss 2.5154 | lr 1.19e-04
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 13/15 | train_loss 2.2582 | val_loss 2.4970 | lr 8.17e-05
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 14/15 | train_loss 2.2217 | val_loss 2.4856 | lr 5.81e-05
  ✔ Saved best: best_vlsp_transformer_spm.pt


  0%|          | 0/4849 [00:00<?, ?it/s]

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch 15/15 | train_loss 2.1980 | val_loss 2.4775 | lr 5.00e-05
  ✔ Saved best: best_vlsp_transformer_spm.pt
Best val loss: 2.477543691571929


## 6) Beam search decode

In [16]:
@torch.no_grad()
def _no_repeat_ngram_ok(seq, next_tok, n):
    if n <= 0 or len(seq) < n - 1:
        return True
    prefix = seq[-(n-1):]
    new_ng = tuple(prefix + [next_tok])
    seen = set()
    for i in range(len(seq) - n + 1):
        seen.add(tuple(seq[i:i+n]))
    return new_ng not in seen

@torch.no_grad()
def beam_search_decode(model, src_text: str, beam_size=6, max_len=MAX_LEN, length_penalty=0.7, no_repeat_ngram=3):
    """Translate one source sentence with beam search.

    The source is encoded once. Every active hypothesis is then extended one
    token at a time by re-running the decoder on its full prefix. Hypotheses
    are ranked by ``logp / len(seq) ** length_penalty``.

    Args:
        model: The (possibly DataParallel-wrapped) Transformer.
        src_text: Raw source sentence.
        beam_size: Number of hypotheses kept after each step.
        max_len: Cap on source pieces (including BOS/EOS) and on decoding steps.
        length_penalty: Exponent of the length normalization.
        no_repeat_ngram: N-gram size that may not repeat; 0 disables the check.

    Returns:
        The decoded target string of the best-scoring hypothesis.
    """
    model.eval()
    net = unwrap(model)
    src_ids = [BOS] + sp_src.encode(src_text, out_type=int)[:max_len-2] + [EOS]
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)
    enc, src_mask = net.encode(src)

    def norm(item):
        # Length-normalized score used both for pruning and for the final pick.
        lp, s = item
        return lp / (max(1, len(s)) ** length_penalty)

    beams = [(0.0, [BOS])]
    completed = []

    for _ in range(max_len):
        new_beams = []
        for logp, seq in beams:
            if seq[-1] == EOS:
                completed.append((logp, seq))
                continue

            trg = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
            logits = net.decode(trg, enc, src_mask)
            lp = torch.log_softmax(logits[0, -1], dim=-1)

            # Over-sample candidates so n-gram blocking still leaves enough;
            # never ask topk for more entries than the vocabulary holds.
            topk = torch.topk(lp, k=min(beam_size * 3, lp.size(-1)))
            for score, tok in zip(topk.values.tolist(), topk.indices.tolist()):
                if no_repeat_ngram and not _no_repeat_ngram_ok(seq, tok, no_repeat_ngram):
                    continue
                new_beams.append((logp + score, seq + [tok]))

        if not new_beams:
            # Finished hypotheses were moved to `completed` above; keep only
            # the unfinished ones so they are not counted twice below.
            beams = [b for b in beams if b[1][-1] != EOS]
            break

        new_beams.sort(key=norm, reverse=True)
        beams = new_beams[:beam_size]

    # Hypotheses that produced EOS on the very last step were never revisited
    # by the loop; collect them so they can compete for the final pick.
    completed.extend(b for b in beams if b[1][-1] == EOS)

    if not completed:
        completed = beams

    best = max(completed, key=norm)[1]
    out_ids = [i for i in best if i not in (BOS, EOS, PAD)]
    return sp_trg.decode(out_ids)


## 7) Metrics (BLEU / TER / METEOR)

In [19]:
from sacrebleu.metrics import TER
ter_metric = TER()

assert os.path.exists(PUB_VI), "Không tìm thấy public_test.vi"
assert os.path.exists(PUB_EN), "Không tìm thấy public_test.en -> không thể tính BLEU trên test"

with open(PUB_VI, encoding="utf-8") as f_vi, open(PUB_EN, encoding="utf-8") as f_en:
    raw_vi = [l.strip() for l in f_vi]
    raw_en = [l.strip() for l in f_en]

if len(raw_vi) == len(raw_en):
    # Line-aligned files: drop a pair only when either side is blank, so a
    # blank line on one side cannot shift every later pair out of alignment.
    kept = [(v, e) for v, e in zip(raw_vi, raw_en) if v and e]
    test_vi = [normalize_vi(v) for v, _ in kept]
    test_en = [normalize_en(e) for _, e in kept]
else:
    # Line counts differ, so alignment cannot be verified. Fall back to
    # filtering blanks per file and truncating to the shorter side.
    print(f"WARNING: line count mismatch (vi={len(raw_vi)}, en={len(raw_en)}); pairs may be misaligned")
    test_vi = [normalize_vi(v) for v in raw_vi if v]
    test_en = [normalize_en(e) for e in raw_en if e]

n0 = min(len(test_vi), len(test_en))
test_vi = test_vi[:n0]
test_en = test_en[:n0]
print("Loaded TEST pairs:", n0)

@torch.no_grad()
def evaluate_test_subset(model, src_list, ref_list, beam_size=6, limit=500):
    """Decode the first `limit` sentences and score them against the references.

    Returns a `(metrics, preds)` tuple. `metrics` holds corpus BLEU, corpus
    TER, the mean sentence-level METEOR (on whitespace tokens) and the
    number of sentences scored. `preds` is the list of decoded hypotheses.
    """
    n = min(limit, len(src_list), len(ref_list))
    refs = ref_list[:n]
    preds = [
        beam_search_decode(model, src_list[i], beam_size=beam_size, max_len=MAX_LEN)
        for i in tqdm(range(n), desc=f"Decoding TEST({n})", leave=False)
    ]

    bleu_score = sacrebleu.corpus_bleu(preds, [refs]).score
    ter_score = ter_metric.corpus_score(preds, [refs]).score
    meteor_scores = [meteor_score([ref.split()], hyp.split()) for ref, hyp in zip(refs, preds)]

    summary = {
        "BLEU": float(bleu_score),
        "TER": float(ter_score),
        "METEOR": float(float(np.mean(meteor_scores))),
        "n": int(n),
    }
    return summary, preds

# Evaluate the best checkpoint (lowest validation loss), not the last-epoch weights.
load_ckpt(best_path, model)
model.to(device).eval()

# Score only the first TEST_LIMIT sentences of the public test set.
TEST_LIMIT = 500
metrics_test, test_preds = evaluate_test_subset(
    model, test_vi, test_en, beam_size=6, limit=TEST_LIMIT
)

metrics_test


Loaded TEST pairs: 3000


Decoding TEST(500):   0%|          | 0/500 [00:00<?, ?it/s]

{'BLEU': 42.815546103242156,
 'TER': 49.45105215004574,
 'METEOR': 0.6359299289390538,
 'n': 500}

## 8) Error analysis

In [20]:
def extract_caps_phrases(s: str):
    """Return every run of two or more consecutive Capitalized words (e.g. "Glasgow Coma Scale") found in `s`."""
    pattern = r"(?:\b[A-Z][a-z]+\b(?:\s+\b[A-Z][a-z]+\b)+)"
    return [match.group(0) for match in re.finditer(pattern, s)]

def extract_numbers(s: str):
    """Return all digit runs in `s`, each with at most one ``.``/``,`` fractional part (e.g. "3,5", "10.25")."""
    number_re = re.compile(r"\d+(?:[.,]\d+)?")
    return number_re.findall(s)

def punctuation_signature(s: str):
    """Return the punctuation characters of `s`, in order, joined into one string."""
    marks = re.finditer(r'[.,;:!?()\[\]"\'`-]', s)
    return ''.join(mark.group(0) for mark in marks)

def length_bucket(ratio: float):
    """Label a hypothesis/reference length ratio: below 0.6 is "too_short", above 1.6 is "too_long", otherwise "ok"."""
    return "too_short" if ratio < 0.6 else ("too_long" if ratio > 1.6 else "ok")

def error_analysis(srcs, refs, preds, topk=15):
    """Rank translations by heuristic error signals and return the `topk` worst.

    Each row is ``(cap_miss, num_miss, punct_diff, |ratio - 1|, ratio, src,
    ref, pred)``:
      * cap_miss — capitalized phrases in the reference missing from the prediction
      * num_miss — numbers present on only one side
      * punct_diff — 1 when the punctuation sequences differ, else 0
      * ratio — prediction/reference length in whitespace tokens

    Rows are sorted in descending order on the first four fields.
    """
    def score(src, ref, hyp):
        missing_caps = len(set(extract_caps_phrases(ref)) - set(extract_caps_phrases(hyp)))
        number_diff = len(set(extract_numbers(ref)) ^ set(extract_numbers(hyp)))
        punct_changed = int(punctuation_signature(ref) != punctuation_signature(hyp))
        # The epsilon keeps the ratio finite when the reference is empty.
        ratio = (len(hyp.split()) + 1e-6) / (len(ref.split()) + 1e-6)
        return (missing_caps, number_diff, punct_changed, abs(ratio - 1.0), ratio, src, ref, hyp)

    scored = [score(s, r, p) for s, r, p in zip(srcs, refs, preds)]
    return sorted(scored, key=lambda row: row[:4], reverse=True)[:topk]

# ---- ERROR ANALYSIS ON TEST (500) ----
# Guard against running this cell before the evaluation cell has produced its outputs.
assert "metrics_test" in globals(), "Chưa có metrics_test – hãy chạy cell evaluate TEST trước"
assert "test_preds" in globals(), "Chưa có test_preds – hãy chạy cell evaluate TEST trước"
assert "test_vi" in globals() and "test_en" in globals(), "Chưa có test_vi/test_en"

# Analyse exactly the sentences that were decoded and scored.
n = metrics_test["n"]
bad = error_analysis(test_vi[:n], test_en[:n], test_preds[:n], topk=15)

# Print the worst cases with their error signals, followed by source, reference and prediction.
for i, (cap_miss, num_miss, punct_diff, dist, ratio, s, r, p) in enumerate(bad, 1):
    print("=" * 90)
    print(f"[{i}] cap_miss={cap_miss} | num_miss={num_miss} | punct_diff={bool(punct_diff)} | len_ratio={ratio:.2f} ({length_bucket(ratio)})")
    print("SRC:", s)
    print("REF:", r)
    print("PRD:", p)


[1] cap_miss=5 | num_miss=0 | punct_diff=True | len_ratio=0.50 (too_short)
SRC: Đánh giá một số đặc điểm lâm sàng, kết quả điều trị theo các thang điểm GCS, MRC, NIHSS và mRS.
REF: Assessment some clinical features, outcome based on Glasgow Coma Scale (GCS), Medical Research Council UK (MRC), National Institute of Health Stroke Scale (NIHSS) and modified Rankin Scale (mRS).
PRD: Evaluate some clinical features, treatment results according to GCS, MRC, NIHSS and mRS scores.
[2] cap_miss=3 | num_miss=0 | punct_diff=True | len_ratio=0.97 (ok)
SRC: Nghiên cứu sử dụng các yếu tố (chuẩn mực chủ quan, nhận thức kiểm soát hành vi) trong mô hình lý thuyết về hành vi có kế hoạch (TPB) cùng việc kết hợp một số yếu tố được chỉ ra từ các nghiên cứu liên quan trước đó để lường sự phù hợp của nó.
REF: The study used the factors (Subjective Norm (SN), Perceived Behavioral Control) in the model Theory of Planned Behavior (TPB) and use some other factors, based on the previous studies on the same subjec