In [1]:
import math
import random
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# -------------------------
# Seed (deterministic for reproducibility)
# -------------------------
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# ==============================
# 2️⃣ LOAD DỮ LIỆU
# ==============================
train_en = open("/kaggle/input/en-vi-ds/data/train.en", "r", encoding="utf-8").read().splitlines()
train_vi = open("/kaggle/input/en-vi-ds/data/train.vi", "r", encoding="utf-8").read().splitlines()
test_en  = open("/kaggle/input/en-vi-ds/data/tst2013.en", "r", encoding="utf-8").read().splitlines()
test_vi  = open("/kaggle/input/en-vi-ds/data/tst2013.vi", "r", encoding="utf-8").read().splitlines()

print("Train:", len(train_en), "Test:", len(test_en))

# ==============================
# 3️⃣ TOKENIZER (min_freq=1 to avoid too many <unk>)
# ==============================
class Tokenizer:
    def __init__(self, texts, min_freq=1):
        self.word2idx = {"<pad>":0, "<sos>":1, "<eos>":2, "<unk>":3}
        self.idx2word = {v:k for k,v in self.word2idx.items()}
        self.build_vocab(texts, min_freq)

    def build_vocab(self, texts, min_freq):
        counter = Counter()
        for line in texts:
            counter.update(line.strip().lower().split())
        for word, freq in counter.items():
            if freq >= min_freq and word not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word

    def encode(self, text):
        # returns list of token ids (no <sos>/<eos> here)
        return [self.word2idx.get(tok, self.word2idx["<unk>"]) for tok in text.strip().lower().split()]

    def decode(self, ids):
        words = []
        for i in ids:
            if i == self.word2idx["<eos>"]:
                break
            if i <= 3:  # 0-3 are special tokens or pad/unk/sos/eos: skip except unk? keep readable
                if i == self.word2idx["<unk>"]:
                    words.append("<unk>")
                continue
            words.append(self.idx2word.get(i, "<unk>"))
        return " ".join(words)

tok_src = Tokenizer(train_en, min_freq=1)
tok_trg = Tokenizer(train_vi, min_freq=1)

# ==============================
# 4️⃣ DATASET + collate
# ==============================
class TranslationDataset(Dataset):
    def __init__(self, src, trg, tok_src, tok_trg):
        self.src = src
        self.trg = trg
        self.tok_src = tok_src
        self.tok_trg = tok_trg

    def __len__(self): return len(self.src)

    def __getitem__(self, idx):
        s = [self.tok_src.word2idx["<sos>"]] + self.tok_src.encode(self.src[idx]) + [self.tok_src.word2idx["<eos>"]]
        t = [self.tok_trg.word2idx["<sos>"]] + self.tok_trg.encode(self.trg[idx]) + [self.tok_trg.word2idx["<eos>"]]
        return torch.tensor(s, dtype=torch.long), torch.tensor(t, dtype=torch.long)

def collate_fn(batch, pad_idx=0):
    src, trg = zip(*batch)
    src = nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=pad_idx)
    trg = nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=pad_idx)
    return src, trg

dataset = TranslationDataset(train_en, train_vi, tok_src, tok_trg)
train_len = int(0.9 * len(dataset))
val_len = len(dataset) - train_len
train_set, val_set = random_split(dataset, [train_len, val_len])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_set, batch_size=32, shuffle=False, collate_fn=collate_fn)

# ==============================
# 5️⃣ POS ENCODING
# ==============================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))  # registered buffer -> moves with .to(device)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# ==============================
# 6️⃣ TRANSFORMER MODEL (with weight tying & device-safety)
# ==============================
class TransformerModel(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model=256, nhead=4, num_layers=3, pad_idx=0, dropout=0.1):
        super().__init__()
        self.pad_idx = pad_idx
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=pad_idx)
        self.trg_emb = nn.Embedding(trg_vocab, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )
        self.fc = nn.Linear(d_model, trg_vocab)
        # weight tying (output projection shares embedding weight)
        self.fc.weight = self.trg_emb.weight

    def forward(self, src, trg):
        # src: (B, S), trg: (B, T)
        device = src.device
        src_key_padding_mask = (src == self.pad_idx)  # (B, S)
        trg_key_padding_mask = (trg == self.pad_idx)  # (B, T)
        # subsequent mask for target (T, T)
        T = trg.size(1)
        trg_mask = nn.Transformer.generate_square_subsequent_mask(T).to(device)

        src_emb = self.pos(self.src_emb(src))  # (B, S, d_model)
        trg_emb = self.pos(self.trg_emb(trg))  # (B, T, d_model)

        out = self.transformer(
            src_emb, trg_emb,
            tgt_mask=trg_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=trg_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )  # (B, T, d_model)
        return self.fc(out)  # (B, T, trg_vocab)

