#### Bài 1: Xây dựng kiến trúc Encoder-Decoder gồm 3 lớp LSTM cho module encoder và 3 lớp LSTM cho module decoder, với hidden size là 256, cho bài toán dịch máy từ tiếng Anh sang tiếng Việt. Huấn luyện mô hình này trên bộ dữ liệu PhoMT sử dụng Adam làm phương thức tối ưu tham số. Đánh giá độ hiệu quả của mô hình sử dụng độ đo ROUGE-L.

#IMPORT THƯ VIỆN

In [1]:
import json
import random
from collections import Counter
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

print("Torch:", torch.__version__)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)


Torch: 2.9.0+cu126
Device: cuda


In [2]:
TRAIN_PATH = "/content/small-train.json"
DEV_PATH   = "/content/small-dev.json"
TEST_PATH  = "/content/small-test.json"


#Khảo sát dữ liệu

In [3]:
def basic_tokenize(s: str) -> List[str]:
    return s.strip().split()

def dataset_stats(path: str, n_show: int = 3):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    print(f"\n[EDA] File: {path}")
    print("[EDA] Num samples:", len(data))

    src_lens, tgt_lens = [], []
    src_all, tgt_all = [], []

    for ex in data:
        s = basic_tokenize(ex["english"])
        t = basic_tokenize(ex["vietnamese"])
        src_lens.append(len(s)); tgt_lens.append(len(t))
        src_all += s; tgt_all += t

    def summarize(lens):
        lens = sorted(lens)
        mean = sum(lens)/max(1,len(lens))
        p50 = lens[len(lens)//2] if lens else 0
        p90 = lens[int(0.9*len(lens))-1] if len(lens) > 0 else 0
        mx  = max(lens) if lens else 0
        return mean, p50, p90, mx

    sm, s50, s90, smax = summarize(src_lens)
    tm, t50, t90, tmax = summarize(tgt_lens)

    print(f"[EDA] EN len: mean={sm:.2f}, p50={s50}, p90={s90}, max={smax}")
    print(f"[EDA] VI len: mean={tm:.2f}, p50={t50}, p90={t90}, max={tmax}")

    print("[EDA] Top EN tokens:", Counter(src_all).most_common(10))
    print("[EDA] Top VI tokens:", Counter(tgt_all).most_common(10))

    print("\n[EDA] Samples:")
    for i in range(min(n_show, len(data))):
        print(f"- EN: {data[i]['english']}")
        print(f"  VI: {data[i]['vietnamese']}")

dataset_stats(TRAIN_PATH, n_show=3)
dataset_stats(DEV_PATH, n_show=2)
dataset_stats(TEST_PATH, n_show=2)



[EDA] File: /content/small-train.json
[EDA] Num samples: 20000
[EDA] EN len: mean=19.79, p50=16, p90=37, max=164
[EDA] VI len: mean=23.76, p50=19, p90=45, max=179
[EDA] Top EN tokens: [(',', 23180), ('.', 18302), ('the', 14556), ('to', 9433), ('of', 8532), ('and', 8242), ('a', 7647), ('that', 7083), ('I', 5973), ('in', 5556)]
[EDA] Top VI tokens: [(',', 19699), ('.', 16254), ('là', 8091), ('tôi', 7122), ('và', 7024), ('có', 6977), ('một', 6598), ('chúng', 5131), ('những', 5037), ('của', 4960)]

[EDA] Samples:
- EN: It begins with a countdown .
  VI: Câu chuyện bắt đầu với buổi lễ đếm ngược .
- EN: On August 14th , 1947 , a woman in Bombay goes into labor as the clock ticks towards midnight .
  VI: Ngày 14 , tháng 8 , năm 1947 , gần nửa đêm , ở Bombay , có một phụ nữ sắp lâm bồn .
- EN: Across India , people hold their breath for the declaration of independence after nearly two centuries of British occupation and rule .
  VI: Cùng lúc , trên khắp đất Ấn , người ta nín thở chờ đợi tuyên

#Xây dựng vocab

In [9]:
SPECIALS = ["<pad>", "<unk>", "<bos>", "<eos>"]
PAD, UNK, BOS, EOS = SPECIALS

class Vocab:
    def __init__(self, tokenized_sents: List[List[str]], min_freq=1, max_size=50000):
        freq = Counter()
        for sent in tokenized_sents:
            freq.update(sent)

        items = [(w,c) for w,c in freq.items() if c >= min_freq]
        items.sort(key=lambda x: (-x[1], x[0]))
        items = items[: max(0, max_size - len(SPECIALS))]

        self.itos = SPECIALS + [w for w,_ in items]
        self.stoi = {w:i for i,w in enumerate(self.itos)}
        self.pad_id = self.stoi[PAD]
        self.unk_id = self.stoi[UNK]
        self.bos_id = self.stoi[BOS]
        self.eos_id = self.stoi[EOS]

    def __len__(self):
        return len(self.itos)

    def encode(self, toks: List[str]) -> List[int]:
        return [self.stoi.get(t, self.unk_id) for t in toks]

    def decode(self, ids: List[int]) -> List[str]:
        out = []
        for i in ids:
            if i == self.eos_id: break
            if i in (self.pad_id, self.bos_id): continue
            out.append(self.itos[i] if 0 <= i < len(self.itos) else UNK)
        return out

class PhoMTJsonDataset(Dataset):
    def __init__(self, path: str):
        with open(path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        ex = self.data[idx]
        return ex["english"], ex["vietnamese"]

def build_vocabs(train_path: str, min_freq=1, max_size=50000):
    with open(train_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    src_tok = [basic_tokenize(x["english"]) for x in data]
    tgt_tok = [basic_tokenize(x["vietnamese"]) for x in data]
    return Vocab(src_tok, min_freq=min_freq, max_size=max_size), Vocab(tgt_tok, min_freq=min_freq, max_size=max_size)

@dataclass
class Batch:
    src: torch.Tensor
    tgt_in: torch.Tensor
    tgt_out: torch.Tensor

def pad_2d(seqs: List[torch.Tensor], pad_id: int) -> torch.Tensor:
    maxlen = max(s.size(0) for s in seqs)
    out = torch.full((len(seqs), maxlen), pad_id, dtype=torch.long)
    for i,s in enumerate(seqs):
        out[i, :s.size(0)] = s
    return out

MAX_LEN_SRC = 170
MAX_LEN_TGT = 190

def collate_fn(batch_items, src_vocab: Vocab, tgt_vocab: Vocab):
    src_list, tgt_in_list, tgt_out_list = [], [], []
    for en, vi in batch_items:
        src_toks = basic_tokenize(en)[:MAX_LEN_SRC]
        tgt_toks = basic_tokenize(vi)[:MAX_LEN_TGT]
        src_ids = src_vocab.encode(src_toks)
        tgt_ids = tgt_vocab.encode(tgt_toks)
        tgt_in  = [tgt_vocab.bos_id] + tgt_ids
        tgt_out = tgt_ids + [tgt_vocab.eos_id]
        src_list.append(torch.tensor(src_ids, dtype=torch.long))
        tgt_in_list.append(torch.tensor(tgt_in, dtype=torch.long))
        tgt_out_list.append(torch.tensor(tgt_out, dtype=torch.long))

    src = pad_2d(src_list, src_vocab.pad_id)
    tgt_in = pad_2d(tgt_in_list, tgt_vocab.pad_id)
    tgt_out = pad_2d(tgt_out_list, tgt_vocab.pad_id)
    return Batch(src=src, tgt_in=tgt_in, tgt_out=tgt_out)

src_vocab, tgt_vocab = build_vocabs(TRAIN_PATH, min_freq=1, max_size=50000)
print("src_vocab:", len(src_vocab))
print("tgt_vocab:", len(tgt_vocab))


src_vocab: 19065
tgt_vocab: 8297


#Xây dựng mô hình encoderdecoder

In [10]:
EMBED_DIM = 256
HIDDEN_SIZE = 256
NUM_LAYERS = 3
DROPOUT = 0.2

class Encoder(nn.Module):
    def __init__(self, vocab_size: int, pad_id: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=pad_id)
        self.lstm = nn.LSTM(
            input_size=EMBED_DIM,
            hidden_size=HIDDEN_SIZE,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT if NUM_LAYERS > 1 else 0.0,
            batch_first=True
        )

    def forward(self, src_ids):
        # src_ids: [B,S]
        x = self.emb(src_ids)      # [B,S,256]
        _, (h, c) = self.lstm(x)   # h,c: [3,B,256]
        return h, c

class Decoder(nn.Module):
    def __init__(self, vocab_size: int, pad_id: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=pad_id)
        self.lstm = nn.LSTM(
            input_size=EMBED_DIM,
            hidden_size=HIDDEN_SIZE,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT if NUM_LAYERS > 1 else 0.0,
            batch_first=True
        )
        self.proj = nn.Linear(HIDDEN_SIZE, vocab_size)

    def forward(self, tgt_in_ids, h, c):
        # tgt_in_ids: [B,T]
        x = self.emb(tgt_in_ids)        # [B,T,256]
        out, (h, c) = self.lstm(x, (h,c))  # out: [B,T,256]
        logits = self.proj(out)         # [B,T,V]
        return logits, h, c

class Seq2Seq(nn.Module):
    def __init__(self, enc: Encoder, dec: Decoder, pad_id: int):
        super().__init__()
        self.enc = enc
        self.dec = dec
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, src_ids, tgt_in, tgt_out):
        h, c = self.enc(src_ids)
        logits, _, _ = self.dec(tgt_in, h, c)
        loss = self.loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        return loss

enc = Encoder(len(src_vocab), src_vocab.pad_id)
dec = Decoder(len(tgt_vocab), tgt_vocab.pad_id)
model = Seq2Seq(enc, dec, pad_id=tgt_vocab.pad_id).to(DEVICE)
print(model)

Seq2Seq(
  (enc): Encoder(
    (emb): Embedding(19065, 256, padding_idx=0)
    (lstm): LSTM(256, 256, num_layers=3, batch_first=True, dropout=0.2)
  )
  (dec): Decoder(
    (emb): Embedding(8297, 256, padding_idx=0)
    (lstm): LSTM(256, 256, num_layers=3, batch_first=True, dropout=0.2)
    (proj): Linear(in_features=256, out_features=8297, bias=True)
  )
  (loss_fn): CrossEntropyLoss()
)


#DataLoader + kiểm tra 1 batch

In [11]:
train_ds = PhoMTJsonDataset(TRAIN_PATH)
dev_ds   = PhoMTJsonDataset(DEV_PATH)
test_ds  = PhoMTJsonDataset(TEST_PATH)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,
                          collate_fn=lambda b: collate_fn(b, src_vocab, tgt_vocab))
dev_loader   = DataLoader(dev_ds, batch_size=32, shuffle=False,
                          collate_fn=lambda b: collate_fn(b, src_vocab, tgt_vocab))
test_loader  = DataLoader(test_ds, batch_size=32, shuffle=False,
                          collate_fn=lambda b: collate_fn(b, src_vocab, tgt_vocab))

batch = next(iter(train_loader))
print("src:", batch.src.shape)
print("tgt_in:", batch.tgt_in.shape)
print("tgt_out:", batch.tgt_out.shape)

loss = model(batch.src.to(DEVICE), batch.tgt_in.to(DEVICE), batch.tgt_out.to(DEVICE))
print("sanity loss:", float(loss.item()))


src: torch.Size([32, 51])
tgt_in: torch.Size([32, 56])
tgt_out: torch.Size([32, 56])
sanity loss: 9.021029472351074


In [12]:
@torch.no_grad()
def greedy_decode(model: Seq2Seq, src_ids: torch.Tensor, tgt_vocab: Vocab, max_len= MAX_LEN_TGT):
    model.eval()
    h, c = model.enc(src_ids)
    B = src_ids.size(0)
    ys = torch.full((B,1), tgt_vocab.bos_id, dtype=torch.long, device=src_ids.device)

    for _ in range(max_len):
        logits, h, c = model.dec(ys, h, c)
        next_tok = logits[:, -1, :].argmax(-1, keepdim=True)
        ys = torch.cat([ys, next_tok], dim=1)
        if (next_tok.squeeze(1) == tgt_vocab.eos_id).all():
            break
    return ys

def lcs_length(a: List[str], b: List[str]) -> int:
    n, m = len(a), len(b)
    dp = [0]*(m+1)
    for i in range(1, n+1):
        prev = 0
        for j in range(1, m+1):
            tmp = dp[j]
            if a[i-1] == b[j-1]:
                dp[j] = prev + 1
            else:
                dp[j] = max(dp[j], dp[j-1])
            prev = tmp
    return dp[m]

def rouge_l_f1(pred: List[str], ref: List[str]) -> float:
    if not pred or not ref:
        return 0.0
    lcs = lcs_length(pred, ref)
    p = lcs / max(1, len(pred))
    r = lcs / max(1, len(ref))
    return 0.0 if (p+r)==0 else (2*p*r)/(p+r)

@torch.no_grad()
def evaluate_rouge_l(model: Seq2Seq, loader: DataLoader, tgt_vocab: Vocab):
    model.eval()
    scores = []
    for batch in loader:
        src = batch.src.to(DEVICE)
        pred_ids = greedy_decode(model, src, tgt_vocab, max_len=MAX_LEN_TGT)[:, 1:].tolist()  # drop BOS
        for i in range(len(pred_ids)):
            pred_toks = tgt_vocab.decode(pred_ids[i])
            ref_toks  = tgt_vocab.decode(batch.tgt_out[i].tolist())
            scores.append(rouge_l_f1(pred_toks, ref_toks))
    return sum(scores)/max(1,len(scores))


#Huấn luyện mô hình

In [14]:
LR = 3e-4
EPOCHS = 30
GRAD_CLIP = 1.0
PATIENCE = 5         # dừng nếu 5 epoch liên tiếp không cải thiện dev ROUGE-L
MIN_DELTA = 1e-4     # cải thiện tối thiểu để tính là "better"

optimizer = optim.Adam(model.parameters(), lr=LR)

best_dev = -1.0
best_epoch = 0
bad_epochs = 0
BEST_PATH = "best_lstm3_seq2seq_envi.pt"

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}"):
        src = batch.src.to(DEVICE)
        tgt_in = batch.tgt_in.to(DEVICE)
        tgt_out = batch.tgt_out.to(DEVICE)

        optimizer.zero_grad()
        loss = model(src, tgt_in, tgt_out)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        total_loss += float(loss.item())

    train_loss = total_loss / max(1, len(train_loader))
    dev_rouge = evaluate_rouge_l(model, dev_loader, tgt_vocab)

    print(f"\nEpoch {epoch:02d} | train_loss={train_loss:.4f} | dev_ROUGE-L(F1)={dev_rouge:.4f}")

    #Early stopping logic (maximize dev_rouge)
    if dev_rouge > best_dev + MIN_DELTA:
        best_dev = dev_rouge
        best_epoch = epoch
        bad_epochs = 0

        torch.save({
            "model_state": model.state_dict(),
            "src_itos": src_vocab.itos,
            "tgt_itos": tgt_vocab.itos,
            "cfg": {"embed_dim": 256, "hidden_size": 256, "num_layers": 3,
                    "max_len_src": MAX_LEN_SRC, "max_len_tgt": MAX_LEN_TGT}
        }, BEST_PATH)

        print(f"[SAVE] Best model -> {BEST_PATH} (dev_ROUGE-L={best_dev:.4f})")
    else:
        bad_epochs += 1
        print(f"[EARLY STOP] No improvement for {bad_epochs}/{PATIENCE} epochs (best={best_dev:.4f} at epoch {best_epoch})")

        if bad_epochs >= PATIENCE:
            print(f"[EARLY STOP] Stop training. Best dev ROUGE-L={best_dev:.4f} at epoch {best_epoch}.")
            break


