# 01 · Train Tiny Transformer EN→VI (from scratch)

Pipeline tối giản: **dữ liệu → tiền xử lý → tokenization → mô hình (2E/2D) → huấn luyện → metrics → checkpoint**.
> Lưu ý: Notebook này tự đủ, không phụ thuộc module ngoài `requirements.txt`.

In [None]:
# 0) Imports & seed
import os, math, json, random, pathlib, time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and disable auto-tuning, which may pick
# different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve the project root. When run as a script, this file is assumed to live
# one level below the root (e.g. <root>/notebooks/<this file>), hence parents[1].
# In a notebook there is no __file__; Jupyter by default starts the kernel in the
# notebook's own directory, so the root is cwd's *parent*. (cwd.parents[1] would
# point one level above the root.)
if "__file__" in globals():
    PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
else:
    PROJECT_ROOT = pathlib.Path.cwd().resolve().parent

DATA_RAW = PROJECT_ROOT / "data" / "raw"
DATA_PROCESSED = PROJECT_ROOT / "data" / "processed"
CKPT_DIR = PROJECT_ROOT / "models" / "checkpoints"
METRIC_DIR = PROJECT_ROOT / "reports" / "metrics"
OUT_DIR = PROJECT_ROOT / "outputs" / "translations"
# DATA_RAW is included: the data cell may need to write a toy corpus there.
for p in [CKPT_DIR, METRIC_DIR, OUT_DIR, DATA_PROCESSED, DATA_RAW]:
    p.mkdir(parents=True, exist_ok=True)

print("Device:", device)
print("Project root:", PROJECT_ROOT)


## 1) Nạp dữ liệu song ngữ
- Đặt file song ngữ vào `data/raw/` (ví dụ: `article_en_vi.txt`).
- Định dạng tối giản: mỗi dòng là một cặp câu `en<TAB>vi` (ký tự tab ngăn cách câu EN và câu VI); các dòng không có tab sẽ bị bỏ qua.

In [None]:
# Example path (replace with your real file)
raw_file = DATA_RAW / "article_en_vi.txt"
if not raw_file.exists():
    # Create a toy file so the notebook can run on its own (placeholder)
    toy = [
        "Hello world\tXin chào thế giới",
        "This is a small dataset.\tĐây là một tập dữ liệu nhỏ.",
        "We build a tiny transformer.\tChúng tôi xây một transformer nhỏ.",
        "Attention is all you need.\tSự chú ý là tất cả những gì bạn cần.",
    ]
    # data/raw/ is not guaranteed to exist on a fresh checkout.
    raw_file.parent.mkdir(parents=True, exist_ok=True)
    raw_file.write_text("\n".join(toy), encoding="utf-8")
print("Using data file:", raw_file)

pairs = []
for line in raw_file.read_text(encoding="utf-8").splitlines():
    # Split on the FIRST tab only: a stray extra tab stays on the VI side rather
    # than producing a 3-element "pair" that breaks `for en, vi in pairs`.
    en, sep, vi = line.partition("\t")
    en, vi = en.strip(), vi.strip()
    # Skip lines with no tab or with an empty side (e.g. "Hello\t").
    if sep and en and vi:
        pairs.append([en, vi])
print("Num pairs:", len(pairs))
pairs[:3]


## 2) Tokenization tối giản
Ta bắt đầu với tokenizer tách theo khoảng trắng (whitespace) và chuyển về chữ thường — đủ cho dữ liệu toy. Sau này có thể nâng cấp lên BPE/SentencePiece nếu cần.

In [None]:
# Whitespace tokenizer
def tokenize_ws(s):
    """Lower-case *s* and split it on runs of whitespace.

    ``str.split()`` without arguments already ignores leading/trailing
    whitespace, so no explicit strip is needed.
    """
    lowered = s.lower()
    return lowered.split()

# Build a vocabulary from the (small) corpus
def build_vocab(texts, min_freq=1, specials=("<pad>", "<s>", "</s>", "<unk>"), tokenizer=None):
    """Build a token <-> id mapping from raw sentences.

    Args:
        texts: iterable of raw sentences.
        min_freq: minimum count for a token to enter the vocabulary.
        specials: tokens placed first, in this order (ids 0..len(specials)-1).
        tokenizer: callable ``str -> list[str]``; defaults to ``tokenize_ws``.
            Lets a subword tokenizer (BPE/SentencePiece) be plugged in later.

    Returns:
        (stoi, itos): dict token -> id, and list id -> token. Regular tokens
        follow the specials in first-occurrence order.
    """
    from collections import Counter
    tokenize = tokenize_ws if tokenizer is None else tokenizer
    cnt = Counter()
    for t in texts:
        cnt.update(tokenize(t))
    itos = list(specials)
    reserved = set(specials)
    # A literal special token appearing in the text must not get a second id.
    itos += [w for w, f in cnt.items() if f >= min_freq and w not in reserved]
    stoi = {w: i for i, w in enumerate(itos)}
    return stoi, itos

# Separate the two sides of the parallel corpus (unpacking keeps the 2-tuple check).
en_texts = [source for source, _target in pairs]
vi_texts = [target for _source, target in pairs]

# One vocabulary per language.
src2i, i2src = build_vocab(en_texts)
tgt2i, i2tgt = build_vocab(vi_texts)

# Special-token ids; they follow the default `specials` order of build_vocab.
PAD, BOS, EOS, UNK = range(4)

def encode_line(s, stoi, add_bos=False, add_eos=True, max_len=64):
    """Encode a sentence into a list of token ids of length <= max_len.

    Args:
        s: raw sentence (tokenized with ``tokenize_ws``).
        stoi: token -> id mapping; unknown tokens map to UNK.
        add_bos: prepend BOS.
        add_eos: append EOS.
        max_len: maximum length of the returned list, specials included.

    Returns:
        list[int] of token ids.
    """
    ids = [stoi.get(t, UNK) for t in tokenize_ws(s)]
    # Truncate the content (not the final list) so BOS/EOS survive on long
    # sentences; without EOS the model never learns where such targets end.
    n_special = int(add_bos) + int(add_eos)
    ids = ids[:max(0, max_len - n_special)]
    if add_bos:
        ids = [BOS] + ids
    if add_eos:
        ids = ids + [EOS]
    # Final guard keeps the max_len contract even if max_len < number of specials.
    return ids[:max_len]

print(f"src_vocab_size: {len(i2src)} tgt_vocab_size: {len(i2tgt)}")


## 3) Dataloader nhỏ

In [None]:
def pad_batch(seqs, pad=PAD):
    """Right-pad id lists to a common length and stack them.

    Args:
        seqs: non-empty list of lists of token ids.
        pad: id used for padding.

    Returns:
        LongTensor of shape [len(seqs), longest sequence length].
    """
    width = max(map(len, seqs))
    padded = []
    for seq in seqs:
        padded.append(seq + [pad] * (width - len(seq)))
    return torch.tensor(padded, dtype=torch.long)

# Source side gets only EOS; target side gets BOS ... EOS (decoder input starts with BOS).
dataset = []
for en, vi in pairs:
    src_ids = encode_line(en, src2i, add_bos=False, add_eos=True)
    tgt_ids = encode_line(vi, tgt2i, add_bos=True, add_eos=True)
    dataset.append((src_ids, tgt_ids))

# Simple train/dev split: shuffle (seeded above), then take ~20% (at least 1) for dev.
random.shuffle(dataset)
n_dev = max(1, int(len(dataset) * 0.2))
dev_data, train_data = dataset[:n_dev], dataset[n_dev:]

def batches(data, bs=32):
    """Yield consecutive slices of *data* of size *bs* (the last may be shorter)."""
    yield from (data[start:start + bs] for start in range(0, len(data), bs))


## 4) Kiến trúc Transformer tối giản (2E/2D)
- Multi-Head Attention, FFN, Positional Encoding.
- Không dùng library `transformers`; chỉ PyTorch thuần.

In [None]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even columns hold sin(pos * w_i), odd columns cos(pos * w_i), with
    w_i = 10000^(-2i / d_model).

    Args:
        d_model: embedding size. Odd sizes are supported; the extra last
            column is a sine.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than div_term entries.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as [batch, seq, d_model]; seq must be <= max_len.
        """
        return x + self.pe[:, :x.size(1)]

# Scaled Dot-Product Attention
def attention(q, k, v, mask=None, dropout=None):
    """Compute softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q, k, v: tensors whose last two dims are [seq, d_k]; leading dims broadcast.
        mask: optional tensor broadcastable to the score shape [..., T_q, T_k].
            Positions where ``mask == 0`` are set to -inf before the softmax.
        dropout: optional module applied to the attention weights.

    Returns:
        (output, attn): the weighted values and the attention weights.
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # NOTE: a query row whose keys are ALL masked becomes all -inf -> NaN after softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    return torch.matmul(attn, v), attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention with ``h`` heads of size ``d_model // h``.

    ``forward(query, key, value, mask)`` takes [B, T, d_model] inputs and returns
    [B, T_query, d_model]. ``mask`` may be given
      * without a head axis: [B, 1, T_k] or [B, T_q, T_k] (a head axis is inserted), or
      * already head-broadcastable: [B, 1, 1, T_k] or [B, 1, T_q, T_k] (used as is).
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        bs = query.size(0)

        def split(x):
            # [B, T, d_model] -> [B, h, T, d_k]
            return x.view(bs, -1, self.h, self.d_k).transpose(1, 2)

        q = split(self.linear_q(query))
        k = split(self.linear_k(key))
        v = split(self.linear_v(value))
        if mask is not None and mask.dim() == 3:
            # Insert the head axis only when it is missing. make_src_mask /
            # make_tgt_mask already return 4-D masks; unsqueezing those again
            # gives a 5-D mask that broadcasts scores to [B, B, h, T_q, T_k] and
            # corrupts the output shape (and head layout).
            mask = mask.unsqueeze(1)
        x, _ = attention(q, k, v, mask=mask, dropout=self.dropout)
        # [B, h, T_q, d_k] -> [B, T_q, h * d_k]
        x = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linear_out(x)

class PositionwiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Applied independently at every position; maps d_model -> d_ff -> d_model.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names are state_dict keys (checkpoint format) - keep them.
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w1(x))
        hidden = self.dropout(hidden)
        return self.w2(hidden)

class EncoderLayer(nn.Module):
    """One post-LayerNorm Transformer encoder layer.

    Two sub-layers, each wrapped as LayerNorm(x + Dropout(sublayer(x))):
    self-attention (masked by ``src_mask``), then a position-wise FFN.
    """

    def __init__(self, d_model, self_h, d_ff, dropout=0.1):
        super().__init__()
        # Sub-modules are created in a fixed order so seeded initialisation is reproducible;
        # attribute names are state_dict keys.
        self.self_attn = MultiHeadAttention(self_h, d_model, dropout)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        attended = self.self_attn(x, x, x, src_mask)
        x = self.norm1(x + self.drop(attended))
        transformed = self.ffn(x)
        return self.norm2(x + self.drop(transformed))

class DecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder layer.

    Three sub-layers, each wrapped as LayerNorm(x + Dropout(sublayer(x))):
      1. self-attention over the target prefix (masked by ``tgt_mask``),
      2. cross-attention from target queries to encoder ``memory`` (``src_mask``),
      3. position-wise FFN.
    """

    def __init__(self, d_model, self_h, cross_h, d_ff, dropout=0.1):
        super().__init__()
        # Fixed creation order keeps seeded initialisation reproducible;
        # attribute names are state_dict keys.
        self.self_attn = MultiHeadAttention(self_h, d_model, dropout)
        self.cross_attn = MultiHeadAttention(cross_h, d_model, dropout)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        self.drop = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask, src_mask):
        self_out = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.drop(self_out))
        cross_out = self.cross_attn(x, memory, memory, src_mask)
        x = self.norm2(x + self.drop(cross_out))
        ffn_out = self.ffn(x)
        return self.norm3(x + self.drop(ffn_out))

def subsequent_mask(sz):
    """Return a causal mask of shape [1, sz, sz] (bool).

    Entry (i, j) is True iff j <= i, so decoder position i may only attend to
    itself and earlier positions.
    """
    positions = torch.arange(sz)
    # Rows index queries (i), columns index keys (j).
    allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return allowed.unsqueeze(0)

class TinyTransformer(nn.Module):
    """Encoder-decoder Transformer with N encoder and N decoder layers of h heads.

    Token embeddings (``padding_idx=PAD``) get sinusoidal position encodings
    added. The decoder output is projected to target-vocabulary logits.

    NOTE(review): embeddings are not scaled by sqrt(d_model), and there is no
    dropout after the positional encoding (both differ from the original
    Transformer paper). Confirm this is intended.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model=128, N=2, h=4, d_ff=256, dropout=0.1):
        super().__init__()
        # Creation order and attribute names are kept: they fix seeded init and state_dict keys.
        self.src_embed = nn.Embedding(src_vocab, d_model, padding_idx=PAD)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD)
        self.pos = PositionalEncoding(d_model)
        self.enc_layers = nn.ModuleList(
            [EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)]
        )
        self.dec_layers = nn.ModuleList(
            [DecoderLayer(d_model, h, h, d_ff, dropout) for _ in range(N)]
        )
        self.proj = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask):
        """Embed source ids and run them through the encoder stack."""
        hidden = self.pos(self.src_embed(src))
        for enc_layer in self.enc_layers:
            hidden = enc_layer(hidden, src_mask)
        return hidden

    def decode(self, tgt, memory, tgt_mask, src_mask):
        """Embed target ids and run them through the decoder stack (no projection)."""
        hidden = self.pos(self.tgt_embed(tgt))
        for dec_layer in self.dec_layers:
            hidden = dec_layer(hidden, memory, tgt_mask, src_mask)
        return hidden

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return logits over the target vocabulary for every target position."""
        return self.proj(self.decode(tgt, self.encode(src, src_mask), tgt_mask, src_mask))


## 5) Huấn luyện tối giản
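- Loss: CrossEntropy, bỏ qua các vị trí `<pad>`; nhãn là chuỗi đích dịch trái 1 token so với đầu vào decoder (teacher forcing).
- Optim: Adam.
- Metrics: loss/accuracy theo token (không tính `<pad>`); BLEU trên tập dev được tính ở mục 6.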

In [None]:
def make_src_mask(src):
    """Key-padding mask for source ids [B, S]: True where the token is not PAD, shape [B,1,1,S]."""
    not_pad = src.ne(PAD)
    return not_pad.unsqueeze(1).unsqueeze(2)

def make_tgt_mask(tgt):
    """Decoder self-attention mask for target ids [B, T].

    True where the key is not PAD AND not in the future; shape [B, 1, T, T].
    """
    _, seq_len = tgt.size()
    key_not_pad = tgt.ne(PAD).unsqueeze(1).unsqueeze(2)   # [B,1,1,T]
    causal = subsequent_mask(seq_len).to(tgt.device)       # [1,T,T]
    return key_not_pad & causal

def shift_tgt(tgt):
    """Split a target batch for teacher forcing.

    Returns (decoder_input, labels): the input drops the last position and the
    labels drop the first, so position t of the input predicts token t+1.
    """
    decoder_input = tgt[:, :-1]
    labels = tgt[:, 1:]
    return decoder_input, labels

# Model sized by the two vocabularies (other hyper-parameters use TinyTransformer defaults).
model = TinyTransformer(src_vocab=len(i2src), tgt_vocab=len(i2tgt)).to(device)
optim = torch.optim.Adam(params=model.parameters(), lr=3e-4)
# Label positions equal to PAD are excluded from the loss and from its mean.
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

def run_epoch(data, train=True, bs=32):
    """Run one pass over ``data`` with the global model/optim/criterion.

    Args:
        data: list of (src_ids, tgt_ids) pairs; tgt_ids start with BOS.
        train: True -> train mode with optimizer steps; False -> eval mode with
            autograd disabled (no graph is built, saving memory and time).
        bs: batch size.

    Returns:
        (loss, accuracy): mean cross-entropy per non-PAD label token and
        token-level accuracy over non-PAD labels; (0.0, 0.0) for empty data.
    """
    model.train(train)
    total_loss, total_tok, correct = 0.0, 0, 0
    with torch.set_grad_enabled(train):
        for batch in batches(data, bs):
            src, tgt = zip(*batch)
            src = pad_batch(list(src)).to(device)
            tgt = pad_batch(list(tgt)).to(device)
            tgt_in, tgt_out = shift_tgt(tgt)
            src_mask = make_src_mask(src)
            tgt_mask = make_tgt_mask(tgt_in)
            logits = model(src, tgt_in, src_mask, tgt_mask)  # [B, T, V]
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            if train:
                optim.zero_grad()
                loss.backward()
                optim.step()
            mask = tgt_out.ne(PAD)
            n_tok = mask.sum().item()
            # criterion averages over non-PAD labels; re-weight to accumulate a token sum.
            total_loss += loss.item() * n_tok
            preds = logits.argmax(-1)
            correct += (preds[mask] == tgt_out[mask]).sum().item()
            total_tok += n_tok
    return total_loss / max(1, total_tok), correct / max(1, total_tok)

EPOCHS = 5
history = {"train_loss": [], "train_acc": [], "dev_loss": [], "dev_acc": []}
for ep in range(1, EPOCHS + 1):
    tr_loss, tr_acc = run_epoch(train_data, train=True, bs=16)
    dv_loss, dv_acc = run_epoch(dev_data, train=False, bs=16)
    history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc)
    history["dev_loss"].append(dv_loss); history["dev_acc"].append(dv_acc)
    print(f"Epoch {ep:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | dev loss {dv_loss:.4f} acc {dv_acc:.3f}")

# Save checkpoint & metrics
ckpt_path = CKPT_DIR / "tiny_transformer_en2vi.pt"
# Store the constructor arguments (read back from the built model) so a demo can
# rebuild the network with TinyTransformer(**ckpt["config"]) before load_state_dict.
first_enc = model.enc_layers[0]
config = {
    "src_vocab": len(i2src),
    "tgt_vocab": len(i2tgt),
    "d_model": model.src_embed.embedding_dim,
    "N": len(model.enc_layers),
    "h": first_enc.self_attn.h,
    "d_ff": first_enc.ffn.w1.out_features,
    "dropout": first_enc.drop.p,
}
torch.save({"model": model.state_dict(),
            "config": config,
            "src2i": src2i, "i2src": i2src,
            "tgt2i": tgt2i, "i2tgt": i2tgt},
           ckpt_path)
metrics_path = METRIC_DIR / "history.json"
metrics_path.write_text(json.dumps(history, ensure_ascii=False, indent=2), encoding="utf-8")
print("Saved:", ckpt_path, "and", metrics_path)


## 6) BLEU (dev nhỏ)

In [None]:
from sacrebleu.metrics import BLEU

@torch.no_grad()
def greedy_decode(model, src_ids, max_len=64):
    """Greedily decode one source sentence.

    Puts ``model`` into eval mode (and leaves it there). Runs without autograd.

    Args:
        model: a TinyTransformer.
        src_ids: list of source token ids.
        max_len: maximum length of the returned sequence, BOS included.

    Returns:
        list[int] starting with BOS and ending with EOS if one was produced
        within ``max_len`` tokens.
    """
    model.eval()
    # Use the model's own device instead of relying on a global.
    dev = next(model.parameters()).device
    src = torch.tensor([src_ids], dtype=torch.long, device=dev)
    src_mask = make_src_mask(src)
    memory = model.encode(src, src_mask)
    ys = torch.tensor([[BOS]], dtype=torch.long, device=dev)
    for _ in range(max_len - 1):
        out = model.decode(ys, memory, make_tgt_mask(ys), src_mask)
        # Only the last position is needed; argmax of logits == argmax of softmax.
        next_token = model.proj(out[:, -1]).argmax(-1).item()
        ys = torch.cat([ys, ys.new_tensor([[next_token]])], dim=1)
        if next_token == EOS:
            break
    return ys.squeeze(0).tolist()

def detok(ids, i2w):
    """Turn token ids back into a space-joined string.

    PAD and BOS are skipped; everything from the first EOS onward is dropped.
    """
    words = []
    for idx in ids:
        if idx == EOS:
            break
        if idx not in (PAD, BOS):
            words.append(i2w[idx])
    return " ".join(words)

bleu = BLEU()
hyps, refs = [], []
for src_ids, tgt_ids in dev_data:
    hyps.append(detok(greedy_decode(model, src_ids), i2tgt))
    refs.append(detok(tgt_ids, i2tgt))

if hyps:
    # sacrebleu takes a list of reference STREAMS, each aligned with `hyps`.
    # With one reference per sentence that is a single stream: [refs].
    # Passing [[r1], [r2], ...] would instead mean N streams of one sentence each.
    score = bleu.corpus_score(hyps, [refs])
    print("DEV BLEU:", score)
else:
    score = None
    print("DEV BLEU: skipped (empty dev set)")

# Save samples
sample_path = OUT_DIR / "dev_samples.txt"
sample = "\n".join(f"HYP: {h}\nREF: {r}" for h, r in zip(hyps, refs))
sample_path.write_text(sample, encoding="utf-8")
print("Saved samples to", sample_path)


> Notebook kết thúc: bạn đã có **checkpoint** và **metrics** để dùng trong demo.