In [None]:
import math
import random
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# -------------------------
# Seed (deterministic for reproducibility)
# -------------------------
def set_seed(seed=42):
    """Seed Python's ``random`` and PyTorch (CPU, plus every CUDA device if present).

    Args:
        seed: integer seed handed to each generator.
    """
    seeders = [random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

set_seed(42)

# ==============================
# 2️⃣ LOAD DỮ LIỆU
# ==============================
def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, without line endings.

    The ``with`` block closes the file handle as soon as it has been read
    instead of leaving open file objects for the garbage collector.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read().splitlines()


train_en = _read_lines("/kaggle/input/en-vi-ds/data/train.en")
train_vi = _read_lines("/kaggle/input/en-vi-ds/data/train.vi")
test_en = _read_lines("/kaggle/input/en-vi-ds/data/tst2013.en")
test_vi = _read_lines("/kaggle/input/en-vi-ds/data/tst2013.vi")

# Line i of each .en file is paired with line i of the matching .vi file
# (the dataset and the evaluation both index the two lists with the same i),
# so a length mismatch means the corpora are not aligned.
if len(train_en) != len(train_vi):
    raise ValueError(f"train files not aligned: {len(train_en)} EN vs {len(train_vi)} VI lines")
if len(test_en) != len(test_vi):
    raise ValueError(f"test files not aligned: {len(test_en)} EN vs {len(test_vi)} VI lines")

print("Train:", len(train_en), "Test:", len(test_en))

# ==============================
# 3️⃣ TOKENIZER (min_freq=1 to avoid too many <unk>)
# ==============================
class Tokenizer:
    """Word-level tokenizer: lower-cases, splits on whitespace, thresholds by frequency.

    Ids 0-3 are reserved for ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``.
    Every other word seen at least ``min_freq`` times gets the next free id,
    in first-seen order.
    """

    def __init__(self, texts, min_freq=1):
        specials = ["<pad>", "<sos>", "<eos>", "<unk>"]
        self.word2idx = {token: position for position, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))
        self.build_vocab(texts, min_freq)

    def build_vocab(self, texts, min_freq):
        """Add each word occurring at least ``min_freq`` times in *texts*."""
        freqs = Counter(token for line in texts for token in line.strip().lower().split())
        for word, count in freqs.items():
            if count < min_freq or word in self.word2idx:
                continue
            new_id = len(self.word2idx)
            self.word2idx[word] = new_id
            self.idx2word[new_id] = word

    def encode(self, text):
        """Return token ids for *text* (no ``<sos>``/``<eos>``); unknown words -> ``<unk>``."""
        unk_id = self.word2idx["<unk>"]
        lookup = self.word2idx.get
        return [lookup(token, unk_id) for token in text.strip().lower().split()]

    def decode(self, ids):
        """Join the words for *ids* with spaces, stopping at the first ``<eos>``.

        Ids up to 3 are special: ``<unk>`` is rendered literally and the others
        (``<pad>``, ``<sos>``) are dropped. Ids missing from the vocabulary
        become ``"<unk>"``.
        """
        eos_id = self.word2idx["<eos>"]
        unk_id = self.word2idx["<unk>"]
        pieces = []
        for tok_id in ids:
            if tok_id == eos_id:
                break
            if tok_id > 3:
                pieces.append(self.idx2word.get(tok_id, "<unk>"))
            elif tok_id == unk_id:
                pieces.append("<unk>")
        return " ".join(pieces)

# Separate vocabularies for the English source and the Vietnamese target.
tok_src, tok_trg = (Tokenizer(corpus, min_freq=1) for corpus in (train_en, train_vi))

# ==============================
# 4️⃣ DATASET + collate
# ==============================
class TranslationDataset(Dataset):
    """Parallel corpus yielding ``(source_ids, target_ids)`` LongTensor pairs.

    Each side is encoded with its own tokenizer and framed as
    ``<sos> ... <eos>``.
    """

    def __init__(self, src, trg, tok_src, tok_trg):
        self.src, self.trg = src, trg
        self.tok_src, self.tok_trg = tok_src, tok_trg

    def __len__(self):
        return len(self.src)

    @staticmethod
    def _framed_ids(tokenizer, text):
        """Encode *text* and surround it with that tokenizer's <sos>/<eos> ids."""
        vocab = tokenizer.word2idx
        return [vocab["<sos>"], *tokenizer.encode(text), vocab["<eos>"]]

    def __getitem__(self, idx):
        source_ids = self._framed_ids(self.tok_src, self.src[idx])
        target_ids = self._framed_ids(self.tok_trg, self.trg[idx])
        return tuple(torch.tensor(ids, dtype=torch.long) for ids in (source_ids, target_ids))

