# 02 · Demo Translate (EN→VI)

Notebook này nạp checkpoint từ `models/checkpoints/tiny_transformer_en2vi.pt` và cho phép bạn nhập **câu tiếng Anh** để suy diễn bản dịch **tiếng Việt**.

In [None]:
import json, math, pathlib, torch, torch.nn as nn, torch.nn.functional as F
from pathlib import Path

# Locate the project root. When `__file__` is defined, use the parent of the
# directory that contains this file; otherwise (e.g. inside a notebook kernel)
# fall back to the grandparent of the current working directory.
# NOTE(review): the fallback uses cwd().parents[1], i.e. two levels above the
# working directory — confirm this matches where the notebook is launched from.
if '__file__' in globals():
    PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
else:
    PROJECT_ROOT = pathlib.Path.cwd().parents[1]
CKPT = PROJECT_ROOT / "models" / "checkpoints" / "tiny_transformer_en2vi.pt"

# Use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Load the checkpoint onto the selected device and pull out the vocabulary
# tables stored next to the weights.
ckpt = torch.load(CKPT, map_location=device)
src2i = ckpt["src2i"]
i2src = ckpt["i2src"]
tgt2i = ckpt["tgt2i"]
i2tgt = ckpt["i2tgt"]

# Ids of the special tokens.
PAD = 0
BOS = 1
EOS = 2
UNK = 3


