In [180]:
import torch
import numpy
import re

from tqdm.notebook import tqdm

In [181]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [182]:
# For simplicity, work with lower-cased text only.
# The encoding is passed explicitly so that decoding the Cyrillic corpus does
# not depend on the platform's default locale.
# NOTE(review): assumes the corpus file is UTF-8 encoded — confirm.
with open('602016.txt', encoding='utf-8') as f:
    text = f.read().lower()

# Fold 'ё' into 'е' first (it lies outside the а-я range used below), then drop
# every character that is not a lower-case Cyrillic letter, a space or a newline.
text = text.replace('ё', 'е')
text = re.sub(r'[^а-я \n]', '', text)

# Рассмотрим простую задачу генерации текста как генерацию последовательности символов (генерация текста из ничего)

## Посмотрим на данные

In [183]:
# Number of distinct characters left in the corpus after cleaning.
len(set(text))

34

In [184]:
# Total number of characters in the cleaned corpus.
len(text)

9403794

In [185]:
# Split the corpus into lines and keep only those whose stripped length is
# strictly between 20 and 300 characters: very short sentences are dropped
# right away, and very long ones are dropped as an arbitrary choice.
dataset = [
    line
    for line in map(str.strip, text.split('\n'))
    if 20 < len(line) < 300
]

len(dataset)

34073

## Строим отображение символов в индексы

In [186]:
# Special tokens occupy the first four indices; corpus characters follow.
char2idx = {'<PAD>': 0, '<UNK>': 1, '<START>': 2, '<FINISH>': 3}

# Iterate over the characters in sorted order: the iteration order of a set of
# strings changes between interpreter runs (hash randomization), which would
# give every kernel restart a different character -> index mapping.
for item in sorted(set(text)):
    char2idx[item] = len(char2idx)

# Inverse mapping, derived from char2idx so the two can never disagree.
idx2char = {index: token for token, index in char2idx.items()}

## Строим модели

In [187]:
class Encoder(torch.nn.Module):
    r'''
    LSTM encoder over sequences of token indices.

    The module embeds the indices, runs the embeddings through an LSTM and
    returns the LSTM outputs together with its final hidden and cell states.
    '''
    def __init__(self,
                 vocab_dim = len(char2idx),
                 emb_dim = 10, 
                 hidden_dim = 10,
                 num_layers = 3,
                 bidirectional = False,
                 device=device,
                 ):
        r'''
        :param vocab_dim: number of rows in the embedding table
        :param emb_dim: size of each embedding vector
        :param hidden_dim: LSTM hidden size
        :param num_layers: number of stacked LSTM layers
        :param bidirectional: whether the LSTM runs in both directions
        :param device: device the module is moved to on construction
        '''
        super().__init__()

        self.device = device
        self.hidden_dim = hidden_dim
        self.emb_dim = emb_dim
        # 1 for a unidirectional LSTM, 2 for a bidirectional one.
        self.num_direction = int(1 + bidirectional)

        self.embedding = torch.nn.Embedding(vocab_dim, emb_dim)
        self.encoder = torch.nn.LSTM(
            emb_dim,
            hidden_dim,
            num_layers,
            bidirectional=bidirectional,
        )

        self.to(device)

    def forward(self, input):
        r'''
        :param input: tensor of token indices; the first two axes are swapped
            before the LSTM, so it is presumably (batch_size, seq_len)
        :return: tuple (outputs, h, c) where outputs come straight from the
            LSTM and h, c are the final states with their first two axes
            swapped (batch axis first)
        '''
        # The LSTM is not batch_first, so swap the first two axes of the embeddings.
        embedded = self.embedding(input).transpose(0, 1)

        outputs, (hidden, cell) = self.encoder(embedded)

        return outputs, hidden.transpose(0, 1), cell.transpose(0, 1)