Epoch 1/30: 100%|██████████| 625/625 [00:13<00:00, 45.98it/s]



Epoch 01 | train_loss=5.2009 | dev_ROUGE-L(F1)=0.1290
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1290)


Epoch 2/30: 100%|██████████| 625/625 [00:13<00:00, 46.48it/s]



Epoch 02 | train_loss=5.0063 | dev_ROUGE-L(F1)=0.1347
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1347)


Epoch 3/30: 100%|██████████| 625/625 [00:13<00:00, 45.93it/s]



Epoch 03 | train_loss=4.8434 | dev_ROUGE-L(F1)=0.1518
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1518)


Epoch 4/30: 100%|██████████| 625/625 [00:13<00:00, 46.46it/s]



Epoch 04 | train_loss=4.7016 | dev_ROUGE-L(F1)=0.1621
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1621)


Epoch 5/30: 100%|██████████| 625/625 [00:13<00:00, 46.09it/s]



Epoch 05 | train_loss=4.5757 | dev_ROUGE-L(F1)=0.1689
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1689)


Epoch 6/30: 100%|██████████| 625/625 [00:13<00:00, 45.85it/s]



Epoch 06 | train_loss=4.4631 | dev_ROUGE-L(F1)=0.1746
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1746)


Epoch 7/30: 100%|██████████| 625/625 [00:13<00:00, 46.53it/s]



