## Импорт необходимых зависимостей

In [1]:
import os
import pickle
import random as py_random
import time
from collections import Counter
from functools import reduce
from itertools import chain
from random import random, sample
from typing import List

import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim
from sklearn import model_selection
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

In [2]:
# Single seed for the data splits and for every RNG the notebook uses.
RANDOM_STATE = 1

# Seed all generators so Restart & Run All reproduces the same results:
# Python's `random` drives teacher forcing in Seq2SeqModel.forward, torch drives
# weight init / dropout / DataLoader shuffling. `py_random` is the module alias;
# the bare name `random` is the function imported in the first cell.
py_random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE);
np.random.seed(RANDOM_STATE)

## Подготовка данных

In [3]:
# The CSV was written together with its pandas index; read it back as the index
# instead of letting it appear as a junk "Unnamed: 0" data column.
df = pd.read_csv('./data/russian_literature/processed/dataset.csv', index_col=0)

In [4]:
# Quick look at the corpus: `orig_texts` (original sentences) and `lemm_texts` (their lemmatized forms).
df

Unnamed: 0,orig_texts,lemm_texts
0,вообще он опасался правил.,вообще он опасаться правило.
1,кто сможет ему родить?,кто смочь он родить?
2,вагон был совершенно гнилой.,вагон быть совершенно гнилой.
3,это оказалась станция железнодорожная.,это оказаться станция железнодорожный.
4,и тестирования весьма неожиданного.,и тестирование весьма неожиданный.
...,...,...
27889,"лаудон же, вместо того чтобы оборонять колонии...","лаудон же, вместо тот чтобы оборонять колония ..."
27890,я с готовностью оказал ей несколько маленьких ...,я с готовность оказать она несколько маленький...
27891,"это конечно особая тема, и она требует самого ...","это конечно особый тема, и она требовать сам п..."
27892,"конечно, если вы смогли вскарабкаться на верши...","конечно, если вы смочь вскарабкаться на вершин..."


### Определение классов словаря и трансформера текста

In [5]:
class Vocab:
    """Token <-> index mapping with a fallback index for unknown tokens.

    Args:
        tokens: ordered token list; a token's position in it is its index
            (for duplicated tokens the last position wins).
        unk_idx: index returned by `token_to_idx` for tokens not in `tokens`.
    """

    def __init__(self, tokens: List[str], unk_idx: int):
        self._tokens = tokens
        self._unk_idx = unk_idx
        self._token_to_idx = dict(zip(tokens, range(len(tokens))))

    def token_to_idx(self, token: str) -> int:
        """Return the index of `token`, or the unknown-token index when it is absent."""
        if token in self._token_to_idx:
            return self._token_to_idx[token]
        return self._unk_idx

    def idx_to_token(self, idx: int) -> str:
        """Return the token stored at position `idx` (plain list indexing)."""
        return self._tokens[idx]

