# In [None]:  (notebook cell marker, kept as a comment so the file parses as Python)
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

# raw data 
# Toy English -> Vietnamese parallel corpus as (source, target) pairs.
# Each side is lower-cased and stripped of surrounding whitespace while the
# list is built.
raw_data = [
    (en.lower().strip(), vi.lower().strip())
    for en, vi in (
        ("hello", "xin chào"),
        ("how are you", "bạn khỏe không"),
        ("i am fine", "mình khỏe"),
        ("thank you", "cảm ơn"),
        ("what is your name", "tên bạn là gì"),
        ("my name is john", "tên mình là john"),
        ("good morning", "chào buổi sáng"),
        ("good night", "chúc ngủ ngon"),
        ("see you later", "hẹn gặp lại"),
        ("i love you", "tôi yêu bạn"),
    )
]

# Reserved vocabulary symbols. build_vocab places them first, in this order,
# so they occupy indices 0-3 of every vocabulary it returns.
PAD_TOKEN = "<pad>"
BOS_TOKEN = "<s>"
EOS_TOKEN = "</s>"
UNK_TOKEN = "<unk>"


# Building vocab
def build_vocab(pairs, min_freq=1):
    """Build one word-level vocabulary covering both sides of the pairs.

    Args:
        pairs: iterable of (source, target) strings; words are obtained by
            splitting each string on whitespace.
        min_freq: a word is kept only if it occurs at least this many times
            across all sources and targets combined.

    Returns:
        A tuple ``(stoi, itos)``: ``itos`` is the list of the four reserved
        symbols followed by the kept words in sorted order, and ``stoi`` maps
        each entry of ``itos`` to its index.
    """
    freq = {}
    for s, t in pairs:
        for w in s.split() + t.split():
            freq[w] = freq.get(w, 0) + 1

    specials = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
    # A corpus word that happens to spell a reserved symbol must not get a
    # second entry: the duplicate would make stoi point the symbol at the
    # later index and break the "specials occupy 0-3" layout.
    words = sorted(
        w for w, c in freq.items() if c >= min_freq and w not in specials
    )
    itos = specials + words
    stoi = {w: i for i, w in enumerate(itos)}
    return stoi, itos

# A single vocabulary is built from both languages; the target-side names
# below are aliases of the very same objects, not copies.
src_stoi, src_itos = build_vocab(raw_data)
tgt_stoi = src_stoi
tgt_itos = src_itos
SRC_VOCAB_SIZE = len(src_itos)
TGT_VOCAB_SIZE = len(tgt_itos)
print("Vocab size:", SRC_VOCAB_SIZE)

# Encode sentence
MAX_LEN = 10


def encode_sentence(sent, stoi, max_len=MAX_LEN):
    """Convert a sentence into a fixed-length list of token ids.

    The sentence is split on whitespace, words missing from ``stoi`` map to
    the UNK id, and the result is wrapped in BOS ... EOS. Short sequences are
    right-padded with the PAD id; sequences that reach or exceed ``max_len``
    are cut to ``max_len - 1`` ids and closed with EOS.

    Args:
        sent: the sentence text.
        stoi: mapping from token string to integer id; must contain the
            reserved symbols that end up being used.
        max_len: length of the returned list (for ``max_len >= 1``).

    Returns:
        A list of integer ids.
    """
    body = [stoi.get(word, stoi[UNK_TOKEN]) for word in sent.split()]
    framed = [stoi[BOS_TOKEN], *body, stoi[EOS_TOKEN]]

    shortfall = max_len - len(framed)
    if shortfall > 0:
        return framed + [stoi[PAD_TOKEN]] * shortfall

    # Full or over-long: keep the first max_len - 1 ids and force EOS last.
    return framed[:max_len - 1] + [stoi[EOS_TOKEN]]