Epoch 07 | train_loss=4.3608 | dev_ROUGE-L(F1)=0.1787
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1787)


Epoch 8/30: 100%|██████████| 625/625 [00:13<00:00, 46.43it/s]



Epoch 08 | train_loss=4.2695 | dev_ROUGE-L(F1)=0.1821
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1821)


Epoch 9/30: 100%|██████████| 625/625 [00:13<00:00, 46.29it/s]



Epoch 09 | train_loss=4.1854 | dev_ROUGE-L(F1)=0.1834
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1834)


Epoch 10/30: 100%|██████████| 625/625 [00:13<00:00, 46.22it/s]



Epoch 10 | train_loss=4.1073 | dev_ROUGE-L(F1)=0.1857
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1857)


Epoch 11/30: 100%|██████████| 625/625 [00:13<00:00, 46.04it/s]



Epoch 11 | train_loss=4.0328 | dev_ROUGE-L(F1)=0.1852
[EARLY STOP] No improvement for 1/5 epochs (best=0.1857 at epoch 10)


Epoch 12/30: 100%|██████████| 625/625 [00:13<00:00, 46.49it/s]



Epoch 12 | train_loss=3.9647 | dev_ROUGE-L(F1)=0.1866
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1866)


Epoch 13/30: 100%|██████████| 625/625 [00:13<00:00, 46.01it/s]



Epoch 13 | train_loss=3.8999 | dev_ROUGE-L(F1)=0.1892
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1892)


Epoch 14/30: 100%|██████████| 625/625 [00:13<00:00, 46.20it/s]



Epoch 14 | train_loss=3.8376 | dev_ROUGE-L(F1)=0.1896
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1896)


Epoch 15/30: 100%|██████████| 625/625 [00:13<00:00, 46.07it/s]



Epoch 15 | train_loss=3.7794 | dev_ROUGE-L(F1)=0.1899
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1899)


Epoch 16/30: 100%|██████████| 625/625 [00:13<00:00, 45.68it/s]



Epoch 16 | train_loss=3.7232 | dev_ROUGE-L(F1)=0.1905
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1905)


Epoch 17/30: 100%|██████████| 625/625 [00:13<00:00, 45.39it/s]



Epoch 17 | train_loss=3.6694 | dev_ROUGE-L(F1)=0.1919
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1919)


Epoch 18/30: 100%|██████████| 625/625 [00:13<00:00, 45.36it/s]



Epoch 18 | train_loss=3.6176 | dev_ROUGE-L(F1)=0.1918
[EARLY STOP] No improvement for 1/5 epochs (best=0.1919 at epoch 17)


Epoch 19/30: 100%|██████████| 625/625 [00:13<00:00, 45.98it/s]



Epoch 19 | train_loss=3.5691 | dev_ROUGE-L(F1)=0.1906
[EARLY STOP] No improvement for 2/5 epochs (best=0.1919 at epoch 17)


Epoch 20/30: 100%|██████████| 625/625 [00:13<00:00, 45.90it/s]



Epoch 20 | train_loss=3.5203 | dev_ROUGE-L(F1)=0.1938
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1938)


Epoch 21/30: 100%|██████████| 625/625 [00:13<00:00, 44.94it/s]



Epoch 21 | train_loss=3.4746 | dev_ROUGE-L(F1)=0.1908
[EARLY STOP] No improvement for 1/5 epochs (best=0.1938 at epoch 20)


Epoch 22/30: 100%|██████████| 625/625 [00:13<00:00, 45.89it/s]



Epoch 22 | train_loss=3.4317 | dev_ROUGE-L(F1)=0.1917
[EARLY STOP] No improvement for 2/5 epochs (best=0.1938 at epoch 20)


Epoch 23/30: 100%|██████████| 625/625 [00:13<00:00, 46.08it/s]



Epoch 23 | train_loss=3.3855 | dev_ROUGE-L(F1)=0.1941
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1941)


Epoch 24/30: 100%|██████████| 625/625 [00:13<00:00, 45.37it/s]



Epoch 24 | train_loss=3.3446 | dev_ROUGE-L(F1)=0.1932
[EARLY STOP] No improvement for 1/5 epochs (best=0.1941 at epoch 23)


Epoch 25/30: 100%|██████████| 625/625 [00:13<00:00, 46.16it/s]



Epoch 25 | train_loss=3.3048 | dev_ROUGE-L(F1)=0.1955
[SAVE] Best model -> best_lstm3_seq2seq_envi.pt (dev_ROUGE-L=0.1955)


Epoch 26/30: 100%|██████████| 625/625 [00:13<00:00, 46.59it/s]



Epoch 26 | train_loss=3.2627 | dev_ROUGE-L(F1)=0.1954
[EARLY STOP] No improvement for 1/5 epochs (best=0.1955 at epoch 25)


Epoch 27/30: 100%|██████████| 625/625 [00:13<00:00, 46.18it/s]



Epoch 27 | train_loss=3.2254 | dev_ROUGE-L(F1)=0.1931
[EARLY STOP] No improvement for 2/5 epochs (best=0.1955 at epoch 25)


Epoch 28/30: 100%|██████████| 625/625 [00:13<00:00, 46.52it/s]



Epoch 28 | train_loss=3.1861 | dev_ROUGE-L(F1)=0.1919
[EARLY STOP] No improvement for 3/5 epochs (best=0.1955 at epoch 25)


Epoch 29/30: 100%|██████████| 625/625 [00:13<00:00, 46.14it/s]



Epoch 29 | train_loss=3.1521 | dev_ROUGE-L(F1)=0.1933
[EARLY STOP] No improvement for 4/5 epochs (best=0.1955 at epoch 25)


Epoch 30/30: 100%|██████████| 625/625 [00:13<00:00, 46.26it/s]



Epoch 30 | train_loss=3.1191 | dev_ROUGE-L(F1)=0.1935
[EARLY STOP] No improvement for 5/5 epochs (best=0.1955 at epoch 25)
[EARLY STOP] Stop training. Best dev ROUGE-L=0.1955 at epoch 25.


In [15]:
ckpt = torch.load(BEST_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])

test_rouge = evaluate_rouge_l(model, test_loader, tgt_vocab)
print("TEST ROUGE-L(F1) =", test_rouge)


TEST ROUGE-L(F1) = 0.19458110674019324


In [16]:
model.eval()
for i in range(min(5, len(test_ds))):
    en, vi = test_ds[i]
    src_ids = torch.tensor([src_vocab.encode(basic_tokenize(en)[:MAX_LEN_SRC])], dtype=torch.long).to(DEVICE)
    pred = greedy_decode(model, src_ids, tgt_vocab, max_len=MAX_LEN_TGT)[0].tolist()
    pred_text = " ".join(tgt_vocab.decode(pred[1:]))

    print("\nEN:", en)
    print("GT:", vi)
    print("PR:", pred_text)



