# In [None]:
# transformer_toy_example.py
import math, random, torch, torch.nn as nn, torch.optim as optim

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Seed both random sources for reproducible runs: torch's RNG is used by the
# model, while the toy data generator draws from Python's `random` module.
torch.manual_seed(0)
random.seed(0)

# ===== 1) Toy data =====
# Character-level "translation": a,b,c,d,e -> x,y,z,u,v.
# Ids 0-2 are the special tokens in both vocabularies; ids are assigned in
# list order.
src_vocab = {
    tok: idx
    for idx, tok in enumerate(['<pad>', '<sos>', '<eos>', 'a', 'b', 'c', 'd', 'e'])
}
tgt_vocab = {
    tok: idx
    for idx, tok in enumerate(['<pad>', '<sos>', '<eos>', 'x', 'y', 'z', 'u', 'v'])
}
# Reverse lookup: target id -> target token.
inv_tgt = dict(zip(tgt_vocab.values(), tgt_vocab.keys()))

# Explicit source -> target character mapping for the toy translation task.
_CHAR_MAP = {'a': 'x', 'b': 'y', 'c': 'z', 'd': 'u', 'e': 'v'}

def make_sample():
    """Draw one random (source, target) pair for the toy translation task.

    The source is 3 to 5 characters sampled with replacement from a-e; the
    target is the same sequence mapped through ``_CHAR_MAP``.

    An explicit table is used instead of ``chr(ord(ch) + 23)``: that fixed
    offset is only right for a, b and c, and turns d/e into ``{``/``|``
    rather than u/v.

    Returns:
        tuple[list[str], list[str]]: ``(seq, trg)``, two lists of equal length.
    """
    seq = random.choices(list(_CHAR_MAP), k=random.randint(3, 5))
    trg = [_CHAR_MAP[ch] for ch in seq]
    return seq, trg

def encode(seq, vocab):
    """Convert a token sequence to ids, wrapped in ``<sos>`` ... ``<eos>``.

    Args:
        seq: iterable of tokens, each of which must be a key of ``vocab``.
        vocab: token -> id mapping containing ``'<sos>'`` and ``'<eos>'``.

    Returns:
        list of ids: the start id, one id per token, then the end id.

    Raises:
        KeyError: if a token (or a special token) is missing from ``vocab``.
    """
    ids = [vocab['<sos>']]
    ids.extend(vocab[tok] for tok in seq)
    ids.append(vocab['<eos>'])
    return ids

def pad_seq(seq, max_len):
    """Right-pad the list ``seq`` with 0 until it has ``max_len`` entries.

    A new list is returned. When ``seq`` already has ``max_len`` or more
    entries, the result is an unpadded copy (nothing is truncated).
    """
    n_missing = max_len - len(seq)
    padding = [0] * n_missing  # empty when n_missing <= 0
    return seq + padding

# Build a toy batch
def gen_batch(batch_size=4):
    """Sample ``batch_size`` random pairs and return them as padded id tensors.

    Args:
        batch_size: number of (source, target) pairs to draw.

    Returns:
        ``(src_pad, tgt_pad)``: two tensors on ``DEVICE`` with one encoded
        sequence per row, right-padded with 0 to the longest row of that
        tensor.
    """
    encoded = []
    for _ in range(batch_size):
        source, target = make_sample()
        encoded.append((encode(source, src_vocab), encode(target, tgt_vocab)))
    src_ids = [pair[0] for pair in encoded]
    tgt_ids = [pair[1] for pair in encoded]
    # Source and target are padded independently, each to its own longest row.
    width_src = max(map(len, src_ids))
    width_tgt = max(map(len, tgt_ids))
    src_pad = torch.tensor([pad_seq(ids, width_src) for ids in src_ids], device=DEVICE)
    tgt_pad = torch.tensor([pad_seq(ids, width_tgt) for ids in tgt_ids], device=DEVICE)
    return src_pad, tgt_pad

