# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# Sample data (extend as needed). Each entry is [input sentence, reply]:
# index 0 is what the encoder reads and index 1 is the decoder's training
# target. Sentences are tokenized with str.split(), so words are
# case-sensitive and punctuation stays attached to the word it touches.
pairs = [
    ["hello", "hi"],
    ["how are you", "i am fine"],
    ["what is your name", "i am chatbot"],
    ["bye", "see you"]
]


# Simple word-level token vocabulary
class Vocab:
    """Bidirectional word <-> integer-id mapping built from whitespace-split text.

    Ids 0, 1 and 2 are reserved for the special tokens PAD, SOS and EOS.
    Every new word gets the next free id, starting at 3.
    """

    def __init__(self):
        specials = ("PAD", "SOS", "EOS")
        self.word2idx = {token: i for i, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))
        self.word_count = len(specials)  # next id to hand out

    def add_sentence(self, sentence):
        """Assign ids to the words of ``sentence`` not seen before, in order of appearance."""
        for token in sentence.split():
            if token in self.word2idx:
                continue
            new_id = self.word_count
            self.word2idx[token] = new_id
            self.idx2word[new_id] = token
            self.word_count = new_id + 1

    def sentence_to_ids(self, sentence):
        """Return the ids of the words in ``sentence`` followed by the EOS id (2).

        Raises:
            KeyError: if a word was never registered with ``add_sentence``.
        """
        ids = []
        for token in sentence.split():
            ids.append(self.word2idx[token])
        ids.append(2)  # EOS
        return ids

    def ids_to_sentence(self, ids):
        """Map ``ids`` back to words joined by single spaces, dropping PAD/SOS/EOS (0, 1, 2)."""
        words = []
        for idx in ids:
            if idx in (0, 1, 2):
                continue
            words.append(self.idx2word[idx])
        return " ".join(words)


vocab = Vocab()
for q, a in pairs:
    # Register the words of both the question and its reply.
    for text in (q, a):
        vocab.add_sentence(text)

# Model hyperparameters. VOCAB_SIZE is an int snapshot taken here, so it only
# counts words added to `vocab` above.
EMBED_SIZE, HIDDEN_SIZE = 16, 32
VOCAB_SIZE = vocab.word_count


# Encoder
class Encoder(nn.Module):
    """Embedding + single-layer GRU that returns the final hidden state.

    Args:
        vocab_size: Number of token ids the embedding table holds.
            Defaults to the module-level ``VOCAB_SIZE``.
        embed_size: Embedding dimension. Defaults to ``EMBED_SIZE``.
        hidden_size: GRU hidden dimension. Defaults to ``HIDDEN_SIZE``.
    """

    def __init__(self, vocab_size=None, embed_size=None, hidden_size=None):
        super().__init__()
        # Resolve defaults at construction time, exactly as the original did by
        # reading the globals inside __init__; explicit sizes now also work.
        vocab_size = VOCAB_SIZE if vocab_size is None else vocab_size
        embed_size = EMBED_SIZE if embed_size is None else embed_size
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size)

    def forward(self, x):
        """Encode token ids ``x`` and return the GRU's final hidden state.

        The GRU is built without ``batch_first``, so a 2-D ``x`` is read as
        (seq_len, batch) and the result has shape (1, batch, hidden_size).
        """
        embedded = self.embedding(x)
        _, hidden = self.gru(embedded)
        return hidden