def collate_fn(batch, pad_idx=0):
    """Right-pad a batch of ``(src, trg)`` 1-D tensors into two batch-first tensors.

    Args:
        batch: sequence of ``(src_ids, trg_ids)`` pairs from the dataset.
        pad_idx: fill value for positions past each sequence's end.

    Returns:
        ``(src, trg)``, each shaped ``(batch, longest_sequence_on_that_side)``.
    """
    src_seqs, trg_seqs = zip(*batch)

    def _pad(seqs):
        return nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_idx)

    return _pad(src_seqs), _pad(trg_seqs)

# Hold out 10% of the training pairs for validation. random_split draws from
# the global torch RNG, which set_seed(42) seeded at the top of this cell.
dataset = TranslationDataset(train_en, train_vi, tok_src, tok_trg)
train_len = int(0.9 * len(dataset))
val_len = len(dataset) - train_len
train_set, val_set = random_split(dataset, lengths=[train_len, val_len])

# Training batches are shuffled; validation batches keep a fixed order.
train_loader, val_loader = (
    DataLoader(subset, batch_size=32, shuffle=is_train, collate_fn=collate_fn)
    for subset, is_train in ((train_set, True), (val_set, False))
)

# ==============================
# 5️⃣ POS ENCODING
# ==============================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first embedding tensor.

    Args:
        d_model: embedding width. Odd widths are supported: the last (even)
            column gets a sine term without a cosine partner.
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(d_model / 2) of them.
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # There are only floor(d_model / 2) odd columns, so for odd d_model the
        # last frequency has no cosine slot; slicing avoids a shape mismatch.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # registered buffer -> moves with .to(device)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, :x.size(1)]

# ==============================
# 6️⃣ TRANSFORMER MODEL (with weight tying & device-safety)
# ==============================
class TransformerModel(nn.Module):
    """Batch-first encoder-decoder ``nn.Transformer`` over token ids.

    The output projection shares its weight Parameter with the target
    embedding (weight tying).

    Args:
        src_vocab: source vocabulary size.
        trg_vocab: target vocabulary size (number of output logits).
        d_model, nhead, dropout: forwarded to ``nn.Transformer``.
        num_layers: depth of both the encoder and the decoder stack.
        pad_idx: padding id, used as ``padding_idx`` for both embeddings and
            to build the key-padding masks.
    """

    def __init__(self, src_vocab, trg_vocab, d_model=256, nhead=4, num_layers=3, pad_idx=0, dropout=0.1):
        super().__init__()
        self.pad_idx = pad_idx
        # Submodules are created in this fixed order so seeded initialisation is reproducible.
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=pad_idx)
        self.trg_emb = nn.Embedding(trg_vocab, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)
        depth = dict(num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True, **depth)
        self.fc = nn.Linear(d_model, trg_vocab)
        self.fc.weight = self.trg_emb.weight

    def forward(self, src, trg):
        """Return logits (B, T, trg_vocab) for ``src`` (B, S) and ``trg`` (B, T) token ids."""
        causal = nn.Transformer.generate_square_subsequent_mask(trg.size(1)).to(src.device)
        src_pad = src.eq(self.pad_idx)
        hidden = self.transformer(
            self.pos(self.src_emb(src)),
            self.pos(self.trg_emb(trg)),
            tgt_mask=causal,
            src_key_padding_mask=src_pad,
            tgt_key_padding_mask=trg.eq(self.pad_idx),
            memory_key_padding_mask=src_pad,
        )
        return self.fc(hidden)