# ==============================
# 7️⃣ TRAINING FUNCTION (improvements: lr scheduler, AdamW, save best)
# ==============================
def train_model(model, train_loader, val_loader, device, epochs=10, lr=3e-4, pad_idx=0):
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=1, factor=0.5, verbose=True)

    best_val = float('inf')
    for ep in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        for src, trg in train_loader:
            src, trg = src.to(device), trg.to(device)
            opt.zero_grad()
            # teacher forcing: feed trg tokens except last as input
            out = model(src, trg[:, :-1])  # predict next token for each pos
            # out: (B, T-1, V) -> flatten
            loss = loss_fn(out.reshape(-1, out.size(-1)), trg[:, 1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += loss.item()

        avg_train = total_loss / len(train_loader)

        # validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for src, trg in val_loader:
                src, trg = src.to(device), trg.to(device)
                out = model(src, trg[:, :-1])
                loss = loss_fn(out.reshape(-1, out.size(-1)), trg[:, 1:].reshape(-1))
                val_loss += loss.item()
        avg_val = val_loss / len(val_loader)

        scheduler.step(avg_val)

        print(f"Epoch {ep}/{epochs} | Train {avg_train:.4f} | Val {avg_val:.4f}")

        if avg_val < best_val:
            best_val = avg_val
            torch.save({
                'model_state': model.state_dict(),
                'tok_src': tok_src.word2idx,
                'tok_trg': tok_trg.word2idx
            }, "best_model.pt")
            print("✔ Saved best model!")

# ==============================
# 8️⃣ TRANSLATE (greedy) - improved (stop when <eos>, limit length)
# ==============================
def translate(model, text, tok_src, tok_trg, device, max_len=60):
    model.eval()
    src = [tok_src.word2idx["<sos>"]] + tok_src.encode(text) + [tok_src.word2idx["<eos>"]]
    src = torch.tensor(src, dtype=torch.long).unsqueeze(0).to(device)  # (1, S)
    trg = torch.tensor([[tok_trg.word2idx["<sos>"]]], dtype=torch.long).to(device)  # (1,1)
    with torch.no_grad():
        for _ in range(max_len):
            out = model(src, trg)  # (1, T, V)
            next_tok = out[0, -1].argmax().item()
            trg = torch.cat([trg, torch.tensor([[next_tok]], device=device)], dim=1)
            if next_tok == tok_trg.word2idx["<eos>"]:
                break
    # strip leading <sos> when decoding
    return tok_trg.decode(trg[0].tolist()[1:])

# ==============================
# 9️⃣ EVALUATE BLEU (use smoothing)
# ==============================
def evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=50):
    model.eval()
    total_bleu = 0
    smooth = SmoothingFunction().method1
    n = min(n, len(test_en))
    for i in range(n):
        pred = translate(model, test_en[i], tok_src, tok_trg, device)
        bleu = sentence_bleu([test_vi[i].split()], pred.split(), smoothing_function=smooth)
        total_bleu += bleu
        if i < 10:  # print a few examples
            print(f"\nEN: {test_en[i]}")
            print(f"GT: {test_vi[i]}")
            print(f"PR: {pred}")
            print(f"BLEU: {bleu:.4f}")

    print("\nAVERAGE BLEU =", total_bleu / n)

Train: 133317 Test: 1268


In [2]:
# ==============================
# RUN (example)
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerModel(len(tok_src.word2idx), len(tok_trg.word2idx), d_model=256, nhead=4, num_layers=3, pad_idx=tok_src.word2idx["<pad>"])
train_model(model, train_loader, val_loader, device, epochs=15, lr=3e-4, pad_idx=tok_src.word2idx["<pad>"])


  output = torch._nested_tensor_from_mask(


Epoch 1/15 | Train 7.0317 | Val 4.2933
✔ Saved best model!
Epoch 2/15 | Train 3.9949 | Val 3.5773
✔ Saved best model!
Epoch 3/15 | Train 3.4649 | Val 3.2186
✔ Saved best model!
Epoch 4/15 | Train 3.1537 | Val 2.9783
✔ Saved best model!
Epoch 5/15 | Train 2.9399 | Val 2.8171
✔ Saved best model!
Epoch 6/15 | Train 2.7756 | Val 2.7042
✔ Saved best model!
Epoch 7/15 | Train 2.6472 | Val 2.6203
✔ Saved best model!
Epoch 8/15 | Train 2.5423 | Val 2.5537
✔ Saved best model!
Epoch 9/15 | Train 2.4529 | Val 2.5009
✔ Saved best model!
Epoch 10/15 | Train 2.3783 | Val 2.4501
✔ Saved best model!
Epoch 11/15 | Train 2.3134 | Val 2.4152
✔ Saved best model!
Epoch 12/15 | Train 2.2537 | Val 2.3849
✔ Saved best model!
Epoch 13/15 | Train 2.2022 | Val 2.3647
✔ Saved best model!
Epoch 14/15 | Train 2.1551 | Val 2.3438
✔ Saved best model!
Epoch 15/15 | Train 2.1136 | Val 2.3305
✔ Saved best model!


In [3]:
evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=10)


EN: When I was little , I thought my country was the best on the planet , and I grew up singing a song called &quot; Nothing To Envy . &quot;
GT: Khi tôi còn nhỏ , Tôi nghĩ rằng BắcTriều Tiên là đất nước tốt nhất trên thế giới và tôi thường hát bài &quot; Chúng ta chẳng có gì phải ghen tị . &quot;
PR: khi tôi còn nhỏ , tôi nghĩ đất nước tôi là đất nước tốt nhất trên hành tinh , và tôi lớn lên một bài hát &quot; chẳng có gì để ghen tị . &quot;
BLEU: 0.3277

EN: And I was very proud .
GT: Tôi đã rất tự hào về đất nước tôi .
PR: và tôi rất tự hào .
BLEU: 0.1179

EN: In school , we spent a lot of time studying the history of Kim Il-Sung , but we never learned much about the outside world , except that America , South Korea , Japan are the enemies .
GT: Ở trường , chúng tôi dành rất nhiều thời gian để học về cuộc đời của chủ tịch Kim II- Sung , nhưng lại không học nhiều về thế giới bên ngoài , ngoại trừ việc Hoa Kỳ , Hàn Quốc và Nhật Bản là kẻ thù của chúng tôi .
PR: ở trường , chúng tôi d

In [4]:
sent = "See you again"
print(translate(model, sent, tok_src, tok_trg, device))

bạn thấy lần nữa .