# Decoder
class Decoder(nn.Module):
    """Single-step GRU decoder producing vocabulary logits.

    Args:
        vocab_size: Size of the embedding table and of the output logits.
            Defaults to the module-level ``VOCAB_SIZE``.
        embed_size: Embedding dimension. Defaults to ``EMBED_SIZE``.
        hidden_size: GRU hidden dimension. Defaults to ``HIDDEN_SIZE``.
    """

    def __init__(self, vocab_size=None, embed_size=None, hidden_size=None):
        super().__init__()
        vocab_size = VOCAB_SIZE if vocab_size is None else vocab_size
        embed_size = EMBED_SIZE if embed_size is None else embed_size
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        """Run one decoding step.

        Args:
            x: Token ids for a single time step, one per batch element. Any
                shape holding ``batch`` ids is accepted, e.g. (batch,), the
                (batch, 1) indices returned by ``topk(1)``, or a 0-d tensor.
            hidden: GRU hidden state of shape (1, batch, hidden_size).

        Returns:
            ``(logits, hidden)``: logits of shape (batch, vocab_size) and the
            updated hidden state.
        """
        # Force ids to (1, batch): a single time step for a non-batch_first
        # GRU. Unsqueezing after embedding only worked for 1-D ids; (1, 1)
        # ids turned into an invalid 4-D GRU input.
        embedded = self.embedding(x.reshape(1, -1))
        output, hidden = self.gru(embedded, hidden)
        output = self.out(output.squeeze(0))
        return output, hidden


encoder = Encoder()
decoder = Decoder()
criterion = nn.CrossEntropyLoss()
enc_optimizer = torch.optim.Adam(encoder.parameters())
dec_optimizer = torch.optim.Adam(decoder.parameters())

# Training loop: each step trains on one randomly chosen pair. The decoder is
# fed the ground-truth previous token (teacher forcing), and the loss is the
# sum of the per-token cross-entropies.
for epoch in range(1000):
    idx = random.randint(0, len(pairs) - 1)
    question = pairs[idx][0]
    answer = pairs[idx][1]
    input_seq = torch.tensor(vocab.sentence_to_ids(question), dtype=torch.long).unsqueeze(1)
    target_seq = torch.tensor(vocab.sentence_to_ids(answer), dtype=torch.long)

    # Gradients are only written by backward(), so clearing them before the
    # forward pass is equivalent to clearing them right before backward().
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()

    # Take the last layer of the encoder's hidden state and restore the
    # leading layer axis the decoder's GRU expects.
    enc_hidden = encoder(input_seq)[-1].unsqueeze(0)
    dec_hidden = enc_hidden
    dec_input = torch.tensor([1])  # SOS

    step_losses = []
    for t in range(len(target_seq)):
        output, dec_hidden = decoder(dec_input, dec_hidden)
        gold = target_seq[t].unsqueeze(0)
        step_losses.append(criterion(output, gold))
        dec_input = gold  # teacher forcing
    loss = sum(step_losses)

    loss.backward()
    enc_optimizer.step()
    dec_optimizer.step()

    if not epoch % 200:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


# Inference
def respond(sentence, max_len=10):
    """Generate a reply to ``sentence`` by greedy decoding.

    Words missing from ``vocab`` are skipped, because the vocabulary has no
    unknown-word token. Arbitrary user text therefore does not raise KeyError.

    Args:
        sentence: Input text, split on whitespace like the training data.
        max_len: Maximum number of tokens to generate if EOS is never
            predicted (defaults to the previously hard-coded 10).

    Returns:
        The generated words joined by single spaces, with PAD/SOS/EOS removed.
    """
    sos_id = vocab.word2idx["SOS"]
    eos_id = vocab.word2idx["EOS"]
    known_words = [w for w in sentence.split() if w in vocab.word2idx]

    with torch.no_grad():
        input_ids = vocab.sentence_to_ids(" ".join(known_words))
        input_seq = torch.tensor(input_ids, dtype=torch.long).unsqueeze(1)
        dec_hidden = encoder(input_seq)

        dec_input = torch.tensor([sos_id])
        result = []
        for _ in range(max_len):
            output, dec_hidden = decoder(dec_input, dec_hidden)
            _, topi = output.topk(1)  # topi has shape (1, 1)
            next_word = topi.item()
            if next_word == eos_id:
                break
            result.append(next_word)
            # Flatten the (1, 1) topk index to shape (1,), so every decoder
            # step receives ids shaped like the SOS input above.
            dec_input = topi.view(1)

        return vocab.ids_to_sentence(result)


# Examples
for prompt in ("hello", "what is your name"):
    print(respond(prompt))