# ==============================
# 7️⃣ TRAINING FUNCTION (improvements: lr scheduler, AdamW, save best)
# ==============================
def train_model(model, train_loader, val_loader, device, epochs=10, lr=3e-4, pad_idx=0):
    """Train *model* with teacher forcing and checkpoint the best validation epoch.

    Setup:
        * AdamW with weight decay 1e-5.
        * Cross-entropy that ignores ``pad_idx`` targets.
        * Gradient-norm clipping at 1.0.
        * ``ReduceLROnPlateau`` (factor 0.5, patience 1) stepped on the validation loss.

    Side effects:
        * Moves *model* to *device* and prints per-epoch losses.
        * Writes ``best_model.pt`` whenever the validation loss improves.
          The checkpoint also stores the module-level ``tok_src``/``tok_trg``
          vocabularies.
    """
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    # The ``verbose`` argument of PyTorch LR schedulers is deprecated, so LR
    # reductions are reported explicitly after ``scheduler.step`` below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=1, factor=0.5)

    best_val = float('inf')
    for ep in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for src, trg in train_loader:
            src, trg = src.to(device), trg.to(device)
            opt.zero_grad()
            # Teacher forcing: feed trg[:, :-1] and predict the next token trg[:, 1:].
            out = model(src, trg[:, :-1])
            loss = loss_fn(out.reshape(-1, out.size(-1)), trg[:, 1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += loss.item()
        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_train = total_loss / max(len(train_loader), 1)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for src, trg in val_loader:
                src, trg = src.to(device), trg.to(device)
                out = model(src, trg[:, :-1])
                loss = loss_fn(out.reshape(-1, out.size(-1)), trg[:, 1:].reshape(-1))
                val_loss += loss.item()
        avg_val = val_loss / max(len(val_loader), 1)

        lrs_before = [group["lr"] for group in opt.param_groups]
        scheduler.step(avg_val)
        lrs_after = [group["lr"] for group in opt.param_groups]

        print(f"Epoch {ep}/{epochs} | Train {avg_train:.4f} | Val {avg_val:.4f}")
        if lrs_after != lrs_before:
            print(f"  LR reduced to {lrs_after[0]:.2e}")

        if avg_val < best_val:
            best_val = avg_val
            torch.save({
                'model_state': model.state_dict(),
                'tok_src': tok_src.word2idx,
                'tok_trg': tok_trg.word2idx
            }, "best_model.pt")
            print("✔ Saved best model!")
            print("✔ Saved best model!")

# ==============================
# 8️⃣ TRANSLATE (greedy) - improved (stop when <eos>, limit length)
# ==============================
def translate(model, text, tok_src, tok_trg, device, max_len=60):
    """Greedily translate *text*, taking the argmax token at each step.

    At most ``max_len`` tokens are generated; generation stops right after the
    first ``<eos>``. Returns the decoded target string without the leading
    ``<sos>``.
    """
    model.eval()
    src_ids = [tok_src.word2idx["<sos>"], *tok_src.encode(text), tok_src.word2idx["<eos>"]]
    src = torch.tensor([src_ids], dtype=torch.long, device=device)  # (1, S)
    eos_id = tok_trg.word2idx["<eos>"]
    generated = [tok_trg.word2idx["<sos>"]]
    with torch.no_grad():
        for _ in range(max_len):
            prefix = torch.tensor([generated], dtype=torch.long, device=device)  # (1, T)
            best = model(src, prefix)[0, -1].argmax().item()
            generated.append(best)
            if best == eos_id:
                break
    return tok_trg.decode(generated[1:])

# ==============================
# 9️⃣ EVALUATE BLEU (use smoothing)
# ==============================
def evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=50):
    """Greedily translate the first *n* test sentences and report average sentence BLEU.

    ``Tokenizer`` lower-cases its input, so every prediction is lower-case.
    Each reference is lower-cased and whitespace-split the same way before
    scoring. Scoring against the cased reference counted every capitalised
    word as a mismatch and under-reported BLEU.

    Args:
        n: number of leading test pairs to score. It is capped at the length
            of the shorter of ``test_en`` / ``test_vi``.

    Returns:
        The average BLEU, which is also printed. Returns 0.0 when no pairs
        are evaluated.
    """
    model.eval()
    smooth = SmoothingFunction().method1
    n = min(n, len(test_en), len(test_vi))
    if n <= 0:
        print("\nAVERAGE BLEU = 0.0 (no sentence pairs evaluated)")
        return 0.0

    total_bleu = 0.0
    for i in range(n):
        pred = translate(model, test_en[i], tok_src, tok_trg, device)
        reference = test_vi[i].strip().lower().split()
        bleu = sentence_bleu([reference], pred.split(), smoothing_function=smooth)
        total_bleu += bleu
        if i < 10:  # print a few examples
            print(f"\nEN: {test_en[i]}")
            print(f"GT: {test_vi[i]}")
            print(f"PR: {pred}")
            print(f"BLEU: {bleu:.4f}")

    avg_bleu = total_bleu / n
    print("\nAVERAGE BLEU =", avg_bleu)
    return avg_bleu



In [None]:
# ==============================
# RUN (example)
# ==============================
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = TransformerModel(
    len(tok_src.word2idx),
    len(tok_trg.word2idx),
    d_model=256,
    nhead=4,
    num_layers=3,
    pad_idx=tok_src.word2idx["<pad>"],
)
train_model(
    model, train_loader, val_loader, device,
    epochs=10, lr=3e-4, pad_idx=tok_src.word2idx["<pad>"],
)
# Afterwards, evaluate on the test set, e.g.:
# evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=50)


In [None]:


# Score translations of the first 10 test pairs.
evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device=device, n=10)

