In [1]:
import torch
import torch.nn as nn

# Use the GPU when PyTorch reports one as available; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === Vocabulary
# Token -> id table. Ids follow the listing order, so "<pad>" is 0,
# "<sos>" is 1 and "<eos>" is 2.
vocab = {
    token: index
    for index, token in enumerate(
        ["<pad>", "<sos>", "<eos>", "hello", "hi", "bye", "goodbye"]
    )
}
# Reverse table (id -> token), used to turn predicted ids back into words.
inv_vocab = dict(zip(vocab.values(), vocab.keys()))
vocab_size = len(vocab)

# Model and sequence hyperparameters.
embedding_dim = 8   # width of each token embedding
hidden_size = 16    # width of the GRU hidden state
max_len = 5         # fixed length every encoded sequence is padded/cut to

# === Model
class Encoder(nn.Module):
    """Embed a batch of token ids and run it through a single GRU.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector (the GRU input size).
        hidden_size: Width of the GRU hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.rnn = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x):
        """Return the GRU's final hidden state for the id tensor ``x``.

        ``x`` is laid out batch-first (the GRU is built with
        ``batch_first=True``). The per-step GRU outputs are discarded; only
        the last hidden state is returned.
        """
        embedded = self.embed(x)
        _step_outputs, final_hidden = self.rnn(embedded)
        return final_hidden

class Decoder(nn.Module):
    """Embed target token ids, run a GRU, and project to vocabulary scores.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output scores per step.
        embedding_dim: Width of each embedding vector (the GRU input size).
        hidden_size: Width of the GRU hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.rnn = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, h):
        """Run the decoder over ``x`` starting from hidden state ``h``.

        Returns:
            A ``(scores, hidden)`` pair: the linear projection of every GRU
            step output, and the GRU's final hidden state.
        """
        embedded = self.embed(x)
        step_outputs, next_hidden = self.rnn(embedded, h)
        scores = self.fc(step_outputs)
        return scores, next_hidden

