In [None]:
# encoder_decoder_reverse_seq.py
import random
import torch
import torch.nn as nn
import torch.optim as optim

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Seed both RNGs for reproducibility: `random` drives data sampling and the
# teacher-forcing coin flips, `torch` drives parameter initialisation.
random.seed(0)
torch.manual_seed(0)

# ===== 1) Vocab & toy data =====
# Vocab: PAD=0, SOS=1, EOS=2, real tokens: 'a'=3, 'b'=4, 'c'=5, 'd'=6
# Special ids; they must match the positions of '<pad>', '<sos>', '<eos>'
# in the symbol list used to build `stoi` below.
PAD = 0
SOS = 1
EOS = 2
TOKENS = ['a', 'b', 'c', 'd']
# symbol -> id, specials first, then the real tokens in order.
stoi = {sym: idx for idx, sym in enumerate(['<pad>', '<sos>', '<eos>', *TOKENS])}
# id -> symbol (inverse mapping, used for pretty-printing).
itos = {idx: sym for sym, idx in stoi.items()}
V = len(stoi)  # vocabulary size: 3 specials + 4 tokens = 7

def rand_seq(min_len=3, max_len=7):
    """Return a random list of real-token ids.

    The length is drawn uniformly from [min_len, max_len] (inclusive) and each
    element is the id of a uniformly chosen entry of TOKENS.
    """
    length = random.randint(min_len, max_len)
    picks = (random.choice(TOKENS) for _ in range(length))
    return [stoi[tok] for tok in picks]

# Task: out = reverse(in)
def make_pair():
    """Draw one training example.

    Returns:
        (inp, out): a fresh random id list and a new list with the same ids
        in reverse order.
    """
    source = rand_seq()
    return source, source[::-1]

# Helpers: add <sos> / <eos>
def wrap_inp(seq):
    """Return a new list: the encoder input ``seq`` followed by EOS.

    The encoder needs no leading SOS; EOS is appended as an explicit
    end-of-source marker. ``seq`` itself is not modified.
    """
    suffix = [EOS]
    return seq + suffix
def wrap_out(seq):
    """Return a new list laid out as a decoder target: SOS, ``seq``, EOS."""
    prefix, suffix = [SOS], [EOS]
    return prefix + seq + suffix

# ===== 2) Model =====
class Encoder(nn.Module):
    """Embed a token-id sequence and run a single-layer GRU over it.

    Args:
        vocab_size: number of distinct token ids.
        emb_size: embedding width E.
        hidden_size: GRU hidden width H.
    """

    def __init__(self, vocab_size, emb_size=64, hidden_size=128):
        super().__init__()
        # Submodules are created in this order (embedding, then GRU) so that
        # parameter initialisation draws from the seeded RNG identically.
        self.emb = nn.Embedding(vocab_size, emb_size, padding_idx=PAD)
        self.rnn = nn.GRU(emb_size, hidden_size, batch_first=True)

    def forward(self, x):
        """Encode ``x`` ([B, T_in] ids) and return the final hidden state [1, B, H].

        The per-step GRU outputs are discarded; only the last state is used to
        initialise the decoder.

        NOTE(review): the GRU runs over every position (no packing), so in a
        right-padded batch the returned state would also have consumed PAD
        steps. This is harmless for the batch-size-1 usage here.
        """
        embedded = self.emb(x)                 # [B, T_in, E]
        _, last_hidden = self.rnn(embedded)    # last_hidden: [1, B, H]
        return last_hidden

class Decoder(nn.Module):
    """One-step GRU decoder: (previous token, hidden state) -> next-token logits.

    Args:
        vocab_size: number of distinct token ids (also the logit width V).
        emb_size: embedding width E.
        hidden_size: GRU hidden width H; must match the encoder's.
    """

    def __init__(self, vocab_size, emb_size=64, hidden_size=128):
        super().__init__()
        # Same creation order as before (embedding, GRU, projection) to keep
        # parameter initialisation reproducible.
        self.emb = nn.Embedding(vocab_size, emb_size, padding_idx=PAD)
        self.rnn = nn.GRU(emb_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, y_prev, h):
        """Advance the decoder by one token.

        Args:
            y_prev: [B, 1] ids of the previously emitted (or gold) token.
            h: [1, B, H] current hidden state.

        Returns:
            (logits, h_next): [B, V] scores for the next token and the updated
            [1, B, H] hidden state.
        """
        step_out, h_next = self.rnn(self.emb(y_prev), h)   # step_out: [B, 1, H]
        last_step = step_out[:, -1]                         # [B, H]
        return self.fc(last_step), h_next