class Decoder(torch.nn.Module):
    r'''
    LSTM decoder that maps token indices to per-position vocabulary logits.

    It works in two modes: teacher forcing (``real`` is given) and greedy
    autoregressive generation starting from ``<START>`` (``real`` is None).
    '''
    def __init__(self,
                 vocab_dim = len(char2idx),
                 emb_dim = 10, 
                 hidden_dim = 10,
                 output_dim = len(char2idx),
                 num_layers = 3,
                 bidirectional = False,
                 device=device,
                 ):
        r'''
        :param vocab_dim: number of rows in the embedding table
        :param emb_dim: size of each embedding vector
        :param hidden_dim: LSTM hidden size
        :param output_dim: number of logits produced per position
        :param num_layers: number of stacked LSTM layers
        :param bidirectional: whether the LSTM runs in both directions
        :param device: device the module is moved to on construction
            (mirrors the ``device`` parameter of ``Encoder``)
        '''
        super(Decoder, self).__init__()
        
        self.num_direction = int(bidirectional + 1)
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.embedding = torch.nn.Embedding(vocab_dim, self.emb_dim)

        self.decoder = torch.nn.LSTM(
            emb_dim, hidden_dim, num_layers, bidirectional = bidirectional)

        self.linear = torch.nn.Linear(
            self.num_direction*hidden_dim, output_dim)
        
        self.device=device
        self.to(device)

    def _prepare_state(self, state, batch_size):
        r'''
        Return an LSTM state laid out as
        (num_layers*num_direction, batch_size, hidden_dim).

        A missing state is sampled from a standard normal distribution; a given
        state is expected batch-first and only has its first two axes swapped.
        '''
        if state is None:
            # The LSTM needs one state slice per layer and direction, each of
            # size hidden_dim.
            state = torch.randn(
                (batch_size,
                 self.num_layers*self.num_direction,
                 self.hidden_dim)).to(self.device)
        # contiguous(): a transposed view may not be accepted as an RNN state.
        return torch.transpose(state, 0, 1).contiguous()

    def forward(self, real=None, h = None, c = None, max_len = 50):
        r'''
        :param real:
            tensor of token indices of size batch_size \times seq_len used for
            teacher forcing (it is embedded inside this method); when None the
            decoder generates ``max_len`` steps on its own
        :type real: tensor
        :param h:
            initial hidden state of size
            batch_size \times (num_layers * num_direction) \times hidden_dim;
            sampled randomly when None
        :type h: tensor
        :param c:
            initial cell state, same layout as ``h``; sampled randomly when None
        :type c: tensor
        :param max_len: number of generation steps when ``real`` is None
        :type max_len: int
        :return:
            logits of size batch_size \times seq_len \times output_dim
            (seq_len equals ``max_len`` in generation mode)
        :rtype: tensor
        '''
        # Batch size is taken from the most specific argument available.
        if real is not None:
            batch_size = real.shape[0]
        elif c is not None:
            batch_size = c.shape[0]
        elif h is not None:
            batch_size = h.shape[0]
        else:
            batch_size = 1

        h = self._prepare_state(h, batch_size)
        c = self._prepare_state(c, batch_size)

        if real is not None:
            # Teacher forcing: the whole target sequence is fed in one call.
            input = torch.transpose(self.embedding(real), 0, 1)
            d, _ = self.decoder(input, (h, c))
            answers = self.linear(d)
        else:
            # Generation: start every sequence from <START>, shape (1, batch).
            tokens = torch.transpose(
                torch.tensor(
                    [[char2idx['<START>']] for _ in range(
                        batch_size)]).long().to(self.device), 0, 1)

            answers = torch.zeros(
                (max_len, batch_size, self.output_dim)).to(
                    self.device)

            for i in range(max_len):
                d, (h, c) = self.decoder(self.embedding(tokens), (h, c))
                logits = self.linear(d)[0]
                answers[i, :, :] = logits
                # Feed the greedy prediction back as the next input, matching
                # how the decoder is trained. Feeding <START> at every step
                # would make the output ignore what was already generated.
                # Assumes output indices are valid embedding indices
                # (output_dim <= vocab_dim).
                tokens = torch.argmax(logits, dim=-1).unsqueeze(0)

        return torch.transpose(answers, 0, 1)

## Генератор батчей