In [None]:
sent = "See you again"
print(translate(model, text=sent, tok_src=tok_src, tok_trg=tok_trg, device=device))


In [1]:
# full_transformer_nmt.py
import math
import random
from collections import Counter
import os
import sys
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# -------------------------
# CẤU HÌNH / SEED
# -------------------------
def set_seed(seed=42):
    """Make runs repeatable by seeding ``random`` and ``torch`` (all CUDA devices too).

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    # Only touch the CUDA generators when a CUDA device is available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(42)

# Locations of the English / Vietnamese data files.
DATA_DIR = "/kaggle/input/en-vi-ds/data"
TRAIN_EN, TRAIN_VI = os.path.join(DATA_DIR, "train.en"), os.path.join(DATA_DIR, "train.vi")
TEST_EN, TEST_VI = os.path.join(DATA_DIR, "tst2013.en"), os.path.join(DATA_DIR, "tst2013.vi")

# Hyper-parameters (adjust as needed).
# Model size
D_MODEL, NHEAD, NUM_LAYERS = 256, 4, 3
# Optimisation
BATCH_SIZE, EPOCHS, LR, WARMUP_STEPS = 32, 20, 3e-4, 4000
# Decoding
BEAM_SIZE, MAX_LEN = 5, 60
# Vocabulary
MIN_FREQ = 1   # word-level vocab threshold
PAD_IDX = 0    # must equal the Tokenizer's "<pad>" id

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# ==============================
# 1️⃣ LOAD DỮ LIỆU
# ==============================
def load_lines(path):
    """Read the UTF-8 text file at *path* and return its lines without line endings."""
    with open(path, encoding="utf-8") as handle:
        content = handle.read()
    return content.splitlines()

train_en = load_lines(TRAIN_EN)
train_vi = load_lines(TRAIN_VI)
test_en = load_lines(TEST_EN)
test_vi = load_lines(TEST_VI)

# TranslationDataset and evaluate_test_set pair line i of the EN file with
# line i of the VI file, so both sides of each split must have the same length.
if len(train_en) != len(train_vi):
    raise ValueError(f"train files not aligned: {len(train_en)} EN vs {len(train_vi)} VI lines")
if len(test_en) != len(test_vi):
    raise ValueError(f"test files not aligned: {len(test_en)} EN vs {len(test_vi)} VI lines")
print("Train pairs:", len(train_en), "Test pairs:", len(test_en))

# ==============================
# 2️⃣ TOKENIZER (word-level) - compatible with your pipeline
# ==============================
class Tokenizer:
    """Lower-casing whitespace tokenizer with a frequency-thresholded vocabulary.

    Reserved ids: 0 ``<pad>``, 1 ``<sos>``, 2 ``<eos>``, 3 ``<unk>``. Other
    words are numbered from 4 in the order they first appear in ``texts``.
    """

    def __init__(self, texts, min_freq=1):
        self.word2idx = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.build_vocab(texts, min_freq)

    def build_vocab(self, texts, min_freq):
        """Register every new word that occurs at least ``min_freq`` times."""
        counts = Counter()
        for sentence in texts:
            counts.update(sentence.strip().lower().split())
        kept = [w for w, c in counts.items() if c >= min_freq and w not in self.word2idx]
        for new_id, word in enumerate(kept, start=len(self.word2idx)):
            self.word2idx[word] = new_id
            self.idx2word[new_id] = word

    def encode(self, text):
        """Return the id of each lower-cased token; words not in the vocabulary map to ``<unk>``."""
        ids = []
        for word in text.strip().lower().split():
            ids.append(self.word2idx[word] if word in self.word2idx else self.word2idx["<unk>"])
        return ids

    def decode(self, ids):
        """Convert ids to space-joined text, up to (not including) the first ``<eos>``.

        Ids <= 3 are special: ``<unk>`` is rendered as "<unk>", the rest are
        dropped. Ids missing from the vocabulary are rendered as "<unk>".
        """
        eos = self.word2idx["<eos>"]
        unk = self.word2idx["<unk>"]
        words = []
        for i in ids:
            if i == eos:
                break
            if i <= 3 and i != unk:
                continue
            words.append("<unk>" if i <= 3 else self.idx2word.get(i, "<unk>"))
        return " ".join(words)

# English source vocabulary first, then the Vietnamese target vocabulary.
tok_src, tok_trg = [Tokenizer(corpus, min_freq=MIN_FREQ) for corpus in (train_en, train_vi)]
print("Src vocab:", len(tok_src.word2idx), "Trg vocab:", len(tok_trg.word2idx))

# ==============================
# 3️⃣ DATASET
# ==============================
class TranslationDataset(Dataset):
    """Map-style dataset of ``(source_ids, target_ids)`` tensors from parallel lines.

    Each side is encoded with its own tokenizer and wrapped in ``<sos>``/``<eos>``.
    """

    def __init__(self, src_lines, trg_lines, tok_src, tok_trg):
        self.src = src_lines
        self.trg = trg_lines
        self.tok_src = tok_src
        self.tok_trg = tok_trg

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        encoded = []
        for tok, line in ((self.tok_src, self.src[idx]), (self.tok_trg, self.trg[idx])):
            ids = [tok.word2idx["<sos>"]] + tok.encode(line) + [tok.word2idx["<eos>"]]
            encoded.append(torch.tensor(ids, dtype=torch.long))
        return encoded[0], encoded[1]

def collate_fn(batch, pad_idx=0):
    """Stack variable-length ``(src, trg)`` pairs into right-padded (B, L) batches.

    Args:
        batch: sequence of ``(src_ids, trg_ids)`` 1-D tensors.
        pad_idx: value written into padding positions.

    Returns:
        Tuple ``(src, trg)`` of batch-first padded tensors.
    """
    src_seqs, trg_seqs = zip(*batch)
    padded = [
        nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_idx)
        for seqs in (src_seqs, trg_seqs)
    ]
    return padded[0], padded[1]

dataset = TranslationDataset(train_en, train_vi, tok_src, tok_trg)
train_len = int(0.9 * len(dataset))   # 90% of pairs for training ...
val_len = len(dataset) - train_len    # ... and the remainder for validation
train_set, val_set = random_split(dataset, [train_len, val_len])

# batch_size is DataLoader's second positional parameter.
train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# ==============================
# 4️⃣ POS ENCODING
# ==============================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first embedding tensor.

    Args:
        d_model: embedding width. Odd widths are supported: the last (even)
            column gets a sine term without a cosine partner.
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(d_model / 2) of them.
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # Only floor(d_model / 2) odd columns exist, so for odd d_model the last
        # frequency has no cosine slot; slicing avoids a shape-mismatch error.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, :x.size(1)]

# ==============================
# 5️⃣ TRANSFORMER MODEL
# ==============================
class TransformerModel(nn.Module):
    """Batch-first encoder-decoder Transformer over token ids with tied output weights.

    Args:
        src_vocab: source vocabulary size.
        trg_vocab: target vocabulary size (number of output logits).
        d_model, nhead, dropout: forwarded to ``nn.Transformer``.
        num_layers: depth of both the encoder and the decoder stack.
        pad_idx: padding id, used as ``padding_idx`` for both embeddings and
            to build the key-padding masks.
    """

    def __init__(self, src_vocab, trg_vocab, d_model=256, nhead=4, num_layers=3, pad_idx=0, dropout=0.1):
        super().__init__()
        self.pad_idx = pad_idx
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=pad_idx)
        self.trg_emb = nn.Embedding(trg_vocab, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )
        self.fc = nn.Linear(d_model, trg_vocab)
        # Tie the output projection to the target embedding; both weights are
        # (trg_vocab, d_model). Assigned directly, without a try/except: if
        # tying ever failed it must surface as an error, not silently leave
        # the weights untied.
        self.fc.weight = self.trg_emb.weight

    def forward(self, src, trg):
        """Return logits (B, T, trg_vocab) for ``src`` (B, S) and ``trg`` (B, T) token ids."""
        device = src.device
        src_key_padding_mask = (src == self.pad_idx)  # True where src is padding
        trg_key_padding_mask = (trg == self.pad_idx)  # True where trg is padding
        T = trg.size(1)
        # Causal mask so position t only attends to target positions <= t.
        trg_mask = nn.Transformer.generate_square_subsequent_mask(T).to(device)

        src_emb = self.pos(self.src_emb(src))
        trg_emb = self.pos(self.trg_emb(trg))

        out = self.transformer(
            src_emb, trg_emb,
            tgt_mask=trg_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=trg_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )
        return self.fc(out)

# ==============================
# 6️⃣ LABEL SMOOTHING LOSS
# ==============================
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    Target distribution for every non-ignored row:
        * ``1 - smoothing`` on the gold class.
        * ``smoothing`` spread uniformly over the other ``classes - 1`` classes.

    Rows whose target equals ``ignore_index`` contribute nothing, and the loss
    is averaged over the remaining rows only. This matches
    ``nn.CrossEntropyLoss(ignore_index=...)``, so the value no longer shrinks
    as the share of padding in a batch grows.

    The smoothed (N, C) distribution is never materialised. The per-row loss
    is computed as
    ``-(confidence * logp[gold] + eps * (sum_j logp[j] - logp[gold]))`` with
    ``eps = smoothing / (classes - 1)``, which needs only the log-softmax
    tensor. Building the distribution (and masked/negated copies of it)
    multiplied peak memory for large vocabularies.
    """

    def __init__(self, classes, smoothing=0.1, ignore_index=0):
        super().__init__()
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.classes = classes

    def forward(self, pred, target):
        """Return the mean smoothed loss over non-ignored rows.

        Args:
            pred: (N, C) unnormalised logits, with C == ``self.classes``.
            target: (N,) integer class ids; rows equal to ``ignore_index`` are skipped.

        Returns:
            A 0-dim tensor. It is 0 when every row is ignored.
        """
        log_probs = pred.log_softmax(dim=-1)  # (N, C)
        mask = target != self.ignore_index   # (N,)
        # gather needs an in-range index even on ignored rows; those rows are masked out below.
        safe_target = target.masked_fill(~mask, 0)
        gold = log_probs.gather(1, safe_target.unsqueeze(1)).squeeze(1)  # (N,)
        off_target = self.smoothing / (self.classes - 1)
        row_loss = -(self.confidence * gold + off_target * (log_probs.sum(dim=1) - gold))
        row_loss = row_loss * mask
        return row_loss.sum() / mask.sum().clamp(min=1)