In [6]:
class TextTransformer:
    """Tokenizes texts, builds a frequency-capped vocabulary and encodes texts as index tensors.

    Indices 0-3 are reserved for the special tokens <UNK>, <PAD>, <SOS>, <EOS>;
    the remaining `vocab_size - 4` slots go to the most frequent training tokens.
    """

    def __init__(self, vocab_size: int):
        self.vocab = None
        self.vocab_size = vocab_size
        self.special_tokens_to_idx = {'<UNK>': 0, '<PAD>': 1, '<SOS>': 2, '<EOS>': 3}
        self._tokenizer = nltk.tokenize.wordpunct_tokenize

    def tokenize(self, text) -> List[str]:
        """Lower-case `text` and split it with the configured tokenizer."""
        return self._tokenizer(text.lower())

    def build_vocab(self, tokens: List[str]):
        """Build `self.vocab` from a token stream (any iterable of str is accepted).

        Special tokens take the first indices, followed by the
        `vocab_size - len(special_tokens)` most frequent tokens.
        """
        vocab_tokens = list(self.special_tokens_to_idx)
        n_regular = self.vocab_size - len(vocab_tokens)
        vocab_tokens.extend(token for token, _ in Counter(tokens).most_common(n_regular))
        self.vocab = Vocab(vocab_tokens, self.special_tokens_to_idx['<UNK>'])

    def transform_text(self, text: str) -> List[int]:
        """Encode one text as vocabulary indices (tokens outside the vocabulary map to <UNK>).

        Raises:
            RuntimeError: if the vocabulary has not been built yet.
        """
        if self.vocab is None:
            raise RuntimeError('Vocabulary is not built yet: call fit() or build_vocab() first')
        return [self.vocab.token_to_idx(token) for token in self.tokenize(text)]

    def fit(self, texts: List[str]) -> None:
        """Build the vocabulary from the tokens of `texts`."""
        tokenized_texts = (self.tokenize(text) for text in tqdm(texts, 'Tokenizing texts'))
        self.build_vocab(chain.from_iterable(tokenized_texts))

    def transform_texts(self, texts: List[str]) -> List[List[int]]:
        """Encode every text in `texts` (see `transform_text`)."""
        return [self.transform_text(text) for text in tqdm(texts, 'Transforming texts')]

    def _pad_and_wrap(self, indices: List[int], max_seq_len: int) -> List[int]:
        """Truncate/pad `indices` to `max_seq_len` and wrap them in <SOS> ... <EOS>.

        <EOS> follows the last real token and any <PAD> comes after it, so the
        result always has length `max_seq_len + 2`.
        """
        pad_idx = self.special_tokens_to_idx['<PAD>']
        sos_idx = self.special_tokens_to_idx['<SOS>']
        eos_idx = self.special_tokens_to_idx['<EOS>']
        body = indices[:max_seq_len]
        pad_size = max_seq_len - len(body)
        return [sos_idx] + body + [eos_idx] + [pad_idx] * pad_size

    def text_to_tensor(self, text: str, max_seq_len) -> torch.Tensor:
        """Encode one text as a long tensor of shape (1, max_seq_len + 2)."""
        wrapped = self._pad_and_wrap(self.transform_text(text), max_seq_len)
        return torch.tensor(wrapped, dtype=torch.long).unsqueeze(0)

    def texts_to_tensor(self, texts: List[str], max_seq_len) -> torch.Tensor:
        """Encode `texts` as a long tensor of shape (max_seq_len + 2, len(texts)), one column per text."""
        rows = [self._pad_and_wrap(self.transform_text(text), max_seq_len)
                for text in tqdm(texts, 'Building tensor')]
        if not rows:
            # torch.tensor([]) is 1-D and cannot be permuted; return an empty 2-D tensor instead.
            return torch.empty((max_seq_len + 2, 0), dtype=torch.long)
        return torch.tensor(rows, dtype=torch.long).permute(1, 0)

### Разбиение данных на обучающую, тестовую и валидационную выборки

In [7]:
# Hold out 10% of the rows (fixed seed); the held-out part is split again below.
train_df, test_df = model_selection.train_test_split(df, test_size=0.1, random_state=RANDOM_STATE)

In [8]:
# Split the held-out 10% evenly into test (5%) and validation (5%).
# NOTE(review): this overwrites `test_df` with a subset of itself, so re-running this cell
# alone keeps shrinking it; re-run the previous cell first.
test_df, val_df = model_selection.train_test_split(test_df, test_size=0.5, random_state=RANDOM_STATE)

### Токенизация текстов и индексация токенов

In [9]:
# Vocabulary caps (including the 4 special tokens) for the lemma and original-form sides.
lemm_vocab_size = 23000
orig_vocab_size = 65000
# Tokens kept per sentence before <SOS>/<EOS> are added (tensors are max_seq_len + 2 long).
max_seq_len = 50

In [10]:
# One transformer per side: lemmatized source texts and original-form target texts.
lemm_text_transformer = TextTransformer(lemm_vocab_size)
orig_text_transformer = TextTransformer(orig_vocab_size)

In [11]:
# Vocabulary is built from the training split only; other tokens map to <UNK> later.
lemm_text_transformer.fit(train_df.lemm_texts)