# Dataset 
class TinyTranslationDataset(Dataset):
    """Map-style dataset over (source, target) sentence pairs.

    Sentences are encoded on access with ``encode_sentence`` using the given
    vocabularies and ``max_len``; nothing is precomputed.
    """

    def __init__(self, pairs, src_stoi, tgt_stoi, max_len=MAX_LEN):
        self.pairs, self.max_len = pairs, max_len
        self.src_stoi, self.tgt_stoi = src_stoi, tgt_stoi

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(src_ids, tgt_ids)`` as two int64 tensors for pair ``idx``."""
        source_text, target_text = self.pairs[idx]
        encoded = (
            encode_sentence(source_text, self.src_stoi, self.max_len),
            encode_sentence(target_text, self.tgt_stoi, self.max_len),
        )
        src_ids, tgt_ids = (torch.tensor(ids, dtype=torch.long) for ids in encoded)
        return src_ids, tgt_ids

# Wrap the corpus in a dataset and serve it in shuffled mini-batches of two.
dataset = TinyTranslationDataset(
    pairs=raw_data,
    src_stoi=src_stoi,
    tgt_stoi=tgt_stoi,
    max_len=MAX_LEN,
)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

#POS ENCODING
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor.

    A table of shape ``(1, max_len, d_model)`` is precomputed and registered
    as the buffer ``pe``: even feature columns hold sines and odd feature
    columns hold cosines of ``position * div_term``.

    Args:
        d_model: size of the last (feature) dimension of the inputs.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine half uses only the first d_model // 2 frequencies.
        # For even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # Buffer: follows the module across devices and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1), :]

# Model 
class TransformerSeq2Seq(nn.Module):
    """Encoder-decoder Transformer over token ids (batch-first).

    Source and target ids are embedded, scaled by ``sqrt(d_model)``, given
    sinusoidal position encodings, passed through ``nn.Transformer`` and
    projected to target-vocabulary logits.

    Args:
        src_vocab_size: number of rows in the source embedding table.
        tgt_vocab_size: number of rows in the target embedding table and size
            of the output logits.
        d_model, nhead, num_encoder_layers, num_decoder_layers,
        dim_feedforward, dropout: forwarded to ``nn.Transformer``.
        max_len: number of positions precomputed by the positional encoding.
        src_pad_idx: id treated as padding in ``src``. ``None`` (default)
            falls back to the module-level ``src_stoi[PAD_TOKEN]``.
        tgt_pad_idx: id treated as padding in ``tgt``. ``None`` (default)
            falls back to the module-level ``tgt_stoi[PAD_TOKEN]``.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=64, nhead=4,
                 num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128,
                 dropout=0.1, max_len=MAX_LEN, src_pad_idx=None, tgt_pad_idx=None):
        super().__init__()
        self.d_model = d_model
        # Resolve the padding ids once, at construction, so forward() no
        # longer reaches into module globals. These are plain ints, so the
        # state dict layout is unchanged.
        self.src_pad_idx = src_stoi[PAD_TOKEN] if src_pad_idx is None else src_pad_idx
        self.tgt_pad_idx = tgt_stoi[PAD_TOKEN] if tgt_pad_idx is None else tgt_pad_idx
        self.src_tok_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True)
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Return logits over the target vocabulary for every ``tgt`` position.

        Args:
            src: integer id tensor, batch first (sequence length on dim 1).
            tgt: integer id tensor, batch first (the decoder input).

        Returns:
            The output layer applied to the decoder output; the last
            dimension has size ``tgt_vocab_size``.
        """
        scale = math.sqrt(self.d_model)
        src_emb = self.pos_encoder(self.src_tok_emb(src) * scale)
        tgt_emb = self.pos_encoder(self.tgt_tok_emb(tgt) * scale)

        # Boolean masks: True marks a position that must NOT be attended to.
        src_pad_mask = src == self.src_pad_idx
        tgt_pad_mask = tgt == self.tgt_pad_idx

        # Causal mask in the same (boolean) dtype as the padding masks: True
        # strictly above the diagonal blocks attention to future positions.
        # Keeping both mask kinds boolean avoids mixing a float attention
        # mask with boolean key-padding masks.
        tgt_len = tgt.size(1)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask,
                               src_key_padding_mask=src_pad_mask,
                               tgt_key_padding_mask=tgt_pad_mask,
                               memory_key_padding_mask=src_pad_mask)
        return self.output_layer(out)

# Train
# Use the GPU when one is available; build the model, loss and optimiser.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = TransformerSeq2Seq(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE)
model = model.to(device)
# Padded target positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tgt_stoi[PAD_TOKEN])
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def shift_right(tgt_batch):
    """Return the decoder input: every column of ``tgt_batch`` except the last."""
    return tgt_batch[:, :-1]


def target_for_loss(tgt_batch):
    """Return the prediction targets: every column of ``tgt_batch`` except the first."""
    return tgt_batch[:, 1:]

EPOCHS = 120
losses = []  # mean batch loss of every epoch, in order

model.train()
for epoch in range(1, EPOCHS + 1):
    total_loss = 0
    for src_batch, tgt_batch in dataloader:
        src_batch = src_batch.to(device)
        tgt_batch = tgt_batch.to(device)

        # Teacher forcing: the decoder sees the target without its last
        # token and is scored against the target without its first token.
        decoder_input = shift_right(tgt_batch)

        optimizer.zero_grad()
        logits = model(src_batch, decoder_input)
        # Flatten to (positions, vocab) vs (positions,) for cross-entropy.
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            target_for_loss(tgt_batch).reshape(-1),
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    losses.append(avg_loss)

    # Report the first epoch and then every 30th.
    if epoch == 1 or epoch % 30 == 0:
        print(f"Epoch {epoch}/{EPOCHS} - loss: {avg_loss:.4f}")

# Loss figure
# Plot the per-epoch mean training loss on a single set of axes.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(losses, label="Training Loss", linewidth=2)
ax.set_title("Transformer Training Loss", fontsize=14)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True, linestyle="--", alpha=0.6)
ax.legend()
plt.show()

# Translate 
model.eval()


def greedy_translate(model, src_sentence, src_stoi, tgt_itos, max_len=MAX_LEN):
    """Translate one sentence by greedy (arg-max) decoding.

    Decoding starts from BOS and appends the highest-scoring next token
    until EOS is produced or ``max_len - 1`` tokens have been generated.

    Args:
        model: callable as ``model(src_ids, tgt_ids)`` returning logits whose
            last dimension indexes ``tgt_itos``.
        src_sentence: source text; encoded with ``encode_sentence``.
        src_stoi: source token -> id mapping.
        tgt_itos: target id -> token list; must contain BOS and EOS.
        max_len: encoded source length and cap on the decoded length.

    Returns:
        The decoded words joined by single spaces, without BOS/PAD and
        truncated at the first EOS.
    """
    # Take the special ids from the vocabulary that was passed in and the
    # device from the model itself, instead of relying on module globals.
    bos_id = tgt_itos.index(BOS_TOKEN)
    eos_id = tgt_itos.index(EOS_TOKEN)
    run_device = next(model.parameters()).device

    # Decode with dropout disabled, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            src_ids = torch.tensor(
                [encode_sentence(src_sentence, src_stoi, max_len)],
                dtype=torch.long, device=run_device)
            tgt_ids = torch.tensor([[bos_id]], dtype=torch.long, device=run_device)
            for _ in range(max_len - 1):
                logits = model(src_ids, tgt_ids)
                # Only the prediction for the last position is needed.
                next_token = torch.argmax(logits[0, -1, :], dim=-1).view(1, 1)
                tgt_ids = torch.cat([tgt_ids, next_token], dim=1)
                if next_token.item() == eos_id:
                    break
    finally:
        model.train(was_training)

    words = []
    for idx in tgt_ids.squeeze(0).tolist():
        tok = tgt_itos[idx]
        if tok == EOS_TOKEN:
            break
        if tok in (BOS_TOKEN, PAD_TOKEN):
            continue
        words.append(tok)
    return " ".join(words)

# Decode a few training sentences and print source next to prediction.
print("\n--- Translations (greedy decode) ---")
test_sentences = [
    "hello",
    "what is your name",
    "i love you",
    "good night",
    "see you later",
]
for s in test_sentences:
    out = greedy_translate(model, src_sentence=s, src_stoi=src_stoi, tgt_itos=tgt_itos)
    print(f"EN: {s}  -->  VN: {out}")

# save model
# Bundle the weights with both vocabularies so the checkpoint carries the
# token <-> id mappings alongside the parameters.
checkpoint = {
    "model_state": model.state_dict(),
    "src_itos": src_itos,
    "tgt_itos": tgt_itos,
    "src_stoi": src_stoi,
    "tgt_stoi": tgt_stoi,
}
torch.save(checkpoint, "tiny_transformer_en_vi.pth")