In [None]:
# Định nghĩa lại kiến trúc tối giản (khớp với train nb)
import math, torch, torch.nn as nn, torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd feature columns hold
    ``cos(pos * freq)``. The table is computed once for ``max_len`` positions
    and registered as the buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Size of the feature dimension (may be even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

def attention(q, k, v, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor; its last dimension is the per-head size ``d_k``.
        k: Key tensor; multiplied as ``q @ k^T`` over the last two dimensions.
        v: Value tensor, combined with the attention weights.
        mask: Optional tensor broadcastable against the score tensor.
            Positions where ``mask == 0`` are excluded from attention.
        dropout: Optional callable applied to the attention weights.

    Returns:
        A tuple ``(output, attn)`` with the weighted values and the attention
        weights (after dropout, when given).
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the most negative finite value instead of -inf: masked
        # positions still get weight 0, but a row whose keys are all masked
        # yields uniform weights rather than NaN from softmax.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    return torch.matmul(attn, v), attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention with learned query/key/value/output projections.

    Args:
        h: Number of attention heads; must divide ``d_model``.
        d_model: Feature size of the inputs and of the output.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: Tensors whose first dimension is the batch and
                whose last dimension is ``d_model``.
            mask: Optional mask. A mask with fewer than four dimensions gets a
                head axis inserted at dim 1; a mask that already has four
                dimensions (e.g. ``(batch, 1, 1, key_len)``) is used as is.

        Returns:
            Tensor reshaped to ``(batch, -1, d_model)`` and passed through the
            output projection.
        """
        bs = query.size(0)

        def split(x):
            # (batch, len, d_model) -> (batch, h, len, d_k)
            return x.view(bs, -1, self.h, self.d_k).transpose(1, 2)

        q = split(self.linear_q(query))
        k = split(self.linear_k(key))
        v = split(self.linear_v(value))
        # Only add the head axis when it is missing. Unsqueezing a rank-4 mask
        # would make it rank 5, which broadcasts the scores to rank 5 and
        # corrupts the head/position layout in the final view.
        if mask is not None and mask.dim() < 4:
            mask = mask.unsqueeze(1)
        x, _ = attention(q, k, v, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linear_out(x)

class PositionwiseFFN(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``w2(dropout(relu(w1(x))))``, expanding ``d_model`` to ``d_ff``
    and projecting back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.w2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return the feed-forward transform of ``x``."""
        hidden = torch.relu(self.w1(x))
        hidden = self.dropout(hidden)
        return self.w2(hidden)

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.
    """

    def __init__(self, d_model, self_h, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(self_h, d_model, dropout)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        """Run ``x`` through the layer, attending under ``src_mask``."""
        # Self-attention sub-layer with residual connection.
        attended = self.self_attn(x, x, x, src_mask)
        after_attn = self.norm1(x + self.drop(attended))
        # Feed-forward sub-layer with residual connection.
        transformed = self.ffn(after_attn)
        return self.norm2(after_attn + self.drop(transformed))

class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, then feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.
    """

    def __init__(self, d_model, self_h, cross_h, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(self_h, d_model, dropout)
        self.cross_attn = MultiHeadAttention(cross_h, d_model, dropout)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask, src_mask):
        """Run ``x`` through the layer.

        ``tgt_mask`` is used for self-attention over ``x``; ``src_mask`` is
        used when attending from ``x`` to ``memory``.
        """
        # Self-attention over the target sequence.
        self_out = self.self_attn(x, x, x, tgt_mask)
        state = self.norm1(x + self.drop(self_out))
        # Cross-attention: queries from the decoder, keys/values from memory.
        cross_out = self.cross_attn(state, memory, memory, src_mask)
        state = self.norm2(state + self.drop(cross_out))
        # Feed-forward sub-layer.
        ffn_out = self.ffn(state)
        return self.norm3(state + self.drop(ffn_out))

def subsequent_mask(sz):
    """Return a boolean mask of shape ``(1, sz, sz)`` for causal attention.

    Entry ``[0, i, j]`` is True when ``j <= i``, i.e. position ``i`` may look
    at itself and at earlier positions only.
    """
    steps = torch.arange(sz)
    # Compare column index (key position) against row index (query position).
    allowed = steps.unsqueeze(0) <= steps.unsqueeze(1)
    return allowed.unsqueeze(0)

class TinyTransformer(nn.Module):
    """Small encoder-decoder Transformer for sequence-to-sequence tasks.

    Args:
        src_vocab: Size of the source vocabulary.
        tgt_vocab: Size of the target vocabulary.
        d_model: Embedding / hidden feature size.
        N: Number of encoder layers and of decoder layers.
        h: Number of attention heads in every attention module.
        d_ff: Inner size of the feed-forward networks.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model=128, N=2, h=4, d_ff=256, dropout=0.1):
        super().__init__()
        # Token embeddings; the PAD row is kept as the padding index.
        self.src_embed = nn.Embedding(src_vocab, d_model, padding_idx=PAD)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD)
        # A single positional-encoding module is shared by both sides.
        self.pos = PositionalEncoding(d_model)
        encoder_stack = [EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)]
        self.enc_layers = nn.ModuleList(encoder_stack)
        decoder_stack = [DecoderLayer(d_model, h, h, d_ff, dropout) for _ in range(N)]
        self.dec_layers = nn.ModuleList(decoder_stack)
        # Projects decoder states to target-vocabulary logits.
        self.proj = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask):
        """Embed ``src``, add positions and run the encoder stack."""
        hidden = self.src_embed(src)
        hidden = self.pos(hidden)
        for enc_layer in self.enc_layers:
            hidden = enc_layer(hidden, src_mask)
        return hidden

    def decode(self, tgt, memory, tgt_mask, src_mask):
        """Embed ``tgt``, add positions and run the decoder stack over ``memory``."""
        hidden = self.tgt_embed(tgt)
        hidden = self.pos(hidden)
        for dec_layer in self.dec_layers:
            hidden = dec_layer(hidden, memory, tgt_mask, src_mask)
        return hidden

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return target-vocabulary logits for every position of ``tgt``."""
        encoded = self.encode(src, src_mask)
        decoded = self.decode(tgt, encoded, tgt_mask, src_mask)
        return self.proj(decoded)


In [None]:
def tokenize_ws(s):
    """Lower-case ``s`` and split it on runs of whitespace."""
    return s.lower().strip().split()


def encode_line(s, stoi, add_bos=False, add_eos=True, max_len=64):
    """Convert a sentence into a list of token ids.

    Args:
        s: Input sentence.
        stoi: Mapping from token string to id; unknown tokens map to id 3.
        add_bos: Prepend id 1 (BOS) when true.
        add_eos: Append id 2 (EOS) when true.
        max_len: Maximum length of the returned list, special ids included.

    Returns:
        A list of at most ``max_len`` ids. When the sentence is too long the
        word tokens are truncated first, so a requested BOS/EOS is not cut off.
    """
    # Reserve room for the special ids before truncating the word tokens;
    # slicing the finished list instead would drop the trailing EOS.
    n_special = int(bool(add_bos)) + int(bool(add_eos))
    budget = max(max_len - n_special, 0)
    ids = [stoi.get(t, 3) for t in tokenize_ws(s)[:budget]]  # UNK=3
    if add_bos:
        ids = [1] + ids
    if add_eos:
        ids = ids + [2]
    # Safety net for a max_len smaller than the number of special ids.
    return ids[:max_len]

def make_src_mask(src):
    """Return a boolean mask that is True where ``src`` is not the pad id 0.

    Two singleton axes are inserted at dims 1 and 2, so a ``(batch, src_len)``
    input gives a ``(batch, 1, 1, src_len)`` mask.
    """
    not_pad = src.ne(0)
    return not_pad.unsqueeze(1).unsqueeze(2)
def make_tgt_mask(tgt):
    """Combine the padding mask of ``tgt`` with a causal mask.

    ``tgt`` must be 2-D ``(batch, tgt_len)``. A position is True only when the
    key token is not the pad id 0 and it is not after the query position.
    """
    _, seq_len = tgt.size()
    not_pad = tgt.ne(0).unsqueeze(1).unsqueeze(2)
    causal = subsequent_mask(seq_len).to(tgt.device)
    return not_pad & causal

def greedy_decode(model, src_ids, max_len=64):
    """Translate one source sentence by greedy decoding.

    Args:
        model: Object providing ``eval()``, ``encode``, ``decode`` and ``proj``.
        src_ids: List of source token ids for a single sentence.
        max_len: Maximum length of the returned sequence, BOS included.

    Returns:
        List of target token ids starting with BOS (id 1). Generation stops
        after EOS (id 2) is produced or when ``max_len`` ids were generated.
    """
    model.eval()
    src = torch.tensor([src_ids], dtype=torch.long, device=device)
    src_mask = make_src_mask(src)
    # Inference only: disable autograd so no graph is recorded while the
    # target sequence grows step by step.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.tensor([[1]], dtype=torch.long, device=device)  # BOS
        for _ in range(max_len - 1):
            tgt_mask = make_tgt_mask(ys)
            out = model.decode(ys, memory, tgt_mask, src_mask)
            # Only the last position is needed, and the argmax of the logits
            # is the argmax of their softmax, so skip the normalisation.
            next_token = model.proj(out[:, -1, :]).argmax(-1).item()
            step = torch.tensor([[next_token]], dtype=torch.long, device=device)
            ys = torch.cat([ys, step], dim=1)
            if next_token == 2:  # EOS
                break
    return ys.squeeze(0).tolist()

def detok(ids, i2w):
    """Turn a sequence of token ids into a space-joined string.

    Ids 0 and 1 are skipped, the first id 2 ends the output, and every other
    id is looked up in ``i2w``.
    """
    words = []
    for tok_id in ids:
        if tok_id == 2:
            break
        if tok_id not in (0, 1):
            words.append(i2w[tok_id])
    return " ".join(words)

# Build the model with vocabulary sizes taken from the checkpoint tables,
# move it to the selected device, then restore the saved weights.
model = TinyTransformer(src_vocab=len(i2src), tgt_vocab=len(i2tgt))
model = model.to(device)
# strict=True: every key in the checkpoint must match the architecture above.
model.load_state_dict(ckpt["model"], strict=True)
model.eval()
print("Checkpoint loaded from", CKPT)


In [None]:
# ✨ Type an English sentence to translate
text_en = "This is a small dataset."
# Encode the source without BOS and with a trailing EOS.
src_ids = encode_line(text_en, src2i, add_bos=False, add_eos=True)
# Greedily decode a sequence of at most 64 target ids.
pred_ids = greedy_decode(model, src_ids, max_len=64)
print("EN:", text_en)
print("VI:", detok(pred_ids, i2tgt))