# ==============================
# 7️⃣ LR SCHEDULER (warmup style)
# ==============================
def get_warmup_scheduler(optimizer, d_model=D_MODEL, warmup=WARMUP_STEPS):
    """Return a ``LambdaLR`` with linear warmup followed by inverse-square-root decay.

    ``LambdaLR`` multiplies each param group's initial ``lr`` by
    ``lr_lambda(step)``. The raw Noam formula
    ``d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)`` is already an
    absolute learning rate (peak ~1e-3 for d_model=256, warmup=4000).
    Multiplying it by ``lr=3e-4`` gave a peak LR of ~3e-7, far too small to
    train. Here the Noam curve is divided by its own peak (reached at
    ``step == warmup``):
        * the factor rises linearly to 1.0 over ``warmup`` steps,
        * then decays as ``sqrt(warmup / step)``,
    so the optimizer's configured ``lr`` is the peak learning rate.

    Args:
        optimizer: optimizer whose param-group ``lr`` is the desired peak LR.
        d_model: kept for API compatibility. It is a constant factor in the
            Noam formula and cancels out after normalising to the peak.
        warmup: number of linear-warmup steps (values < 1 are treated as 1).
    """
    warmup = max(1, warmup)

    def lr_lambda(step):
        step = max(1, step)  # LambdaLR evaluates step 0 at construction
        return min(step / warmup, (warmup / step) ** 0.5)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# ==============================
