In [1]:
import torch
from torch import nn
from torch.optim import Adam
from model import load_models, Encoder, AttentionDecoder, EMBEDDING_SIZE
import random
from prepared import load_voc, batch2train_data
import os
import matplotlib.pyplot as plt
import numpy as np

In [2]:
data, Vocabulary = load_voc()

In [3]:
device = "cuda:0"
batch_size = 3

In [4]:
Vocabulary.num_words

35691

In [5]:
embedding = nn.Embedding(2**16, EMBEDDING_SIZE).to(device)
encoder = Encoder(embedding).to(device)
decoder = AttentionDecoder(embedding).to(device)

In [6]:
# encoder, decoder, embedding = load_models()

encoder_optim = Adam(encoder.parameters(), lr=1e-4)
decoder_optim = Adam(decoder.parameters(), lr=5e-4)

encoder.epochs

0

In [7]:
def calculate_loss(inp, target, mask):

    loss = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = loss.masked_select(mask).mean()
    loss = loss.to(device)

    return loss

In [8]:
def train(epochs):
    # для вывода графика
    history = []
    short_mem = []

    for _ in range(epochs):
        # всё стандартно
        encoder_optim.zero_grad()
        decoder_optim.zero_grad()

        inp, lenghts, target, mask, max_target_len = batch2train_data([random.choice(data) \
                                                                       for _ in range(batch_size)])

        inp = inp.to(device)
        lenghts = lenghts.to(device)
        target = target.to(device)
        mask = mask.to(device)
        # провожу через енкодинг
        encoder_out, encoder_hidden = encoder(inp, lenghts)
        # начальное значение для работы декодера
        decoder_input = torch.ones(batch_size).long().to(device).unsqueeze(0)

        loss = 0

        for i in range(max_target_len):
            decoder_out, decoder_hidden = decoder(decoder_input, encoder_hidden[:2], encoder_out)

            decoder_input = torch.LongTensor([[decoder_out.topk(1)[1][x][0] for\
                                               x in range(batch_size)]]).to(device)

            loss += calculate_loss(decoder_out, target[i], mask[i])

        short_mem.append(loss)

        if not encoder.epochs % 500:
            print(f'{encoder.epochs}  {loss}')
            history.append(torch.tensor(short_mem).mean())
            short_mem = []

        loss.backward()

        encoder_optim.step()
        decoder_optim.step()

        encoder.epochs += 1

        # save model

        if not encoder.epochs % 1000:

            torch.save(encoder, f"models/encoder{encoder.epochs}")
            torch.save(decoder, f"models/decoder{encoder.epochs}")
            torch.save(embedding, f"models/embedding{encoder.epochs}")

    return history

In [None]:
%%time
history = train(10000) 

0  476.6577453613281
500  858.7587280273438
1000  230.8096160888672


In [None]:
encoder.epochs

In [None]:
plt.plot(history)