class Seq2Seq(nn.Module):
    """Pair an encoder with a decoder for teacher-forced training.

    Args:
        encoder: Callable mapping a source batch to an initial hidden state.
        decoder: Callable mapping ``(target batch, hidden)`` to an
            ``(output, hidden)`` pair.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt):
        """Encode ``src``, decode ``tgt`` from that state, return the output.

        The decoder's returned hidden state is dropped; only its first
        return value is passed back to the caller.
        """
        context = self.encoder(src)
        decoded, _final_hidden = self.decoder(tgt, context)
        return decoded

# Build both halves on the selected device, then wrap them in the
# sequence-to-sequence container (also placed on that device).
encoder = Encoder(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
).to(device)
decoder = Decoder(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
).to(device)
model = Seq2Seq(encoder=encoder, decoder=decoder).to(device)

# === Data
# Toy training corpus: each entry is (source word list, target word list).
pairs = [
    ([prompt], [reply])
    for prompt, reply in (
        ("hello", "hi"),
        ("hi", "hello"),
        ("bye", "goodbye"),
    )
]

def encode(sentence, vocab, length=None):
    """Convert a list of words into a fixed-length tensor of token ids.

    The words are mapped through ``vocab``, terminated with ``<eos>`` and
    right-padded with ``<pad>`` up to ``length`` entries.

    Args:
        sentence: Iterable of word strings.
        vocab: Mapping from token string to integer id; must contain the
            ``"<pad>"`` and ``"<eos>"`` keys. Words missing from the mapping
            are encoded as ``<pad>``.
        length: Size of the returned tensor. Defaults to the module-level
            ``max_len``.

    Returns:
        A 1-D ``torch.long`` tensor with ``length`` entries.
    """
    if length is None:
        length = max_len
    pad_id = vocab["<pad>"]
    eos_id = vocab["<eos>"]

    # Keep at most ``length - 1`` words so the final slot is always free for
    # <eos>; cutting after appending <eos> would drop the terminator from
    # sentences that are too long.
    word_ids = [vocab.get(word, pad_id) for word in sentence]
    word_ids = word_ids[: max(length - 1, 0)]

    token_ids = word_ids + [eos_id]
    token_ids += [pad_id] * (length - len(token_ids))
    # The final slice only matters for a degenerate ``length`` below 1.
    return torch.tensor(token_ids[:length], dtype=torch.long)

# Encoder inputs: one padded row of token ids per source sentence.
src_data = torch.stack([encode(src, vocab) for src, _ in pairs])
# Decoder inputs: the target sentence prefixed with <sos>.
tgt_data = torch.stack([encode(["<sos>"] + tgt, vocab) for _, tgt in pairs])
# Expected decoder outputs: the target sentence itself. encode() already
# appends <eos>, so adding another one here would put a second <eos> label
# in the row.
tgt_labels = torch.stack([encode(tgt, vocab) for _, tgt in pairs])

# === Training
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Positions whose label is <pad> do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])

# The batch never changes between epochs, so move it to the device and
# flatten the labels once instead of on every iteration.
src_batch = src_data.to(device)
tgt_batch = tgt_data.to(device)
flat_labels = tgt_labels.view(-1).to(device)

model.train()
for epoch in range(1000):
    optimizer.zero_grad()
    output = model(src_batch, tgt_batch)
    # Collapse the leading dimensions so each row of scores lines up with
    # one entry of the flattened labels.
    loss = loss_fn(output.view(-1, vocab_size), flat_labels)
    loss.backward()
    optimizer.step()
    # Report progress every 200 epochs.
    if (epoch + 1) % 200 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# === Inference
def respond_to_input(input_word):
    """Greedily decode the model's reply to a single input word.

    Args:
        input_word: One token as a string. A word that is not in ``vocab``
            is encoded as ``<pad>``.

    Returns:
        The generated words joined by single spaces. Decoding stops at the
        first ``<eos>`` or after ``max_len`` steps, so the result can be an
        empty string.
    """
    # Generation never needs gradients; disabling them here keeps the
    # function safe to call outside a caller-provided no_grad block.
    with torch.no_grad():
        # Reuse encode() so the source row is padded exactly like the
        # training data instead of hard-coding its width and pad id.
        src = encode([input_word], vocab).unsqueeze(0).to(device)
        h = model.encoder(src)
        inputs = torch.tensor([[vocab["<sos>"]]]).to(device)
        output_sentence = []

        for _ in range(max_len):
            out, h = model.decoder(inputs, h)
            # Highest-scoring token at the last decoded position.
            next_token = out.argmax(-1)[:, -1]
            word = inv_vocab[next_token.item()]
            if word == "<eos>":
                break
            output_sentence.append(word)
            # Feed the prediction back in as the next one-token input.
            inputs = next_token.unsqueeze(0)

    return " ".join(output_sentence)

# === Chat
# Interactive loop: read a line, answer with the model, stop on exit/quit.
model.eval()
with torch.no_grad():
    print("Type 'exit' to quit.")
    while True:
        try:
            user_input = input("You: ").strip().lower()
        except EOFError:
            # stdin was closed (e.g. Ctrl-D or piped input ran out): leave
            # the loop the same way an explicit "exit" does.
            print("\nBot: goodbye")
            break
        if user_input in {"exit", "quit"}:
            print("Bot: goodbye")
            break
        if not user_input:
            # A blank line carries no word to answer; ask again rather than
            # feeding an empty token to the model.
            continue
        response = respond_to_input(user_input)
        print("Bot:", response)


Epoch 200, Loss: 0.0032
Epoch 400, Loss: 0.0010
Epoch 600, Loss: 0.0005
Epoch 800, Loss: 0.0003
Epoch 1000, Loss: 0.0002
Type 'exit' to quit.
You: jjj
Bot: hello
You: bye
Bot: goodbye
You: exit
Bot: goodbye