# ===== 2) Model =====
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor.

    Args:
        d_model: width of the last dimension of the inputs. Odd values are
            supported.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        # One frequency per (sin, cos) column pair: 10000^(-2i / d_model).
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer (not a parameter), so the optimizer does not
        # update it.
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as [batch, seq_len, d_model]; seq_len must not
        exceed ``max_len``.
        """
        return x + self.pe[:, :x.size(1)]

class TransformerToy(nn.Module):
    """Small encoder-decoder Transformer for the toy character translation.

    Args:
        src_vocab: source token -> id mapping (only its size is used here).
        tgt_vocab: target token -> id mapping (only its size is used here).
        d_model: embedding / model width.
        nhead: number of attention heads.
        num_layers: number of encoder layers and of decoder layers.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.src_emb = nn.Embedding(len(src_vocab), d_model)
        self.tgt_emb = nn.Embedding(len(tgt_vocab), d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=128,
            batch_first=True
        )
        self.fc_out = nn.Linear(d_model, len(tgt_vocab))

    def forward(self, src, tgt):
        """Compute next-token logits for every decoder input position.

        Args:
            src: [B, S] source token ids, padded with 0.
            tgt: [B, T] decoder input token ids, padded with 0.

        Returns:
            [B, T, len(tgt_vocab)] unnormalized logits.
        """
        src_pad_mask = self.make_pad_mask(src)
        tgt_pad_mask = self.make_pad_mask(tgt)
        # Causal mask, placed on the input's device rather than a module-level
        # global so the model works on whatever device it has been moved to.
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        src_emb = self.pos_enc(self.src_emb(src))
        tgt_emb = self.pos_enc(self.tgt_emb(tgt))
        out = self.transformer(
            src_emb, tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            # Hide padded source positions from decoder cross-attention too;
            # otherwise the output depends on how much padding the batch has.
            memory_key_padding_mask=src_pad_mask,
        )
        return self.fc_out(out)

    def make_pad_mask(self, x):
        """Return a boolean tensor that is True where ``x`` holds the PAD id (0)."""
        return (x == 0)

# ===== 3) Quick training =====
model = TransformerToy(src_vocab, tgt_vocab)
model = model.to(DEVICE)
# Targets equal to 0 (the <pad> id) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 300
for ep in range(1, EPOCHS + 1):
    model.train()
    # Each "epoch" is a single optimizer step on one freshly sampled batch.
    src, tgt = gen_batch(batch_size=16)
    optimizer.zero_grad()
    # Teacher forcing: the decoder input drops the last column and the
    # labels drop the first, so position i is trained to predict token i+1.
    out = model(src, tgt[:, :-1])
    flat_logits = out.reshape(-1, len(tgt_vocab))
    flat_labels = tgt[:, 1:].reshape(-1)
    loss = criterion(flat_logits, flat_labels)
    loss.backward()
    optimizer.step()
    if ep == 1 or ep % 50 == 0:
        print(f"[Epoch {ep}] loss={loss.item():.4f}")

# ===== 4) Inference =====
@torch.no_grad()
def translate(seq, max_len=10):
    """Greedily decode the model's translation of ``seq``.

    Puts the module-level ``model`` into eval mode and generates one token at
    a time, always taking the arg-max, until ``<eos>`` is produced or
    ``max_len`` tokens have been generated.

    Args:
        seq: list of source tokens (keys of ``src_vocab``).
        max_len: maximum number of tokens to generate.

    Returns:
        The generated target tokens joined into one string, without the
        leading ``<sos>`` and without the terminating ``<eos>`` if one was
        produced.
    """
    model.eval()
    eos_id = tgt_vocab['<eos>']
    src = torch.tensor([encode(seq, src_vocab)], device=DEVICE)
    tgt = torch.tensor([[tgt_vocab['<sos>']]], device=DEVICE)
    for _ in range(max_len):
        out = model(src, tgt)
        # Only the logits of the last position are needed for the next token.
        next_tok = out[:, -1, :].argmax(-1).unsqueeze(1)
        tgt = torch.cat([tgt, next_tok], dim=1)
        if next_tok.item() == eos_id:
            break
    token_ids = tgt[0, 1:].tolist()  # drop the leading <sos>
    # Strip <eos> only when decoding really ended with it. If the step budget
    # ran out, the last token is genuine output and must be kept.
    if token_ids and token_ids[-1] == eos_id:
        token_ids.pop()
    return "".join(inv_tgt[i] for i in token_ids)

# Try a few sequences
demo_inputs = (["a","b","c"], ["d","e"], ["a","c","e","b"])
for s in demo_inputs:
    print(f"{s} -> {translate(s)}")