# 8️⃣ BEAM SEARCH (translate)
# ==============================
def translate_beam(model, text, tok_src, tok_trg, device, beam_size=5, max_len=60):
    """Translate *text* with beam search and return the best hypothesis as text.

    Hypotheses are ranked by summed token log-probabilities (no length
    normalisation). At each step, all unfinished hypotheses share the same
    length, so they are scored in ONE batched forward pass instead of one
    pass per hypothesis.

    Args:
        model: callable ``model(src, trg)`` returning (B, T, V) logits; must support ``.eval()``.
        text: source sentence.
        tok_src, tok_trg: tokenizers exposing ``word2idx`` / ``encode`` / ``decode``.
        device: device for the input tensors.
        beam_size: hypotheses kept per step (values < 1 are treated as 1).
        max_len: maximum number of generated tokens.

    Returns:
        ``tok_trg.decode`` of the best hypothesis, without its leading ``<sos>``.
    """
    model.eval()
    sos = tok_trg.word2idx["<sos>"]
    eos = tok_trg.word2idx["<eos>"]
    beam_size = max(1, beam_size)

    src_ids = [tok_src.word2idx["<sos>"]] + tok_src.encode(text) + [tok_src.word2idx["<eos>"]]
    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, S)

    # Each hypothesis is (token-id list starting with <sos>, cumulative log-prob).
    beams = [([sos], 0.0)]
    completed = []

    with torch.no_grad():
        for _step in range(max_len):
            # Hypotheses that emitted <eos> on the previous step are final.
            completed.extend(b for b in beams if b[0][-1] == eos)
            live = [b for b in beams if b[0][-1] != eos]
            if not live:
                beams = []
                break

            prefixes = torch.tensor([seq for seq, _ in live], dtype=torch.long, device=device)  # (B, T)
            logits = model(src.repeat(len(live), 1), prefixes)[:, -1]  # (B, V)
            log_probs = torch.log_softmax(logits, dim=-1)
            k = min(beam_size, log_probs.size(-1))
            top_lp, top_ids = log_probs.topk(k, dim=-1)

            candidates = [
                (seq + [tok], score + lp)
                for (seq, score), lps, toks in zip(live, top_lp.tolist(), top_ids.tolist())
                for lp, tok in zip(lps, toks)
            ]
            # Stable sort: ties keep expansion order.
            beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]

            # Stop once enough hypotheses have finished.
            if len(completed) >= beam_size:
                break

    # Unfinished hypotheses still compete if fewer than beam_size finished.
    completed.extend(beams)
    best_seq, _best_score = max(completed, key=lambda c: c[1])
    return tok_trg.decode(best_seq[1:])