class Seq2Seq(nn.Module):
    """Chain an encoder and a step-wise decoder into a sequence-to-sequence model.

    The encoder's output (its final hidden state) seeds the decoder, which is
    then unrolled one target position at a time.
    """

    def __init__(self, enc, dec):
        super().__init__()
        # Encoder is registered before decoder (affects parameters() order).
        self.enc = enc
        self.dec = dec

    def forward(self, src, tgt, teacher_forcing=0.5):
        """Unroll the decoder across the target sequence.

        Args:
            src: [B, T_in] source ids (EOS already appended).
            tgt: [B, T_out] target ids laid out as SOS ... EOS.
            teacher_forcing: probability, drawn once per step for the whole
                batch, of feeding the gold token ``tgt[:, t]`` as the next
                input instead of the decoder's own argmax prediction.

        Returns:
            [B, T_out - 1, V] logits, one slice per position after the leading SOS.
        """
        _, steps = tgt.size()
        hidden = self.enc(src)          # [1, B, H]
        cur = tgt[:, :1]                # decoding starts from the SOS column
        per_step = []
        for pos in range(1, steps):
            step_logits, hidden = self.dec(cur, hidden)   # [B, V]
            per_step.append(step_logits)
            # The coin is flipped every step (including the last) so the RNG
            # stream is consumed exactly once per decoded position.
            if random.random() < teacher_forcing:
                cur = tgt[:, pos:pos + 1]
            else:
                cur = step_logits.argmax(dim=-1, keepdim=True)
        return torch.stack(per_step, dim=1)

# ===== 3) Train loop =====
# Loss ignores PAD targets. None occur with the unpadded single-example
# batches used below, but this keeps the setup correct for padded batches.
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

# Construction order matters for reproducibility: encoder parameters are
# initialised from the seeded torch RNG before the decoder's.
enc = Encoder(V).to(DEVICE)
dec = Decoder(V).to(DEVICE)
model = Seq2Seq(enc, dec).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=0.001)

def tensorize(ids):
    """Convert a list of token ids into a [1, T] int64 tensor on DEVICE.

    A leading batch dimension of size 1 is added because the model expects
    batch-first input.
    """
    row = torch.tensor(ids, dtype=torch.long, device=DEVICE)   # [T]
    return row[None]                                           # [1, T]

EPOCHS = 2000
for ep in range(1, EPOCHS + 1):
    # Each "epoch" is one optimisation step on a single freshly sampled
    # (input, reversed input) pair, i.e. an effective batch size of 1.
    model.train()
    inp_ids, out_ids = make_pair()
    src = tensorize(wrap_inp(inp_ids))     # [1, len + 1]: tokens + EOS
    tgt = tensorize(wrap_out(out_ids))     # [1, len + 2]: SOS + tokens + EOS

    optimizer.zero_grad()
    logits = model(src, tgt, teacher_forcing=0.5)   # [1, len + 1, V]
    # The decoder predicts every position after SOS, so labels are tgt[:, 1:].
    flat_logits = logits.reshape(-1, V)
    flat_labels = tgt[:, 1:].reshape(-1)
    loss = criterion(flat_logits, flat_labels)
    loss.backward()
    optimizer.step()

    # Report the first step and then every 200th.
    should_log = ep == 1 or ep % 200 == 0
    if should_log:
        print(f"[ep {ep}] loss={loss.item():.4f}")

# ===== 4) Inference =====
@torch.no_grad()
def infer(inp_tokens, max_len=20):
    """Greedily decode the module-level ``model``'s output for one input.

    Args:
        inp_tokens: list of real-token ids. EOS is appended here via
            ``wrap_inp``, so callers pass the raw sequence.
        max_len: maximum number of tokens to emit if EOS is never predicted.

    Returns:
        List of predicted token ids, excluding the terminating EOS.

    The model is put in eval mode for decoding and afterwards restored to the
    mode it was in before the call. A mid-training call therefore does not
    leave the model in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        src = tensorize(wrap_inp(inp_tokens))
        h = model.enc(src)                                   # encoder's final hidden state
        y = torch.tensor([[SOS]], dtype=torch.long, device=DEVICE)
        out = []
        for _ in range(max_len):
            logits, h = model.dec(y, h)
            pred = int(logits.argmax(dim=-1)[0].item())      # greedy choice
            if pred == EOS:
                break
            out.append(pred)
            # Feed the model's own prediction back in (no teacher forcing).
            y = torch.tensor([[pred]], dtype=torch.long, device=DEVICE)
        return out
    finally:
        model.train(was_training)

def pretty(ids):
    """Render token ids as their symbols joined by single spaces.

    Raises KeyError for an id that is not in the vocabulary.
    """
    symbols = [itos[token_id] for token_id in ids]
    return " ".join(symbols)

# Try the trained model on a few random samples
for _ in range(5):
    # make_pair() already returns the gold target; print that rather than
    # re-deriving it, so the demo stays correct if the task definition changes.
    inp, tgt = make_pair()
    pred = infer(inp)
    print("\nIN :", pretty(inp))
    print("GOLD:", pretty(tgt))
    print("OUT:", pretty(pred))
