In [40]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import pickle
import time
import random
import nltk

In [41]:

nltk.download('punkt_tab')

def tokenize(text):
    return nltk.word_tokenize(text.lower())

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\nikolaev_aa\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


## Загрузка пар

In [42]:

with open("got_pairs.pkl", "rb") as f:
    pairs = pickle.load(f)

print(f"Загружено {len(pairs)} пар реплик.")
print("Пример:", pairs[24576])

Загружено 43787 пар реплик.
Пример: ('when my sister arranged this love match', 'did she mention that lollys has an older sister')


## Построение словаря

In [43]:
all_words = []
for src, tgt in pairs:
    all_words.extend(tokenize(src))
    all_words.extend(tokenize(tgt))

word_counts = Counter(all_words)
vocab = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"] + [
    word for word, count in word_counts.items() if count > 1  # фильтр редких слов
]

word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
vocab_size = len(vocab)

print(f"Размер словаря: {vocab_size}")

Размер словаря: 9125


## Dataset

In [44]:
def words_to_indices(words, word_to_idx, max_len=20):
    indices = [word_to_idx.get(w, word_to_idx["<UNK>"]) for w in words]
    indices = indices[:max_len - 1]  # оставляем место для <EOS>
    indices.append(word_to_idx["<EOS>"])
    return indices + [word_to_idx["<PAD>"]] * (max_len - len(indices))

class TranscriptDataset(Dataset):
    def __init__(self, pairs, word_to_idx, max_len=20):
        self.pairs = pairs
        self.word_to_idx = word_to_idx
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src_text, tgt_text = self.pairs[idx]
        src_tokens = tokenize(src_text)
        tgt_tokens = tokenize(tgt_text)

        src_indices = words_to_indices(src_tokens, self.word_to_idx, self.max_len)
        tgt_indices = words_to_indices(tgt_tokens, self.word_to_idx, self.max_len)

        return torch.tensor(src_indices), torch.tensor(tgt_indices)

## Модель (Encoder + Decoder)

In [45]:
def create_encoder(vocab_size, embed_size, hidden_size):
    embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
    lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
    return embedding, lstm

def create_decoder(vocab_size, embed_size, hidden_size):
    embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
    lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
    fc = nn.Linear(hidden_size, vocab_size)
    return embedding, lstm, fc

def seq2seq_forward(enc_emb, enc_lstm, dec_emb, dec_lstm, dec_fc, src, trg, teacher_forcing_ratio=0.5):
    batch_size, trg_len = trg.shape
    vocab_size = dec_fc.out_features

    # Encoder
    enc_emb_out = enc_emb(src)
    enc_out, (hidden, cell) = enc_lstm(enc_emb_out)

    # Первый вход декодера — <SOS>
    dec_input = torch.full((batch_size, 1), word_to_idx["<SOS>"], dtype=torch.long)
    dec_outputs = torch.zeros(batch_size, trg_len, vocab_size)

    for t in range(trg_len):
        dec_emb_out = dec_emb(dec_input)
        dec_out, (hidden, cell) = dec_lstm(dec_emb_out, (hidden, cell))
        pred = dec_fc(dec_out.squeeze(1))
        dec_outputs[:, t, :] = pred

        # Teacher forcing
        use_teacher_forcing = random.random() < teacher_forcing_ratio
        if use_teacher_forcing and t < trg_len - 1:
            dec_input = trg[:, t + 1].unsqueeze(1)
        else:
            top1 = pred.argmax(1)
            dec_input = top1.unsqueeze(1)

    return dec_outputs

## Обучение

In [None]:
device = torch.device("cpu") 
print("Device:", device)


EPOCHS = 10
BATCH_SIZE = 32
MAX_LEN = 20
EMBED_SIZE = 128
HIDDEN_SIZE = 256

dataset = TranscriptDataset(pairs, word_to_idx, max_len=MAX_LEN)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

enc_emb, enc_lstm = create_encoder(vocab_size, EMBED_SIZE, HIDDEN_SIZE)
dec_emb, dec_lstm, dec_fc = create_decoder(vocab_size, EMBED_SIZE, HIDDEN_SIZE)

enc_emb.to(device)
enc_lstm.to(device)
dec_emb.to(device)
dec_lstm.to(device)
dec_fc.to(device)


params = list(enc_emb.parameters()) + list(enc_lstm.parameters()) + \
         list(dec_emb.parameters()) + list(dec_lstm.parameters()) + list(dec_fc.parameters())
optimizer = torch.optim.Adam(params, lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx["<PAD>"])


for epoch in range(EPOCHS):
    start_time = time.time()
    total_loss = 0
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()

        output = seq2seq_forward(enc_emb, enc_lstm, dec_emb, dec_lstm, dec_fc, src, trg)
        # output: [B, T, V], trg: [B, T]
        output = output[:, :-1, :].reshape(-1, vocab_size)
        target = trg[:, 1:].reshape(-1)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | Time: {elapsed:.2f}s")

model_state = {
    "encoder_embedding": enc_emb.state_dict(),
    "encoder_lstm": enc_lstm.state_dict(),
    "decoder_embedding": dec_emb.state_dict(),
    "decoder_lstm": dec_lstm.state_dict(),
    "decoder_fc": dec_fc.state_dict(),
    "word_to_idx": word_to_idx,
    "idx_to_word": idx_to_word,
    "embed_size": EMBED_SIZE,
    "hidden_size": HIDDEN_SIZE,
    "vocab_size": vocab_size,
    "max_len": MAX_LEN,
}
torch.save(model_state, "chatbot_model.pth")
print("✅ Модель сохранена в chatbot_model.pth")

Device: cpu
Epoch 1/10 | Loss: 5.6867 | Time: 217.49s
Epoch 2/10 | Loss: 5.3648 | Time: 221.88s
Epoch 3/10 | Loss: 5.2487 | Time: 219.20s
Epoch 4/10 | Loss: 5.1614 | Time: 225.65s
Epoch 5/10 | Loss: 5.0797 | Time: 225.50s
Epoch 6/10 | Loss: 5.0258 | Time: 225.16s
Epoch 7/10 | Loss: 4.9727 | Time: 218.70s


In [47]:
print("Топ-10 самых частых слов:")
for word, idx in list(word_to_idx.items())[:15]:
    print(f"{word}: {idx}")

Топ-10 самых частых слов:
<PAD>: 0
<SOS>: 1
<EOS>: 2
<UNK>: 3
easy: 4
boy: 5
what: 6
do: 7
you: 8
expect: 9
they: 10
re: 11
savages: 12
one: 13
lot: 14