EN: Brother Albert Barnett and his wife , Sister Susan Barnett , from the West Congregation in Tuscaloosa , Alabama
GT: Anh Albert Barnett và chị Susan Barnett , thuộc hội thánh West ở Tuscaloosa , Alabama
PR: Người phụ nữ này , và không .

EN: Severe storms ripped through parts of the southern and midwestern United States on January 11 and 12 , 2020 .
GT: Ngày 11 và 12-1-2020 , những cơn bão lớn đã quét qua và phá huỷ nhiều vùng ở miền nam và miền trung Hoa Kỳ .
PR: Những người nghèo dùng và những người .

EN: Two days of heavy rain , high winds , and numerous tornadoes caused major damage across multiple states .
GT: Những trận mưa to và gió lớn trong suốt hai ngày cùng với nhiều cơn lốc xoáy đã gây thiệt hại nặng nề cho nhiều bang .
PR: Trong vòng 10 năm , chúng ta không .

EN: Sadly , Brother Albert Barnett and his wife , Sister Susan Barnett , 85 and 75 years old respectively , were killed when a tornado struck their mobile home .
GT: Đáng buồn là anh Albert Barnett 85 tuổi , và 

#### Bài 2: Xây dựng kiến trúc Encoder-Decoder gồm 3 lớp LSTM cho module encoder và 3 lớp LSTM cho module decoder, với hidden size là 256, cho bài toán dịch máy từ tiếng Anh sang tiếng Việt. Module decoder được trang bị kỹ thuật attention theo mô tả của nghiên cứu "[Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)". Huấn luyện mô hình này trên bộ dữ liệu PhoMT sử dụng Adam làm phương thức tối ưu tham số. Đánh giá độ hiệu quả của mô hình sử dụn độ đo ROUGE-L.

In [18]:
EMBED_DIM = 256
HIDDEN_SIZE = 256
NUM_LAYERS = 3
DROPOUT = 0.2

class Encoder(nn.Module):
    def __init__(self, vocab_size: int, pad_id: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=pad_id)
        self.lstm = nn.LSTM(
            EMBED_DIM, HIDDEN_SIZE,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT if NUM_LAYERS > 1 else 0.0,
            batch_first=True
        )

    def forward(self, src_ids):
        x = self.emb(src_ids)                 # [B,S,E]
        enc_outputs, (h, c) = self.lstm(x)    # [B,S,H], ([L,B,H],[L,B,H])
        return enc_outputs, (h, c)

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)  # encoder side
        self.W_s = nn.Linear(hidden_size, hidden_size, bias=False)  # decoder side
        self.v   = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, s_t, enc_outputs, src_mask=None):
        """
        s_t: [B,H]  (decoder top-layer hidden at time t)
        enc_outputs: [B,S,H]
        src_mask: [B,S] True for valid tokens
        """
        # energy: [B,S,H]
        energy = torch.tanh(self.W_h(enc_outputs) + self.W_s(s_t).unsqueeze(1))
        scores = self.v(energy).squeeze(-1)  # [B,S]
        if src_mask is not None:
            scores = scores.masked_fill(~src_mask, -1e9)
        attn_w = F.softmax(scores, dim=-1)   # [B,S]
        context = torch.bmm(attn_w.unsqueeze(1), enc_outputs).squeeze(1)  # [B,H]
        return context, attn_w

class AttnDecoder(nn.Module):
    def __init__(self, vocab_size: int, pad_id: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=pad_id)
        self.attn = BahdanauAttention(HIDDEN_SIZE)

        self.lstm = nn.LSTM(
            input_size=EMBED_DIM + HIDDEN_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT if NUM_LAYERS > 1 else 0.0,
            batch_first=True
        )
        self.proj = nn.Linear(HIDDEN_SIZE + HIDDEN_SIZE, vocab_size)

    def forward(self, tgt_in_ids, enc_outputs, init_state, src_mask=None):
        h, c = init_state
        B, T = tgt_in_ids.size()
        logits_all = []

        for t in range(T):
            y_t = tgt_in_ids[:, t]         # [B]
            emb_t = self.emb(y_t)          # [B,E]
            s_t = h[-1]                    # [B,H] top-layer
            context, _ = self.attn(s_t, enc_outputs, src_mask=src_mask)  # [B,H]

            lstm_in = torch.cat([emb_t, context], dim=-1).unsqueeze(1)   # [B,1,E+H]
            out_t, (h, c) = self.lstm(lstm_in, (h, c))                   # out_t: [B,1,H]
            out_t = out_t.squeeze(1)                                     # [B,H]

            logits_t = self.proj(torch.cat([out_t, context], dim=-1))    # [B,V]
            logits_all.append(logits_t.unsqueeze(1))                     # [B,1,V]

        return torch.cat(logits_all, dim=1)  # [B,T,V]

class Seq2SeqAttn(nn.Module):
    def __init__(self, enc: Encoder, dec: AttnDecoder, pad_id_src: int, pad_id_tgt: int):
        super().__init__()
        self.enc = enc
        self.dec = dec
        self.pad_id_src = pad_id_src
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id_tgt)

    def make_src_mask(self, src_ids):
        return (src_ids != self.pad_id_src)  # [B,S] True for valid

    def forward(self, src_ids, tgt_in, tgt_out):
        enc_outputs, (h, c) = self.enc(src_ids)
        src_mask = self.make_src_mask(src_ids)
        logits = self.dec(tgt_in, enc_outputs, (h, c), src_mask=src_mask)
        loss = self.loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        return loss

enc = Encoder(len(src_vocab), src_vocab.pad_id)
dec = AttnDecoder(len(tgt_vocab), tgt_vocab.pad_id)
model = Seq2SeqAttn(enc, dec, pad_id_src=src_vocab.pad_id, pad_id_tgt=tgt_vocab.pad_id).to(DEVICE)
print(model)


Seq2SeqAttn(
  (enc): Encoder(
    (emb): Embedding(19065, 256, padding_idx=0)
    (lstm): LSTM(256, 256, num_layers=3, batch_first=True, dropout=0.2)
  )
  (dec): AttnDecoder(
    (emb): Embedding(8297, 256, padding_idx=0)
    (attn): BahdanauAttention(
      (W_h): Linear(in_features=256, out_features=256, bias=False)
      (W_s): Linear(in_features=256, out_features=256, bias=False)
      (v): Linear(in_features=256, out_features=1, bias=False)
    )
    (lstm): LSTM(512, 256, num_layers=3, batch_first=True, dropout=0.2)
    (proj): Linear(in_features=512, out_features=8297, bias=True)
  )
  (loss_fn): CrossEntropyLoss()
)


In [19]:
batch = next(iter(train_loader))
loss = model(batch.src.to(DEVICE), batch.tgt_in.to(DEVICE), batch.tgt_out.to(DEVICE))
print("sanity loss:", float(loss.item()))


sanity loss: 9.02454662322998


In [20]:
def lcs_length(a: List[str], b: List[str]) -> int:
    n, m = len(a), len(b)
    dp = [0]*(m+1)
    for i in range(1, n+1):
        prev = 0
        for j in range(1, m+1):
            tmp = dp[j]
            if a[i-1] == b[j-1]:
                dp[j] = prev + 1
            else:
                dp[j] = max(dp[j], dp[j-1])
            prev = tmp
    return dp[m]

def rouge_l_f1(pred: List[str], ref: List[str]) -> float:
    if not pred or not ref:
        return 0.0
    lcs = lcs_length(pred, ref)
    p = lcs / max(1, len(pred))
    r = lcs / max(1, len(ref))
    return 0.0 if (p+r)==0 else (2*p*r)/(p+r)