# ==============================
# 9️⃣ TRAIN + VALIDATION
# ==============================
def train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR, pad_idx=PAD_IDX):
    """Train *model* with teacher forcing and save the best validation checkpoint.

    Setup:
        * AdamW (weight decay 1e-5).
        * Per-batch warmup LR schedule from ``get_warmup_scheduler``.
        * ``LabelSmoothingLoss`` (smoothing 0.1, ignoring ``pad_idx``).
        * Gradient-norm clipping at 1.0.

    Side effects:
        * Prints train/val loss and the epoch time.
        * Writes ``best_model.pt`` whenever the validation loss improves.
          The checkpoint also stores the module-level ``tok_src``/``tok_trg``
          vocabularies.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = get_warmup_scheduler(optimizer, d_model=D_MODEL, warmup=WARMUP_STEPS)
    criterion = LabelSmoothingLoss(classes=len(tok_trg.word2idx), smoothing=0.1, ignore_index=pad_idx)

    def batch_loss(src, trg):
        # Teacher forcing: feed trg[:, :-1] and score the predictions against trg[:, 1:].
        src, trg = src.to(device), trg.to(device)
        logits = model(src, trg[:, :-1])
        vocab = logits.size(-1)
        return criterion(logits.reshape(-1, vocab), trg[:, 1:].reshape(-1))

    best_val = float('inf')
    for ep in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        running = 0.0
        for src, trg in train_loader:
            optimizer.zero_grad()
            loss = batch_loss(src, trg)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # warmup schedule advances once per batch
            running += loss.item()
        avg_train = running / len(train_loader)

        model.eval()
        with torch.no_grad():
            avg_val = sum(batch_loss(src, trg).item() for src, trg in val_loader) / len(val_loader)
        t1 = time.time()

        print(f"Epoch {ep}/{epochs} | Train {avg_train:.4f} | Val {avg_val:.4f} | Time {t1-t0:.1f}s")

        if avg_val >= best_val:
            continue
        best_val = avg_val
        checkpoint = {
            'model_state': model.state_dict(),
            'tok_src': tok_src.word2idx,
            'tok_trg': tok_trg.word2idx
        }
        torch.save(checkpoint, "best_model.pt")
        print("✔ Saved best model!")

# ==============================
# 10️⃣ EVALUATE BLEU
# ==============================
def evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=50, use_beam=True):
    """Translate the first *n* test sentences and report average sentence BLEU.

    ``Tokenizer`` lower-cases its input, so every prediction is lower-case.
    Each reference is lower-cased and whitespace-split the same way before
    scoring. Scoring against the cased reference counted every capitalised
    word as a mismatch and under-reported BLEU.

    Args:
        n: number of leading test pairs to score. It is capped at the length
            of the shorter of ``test_en`` / ``test_vi``.
        use_beam: beam search with ``BEAM_SIZE`` if True, otherwise greedy
            decoding. Both are capped at ``MAX_LEN`` tokens.

    Returns:
        The average BLEU, which is also printed. Returns 0.0 when no pairs
        are evaluated.
    """
    model.to(device)
    model.eval()
    smooth = SmoothingFunction().method1
    n = min(n, len(test_en), len(test_vi))
    if n <= 0:
        print("\nAVERAGE BLEU = 0.0 (no sentence pairs evaluated)")
        return 0.0

    total_bleu = 0.0
    for i in range(n):
        sent = test_en[i]
        if use_beam:
            pred = translate_beam(model, sent, tok_src, tok_trg, device, beam_size=BEAM_SIZE, max_len=MAX_LEN)
        else:
            pred = translate_greedy(model, sent, tok_src, tok_trg, device, max_len=MAX_LEN)
        reference = test_vi[i].strip().lower().split()
        bleu = sentence_bleu([reference], pred.split(), smoothing_function=smooth)
        total_bleu += bleu
        if i < 10:
            print("\nEN:", sent)
            print("GT:", test_vi[i])
            print("PR:", pred)
            print(f"BLEU: {bleu:.4f}")

    avg_bleu = total_bleu / n
    print("\nAVERAGE BLEU =", avg_bleu)
    return avg_bleu

# greedy translate for fallback printing
def translate_greedy(model, text, tok_src, tok_trg, device, max_len=60):
    """Translate *text* by always picking the most likely next token.

    Used by ``evaluate_test_set`` when ``use_beam=False``. Generates at most
    ``max_len`` tokens and stops right after ``<eos>``.

    Returns:
        The decoded target sentence, without the leading ``<sos>``.
    """
    model.eval()
    src_vocab = tok_src.word2idx
    trg_vocab = tok_trg.word2idx
    src_tensor = torch.tensor(
        [[src_vocab["<sos>"], *tok_src.encode(text), src_vocab["<eos>"]]],
        dtype=torch.long,
        device=device,
    )  # (1, S)
    hyp = [trg_vocab["<sos>"]]
    with torch.no_grad():
        for _ in range(max_len):
            prefix = torch.tensor([hyp], dtype=torch.long, device=device)  # (1, T)
            hyp.append(model(src_tensor, prefix)[0, -1].argmax().item())
            if hyp[-1] == trg_vocab["<eos>"]:
                break
    return tok_trg.decode(hyp[1:])

# ==============================
# 11️⃣ RUN: tạo model, train, evaluate, interactive
# ==============================
if __name__ == "__main__":
    model = TransformerModel(
        src_vocab=len(tok_src.word2idx),
        trg_vocab=len(tok_trg.word2idx),
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        pad_idx=PAD_IDX,
        dropout=0.1
    )

    print("Start training ...")
    train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR, pad_idx=PAD_IDX)

    # Reload the checkpoint with the lowest validation loss instead of keeping
    # the last epoch's weights.
    if os.path.exists("best_model.pt"):
        ckpt = torch.load("best_model.pt", map_location=device)
        model.load_state_dict(ckpt['model_state'])
        model.to(device)
        print("Loaded best_model.pt")

    evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=50, use_beam=True)

    # Interactive translation loop. input() raises EOFError when stdin is
    # closed, and Jupyter frontends without stdin support raise a
    # NotImplementedError subclass. Treat both, and Ctrl-C, as "exit" so a
    # non-interactive run finishes cleanly after training and evaluation.
    print("\nNhập câu tiếng Anh để dịch (gõ 'exit' để thoát):")
    while True:
        try:
            text = input("EN> ").strip()
        except (EOFError, KeyboardInterrupt, NotImplementedError):
            break
        if text.lower() in ("exit", "quit"):
            break
        if not text:
            continue
        out = translate_beam(model, text, tok_src, tok_trg, device, beam_size=BEAM_SIZE, max_len=MAX_LEN)
        print("VI>", out)


Device: cuda
Train pairs: 133317 Test pairs: 1268
Src vocab: 47861 Trg vocab: 22443
Start training ...




OutOfMemoryError: CUDA out of memory. Tried to allocate 1.73 GiB. GPU 0 has a total capacity of 14.74 GiB of which 1.61 GiB is free. Process 35505 has 13.13 GiB memory in use. Of the allocated memory 11.54 GiB is allocated by PyTorch, and 1.45 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)