Tokenizing texts:   0%|          | 0/25104 [00:00<?, ?it/s]

Transforming texts:   0%|          | 0/25104 [00:00<?, ?it/s]

In [12]:
# Target-side vocabulary, also built from the training split only.
orig_text_transformer.fit(train_df.orig_texts)

Tokenizing texts:   0%|          | 0/25104 [00:00<?, ?it/s]

Transforming texts:   0%|          | 0/25104 [00:00<?, ?it/s]

### Перевод данных в тензоры

In [13]:
# Encode each split as a (max_seq_len + 2, n_texts) long tensor, one column per sentence.
train_lemm_tensor = lemm_text_transformer.texts_to_tensor(train_df.lemm_texts.to_list(), max_seq_len)
test_lemm_tensor = lemm_text_transformer.texts_to_tensor(test_df.lemm_texts.to_list(), max_seq_len)
val_lemm_tensor = lemm_text_transformer.texts_to_tensor(val_df.lemm_texts.to_list(), max_seq_len)

Building tensor:   0%|          | 0/25104 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/1395 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/1395 [00:00<?, ?it/s]

In [14]:
# Same encoding for the target (original-form) side.
train_orig_tensor = orig_text_transformer.texts_to_tensor(train_df.orig_texts.to_list(), max_seq_len)
test_orig_tensor = orig_text_transformer.texts_to_tensor(test_df.orig_texts.to_list(), max_seq_len)
val_orig_tensor = orig_text_transformer.texts_to_tensor(val_df.orig_texts.to_list(), max_seq_len)

Building tensor:   0%|          | 0/25104 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/1395 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/1395 [00:00<?, ?it/s]