@torch.no_grad()
def greedy_decode_attn(model: Seq2SeqAttn, src_ids: torch.Tensor, tgt_vocab: Vocab, max_len=190):
    model.eval()
    enc_outputs, (h, c) = model.enc(src_ids)
    src_mask = (src_ids != model.pad_id_src)

    B = src_ids.size(0)
    ys = torch.full((B,1), tgt_vocab.bos_id, dtype=torch.long, device=src_ids.device)

    for _ in range(max_len):
        y_last = ys[:, -1]                 # [B]
        emb_t = model.dec.emb(y_last)      # [B,E]
        s_t = h[-1]                        # [B,H]
        context, _ = model.dec.attn(s_t, enc_outputs, src_mask=src_mask)

        lstm_in = torch.cat([emb_t, context], dim=-1).unsqueeze(1)  # [B,1,E+H]
        out_t, (h, c) = model.dec.lstm(lstm_in, (h, c))
        out_t = out_t.squeeze(1)                                    # [B,H]

        logits = model.dec.proj(torch.cat([out_t, context], dim=-1)) # [B,V]
        next_tok = logits.argmax(-1, keepdim=True)                   # [B,1]
        ys = torch.cat([ys, next_tok], dim=1)

        if (next_tok.squeeze(1) == tgt_vocab.eos_id).all():
            break

    return ys

@torch.no_grad()
def evaluate_rouge_l_attn(model, loader, tgt_vocab: Vocab):
    model.eval()
    scores = []
    for batch in loader:
        src = batch.src.to(DEVICE)
        pred_ids = greedy_decode_attn(model, src, tgt_vocab, max_len=MAX_LEN_TGT)[:, 1:].tolist()
        for i in range(len(pred_ids)):
            pred_toks = tgt_vocab.decode(pred_ids[i])
            ref_toks  = tgt_vocab.decode(batch.tgt_out[i].tolist())
            scores.append(rouge_l_f1(pred_toks, ref_toks))
    return sum(scores)/max(1,len(scores))


In [21]:
LR = 3e-4
EPOCHS = 30
GRAD_CLIP = 1.0
PATIENCE = 5
MIN_DELTA = 1e-4

optimizer = optim.Adam(model.parameters(), lr=LR)

best_dev = -1.0
best_epoch = 0
bad_epochs = 0
BEST_PATH = "best_lstm3_bahdanau_envi.pt"

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}"):
        src = batch.src.to(DEVICE)
        tgt_in = batch.tgt_in.to(DEVICE)
        tgt_out = batch.tgt_out.to(DEVICE)

        optimizer.zero_grad()
        loss = model(src, tgt_in, tgt_out)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        total_loss += float(loss.item())

    train_loss = total_loss / max(1, len(train_loader))
    dev_rouge = evaluate_rouge_l_attn(model, dev_loader, tgt_vocab)

    print(f"\nEpoch {epoch:02d} | train_loss={train_loss:.4f} | dev_ROUGE-L(F1)={dev_rouge:.4f}")

    if dev_rouge > best_dev + MIN_DELTA:
        best_dev = dev_rouge
        best_epoch = epoch
        bad_epochs = 0
        torch.save({"model_state": model.state_dict()}, BEST_PATH)
        print(f"[SAVE] Best -> {BEST_PATH} (dev_ROUGE-L={best_dev:.4f})")
    else:
        bad_epochs += 1
        print(f"[EARLY STOP] No improvement {bad_epochs}/{PATIENCE} (best={best_dev:.4f} at epoch {best_epoch})")
        if bad_epochs >= PATIENCE:
            print(f"[EARLY STOP] Stop. Best dev ROUGE-L={best_dev:.4f} at epoch {best_epoch}")
            break


Epoch 1/30: 100%|██████████| 625/625 [02:34<00:00,  4.03it/s]



Epoch 01 | train_loss=6.2215 | dev_ROUGE-L(F1)=0.0603
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.0603)


Epoch 2/30: 100%|██████████| 625/625 [02:33<00:00,  4.08it/s]



Epoch 02 | train_loss=5.7126 | dev_ROUGE-L(F1)=0.1053
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.1053)


Epoch 3/30: 100%|██████████| 625/625 [02:33<00:00,  4.07it/s]



Epoch 03 | train_loss=5.3568 | dev_ROUGE-L(F1)=0.1494
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.1494)


Epoch 4/30: 100%|██████████| 625/625 [02:33<00:00,  4.08it/s]



Epoch 04 | train_loss=5.0525 | dev_ROUGE-L(F1)=0.1778
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.1778)


Epoch 5/30: 100%|██████████| 625/625 [02:33<00:00,  4.06it/s]



Epoch 05 | train_loss=4.7949 | dev_ROUGE-L(F1)=0.1939
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.1939)


Epoch 6/30: 100%|██████████| 625/625 [02:34<00:00,  4.04it/s]



Epoch 06 | train_loss=4.5758 | dev_ROUGE-L(F1)=0.2108
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2108)


Epoch 7/30: 100%|██████████| 625/625 [02:33<00:00,  4.08it/s]



Epoch 07 | train_loss=4.3751 | dev_ROUGE-L(F1)=0.2142
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2142)


Epoch 8/30: 100%|██████████| 625/625 [02:35<00:00,  4.02it/s]



Epoch 08 | train_loss=4.2027 | dev_ROUGE-L(F1)=0.2178
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2178)


Epoch 9/30: 100%|██████████| 625/625 [02:34<00:00,  4.05it/s]



Epoch 09 | train_loss=4.0471 | dev_ROUGE-L(F1)=0.2382
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2382)


Epoch 10/30: 100%|██████████| 625/625 [02:36<00:00,  4.01it/s]



Epoch 10 | train_loss=3.9050 | dev_ROUGE-L(F1)=0.2426
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2426)


Epoch 11/30: 100%|██████████| 625/625 [02:33<00:00,  4.06it/s]



Epoch 11 | train_loss=3.7753 | dev_ROUGE-L(F1)=0.2535
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2535)


Epoch 12/30: 100%|██████████| 625/625 [02:33<00:00,  4.08it/s]



Epoch 12 | train_loss=3.6562 | dev_ROUGE-L(F1)=0.2545
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2545)


Epoch 13/30: 100%|██████████| 625/625 [02:32<00:00,  4.09it/s]



Epoch 13 | train_loss=3.5456 | dev_ROUGE-L(F1)=0.2601
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2601)


Epoch 14/30: 100%|██████████| 625/625 [02:32<00:00,  4.10it/s]



Epoch 14 | train_loss=3.4434 | dev_ROUGE-L(F1)=0.2650
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2650)


Epoch 15/30: 100%|██████████| 625/625 [02:32<00:00,  4.11it/s]



Epoch 15 | train_loss=3.3502 | dev_ROUGE-L(F1)=0.2744
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2744)


Epoch 16/30: 100%|██████████| 625/625 [02:33<00:00,  4.07it/s]



Epoch 16 | train_loss=3.2638 | dev_ROUGE-L(F1)=0.2731
[EARLY STOP] No improvement 1/5 (best=0.2744 at epoch 15)


Epoch 17/30: 100%|██████████| 625/625 [02:33<00:00,  4.07it/s]



Epoch 17 | train_loss=3.1776 | dev_ROUGE-L(F1)=0.2750
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2750)


Epoch 18/30: 100%|██████████| 625/625 [02:32<00:00,  4.09it/s]



Epoch 18 | train_loss=3.0969 | dev_ROUGE-L(F1)=0.2763
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2763)


Epoch 19/30: 100%|██████████| 625/625 [02:34<00:00,  4.04it/s]



Epoch 19 | train_loss=3.0203 | dev_ROUGE-L(F1)=0.2821
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2821)


Epoch 20/30: 100%|██████████| 625/625 [02:34<00:00,  4.04it/s]



Epoch 20 | train_loss=2.9473 | dev_ROUGE-L(F1)=0.2821
[EARLY STOP] No improvement 1/5 (best=0.2821 at epoch 19)


Epoch 21/30: 100%|██████████| 625/625 [02:33<00:00,  4.06it/s]



Epoch 21 | train_loss=2.8772 | dev_ROUGE-L(F1)=0.2865
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2865)


Epoch 22/30: 100%|██████████| 625/625 [02:34<00:00,  4.05it/s]



Epoch 22 | train_loss=2.8141 | dev_ROUGE-L(F1)=0.2894
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2894)


Epoch 23/30: 100%|██████████| 625/625 [02:34<00:00,  4.05it/s]



Epoch 23 | train_loss=2.7489 | dev_ROUGE-L(F1)=0.2896
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2896)


Epoch 24/30: 100%|██████████| 625/625 [02:32<00:00,  4.09it/s]



