# In [None]:
# rnn_minimal_next_token.py
import torch
import torch.nn as nn
import torch.optim as optim

# ===== 1) Toy data: He -> feels -> happy -> today -> (wraps around) =====
vocab = ["He", "feels", "happy", "today"]
stoi = dict(zip(vocab, range(len(vocab))))  # word -> integer id
itos = dict(enumerate(vocab))               # integer id -> word

# Slide a length-3 window over the cyclic sentence; the word right after the
# window is the prediction target:
#   [He, feels, happy]    -> today
#   [feels, happy, today] -> He
#   [happy, today, He]    -> feels
#   [today, He, feels]    -> happy
windows = [
    [vocab[(start + offset) % len(vocab)] for offset in range(3)]
    for start in range(len(vocab))
]
targets = [vocab[(start + 3) % len(vocab)] for start in range(len(vocab))]

# Encode the words as ids.
X = torch.tensor(
    [[stoi[word] for word in window] for window in windows]
)  # [batch=4, seq_len=3]
Y = torch.tensor([stoi[word] for word in targets])  # [batch=4]

# ===== 2) Model: Embedding -> RNN -> Linear =====
class NextTokenRNN(nn.Module):
    """Predict the next token id from a fixed window of previous token ids.

    The token ids are embedded, run through PyTorch's built-in ``nn.RNN``
    (``batch_first=True``), and the RNN output at the last time step is
    projected to one logit per vocabulary entry.

    Args:
        vocab_size: Number of distinct tokens; sizes both the embedding table
            and the output layer.
        emb_size: Dimension of each token embedding.
        hidden_size: Dimension of the RNN hidden state.
    """

    def __init__(self, vocab_size: int, emb_size: int = 16, hidden_size: int = 32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_size)
        # Use the RNN implementation that ships with PyTorch.
        self.rnn = nn.RNN(
            input_size=emb_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x_ids):
        """Return next-token logits of shape [batch, V] for ids of shape [batch, seq_len]."""
        embedded = self.emb(x_ids)             # [batch, seq_len, emb]
        step_outputs, _ = self.rnn(embedded)   # [batch, seq_len, H]; final hidden state unused
        final_step = step_outputs[:, -1, :]    # output at the last time step: [batch, H]
        return self.fc(final_step)             # [batch, V]

model = NextTokenRNN(vocab_size=len(vocab))
criterion = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=0.05)

# ===== 3) Short training loop =====
# Full-batch training: every epoch is a single optimizer step over all of X.
# Nothing inside the loop switches the model out of training mode, so the mode
# is set once up front.
model.train()
for ep in range(1, 401):
    opt.zero_grad()
    logits = model(X)
    loss = criterion(logits, Y)
    loss.backward()
    opt.step()

    # Report on the first epoch and then on every 50th one. The predictions
    # shown come from the logits computed before this epoch's optimizer step.
    if ep == 1 or ep % 50 == 0:
        pred = logits.argmax(dim=-1).tolist()
        pred_words = [itos[token_id] for token_id in pred]
        print(f"[ep {ep:3d}] loss={loss.item():.4f} | pred={pred_words}")

# ===== 4) Inference: predict the next word for a given sequence =====
def predict_next(words3):
    """Return the word the trained model considers most likely to follow ``words3``.

    Args:
        words3: Sequence of words from ``vocab``. The name suggests three
            words (the training window length); the length is not checked here.

    Returns:
        The vocabulary word whose logit is largest.

    Raises:
        KeyError: If a word is not present in ``stoi``.

    Note:
        Puts the module-level ``model`` into eval mode and leaves it there.
    """
    model.eval()
    token_ids = [stoi[word] for word in words3]
    batch = torch.tensor([token_ids])  # single-example batch: [1, seq_len]
    with torch.no_grad():
        best = model(batch).argmax(dim=-1)
    return itos[int(best[0])]

# Show the model's prediction for two of the training windows.
print("\n--- Inference ---")
for test_seq in (["He", "feels", "happy"], ["feels", "happy", "today"]):
    print(test_seq, "->", predict_next(test_seq))