In [15]:
def cut_to_fit_batch(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Drop trailing samples so the sample count is a multiple of `batch_size`.

    Args:
        tensor: 2-D tensor laid out as (seq_len, n_samples), one column per sample.
        batch_size: positive batch size.

    Returns:
        The kept samples transposed to (n_kept, seq_len), one row per sample.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    n_kept = (tensor.shape[1] // batch_size) * batch_size
    # Slicing also handles n_kept == 0, where tensor.split(0, dim=1) would raise.
    return tensor[:, :n_kept].transpose(1, 0)

## Построение модели

In [16]:
class EncoderRNN(nn.Module):
    """Bidirectional LSTM encoder whose final states are compressed for a unidirectional decoder.

    Args:
        vocab_size: size of the source vocabulary.
        embedding_size: embedding dimension.
        hidden_size: LSTM hidden size per direction; also the size of the returned states.
        pad_idx: passed to nn.Embedding as padding_idx.
        device: device used by `init_hidden_state`.
        num_layers: number of stacked LSTM layers.
        dropout_p: dropout applied to the embeddings.
    """

    def __init__(self, vocab_size: int, embedding_size: int, hidden_size: int, pad_idx: int,
                 device, num_layers, dropout_p: float):
        super().__init__()

        self.device = device
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.embedding = nn.Sequential(
            nn.Embedding(vocab_size, embedding_size, pad_idx),
            nn.Dropout(dropout_p)
        )
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=0.0, bidirectional=True)
        # Project concatenated [forward; backward] states (2 * hidden_size) down to hidden_size.
        self.fc_compressor_hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.fc_compressor_cell = nn.Linear(hidden_size * 2, hidden_size)

    def _merge_directions(self, state: torch.Tensor, compressor: nn.Module) -> torch.Tensor:
        """Map an LSTM final state (num_layers * 2, batch, hidden) to (num_layers, batch, hidden).

        For each layer the forward and backward states are concatenated and projected by `compressor`.
        """
        batch_size = state.shape[1]
        # nn.LSTM orders h_n/c_n as (layer, direction); direction 0 is forward, 1 is backward.
        per_layer = state.reshape(self.num_layers, 2, batch_size, self.hidden_size)
        both_directions = torch.cat((per_layer[:, 0], per_layer[:, 1]), dim=2)
        return compressor(both_directions)

    def forward(self, x):
        """Encode a batch of index sequences.

        Args:
            x: long tensor of shape (seq_len, batch_size).

        Returns:
            encoder_states: (seq_len, batch_size, hidden_size * 2) top-layer outputs of both directions.
            hidden, cell: (num_layers, batch_size, hidden_size) compressed final states.
        """
        embedded = self.embedding(x)
        encoder_states, (hidden, cell) = self.lstm(embedded)
        hidden_compressed = self._merge_directions(hidden, self.fc_compressor_hidden)
        # The cell state has its own projection; it must not reuse fc_compressor_hidden.
        cell_compressed = self._merge_directions(cell, self.fc_compressor_cell)
        return encoder_states, hidden_compressed, cell_compressed

    def init_hidden_state(self, batch_size: int):
        """Return zero (h_0, c_0) shaped for this bidirectional LSTM: (num_layers * 2, batch_size, hidden_size)."""
        shape = (self.num_layers * 2, batch_size, self.hidden_size)
        return torch.zeros(shape, device=self.device), torch.zeros(shape, device=self.device)

In [17]:
class DecoderRNN(nn.Module):
    """One-step LSTM decoder with additive attention over the encoder states.

    Args:
        vocab_size: size of the target vocabulary used for the input embeddings.
        embedding_size: embedding dimension.
        hidden_size: decoder LSTM hidden size; encoder states are expected to be 2 * hidden_size wide.
        output_size: number of logits produced per step.
        pad_idx: passed to nn.Embedding as padding_idx.
        device: stored on the module; not used inside it.
        num_layers: number of stacked LSTM layers.
        dropout_p: dropout applied to the embeddings.
    """

    def __init__(self, vocab_size: int, embedding_size: int, hidden_size: int, output_size: int, pad_idx: int,
                 device, num_layers, dropout_p: float):
        super().__init__()

        self.device = device
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Sequential(
            nn.Embedding(vocab_size, embedding_size, pad_idx),
            nn.Dropout(dropout_p)
        )
        # score = v^T tanh(W [query; encoder_state]), normalized over source positions (dim=1).
        self.attn_weights = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size, bias=False),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
            nn.Softmax(dim=1)
        )
        self.lstm = nn.LSTM(embedding_size + 2 * hidden_size, hidden_size, num_layers, dropout=0.0)
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, x, encoder_states, hidden, cell):
        """Run one decoding step.

        Args:
            x: (batch_size,) indices of the previous target token.
            encoder_states: (src_len, batch_size, hidden_size * 2) encoder outputs.
            hidden, cell: (num_layers, batch_size, hidden_size) decoder LSTM state.

        Returns:
            Logits of shape (1, batch_size, output_size) and the updated hidden and cell.

        NOTE(review): attention is not masked, so <PAD> source positions can receive weight.
        """
        embedded = self.embedding(x.unsqueeze(0))
        # embedded: (1, batch_size, embedding_size)
        states_bf = encoder_states.transpose(1, 0)
        # states_bf: (batch_size, src_len, hidden_size * 2)
        src_len = states_bf.shape[1]

        # Query with the top LSTM layer's state; with num_layers == 1 it is the only layer.
        query = hidden[-1].unsqueeze(1).expand(-1, src_len, -1)
        # query: (batch_size, src_len, hidden_size)
        weights = self.attn_weights(torch.cat((query, states_bf), dim=2))
        # weights: (batch_size, src_len, 1)
        context = torch.bmm(weights.transpose(1, 2), states_bf).transpose(1, 0)
        # context: (1, batch_size, hidden_size * 2)

        lstm_out, (hidden, cell) = self.lstm(torch.cat((embedded, context), dim=2), (hidden, cell))
        # lstm_out: (1, batch_size, hidden_size)
        return self.fc_out(lstm_out), hidden, cell