Epoch 24 | train_loss=2.6878 | dev_ROUGE-L(F1)=0.2901
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2901)


Epoch 25/30: 100%|██████████| 625/625 [02:32<00:00,  4.09it/s]



Epoch 25 | train_loss=2.6326 | dev_ROUGE-L(F1)=0.2961
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2961)


Epoch 26/30: 100%|██████████| 625/625 [02:34<00:00,  4.05it/s]



Epoch 26 | train_loss=2.5760 | dev_ROUGE-L(F1)=0.2945
[EARLY STOP] No improvement 1/5 (best=0.2961 at epoch 25)


Epoch 27/30: 100%|██████████| 625/625 [02:33<00:00,  4.08it/s]



Epoch 27 | train_loss=2.5226 | dev_ROUGE-L(F1)=0.2937
[EARLY STOP] No improvement 2/5 (best=0.2961 at epoch 25)


Epoch 28/30: 100%|██████████| 625/625 [02:32<00:00,  4.10it/s]



Epoch 28 | train_loss=2.4699 | dev_ROUGE-L(F1)=0.2954
[EARLY STOP] No improvement 3/5 (best=0.2961 at epoch 25)


Epoch 29/30: 100%|██████████| 625/625 [02:33<00:00,  4.08it/s]



Epoch 29 | train_loss=2.4195 | dev_ROUGE-L(F1)=0.2969
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2969)


Epoch 30/30: 100%|██████████| 625/625 [02:34<00:00,  4.05it/s]



Epoch 30 | train_loss=2.3732 | dev_ROUGE-L(F1)=0.2988
[SAVE] Best -> best_lstm3_bahdanau_envi.pt (dev_ROUGE-L=0.2988)


In [22]:
ckpt = torch.load(BEST_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])

test_rouge = evaluate_rouge_l_attn(model, test_loader, tgt_vocab)
print("TEST ROUGE-L(F1) =", test_rouge)


TEST ROUGE-L(F1) = 0.31492271826802154


In [23]:
model.eval()
for i in range(min(5, len(test_ds))):
    en, vi = test_ds[i]
    src_ids = torch.tensor([src_vocab.encode(basic_tokenize(en)[:MAX_LEN_SRC])], dtype=torch.long).to(DEVICE)
    pred = greedy_decode_attn(model, src_ids, tgt_vocab, max_len=MAX_LEN_TGT)[0].tolist()
    pred_text = " ".join(tgt_vocab.decode(pred[1:]))

    print("\nEN:", en)
    print("GT:", vi)
    print("PR:", pred_text)



EN: Brother Albert Barnett and his wife , Sister Susan Barnett , from the West Congregation in Tuscaloosa , Alabama
GT: Anh Albert Barnett và chị Susan Barnett , thuộc hội thánh West ở Tuscaloosa , Alabama
PR: Người đàn ông và Domitia khóc , Tom nhân viên , anh ấy , nhà thờ trường học ở trường , trong những người không gian .

EN: Severe storms ripped through parts of the southern and midwestern United States on January 11 and 12 , 2020 .
GT: Ngày 11 và 12-1-2020 , những cơn bão lớn đã quét qua và phá huỷ nhiều vùng ở miền nam và miền trung Hoa Kỳ .
PR: Một người mẹ đã trải qua các thành phố trong vòng và các nước Mỹ và Trung Quốc đã mất một tỷ lệ và tội phạm .

EN: Two days of heavy rain , high winds , and numerous tornadoes caused major damage across multiple states .
GT: Những trận mưa to và gió lớn trong suốt hai ngày cùng với nhiều cơn lốc xoáy đã gây thiệt hại nặng nề cho nhiều bang .
PR: Hai năm trước , những người ủng hộ , và Angelina thẳng thắn của các người khác đều bị bạo l

#### Bài 3: Xây dựng kiến trúc Encoder-Decoder gồm 3 lớp LSTM cho module encoder và 3 lớp LSTM cho module decoder, với hidden size là 256, cho bài toán dịch máy từ tiếng Anh sang tiếng Việt. Module decoder được trang bị kỹ thuật attention theo mô tả của nghiên cứu "[Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)". Huấn luyện mô hình này trên bộ dữ liệu PhoMT sử dụng Adam làm phương thức tối ưu tham số. Đánh giá độ hiệu quả của mô hình sử dụn độ đo ROUGE-L.

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F

EMBED_DIM = 256
HIDDEN_SIZE = 256
NUM_LAYERS = 3
DROPOUT = 0.2

class Encoder(nn.Module):
    def __init__(self, vocab_size: int, pad_id: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=pad_id)
        self.lstm = nn.LSTM(
            input_size=EMBED_DIM,
            hidden_size=HIDDEN_SIZE,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT if NUM_LAYERS > 1 else 0.0,
            batch_first=True
        )

    def forward(self, src_ids):
        # src_ids: [B,S]
        x = self.emb(src_ids)                 # [B,S,E]
        enc_outputs, (h, c) = self.lstm(x)    # enc_outputs: [B,S,H]
        return enc_outputs, (h, c)