In [188]:
PAD = char2idx['<PAD>']
def batch_generator(dataset, batch_size=64, shuffle=True, device=device,
                    max_target_len=20):
    r'''
    Yield padded (x, y) index tensors for next-sentence prediction.

    Sentence ``i`` of ``dataset`` is the source and sentence ``i + 1`` is the
    target, so ``len(dataset) - 1`` samples are produced per pass.

    :param dataset: sequence of sentences (strings)
    :param batch_size: maximum number of samples per batch; the last batch may
        be smaller
    :param shuffle: whether to visit the samples in random order
        (uses the global numpy random state)
    :param device: device the yielded tensors are moved to
    :param max_target_len: targets are cut to this many characters before
        ``<START>`` / ``<FINISH>`` are added; None keeps whole sentences
    :return: generator of pairs (x, y) of int64 tensors of sizes
        batch \times max_source_len and batch \times max_target_len_in_batch,
        right-padded with ``<PAD>``
    :raises ValueError: if ``batch_size`` is not positive
    '''
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')

    X, Y = dataset[:-1], dataset[1:]
    n_samples = len(X)

    # Order in which the samples are visited; shuffled in place if requested.
    list_of_indexes = numpy.arange(n_samples, dtype=numpy.int64)
    if shuffle:
        numpy.random.shuffle(list_of_indexes)

    # Characters missing from the vocabulary become <UNK>, not <PAD>: padding
    # is ignored by the loss and must not stand in for real characters.
    unk = char2idx['<UNK>']
    start = char2idx['<START>']
    finish = char2idx['<FINISH>']

    # Slicing past the end of the index array is safe, so the last (possibly
    # shorter) batch needs no special handling.
    for begin in range(0, n_samples, batch_size):
        batch_indexes = list_of_indexes[begin:begin + batch_size]

        This_X_line = [
            [char2idx.get(char, unk) for char in X[indx]]
            for indx in batch_indexes]
        This_Y_line = [
            [start]
            + [char2idx.get(char, unk) for char in Y[indx][:max_target_len]]
            + [finish]
            for indx in batch_indexes]

        this_batch_size = len(This_X_line)
        length_of_sentence_x = max(len(sent) for sent in This_X_line)
        length_of_sentence_y = max(len(sent) for sent in This_Y_line)

        # Integer arrays pre-filled with the padding index.
        x_arr = numpy.ones(
            shape=[this_batch_size, length_of_sentence_x],
            dtype=numpy.int64)*PAD
        y_arr = numpy.ones(
            shape=[this_batch_size, length_of_sentence_y],
            dtype=numpy.int64)*PAD

        for i, (x_line, y_line) in enumerate(zip(This_X_line, This_Y_line)):
            x_arr[i, :len(x_line)] = x_line
            y_arr[i, :len(y_line)] = y_line

        x = torch.LongTensor(x_arr).to(device)
        y = torch.LongTensor(y_arr).to(device)

        yield x, y

## Скрипты обучения

In [189]:
def train_on_batch(model, batch_of_x, batch_of_y, optimizer, loss_function):
    r'''
    Run one optimization step on a single batch and return the batch loss.

    :param model: pair (encoder, decoder)
    :param batch_of_x: source batch passed to the encoder
    :param batch_of_y: target batch, fed to the decoder and used as labels
    :param optimizer: optimizer whose ``step`` is called after backprop
    :param loss_function: callable taking (predictions, targets)
    :return: the loss value as a Python float
    '''
    encoder, decoder = model

    # Switch both networks to training mode, then drop stale gradients.
    for module in (encoder, decoder):
        module.train()
    for module in (encoder, decoder):
        module.zero_grad()

    _, hidden, cell = encoder(batch_of_x)

    # Only the last `decoder.num_layers` encoder state slices seed the decoder.
    n_layers = decoder.num_layers
    logits = decoder(
        batch_of_y,
        h=hidden[:, -n_layers:, :],
        c=cell[:, -n_layers:, :])

    # The output at position t is scored against the token at position t + 1;
    # the class axis is moved to position 1 before the loss is computed.
    predictions = logits[:, :-1, :].transpose(1, 2)
    targets = batch_of_y[:, 1:]
    loss = loss_function(predictions, targets)

    loss.backward()
    optimizer.step()
    return loss.item()
    
