In [1]:
import torch
from torch import nn
from torch.optim import Adam
from model import load_models, Encoder, AttentionDecoder, EMBEDDING_SIZE
import random
from prepared import load_voc, batch2train_data, input_var, indexesFromSentence
import matplotlib.pyplot as plt
import string

In [2]:
data, Vocabulary = load_voc()

In [3]:
device = "cuda:0"
batch_size = 2

In [4]:
Vocabulary.num_words

35691

In [5]:
# embedding = nn.Embedding(2**16, EMBEDDING_SIZE).to(device)
# encoder = Encoder(embedding).to(device)
# decoder = AttentionDecoder(embedding).to(device)

In [6]:
encoder, decoder, embedding = load_models()

encoder_optim = Adam(encoder.parameters(), lr=1e-5)
decoder_optim = Adam(decoder.parameters(), lr=5e-5)

encoder.epochs

45000

In [7]:
def calculate_loss(inp, target, mask):

    loss = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = loss.masked_select(mask).mean()
    loss = loss.to(device)

    return loss

In [8]:
def train(epochs):
    # для вывода графика
    history = []
    short_mem = []

    for _ in range(epochs):
        # всё стандартно
        encoder_optim.zero_grad()
        decoder_optim.zero_grad()

        inp, lenghts, target, mask, max_target_len = batch2train_data([random.choice(data) \
                                                                       for _ in range(batch_size)])

        inp = inp.to(device)
        lenghts = lenghts.to(device)
        target = target.to(device)
        mask = mask.to(device)
        # провожу через енкодинг
        encoder_out, encoder_hidden = encoder(inp, lenghts)
        # начальное значение для работы декодера
        decoder_input = torch.ones(batch_size).long().to(device).unsqueeze(0)

        loss = 0

        decoder_hidden = encoder_hidden[:2]

        for i in range(max_target_len):
            decoder_out, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_out)

            decoder_input = torch.LongTensor([[decoder_out.topk(1)[1][x][0] for\
                                               x in range(batch_size)]]).to(device)

            loss += calculate_loss(decoder_out, target[i], mask[i])

        short_mem.append(loss)

        if not encoder.epochs % 1000:
            print(f'{encoder.epochs}  {loss}')
            history.append(torch.tensor(short_mem).mean())
            short_mem = []

        loss.backward()

        encoder_optim.step()
        decoder_optim.step()

        encoder.epochs += 1

        # save model

        if not encoder.epochs % 1000:

            torch.save(encoder, f"models/encoder{encoder.epochs}")
            torch.save(decoder, f"models/decoder{encoder.epochs}")
            torch.save(embedding, f"models/embedding{encoder.epochs}")

    return history

In [None]:
%%time
history = train(50000)

45000  57.240234375
46000  137.4744873046875
47000  297.71710205078125
48000  816.4173583984375


In [None]:
encoder.epochs

In [None]:
plt.plot(history)

In [18]:
def greedy_search(sequence, length, maximum=30):
    # всё также как на тренировке
    encoder_out, encoder_hidden = encoder(sequence, length)

    decoder_hidden = encoder_hidden[:2]
    decoder_input = torch.ones(1, 1, device=device, dtype=torch.long)
    # здесь складываються ответы по жадному методу
    all_tokens = torch.zeros([0], device=device, dtype=torch.long)
    # всего будет 30 оборотов, при выводе отрежуться ненужные токены, а 30, чтобы больше не было
    for _ in range(maximum):

        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_out)

        _, decoder_input = torch.max(decoder_output, dim=1)

        all_tokens = torch.cat((all_tokens, decoder_input), dim=0)

        decoder_input = decoder_input.unsqueeze(0)

    return all_tokens

In [20]:
clean = lambda text: "".join(x for x in text.lower() if x not in string.punctuation)

In [None]:
while "Матеша идёт":

    text = input('Я > ')
    print(text)
    # ввод текста и очистка
    text = clean(text)

    # перевод в форму для енкодера
    text = torch.tensor([indexesFromSentence(text)])
    length = torch.tensor([len(text[0])])

    text = text.transpose(0, 1).long().to(device)

    text = greedy_search(text, length)

    text = [Vocabulary.index2word[token.item()] for token in text]

    text[:] = [x for x in text if not (x == 'EOS' or x == "PAD")]

    print(f'Машина > {" ".join(text)}')

Привет
Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             
Как дела?
Машина > а                             

Машина > а                             

Машина > а                             

Машина > а                             