In [18]:
class Seq2SeqModel(nn.Module):
    """Attention encoder-decoder mapping source index sequences to target index sequences.

    Args:
        encoder_vocab_size: source vocabulary size.
        decoder_vocab_size: target vocabulary size (decoder input embeddings).
        embedding_size: embedding dimension of both sides.
        hidden_size: hidden size shared by encoder and decoder.
        output_size: number of logits the decoder produces per step.
        pad_idx: padding index for both embeddings.
        device: device for the sub-modules and the output buffer.
        num_layers: LSTM layers in encoder and decoder.
        dropout_p: embedding dropout.
    """

    def __init__(self, encoder_vocab_size: int, decoder_vocab_size: int, embedding_size: int, hidden_size: int, output_size: int,
                 pad_idx: int, device, num_layers, dropout_p: float):
        super().__init__()

        self.device = device

        self.encoder = EncoderRNN(encoder_vocab_size, embedding_size, hidden_size, pad_idx, device, num_layers, dropout_p).to(device)
        self.decoder = DecoderRNN(decoder_vocab_size, embedding_size, hidden_size, output_size, pad_idx, device, num_layers, dropout_p).to(device)
        self.decoder_vocab_size = decoder_vocab_size
        # Width of the decoder logits; the per-step output buffer must match it.
        self.output_size = output_size

    def forward(self, input, target, teacher_forcing_ratio=0.5):
        """Decode `target.shape[0] - 1` steps, optionally feeding ground-truth tokens back in.

        Args:
            input: (src_len, batch_size) source indices.
            target: (tgt_len, batch_size) target indices; target[0] seeds the decoding.
            teacher_forcing_ratio: per-step probability (one draw for the whole batch)
                of feeding target[t] instead of the model's own argmax prediction.

        Returns:
            (tgt_len, batch_size, output_size) logits; row 0 is never written and stays zero.
        """
        target_len, batch_size = target.shape[0], input.shape[1]
        outputs = torch.zeros(target_len, batch_size, self.output_size, device=self.device)

        encoder_states, hidden, cell = self.encoder(input)
        # hidden, cell: compressed final encoder states, (layers, batch_size, hidden_size)

        prev_token_idx = target[0]
        for t in range(1, target_len):
            step_logits, hidden, cell = self.decoder(prev_token_idx, encoder_states, hidden, cell)
            outputs[t] = step_logits.squeeze(0)
            use_teacher = random() < teacher_forcing_ratio
            prev_token_idx = target[t] if use_teacher else outputs[t].argmax(dim=1)

        return outputs

## Обучение модели

### Функция сохранения текущего состояния модели