def train_epoch(train_generator, model, loss_function, optimizer):
    r'''
    Train for one pass over ``train_generator`` and return the mean loss.

    :param train_generator: iterable of (batch_of_x, batch_of_y) pairs that
        also provides ``set_postfix`` (e.g. a tqdm progress bar)
    :param model: pair (encoder, decoder) forwarded to ``train_on_batch``
    :param loss_function: loss forwarded to ``train_on_batch``
    :param optimizer: optimizer forwarded to ``train_on_batch``
    :return: average of the batch losses, weighted by batch size
    '''
    weighted_loss_sum = 0
    n_seen = 0

    for batch_of_x, batch_of_y in train_generator:
        batch_loss = train_on_batch(
            model, batch_of_x, batch_of_y, optimizer, loss_function)
        train_generator.set_postfix({'train batch loss': batch_loss})

        # Weight each batch by its size so a shorter batch does not skew the mean.
        n_in_batch = len(batch_of_x)
        weighted_loss_sum += batch_loss*n_in_batch
        n_seen += n_in_batch

    return weighted_loss_sum/n_seen

def trainer(count_of_epoch, 
            batch_size,
            model,
            dataset,
            loss_function,
            optimizer,
            example_index = 5,
            example_max_len = 5,
           ):
    r'''
    Train the (encoder, decoder) pair and show progress with tqdm bars.

    After every epoch one sentence is encoded and decoded without teacher
    forcing; the decoded characters are shown next to the epoch loss.

    :param count_of_epoch: number of passes over the dataset
    :param batch_size: batch size handed to ``batch_generator``
    :param model: pair (encoder, decoder)
    :param dataset: sequence of sentences (strings)
    :param loss_function: loss forwarded to ``train_epoch``
    :param optimizer: optimizer forwarded to ``train_epoch``
    :param example_index: index of the sentence used for the per-epoch example;
        clamped to the last sentence when the dataset is shorter
    :param example_max_len: number of characters generated for the example
    :return: None
    '''
    encoder, decoder = model

    # batch_generator pairs every sentence with its successor, so it produces
    # len(dataset) - 1 samples; the bar total must be computed from that.
    n_samples = max(len(dataset) - 1, 0)
    number_of_batch = n_samples//batch_size + (n_samples%batch_size>0)

    unk = char2idx['<UNK>']
    iterations = tqdm(range(count_of_epoch))

    for it in iterations:
        generator = tqdm(
            batch_generator(dataset, batch_size, device=device), 
            leave=False, total=number_of_batch)
        
        epoch_loss = train_epoch(
            train_generator = generator, model = model, 
            loss_function = loss_function, 
            optimizer = optimizer)
        
        encoder.eval()
        decoder.eval()
        
        # Clamp the index so a small dataset does not raise IndexError, and map
        # characters missing from the vocabulary to <UNK> instead of failing.
        sent = dataset[min(example_index, len(dataset) - 1)].lower()
        x = torch.LongTensor(
            [[char2idx.get(char, unk) for char in sent]]).to(device)

        # The example is for display only: no autograd graph is needed.
        with torch.no_grad():
            d, h, c = encoder(x)
            logits = decoder(
                h=h[:, -decoder.num_layers:, :], 
                c=c[:, -decoder.num_layers:, :], 
                max_len=example_max_len)

        predicted = torch.argmax(logits, dim=-1)[0].tolist()
        result = ''.join(idx2char[index] for index in predicted)

        iterations.set_postfix({'train epoch loss': epoch_loss, 
                                'example': result})
    return



## Обучение моделей

In [190]:
# Single-layer encoder and decoder; every other size keeps its default.
encoder = Encoder(num_layers=1)
decoder = Decoder(num_layers=1)

# One Adam optimizer updates the parameters of both networks.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *decoder.parameters()],
    lr=1e-5,
)
# Positions holding the padding index do not contribute to the loss.
loss_function = torch.nn.CrossEntropyLoss(ignore_index=char2idx['<PAD>'])

In [None]:
# Train the encoder/decoder pair for 100 epochs with batches of 64.
trainer(
    model=(encoder, decoder),
    dataset=dataset,
    optimizer=optimizer,
    loss_function=loss_function,
    batch_size=64,
    count_of_epoch=100,
)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=533.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=533.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=533.0), HTML(value='')))