#Class attention giống bạn đưa + thêm mask PAD
class LuongAttention(nn.Module):
    """
    Luong 'general' attention:
      score(h_t, h_i) = h_t^T W h_i
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.W = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, decoder_hidden, encoder_outputs, src_mask=None):
        """
        decoder_hidden: [B,H]
        encoder_outputs: [B,S,H]
        src_mask: [B,S] True for valid tokens (non-PAD)
        """
        query = decoder_hidden.unsqueeze(1)         # [B,1,H]
        keys = self.W(encoder_outputs)              # [B,S,H]
        scores = torch.bmm(query, keys.transpose(1,2)).squeeze(1)  # [B,S]

        if src_mask is not None:
            scores = scores.masked_fill(~src_mask, -1e9)

        attn_w = torch.softmax(scores, dim=1)       # [B,S]
        context = torch.bmm(attn_w.unsqueeze(1), encoder_outputs).squeeze(1)  # [B,H]
        return context, attn_w


class LuongDecoderInputFeeding(nn.Module):
    """
    - input_t = [emb(y_{t-1}); tilde_{t-1}]
    - attention dùng h_t (output LSTM) để lấy context
    - tilde_t = tanh(Wc [context; h_t])
    - dự đoán token từ tilde_t
    """
    def __init__(self, vocab_size: int, pad_id: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=pad_id)
        self.attn = LuongAttention(HIDDEN_SIZE)

        self.lstm = nn.LSTM(
            input_size=EMBED_DIM + HIDDEN_SIZE,   # input-feeding
            hidden_size=HIDDEN_SIZE,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT if NUM_LAYERS > 1 else 0.0,
            batch_first=True
        )

        self.Wc = nn.Linear(HIDDEN_SIZE + HIDDEN_SIZE, HIDDEN_SIZE)  # attentional vector
        self.proj = nn.Linear(HIDDEN_SIZE, vocab_size)

    def forward(self, tgt_in_ids, enc_outputs, init_state, src_mask=None):
        """
        tgt_in_ids: [B,T] (BOS + y)
        returns logits: [B,T,V]
        """
        h, c = init_state
        B, T = tgt_in_ids.size()

        tilde_prev = torch.zeros(B, HIDDEN_SIZE, device=tgt_in_ids.device)  # \tilde{h}_0 = 0
        logits_all = []

        for t in range(T):
            y_t = tgt_in_ids[:, t]             # [B]
            emb_t = self.emb(y_t)              # [B,E]

            lstm_in = torch.cat([emb_t, tilde_prev], dim=-1).unsqueeze(1)  # [B,1,E+H]
            out_t, (h, c) = self.lstm(lstm_in, (h, c))
            h_t = out_t.squeeze(1)             # [B,H]

            context, _ = self.attn(h_t, enc_outputs, src_mask=src_mask)    # [B,H]
            tilde_t = torch.tanh(self.Wc(torch.cat([context, h_t], dim=-1))) # [B,H]

            logits_t = self.proj(tilde_t)      # [B,V]
            logits_all.append(logits_t.unsqueeze(1))

            tilde_prev = tilde_t

        return torch.cat(logits_all, dim=1)    # [B,T,V]


class Seq2SeqLuong(nn.Module):
    def __init__(self, enc: Encoder, dec: LuongDecoderInputFeeding, pad_id_src: int, pad_id_tgt: int):
        super().__init__()
        self.enc = enc
        self.dec = dec
        self.pad_id_src = pad_id_src
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id_tgt)

    def make_src_mask(self, src_ids):
        return (src_ids != self.pad_id_src)  # [B,S] True for non-pad

    def forward(self, src_ids, tgt_in, tgt_out):
        enc_outputs, (h, c) = self.enc(src_ids)
        src_mask = self.make_src_mask(src_ids)
        logits = self.dec(tgt_in, enc_outputs, (h, c), src_mask=src_mask)
        loss = self.loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        return loss


# build model
enc = Encoder(len(src_vocab), src_vocab.pad_id)
dec = LuongDecoderInputFeeding(len(tgt_vocab), tgt_vocab.pad_id)
model = Seq2SeqLuong(enc, dec, pad_id_src=src_vocab.pad_id, pad_id_tgt=tgt_vocab.pad_id).to(DEVICE)
print(model)


Seq2SeqLuong(
  (enc): Encoder(
    (emb): Embedding(19065, 256, padding_idx=0)
    (lstm): LSTM(256, 256, num_layers=3, batch_first=True, dropout=0.2)
  )
  (dec): LuongDecoderInputFeeding(
    (emb): Embedding(8297, 256, padding_idx=0)
    (attn): LuongAttention(
      (W): Linear(in_features=256, out_features=256, bias=False)
    )
    (lstm): LSTM(512, 256, num_layers=3, batch_first=True, dropout=0.2)
    (Wc): Linear(in_features=512, out_features=256, bias=True)
    (proj): Linear(in_features=256, out_features=8297, bias=True)
  )
  (loss_fn): CrossEntropyLoss()
)


In [25]:
batch = next(iter(train_loader))
loss = model(batch.src.to(DEVICE), batch.tgt_in.to(DEVICE), batch.tgt_out.to(DEVICE))
print("sanity loss:", float(loss.item()))


sanity loss: 9.026552200317383


In [26]:
def lcs_length(a: List[str], b: List[str]) -> int:
    n, m = len(a), len(b)
    dp = [0]*(m+1)
    for i in range(1, n+1):
        prev = 0
        for j in range(1, m+1):
            tmp = dp[j]
            if a[i-1] == b[j-1]:
                dp[j] = prev + 1
            else:
                dp[j] = max(dp[j], dp[j-1])
            prev = tmp
    return dp[m]

def rouge_l_f1(pred: List[str], ref: List[str]) -> float:
    if not pred or not ref:
        return 0.0
    lcs = lcs_length(pred, ref)
    p = lcs / max(1, len(pred))
    r = lcs / max(1, len(ref))
    return 0.0 if (p+r)==0 else (2*p*r)/(p+r)

@torch.no_grad()
def greedy_decode_luong(model: Seq2SeqLuong, src_ids: torch.Tensor, tgt_vocab: Vocab, max_len=190):
    model.eval()
    enc_outputs, (h, c) = model.enc(src_ids)
    src_mask = (src_ids != model.pad_id_src)

    B = src_ids.size(0)
    ys = torch.full((B,1), tgt_vocab.bos_id, dtype=torch.long, device=src_ids.device)

    tilde_prev = torch.zeros(B, HIDDEN_SIZE, device=src_ids.device)

    for _ in range(max_len):
        y_last = ys[:, -1]                    # [B]
        emb_t = model.dec.emb(y_last)         # [B,E]
        lstm_in = torch.cat([emb_t, tilde_prev], dim=-1).unsqueeze(1)  # [B,1,E+H]

        out_t, (h, c) = model.dec.lstm(lstm_in, (h, c))
        h_t = out_t.squeeze(1)                # [B,H]

        context, _ = model.dec.attn(h_t, enc_outputs, src_mask=src_mask)
        tilde_t = torch.tanh(model.dec.Wc(torch.cat([context, h_t], dim=-1)))
        logits = model.dec.proj(tilde_t)      # [B,V]

        next_tok = logits.argmax(-1, keepdim=True)
        ys = torch.cat([ys, next_tok], dim=1)

        tilde_prev = tilde_t

        if (next_tok.squeeze(1) == tgt_vocab.eos_id).all():
            break

    return ys

@torch.no_grad()
def evaluate_rouge_l_luong(model, loader, tgt_vocab: Vocab):
    model.eval()
    scores = []
    for batch in loader:
        src = batch.src.to(DEVICE)
        pred_ids = greedy_decode_luong(model, src, tgt_vocab, max_len=MAX_LEN_TGT)[:, 1:].tolist()
        for i in range(len(pred_ids)):
            pred_toks = tgt_vocab.decode(pred_ids[i])
            ref_toks  = tgt_vocab.decode(batch.tgt_out[i].tolist())
            scores.append(rouge_l_f1(pred_toks, ref_toks))
    return sum(scores)/max(1,len(scores))


In [27]:
LR = 3e-4
EPOCHS = 30
GRAD_CLIP = 1.0
PATIENCE = 5
MIN_DELTA = 1e-4

optimizer = optim.Adam(model.parameters(), lr=LR)

best_dev = -1.0
best_epoch = 0
bad_epochs = 0
BEST_PATH = "best_lstm3_luong_envi.pt"

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}"):
        src = batch.src.to(DEVICE)
        tgt_in = batch.tgt_in.to(DEVICE)
        tgt_out = batch.tgt_out.to(DEVICE)

        optimizer.zero_grad()
        loss = model(src, tgt_in, tgt_out)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        total_loss += float(loss.item())

    train_loss = total_loss / max(1, len(train_loader))
    dev_rouge = evaluate_rouge_l_luong(model, dev_loader, tgt_vocab)

    print(f"\nEpoch {epoch:02d} | train_loss={train_loss:.4f} | dev_ROUGE-L(F1)={dev_rouge:.4f}")

    if dev_rouge > best_dev + MIN_DELTA:
        best_dev = dev_rouge
        best_epoch = epoch
        bad_epochs = 0
        torch.save({"model_state": model.state_dict()}, BEST_PATH)
        print(f"[SAVE] Best -> {BEST_PATH} (dev_ROUGE-L={best_dev:.4f})")
    else:
        bad_epochs += 1
        print(f"[EARLY STOP] No improvement {bad_epochs}/{PATIENCE} (best={best_dev:.4f} at epoch {best_epoch})")
        if bad_epochs >= PATIENCE:
            print(f"[EARLY STOP] Stop. Best dev ROUGE-L={best_dev:.4f} at epoch {best_epoch}")
            break


Epoch 1/30: 100%|██████████| 625/625 [02:27<00:00,  4.24it/s]



Epoch 01 | train_loss=6.3189 | dev_ROUGE-L(F1)=0.0750
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.0750)


Epoch 2/30: 100%|██████████| 625/625 [02:28<00:00,  4.21it/s]



Epoch 02 | train_loss=5.9989 | dev_ROUGE-L(F1)=0.1125
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.1125)


Epoch 3/30: 100%|██████████| 625/625 [02:30<00:00,  4.16it/s]



Epoch 03 | train_loss=5.6581 | dev_ROUGE-L(F1)=0.1185
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.1185)


Epoch 4/30: 100%|██████████| 625/625 [02:29<00:00,  4.18it/s]



Epoch 04 | train_loss=5.3050 | dev_ROUGE-L(F1)=0.1477
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.1477)


Epoch 5/30: 100%|██████████| 625/625 [02:28<00:00,  4.20it/s]



Epoch 05 | train_loss=5.0263 | dev_ROUGE-L(F1)=0.1973
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.1973)


Epoch 6/30: 100%|██████████| 625/625 [02:27<00:00,  4.23it/s]



Epoch 06 | train_loss=4.7866 | dev_ROUGE-L(F1)=0.2083
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2083)


Epoch 7/30: 100%|██████████| 625/625 [02:27<00:00,  4.24it/s]



Epoch 07 | train_loss=4.5783 | dev_ROUGE-L(F1)=0.2112
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2112)


Epoch 8/30: 100%|██████████| 625/625 [02:27<00:00,  4.23it/s]



Epoch 08 | train_loss=4.3974 | dev_ROUGE-L(F1)=0.2239
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2239)


Epoch 9/30: 100%|██████████| 625/625 [02:28<00:00,  4.21it/s]



Epoch 09 | train_loss=4.2381 | dev_ROUGE-L(F1)=0.2335
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2335)


Epoch 10/30: 100%|██████████| 625/625 [02:28<00:00,  4.20it/s]



Epoch 10 | train_loss=4.0953 | dev_ROUGE-L(F1)=0.2442
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2442)


Epoch 11/30: 100%|██████████| 625/625 [02:28<00:00,  4.21it/s]



Epoch 11 | train_loss=3.9665 | dev_ROUGE-L(F1)=0.2485
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2485)


Epoch 12/30: 100%|██████████| 625/625 [02:30<00:00,  4.16it/s]



Epoch 12 | train_loss=3.8470 | dev_ROUGE-L(F1)=0.2558
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2558)


Epoch 13/30: 100%|██████████| 625/625 [02:28<00:00,  4.20it/s]



Epoch 13 | train_loss=3.7381 | dev_ROUGE-L(F1)=0.2629
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2629)


Epoch 14/30: 100%|██████████| 625/625 [02:27<00:00,  4.23it/s]



Epoch 14 | train_loss=3.6384 | dev_ROUGE-L(F1)=0.2695
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2695)


Epoch 15/30: 100%|██████████| 625/625 [02:29<00:00,  4.19it/s]



Epoch 15 | train_loss=3.5456 | dev_ROUGE-L(F1)=0.2742
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2742)


Epoch 16/30: 100%|██████████| 625/625 [02:28<00:00,  4.20it/s]



Epoch 16 | train_loss=3.4600 | dev_ROUGE-L(F1)=0.2813
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2813)


Epoch 17/30: 100%|██████████| 625/625 [02:27<00:00,  4.24it/s]



Epoch 17 | train_loss=3.3791 | dev_ROUGE-L(F1)=0.2800
[EARLY STOP] No improvement 1/5 (best=0.2813 at epoch 16)


Epoch 18/30: 100%|██████████| 625/625 [02:25<00:00,  4.29it/s]



Epoch 18 | train_loss=3.3034 | dev_ROUGE-L(F1)=0.2885
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2885)


Epoch 19/30: 100%|██████████| 625/625 [02:27<00:00,  4.24it/s]



Epoch 19 | train_loss=3.2314 | dev_ROUGE-L(F1)=0.2893
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2893)


Epoch 20/30: 100%|██████████| 625/625 [02:26<00:00,  4.26it/s]



Epoch 20 | train_loss=3.1630 | dev_ROUGE-L(F1)=0.2918
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2918)


Epoch 21/30: 100%|██████████| 625/625 [02:26<00:00,  4.27it/s]



Epoch 21 | train_loss=3.0984 | dev_ROUGE-L(F1)=0.2969
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2969)


Epoch 22/30: 100%|██████████| 625/625 [02:27<00:00,  4.25it/s]



Epoch 22 | train_loss=3.0370 | dev_ROUGE-L(F1)=0.2976
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.2976)


Epoch 23/30: 100%|██████████| 625/625 [02:25<00:00,  4.28it/s]



Epoch 23 | train_loss=2.9778 | dev_ROUGE-L(F1)=0.3009
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.3009)


Epoch 24/30: 100%|██████████| 625/625 [02:27<00:00,  4.23it/s]



Epoch 24 | train_loss=2.9188 | dev_ROUGE-L(F1)=0.3028
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.3028)


Epoch 25/30: 100%|██████████| 625/625 [02:27<00:00,  4.24it/s]



Epoch 25 | train_loss=2.8646 | dev_ROUGE-L(F1)=0.3066
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.3066)


Epoch 26/30: 100%|██████████| 625/625 [02:26<00:00,  4.26it/s]



Epoch 26 | train_loss=2.8094 | dev_ROUGE-L(F1)=0.3084
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.3084)


Epoch 27/30: 100%|██████████| 625/625 [02:26<00:00,  4.26it/s]



Epoch 27 | train_loss=2.7589 | dev_ROUGE-L(F1)=0.3106
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.3106)


Epoch 28/30: 100%|██████████| 625/625 [02:30<00:00,  4.16it/s]



Epoch 28 | train_loss=2.7093 | dev_ROUGE-L(F1)=0.3144
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.3144)


Epoch 29/30: 100%|██████████| 625/625 [02:29<00:00,  4.18it/s]



Epoch 29 | train_loss=2.6597 | dev_ROUGE-L(F1)=0.3117
[EARLY STOP] No improvement 1/5 (best=0.3144 at epoch 28)


Epoch 30/30: 100%|██████████| 625/625 [02:27<00:00,  4.24it/s]



Epoch 30 | train_loss=2.6157 | dev_ROUGE-L(F1)=0.3162
[SAVE] Best -> best_lstm3_luong_envi.pt (dev_ROUGE-L=0.3162)


In [28]:
ckpt = torch.load(BEST_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])

test_rouge = evaluate_rouge_l_luong(model, test_loader, tgt_vocab)
print("TEST ROUGE-L(F1) =", test_rouge)


TEST ROUGE-L(F1) = 0.33071683254656103


In [29]:
model.eval()
for i in range(min(5, len(test_ds))):
    en, vi = test_ds[i]
    src_ids = torch.tensor([src_vocab.encode(basic_tokenize(en)[:MAX_LEN_SRC])], dtype=torch.long).to(DEVICE)
    pred = greedy_decode_luong(model, src_ids, tgt_vocab, max_len=MAX_LEN_TGT)[0].tolist()
    pred_text = " ".join(tgt_vocab.decode(pred[1:]))

    print("\nEN:", en)
    print("GT:", vi)
    print("PR:", pred_text)



EN: Brother Albert Barnett and his wife , Sister Susan Barnett , from the West Congregation in Tuscaloosa , Alabama
GT: Anh Albert Barnett và chị Susan Barnett , thuộc hội thánh West ở Tuscaloosa , Alabama
PR: Những người phản ứng và vợ tôi , đã cho phép một người ủng hộ , từ những nhà lãnh đạo ở New York ,

EN: Severe storms ripped through parts of the southern and midwestern United States on January 11 and 12 , 2020 .
GT: Ngày 11 và 12-1-2020 , những cơn bão lớn đã quét qua và phá huỷ nhiều vùng ở miền nam và miền trung Hoa Kỳ .
PR: Nó cho phép nó vượt qua những phần ở Kenya và các nhà báo Mỹ đã giảm tuổi da , và 12 ngày , người Nhật .

EN: Two days of heavy rain , high winds , and numerous tornadoes caused major damage across multiple states .
GT: Những trận mưa to và gió lớn trong suốt hai ngày cùng với nhiều cơn lốc xoáy đã gây thiệt hại nặng nề cho nhiều bang .
PR: Một ngày trong thời gian của tôi , những nhà thờ Châu Phi , và có vẻ phản ứng có vẻ phản ứng trên khắp khắp khắp kh

#Lưu lại bộ vocab và token

In [30]:
VOCAB_PATH = "phomt_vocab_shared.pt"

torch.save({
    "src_itos": src_vocab.itos,
    "tgt_itos": tgt_vocab.itos,
    "src_special": {"pad": src_vocab.pad_id, "unk": src_vocab.unk_id, "bos": src_vocab.bos_id, "eos": src_vocab.eos_id},
    "tgt_special": {"pad": tgt_vocab.pad_id, "unk": tgt_vocab.unk_id, "bos": tgt_vocab.bos_id, "eos": tgt_vocab.eos_id},
    "max_len_src": MAX_LEN_SRC,
    "max_len_tgt": MAX_LEN_TGT,
}, VOCAB_PATH)

print("Saved shared vocab ->", VOCAB_PATH)


Saved shared vocab -> phomt_vocab_shared.pt