In [19]:
def save_model(model, optimizer, epoch, path, criterion=None):
    """Write a training checkpoint to `path`, creating missing parent directories.

    The checkpoint dict has the keys 'model_state_dict', 'optimizer_state_dict',
    'criterion' and 'epoch'.

    Args:
        model: module whose state_dict is saved.
        optimizer: optimizer whose state_dict is saved.
        epoch: epoch number stored with the checkpoint.
        path: destination file.
        criterion: optional object stored under 'criterion'. It is an explicit
            parameter (not a notebook global), and None by default so the
            checkpoint holds no pickled nn.Module.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'criterion': criterion,
        'epoch': epoch
    }

    torch.save(checkpoint, path)

### Функция загрузки ранее обученной модели

In [20]:
def load_model(model, optimizer, criterion, path, for_inference=True, device=torch.device('cpu')):
    """Restore model state (and optimizer state when resuming training) from a checkpoint.

    Args:
        model: module to load 'model_state_dict' into.
        optimizer: optimizer to load 'optimizer_state_dict' into; only used when
            `for_inference` is False.
        criterion: unused; kept so existing calls keep working. Rebinding it inside
            this function could never reach the caller.
        path: checkpoint file.
        for_inference: when True, only the model weights are restored.
        device: passed to torch.load as map_location.

    Returns:
        The stored 'epoch' when `for_inference` is False, otherwise None.

    Raises:
        FileNotFoundError: if `path` does not exist.
    """
    # torch.load unpickles the file: only load checkpoints produced by a trusted source.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    if for_inference:
        return None

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

### Инициализация гиперпараметров

In [21]:
learning_rate = 0.001
batch_size = 16
epochs_amount = 50          # upper bound; early stopping may end training sooner
hidden_size = 512           # LSTM hidden size (per direction in the encoder)
embedding_size = 300
num_layers = 1
max_norm = 1.0              # gradient-norm clipping threshold
dropout_p = 0.5             # dropout applied to the embeddings
patience = 5                # epochs without validation improvement before stopping
output_size = orig_vocab_size  # decoder predicts over the original-form vocabulary
# Both transformers use the same <PAD> index; it is also the loss's ignore_index.
pad_idx = lemm_text_transformer.special_tokens_to_idx.get('<PAD>')
model_path = './models/'
model_name = 'seq2seq_attention.model'

In [22]:
# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [23]:
# Encoder reads lemma indices, decoder predicts original-form indices; both embeddings share pad_idx.
model = Seq2SeqModel(lemm_vocab_size, orig_vocab_size, embedding_size, hidden_size, output_size, pad_idx, device, num_layers, dropout_p).to(device)

In [24]:
# Adam over all encoder and decoder parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [25]:
# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [26]:
# Resume from an existing checkpoint if there is one; otherwise start from epoch 1.
try:
    # The checkpoint records the epoch that had just finished, so training continues with the next one.
    epoch = load_model(model, optimizer, criterion, model_path + model_name, for_inference=False) + 1
    print(f'Loaded model from {model_path}')
except FileNotFoundError:
    # Only a missing file means "no checkpoint"; other errors (corrupt file, architecture
    # mismatch) propagate instead of silently restarting training from scratch.
    print(f'No models found at {model_path}')
    epoch = 1

No models found at ./


### Обрезка данных до размера, кратного размеру батча

In [27]:
# Drop trailing samples that do not fill a whole batch; results are (n_samples, seq_len), one row per sentence.
# NOTE(review): the test split is cut to multiples of batch_size too, although test_loader uses batch_size=1.
train_lemm_tensor_f = cut_to_fit_batch(train_lemm_tensor, batch_size)
train_orig_tensor_f = cut_to_fit_batch(train_orig_tensor, batch_size)

test_lemm_tensor_f = cut_to_fit_batch(test_lemm_tensor, batch_size)
test_orig_tensor_f = cut_to_fit_batch(test_orig_tensor, batch_size)

val_lemm_tensor_f = cut_to_fit_batch(val_lemm_tensor, batch_size)
val_orig_tensor_f = cut_to_fit_batch(val_orig_tensor, batch_size)

### Инициализация данных, итерируемых по батчам

In [28]:
# Pair each lemma sequence with its original-form sequence, row by row.
train_dataset = TensorDataset(train_lemm_tensor_f, train_orig_tensor_f)
test_dataset = TensorDataset(test_lemm_tensor_f, test_orig_tensor_f)
val_dataset = TensorDataset(val_lemm_tensor_f, val_orig_tensor_f)

In [29]:
# Reshuffle the training set every epoch so batches are not always in the same order
# (the order follows torch's global RNG); evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

### Определение функции проверки работы сети между эпохами обучения

In [30]:
def test_evaluate(model, input, target_len):
    """Greedy-decode one source sentence and return the predicted original-form tokens.

    Args:
        model: Seq2SeqModel to run.
        input: (seq_len, 1) long tensor, i.e. one sentence in the time-major layout the encoder expects.
        target_len: bound on decoding steps; at most target_len - 1 tokens are produced.

    Returns:
        Predicted tokens without the leading <SOS>; decoding stops early at <EOS>.
    """
    input = input.to(device)
    # Both transformers use the same hard-coded special-token indices.
    sos_idx = lemm_text_transformer.special_tokens_to_idx.get('<SOS>')
    eos_idx = lemm_text_transformer.special_tokens_to_idx.get('<EOS>')

    model.eval()
    with torch.no_grad():
        # The encoder returns three values, and the attention decoder needs the encoder states too.
        encoder_states, hidden, cell = model.encoder(input)

        predicted_indexes = [sos_idx]
        for _ in range(1, target_len):
            prev_idx = torch.tensor([predicted_indexes[-1]], dtype=torch.long, device=device)
            output, hidden, cell = model.decoder(prev_idx, encoder_states, hidden, cell)
            best_prediction = output.squeeze(0).argmax(dim=1).item()
            if best_prediction == eos_idx:
                break
            predicted_indexes.append(best_prediction)

    return [orig_text_transformer.vocab.idx_to_token(idx) for idx in predicted_indexes[1:]]

### Определение функции обучения сети

In [31]:
def _run_train_epoch(model, optimizer, criterion, train_data, max_norm, print_every):
    """Run one training pass over `train_data`, printing the batch loss every `print_every` iterations and at the end."""
    model.train()
    last_iteration = len(train_data) - 1
    for iteration, (src, tgt) in enumerate(tqdm(train_data, 'Epoch training iterations')):
        optimizer.zero_grad()
        # Batches arrive as (batch, seq_len); the model is time-major: (seq_len, batch).
        src = src.transpose(1, 0).to(device)
        tgt = tgt.transpose(1, 0).to(device)
        output = model(src, tgt)
        # Step 0 (the <SOS> slot) is never predicted; flatten the rest for CrossEntropyLoss.
        loss = criterion(output[1:].reshape(-1, output.shape[2]), tgt[1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()

        if iteration % print_every == 0 or iteration == last_iteration:
            print(f'\tIteration #{iteration}: training loss = {loss.item()}')


def _mean_validation_loss(model, criterion, val_data):
    """Return the average per-batch loss over `val_data`, decoding without teacher forcing."""
    model.eval()
    losses = []
    with torch.no_grad():
        for src, tgt in tqdm(val_data, 'Epoch validating iterations'):
            src = src.transpose(1, 0).to(device)
            tgt = tgt.transpose(1, 0).to(device)
            # teacher_forcing_ratio=0 keeps the metric free of random() draws, so early
            # stopping compares like with like and the loss reflects free-running decoding.
            output = model(src, tgt, teacher_forcing_ratio=0.0)
            losses.append(criterion(output[1:].reshape(-1, output.shape[2]), tgt[1:].reshape(-1)).item())
    if not losses:
        raise ValueError('Validation data is empty')
    return sum(losses) / len(losses)


def _print_random_test_sample(model, test_data):
    """Greedy-decode one random pair from `test_data.dataset` and print input, output and target tokens."""
    loader = DataLoader(test_data.dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader), None)
    if batch is None:
        return
    src, tgt = batch
    tgt = tgt.squeeze(0).to(device)
    src = src.transpose(1, 0)

    output = test_evaluate(model, src, tgt.shape[0])
    decoded_input = [lemm_text_transformer.vocab.idx_to_token(idx.item()) for idx in src]
    decoded_target = [orig_text_transformer.vocab.idx_to_token(idx.item()) for idx in tgt]

    print(f'\tInput: {decoded_input}')
    print(f'\tOutput: {output}')
    print(f'\tTarget: {decoded_target}')


def train(model, optimizer, criterion, train_data, val_data, test_data, epochs_amount, max_norm, patience=2, current_epoch=1, n_prints=5):
    """Train `model` with early stopping on the mean validation loss.

    The best model (lowest validation loss so far) is checkpointed to
    `model_path + model_name`; after each epoch one random test sample is decoded and printed.

    Args:
        model, optimizer, criterion: the Seq2SeqModel, its optimizer and the loss.
        train_data, val_data, test_data: DataLoaders yielding (lemma, original) index batches.
        epochs_amount: last epoch number to run (inclusive).
        max_norm: gradient-norm clipping threshold.
        patience: epochs without validation improvement tolerated before stopping.
        current_epoch: first epoch number (for resuming).
        n_prints: approximate number of loss prints per epoch.
    """
    min_mean_val_loss = float('+inf')
    initial_patience = patience
    # At least 1 so the modulo in the training loop never divides by zero on small loaders.
    print_every = max(1, len(train_data) // max(1, n_prints))

    for epoch in tqdm(range(current_epoch, epochs_amount + 1), 'Epochs'):
        print(f'\nEpoch [{epoch} / {epochs_amount}]')

        _run_train_epoch(model, optimizer, criterion, train_data, max_norm, print_every)

        mean_val_loss = _mean_validation_loss(model, criterion, val_data)
        print(f'\tValidation loss = {mean_val_loss}')
        if mean_val_loss < min_mean_val_loss:
            # Track the improvement even if saving fails, so early stopping stays correct.
            min_mean_val_loss = mean_val_loss
            patience = initial_patience
            try:
                save_model(model, optimizer, epoch, model_path + model_name)
            except Exception as exc:
                # Saving is best-effort: report the problem and keep training.
                print(exc)
        else:
            patience -= 1

        _print_random_test_sample(model, test_data)

        if patience <= 0:
            print(f'\nModel learning finished due to early stopping')
            break


### Определение функции эксплуатации обученной модели

In [32]:
def evaluate(model: Seq2SeqModel, sentence: str, max_seq_len):
    """Predict the original word forms for a lemmatized `sentence` by greedy decoding.

    Args:
        model: Seq2SeqModel to run.
        sentence: lemmatized input text.
        max_seq_len: source padding/truncation length; also bounds decoding to max_seq_len - 1 tokens.

    Returns:
        Predicted tokens without the leading <SOS>; decoding stops early at <EOS>.
    """
    # text_to_tensor returns (1, max_seq_len + 2); the encoder expects (seq_len, batch).
    input_tensor = lemm_text_transformer.text_to_tensor(sentence, max_seq_len).to(device).transpose(1, 0)
    sos_idx = lemm_text_transformer.special_tokens_to_idx.get('<SOS>')
    eos_idx = lemm_text_transformer.special_tokens_to_idx.get('<EOS>')

    model.eval()
    with torch.no_grad():
        # The encoder returns three values, and the attention decoder needs the encoder states too.
        encoder_states, hidden, cell = model.encoder(input_tensor)

        predicted_indexes = [sos_idx]
        for _ in range(1, max_seq_len):
            prev_idx = torch.tensor([predicted_indexes[-1]], dtype=torch.long, device=device)
            output, hidden, cell = model.decoder(prev_idx, encoder_states, hidden, cell)
            best_prediction = output.squeeze(0).argmax(dim=1).item()
            if best_prediction == eos_idx:
                break
            predicted_indexes.append(best_prediction)

    return [orig_text_transformer.vocab.idx_to_token(idx) for idx in predicted_indexes[1:]]

In [33]:
# Train (starting at `epoch`, which the checkpoint cell sets); stops early after `patience` epochs without improvement.
train(model, optimizer, criterion, train_loader, val_loader, test_loader, epochs_amount, max_norm, patience, epoch)

Epochs:   0%|          | 0/50 [00:00<?, ?it/s]


Epoch [1 / 50]


Epoch training iterations:   0%|          | 0/1569 [00:00<?, ?it/s]

	Iteration #0: training loss = 11.095197677612305


KeyboardInterrupt: 

In [None]:
# Optional clean-up: uncomment to drop the model and optimizer and release cached GPU memory
# (e.g. before re-creating them in the same kernel).
# import gc
# del model
# del optimizer
# gc.collect()
# torch.cuda.empty_cache()
# gc.collect()

In [None]:
# Fixed random_state so the showcased examples are the same on every run;
# min() avoids a ValueError when the test split has fewer than 10 rows.
test_sample = test_df.sample(min(10, len(test_df)), random_state=RANDOM_STATE)
test_input = test_sample.lemm_texts.to_list()
test_target = test_sample.orig_texts.to_list()
test_pair = list(zip(test_input, test_target))

In [None]:
# Show the model's restored word forms next to the reference sentence for each sampled pair.
for lemmatized, reference in test_pair:
    prediction = evaluate(model, lemmatized, max_seq_len)
    print(f'Input: {lemmatized}')
    print(f'Output: {prediction}')
    print(f'Target: {reference}')
    print('\n')