## Импорт необходимых зависимостей

In [1]:
import pandas as pd
import nltk
import torch
import torch.nn as nn
import torch.optim
import numpy as np
import time
import pickle

from random import random, sample
from typing import List
from collections import Counter
from itertools import chain
from functools import reduce
from tqdm.auto import tqdm
from sklearn import model_selection
from torch.utils.data import DataLoader, TensorDataset
from torchtext.data.metrics import bleu_score

In [2]:
# Seed passed to the data sampling / train-test splitting calls below for reproducibility.
RANDOM_STATE = 1

## Подготовка данных

In [3]:
# Load the Lenta corpus: original texts, their lemmatised versions and the
# sentence context columns (nsubj, gender, tense).
df = pd.read_csv('./data/lenta/dataset.csv')

In [4]:
# Keep a reproducible random 50% of the rows (this also shuffles them).
df = df.sample(frac=0.5, random_state=RANDOM_STATE)
df

Unnamed: 0,orig_texts,lemm_texts,nsubj,gender,tense
1245806,об этом сообщает риа новости со ссылкой на мат...,о это сообщать риа новость с ссылка на мать по...,риа,neut,pres
1594042,генеральный прокурор рф владимир устинов счита...,генеральный прокурор рф владимир устинов счита...,прокурор,masc,pres
705659,"телеканал «дождь» восстановил вещание, прерван...","телеканал « дождь » восстановить вещание, прер...",телеканал,masc,past
603796,соответствующее требование прозвучало во время...,соответствующий требование прозвучать в время ...,требование,neut,past
1430273,"в пятницу вечером на сайтах ""единой россии"", ""...","в пятница вечером на сайт ""единый россия"", ""гр...",заявления,neut,past
...,...,...,...,...,...
1331395,об этом заявил президент - председатель правле...,о это заявить президент - председатель правлен...,президент,masc,past
1710395,"""мы думаем, что они (страны-члены совбеза) пре...","\"" мы думать, что они (страна-член совбез) пре...",мы,undefined,pres
1434398,"социологи ""росгосстраха"" оценили сознательност...","социолог ""росгосстрах""оценить сознательность р...",социологи,masc,past
1832873,российская сборная сохранила за собой 24 строчку.,российский сборная сохранить за себя 24 строчка.,сборная,fem,past


### Определение классов словаря и трансформера текста

In [5]:
class Vocab:
    """Bidirectional token <-> index mapping with an unknown-token fallback.

    The private attribute names are part of the on-disk format: instances are
    pickled by ``TextTransformer.save_vocab`` and restored by ``load_vocab``.
    """

    def __init__(self, tokens: List[str], unk_idx: int):
        self._tokens = tokens
        lookup = {}
        for position, token in enumerate(tqdm(tokens, 'Transforming tokens')):
            lookup[token] = position
        self._token_to_idx = lookup
        self._unk_idx = unk_idx

    def token_to_idx(self, token: str) -> int:
        """Return the index of ``token``, or the unknown-token index if absent."""
        try:
            return self._token_to_idx[token]
        except KeyError:
            return self._unk_idx

    def idx_to_token(self, idx: int) -> str:
        """Return the token stored at position ``idx``."""
        tokens = self._tokens
        return tokens[idx]

In [6]:
class TextTransformer:
    """Tokenize Russian text and map it to padded index tensors.

    Sequences produced by ``text_to_tensor`` / ``texts_to_tensor`` have the
    layout ``<sos> tok_1 ... tok_k <eos> <pad> ... <pad>`` with
    ``k <= max_seq_len``, i.e. exactly ``max_seq_len + 2`` positions.
    """

    def __init__(self, vocab_size: int = 250000):
        self.vocab = None
        self.vocab_size = vocab_size
        # Special tokens always occupy the first indices of the vocabulary.
        self.special_tokens_to_idx = {'<unk>': 0, '<pad>': 1, '<sos>': 2, '<eos>': 3}
        self._tokenizer = nltk.tokenize.word_tokenize

    def tokenize(self, text, language='russian') -> List[str]:
        """Lower-case ``text`` and split it into word tokens."""
        return self._tokenizer(text.lower(), language)

    def save_vocab(self, path='./vocab.vcb'):
        """Pickle the current vocabulary to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self.vocab, f)

    def load_vocab(self, path):
        """Load a pickled vocabulary from ``path``.

        NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
        """
        with open(path, 'rb') as f:
            self.vocab = pickle.load(f)

    def build_vocab(self, tokens: List[str], unk_idx: int = 0, pad_idx: int = 1):
        """Build a vocabulary from the most frequent ``tokens``.

        The special tokens come first, followed by up to
        ``vocab_size - len(special_tokens_to_idx)`` of the most common tokens.
        ``unk_idx`` and ``pad_idx`` are accepted for backward compatibility but
        ignored: indices always come from ``special_tokens_to_idx``.
        """
        tokens_ = list(self.special_tokens_to_idx)
        n_regular = self.vocab_size - len(self.special_tokens_to_idx)
        tokens_.extend(token for token, _ in Counter(tokens).most_common(n_regular))
        self.vocab = Vocab(tokens_, self.special_tokens_to_idx['<unk>'])

    def transform_text(self, text: str) -> List[int]:
        """Tokenize ``text`` and map every token to its vocabulary index."""
        return [self.vocab.token_to_idx(token) for token in self.tokenize(text)]

    def fit(self, texts: List[str]) -> None:
        """Tokenize ``texts`` and build the vocabulary from all of their tokens."""
        tokenized_texts = [self.tokenize(text) for text in tqdm(texts, 'Tokenizing texts')]
        self.build_vocab(chain.from_iterable(tokenized_texts))

    def transform_texts(self, texts: List[str]) -> List[List[int]]:
        """Map every text in ``texts`` to a list of vocabulary indices."""
        # Fixed: previously called the bare name ``transform_text`` (NameError).
        return [self.transform_text(text) for text in tqdm(texts, 'Transforming texts')]

    def _pad_and_wrap(self, indexes: List[int], max_seq_len) -> List[int]:
        """Truncate/pad ``indexes`` to ``max_seq_len`` and wrap them in <sos>/<eos>."""
        pad_idx = self.special_tokens_to_idx.get('<pad>')
        sos_idx = self.special_tokens_to_idx.get('<sos>')
        eos_idx = self.special_tokens_to_idx.get('<eos>')
        body = indexes[:max_seq_len]
        pad_size = max_seq_len - len(body)
        # <eos> goes right after the real tokens, before any padding.
        return [sos_idx] + body + [eos_idx] + [pad_idx] * pad_size

    def text_to_tensor(self, text: str, max_seq_len) -> torch.Tensor:
        """Encode one text as a ``(1, max_seq_len + 2)`` long tensor."""
        indexes = self._pad_and_wrap(self.transform_text(text), max_seq_len)
        return torch.tensor(indexes, dtype=torch.long).unsqueeze(0)

    def texts_to_tensor(self, texts: List[str], max_seq_len) -> torch.Tensor:
        """Encode many texts as a ``(max_seq_len + 2, n_texts)`` long tensor."""
        rows = [self._pad_and_wrap(self.transform_text(text), max_seq_len)
                for text in tqdm(texts, 'Building tensor')]
        return torch.tensor(rows, dtype=torch.long).permute(1, 0)

### Разбиение данных на обучающую, тестовую и валидационную выборки

In [7]:
# 90% train / 10% held out; the held-out part is split again in the next cell.
train_df, test_df = model_selection.train_test_split(df, test_size=0.1, random_state=RANDOM_STATE)

In [8]:
# Split the held-out 10% in half: 5% test and 5% validation.
test_df, val_df = model_selection.train_test_split(test_df, test_size=0.5, random_state=RANDOM_STATE)

### Токенизация текстов и индексация токенов

In [9]:
# Vocabulary size including special tokens; also used as the decoder output size.
# NOTE(review): presumably must match the size of the cached vocabulary loaded below — confirm.
vocab_size = 125000

In [10]:
# Max tokens kept per text; <sos>/<eos> are added on top, giving max_seq_len + 2 positions.
# NOTE(review): presumably the value the cached tensors below were built with — confirm.
max_seq_len = 40

In [11]:
# Tokenizer/indexer; its vocabulary is loaded from a cache in the next cell.
text_transformer = TextTransformer(vocab_size)

In [12]:
# Load the cached vocabulary (pickle: only load trusted files).
text_transformer.load_vocab('./data/cached/vocab.vcb')

In [13]:
# text_transformer.build_vocab(embedding.vocab.words[0:-2], embedding.vocab.unk_id, embedding.vocab.pad_id)

In [14]:
# lemm_vocab_size = 23000
# orig_vocab_size = 65000

In [15]:
# lemm_text_transformer = TextTransformer(lemm_vocab_size)
# orig_text_transformer = TextTransformer(orig_vocab_size)

In [16]:
# text_transformer.fit(train_df.orig_texts.to_list() + train_df.lemm_texts.to_list())

In [17]:
# with open('./data/cached/tokens.list', 'rb') as f:
#     tokens = pickle.load(f)

In [18]:
# text_transformer.build_vocab(tokens)

In [19]:
# tokens = [text_transformer.vocab.idx_to_token(idx) for idx in range(4, 124999)]

In [20]:
# with open('./tokens.list', 'wb') as f:
#     pickle.dump(tokens, f)

In [21]:
# orig_text_transformer.fit(train_df.orig_texts)

### Перевод данных в тензоры

In [22]:
# tensors = {
#     'train_lemm_tensor': train_lemm_tensor,
#     'test_lemm_tensor': test_lemm_tensor,
#     'val_lemm_tensor': val_lemm_tensor,
#     'train_orig_tensor': train_orig_tensor,
#     'test_orig_tensor': test_orig_tensor,
#     'val_orig_tensor': val_orig_tensor
# }

# with open('./data_tensors.data', 'wb') as f:
#     pickle.dump(tensors, f)

In [23]:
# Load the pre-built index tensors for all splits (see the commented-out cells
# that created them). pickle: only load trusted local caches.
with open('./data/cached/data_tensors.data', 'rb') as f:
    tensors = pickle.load(f)

In [24]:
# Unpack by key rather than by dict order, so a cache written with a different
# key order cannot silently swap splits (keys as in the cell that saved them).
train_lemm_tensor = tensors['train_lemm_tensor']
test_lemm_tensor = tensors['test_lemm_tensor']
val_lemm_tensor = tensors['val_lemm_tensor']
train_orig_tensor = tensors['train_orig_tensor']
test_orig_tensor = tensors['test_orig_tensor']
val_orig_tensor = tensors['val_orig_tensor']

In [25]:
# train_lemm_tensor = text_transformer.texts_to_tensor(train_df.lemm_texts.to_list(), max_seq_len)
# test_lemm_tensor = text_transformer.texts_to_tensor(test_df.lemm_texts.to_list(), max_seq_len)
# val_lemm_tensor = text_transformer.texts_to_tensor(val_df.lemm_texts.to_list(), max_seq_len)

In [26]:
# train_orig_tensor = text_transformer.texts_to_tensor(train_df.orig_texts.to_list(), max_seq_len)
# test_orig_tensor = text_transformer.texts_to_tensor(test_df.orig_texts.to_list(), max_seq_len)
# val_orig_tensor = text_transformer.texts_to_tensor(val_df.orig_texts.to_list(), max_seq_len)

In [27]:
# One-hot encoding of the subject's grammatical gender, one column per label.
gender_to_vec = {
    gender: [1 if column == row else 0 for column in range(4)]
    for row, gender in enumerate(('masc', 'fem', 'neut', 'undefined'))
}

In [28]:
# One-hot encoding of the verb tense, one column per label.
tense_to_vec = {
    tense: [1 if column == row else 0 for column in range(3)]
    for row, tense in enumerate(('pres', 'past', 'fut'))
}

In [29]:
def transform_context(df, df_type: str):
    """Encode the per-row context columns (nsubj, gender, tense) of ``df``.

    Returns ``[nsubj_indexes, gender_one_hots, tense_one_hots]``. Gender/tense
    values missing from ``gender_to_vec`` / ``tense_to_vec`` become ``None``
    (``dict.get``); subjects are looked up via ``text_transformer.vocab``.
    ``df_type`` only labels the progress bars.
    """
    genders = tqdm(df.gender, f'Transforming gender ({df_type})')
    gender_vectors = [gender_to_vec.get(value) for value in genders]

    tenses = tqdm(df.tense, f'Transforming tense ({df_type})')
    tense_vectors = [tense_to_vec.get(value) for value in tenses]

    subjects = tqdm(df.nsubj, f'Transforming nsubj ({df_type})')
    subject_indexes = [text_transformer.vocab.token_to_idx(value) for value in subjects]

    return [subject_indexes, gender_vectors, tense_vectors]

In [30]:
def context_to_tensors(context):
    """Convert the output of ``transform_context`` into tensors.

    Subject indexes keep torch's default integer dtype; the gender and tense
    one-hot lists become ``float32`` tensors so they can feed ``nn.Linear``.
    """
    nsubj, gender, tense = context
    return [
        torch.tensor(nsubj),
        torch.tensor(gender, dtype=torch.float32),
        torch.tensor(tense, dtype=torch.float32),
    ]

In [31]:
# Encode the nsubj / gender / tense columns of every split into index and one-hot lists.
train_context = transform_context(train_df, 'train')
test_context = transform_context(test_df, 'test')
val_context = transform_context(val_df, 'validation')

HBox(children=(FloatProgress(value=0.0, description='Transforming gender (train)', max=833725.0, style=Progres…




HBox(children=(FloatProgress(value=0.0, description='Transforming tense (train)', max=833725.0, style=Progress…




HBox(children=(FloatProgress(value=0.0, description='Transforming nsubj (train)', max=833725.0, style=Progress…




HBox(children=(FloatProgress(value=0.0, description='Transforming gender (test)', max=46318.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, description='Transforming tense (test)', max=46318.0, style=ProgressSt…




HBox(children=(FloatProgress(value=0.0, description='Transforming nsubj (test)', max=46318.0, style=ProgressSt…




HBox(children=(FloatProgress(value=0.0, description='Transforming gender (validation)', max=46319.0, style=Pro…




HBox(children=(FloatProgress(value=0.0, description='Transforming tense (validation)', max=46319.0, style=Prog…




HBox(children=(FloatProgress(value=0.0, description='Transforming nsubj (validation)', max=46319.0, style=Prog…




In [32]:
# Turn the encoded context lists into tensors (nsubj: integer, gender/tense: float32).
train_context_tensors = context_to_tensors(train_context)
test_context_tensors = context_to_tensors(test_context)
val_context_tensors = context_to_tensors(val_context)

In [33]:
def cut_to_fit_batch(tensor: torch.Tensor, batch_size: int):
    """Drop trailing samples so the sample count is a multiple of ``batch_size``.

    Samples are laid out along dim 1 of ``tensor`` (e.g. ``(seq_len, n_samples)``
    or ``(1, n_samples, features)``). The kept samples are returned with dims 0
    and 1 swapped, i.e. samples first.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    n_samples = tensor.shape[1]
    new_n_samples = (n_samples // batch_size) * batch_size
    # ``narrow`` also handles new_n_samples == 0 (fewer samples than one batch),
    # where the previous ``tensor.split(0, dim=1)`` raised a RuntimeError.
    return tensor.narrow(1, 0, new_n_samples).transpose(1, 0)

## Построение модели

In [34]:
class ContextMem(nn.Module):
    """Project the sentence context into a single context vector.

    Gender and tense one-hots are each linearly projected to ``hidden_size``,
    concatenated with the subject embedding and mapped to ``output_size``
    (all layers without bias). ``device`` is stored but not used here.
    """

    def __init__(self, gender_input_size, tense_input_size, hidden_size, output_size, nsubj_embedding_size, device):
        super().__init__()
        self.device = device
        # Creation order and attribute names are kept: they fix parameter
        # initialisation order and the state_dict keys stored in checkpoints.
        self.gender_proj = nn.Linear(gender_input_size, hidden_size, bias=False)
        self.tense_proj = nn.Linear(tense_input_size, hidden_size, bias=False)
        self.fc_out = nn.Linear(2 * hidden_size + nsubj_embedding_size, output_size, bias=False)

    def forward(self, nsubj_embedding, gender, tense):
        """Return the context vector of shape ``(batch, output_size)``.

        Expected inputs: nsubj_embedding ``(batch, embedding_size)``,
        gender ``(batch, gender_input_size)``, tense ``(batch, tense_input_size)``.
        """
        pieces = (nsubj_embedding, self.gender_proj(gender), self.tense_proj(tense))
        # concatenated: (batch, embedding_size + 2 * hidden_size)
        return self.fc_out(torch.cat(pieces, dim=-1))

In [35]:
class EncoderRNN(nn.Module):
    """Bidirectional single-layer GRU encoder over token indexes.

    ``vocab_size``, ``pad_idx`` and ``dropout_p`` are accepted for signature
    compatibility but unused; the (possibly shared) ``embedding`` module is
    supplied by the caller.
    """

    def __init__(self, vocab_size: int, embedding_size: int, hidden_size: int, pad_idx: int,
                 device, dropout_p: float, embedding=None, pretrained_embedding_loaded=False):
        super().__init__()
        self.device = device
        self.hidden_size = hidden_size
        # Registration order (embedding before rnn) is kept for stable state_dict layout.
        self.embedding = embedding
        self.pretrained_embedding_loaded = pretrained_embedding_loaded
        self.rnn = nn.GRU(embedding_size, hidden_size, dropout=0.0, bidirectional=True)

    def forward(self, sequence, hidden):
        """Encode ``sequence`` and return the GRU outputs for every position.

        Shapes: sequence ``(seq_len, batch)``; hidden ``(2, batch, hidden_size)``
        (one slice per direction); result ``(seq_len, batch, 2 * hidden_size)``.
        """
        # A pretrained embedding is looked up without recording gradients;
        # otherwise the caller's current grad mode is kept unchanged.
        track_grads = torch.is_grad_enabled() and not self.pretrained_embedding_loaded
        with torch.set_grad_enabled(track_grads):
            embedded = self.embedding(sequence)
        # embedded: (seq_len, batch, embedding_size)
        outputs, _ = self.rnn(embedded, hidden)
        return outputs

In [36]:
class DecoderRNN(nn.Module):
    """Single-step GRU decoder with additive attention over encoder states.

    ``vocab_size``, ``pad_idx`` and ``dropout_p`` are accepted for signature
    compatibility but unused; ``embedding`` is supplied by the caller.

    NOTE(review): ``forward`` runs the GRU from a zero hidden state on every
    call, and the attention weights depend only on ``encoder_states`` (not on
    the decoder state). So only the previous token carries information from one
    decoding step to the next. Fixing this needs the hidden state threaded
    through the callers.
    """

    def __init__(self, vocab_size: int, embedding_size: int, hidden_size: int, output_size: int, pad_idx: int,
                 device, dropout_p: float, embedding=None, pretrained_embedding_loaded=False):
        super(DecoderRNN, self).__init__()

        self.device = device

        self.hidden_size = hidden_size

        self.embedding = embedding
        self.pretrained_embedding_loaded = pretrained_embedding_loaded

        # Scores each encoder position; softmax over dim=1 (the seq_len axis).
        self.attn_weights = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size, bias=False),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
            nn.Softmax(dim=1)
        )
        self.rnn = nn.GRU(embedding_size + 2 * hidden_size, hidden_size, dropout=0.0)
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, token, encoder_states):
        """Predict logits for the next token.

        Args:
            token: previous token indexes, shape ``(batch,)``; not modified.
            encoder_states: encoder outputs, ``(seq_len, batch, 2 * hidden_size)``.

        Returns:
            Logits of shape ``(1, batch, output_size)``.
        """
        # Out-of-place: the previous ``unsqueeze_`` reshaped the caller's tensor.
        token = token.unsqueeze(0)
        # token: (1, batch)

        encoder_states = torch.transpose(encoder_states, 1, 0)
        # encoder_states: (batch, seq_len, 2 * hidden_size)

        if self.pretrained_embedding_loaded:
            with torch.no_grad():
                embedding = self.embedding(token)
        else:
            embedding = self.embedding(token)
        # embedding: (1, batch, embedding_size)

        attn_weights = self.attn_weights(encoder_states)
        # attn_weights: (batch, seq_len, 1)

        context_vec = torch.bmm(attn_weights.permute(0, 2, 1), encoder_states).permute(1, 0, 2)
        # context_vec: (1, batch, 2 * hidden_size) -- attention-weighted sum of encoder states

        combined = torch.cat((embedding, context_vec), dim=2)
        # combined: (1, batch, embedding_size + 2 * hidden_size)

        rnn_out = self.rnn(combined)[0]
        # rnn_out: (1, batch, hidden_size)

        return self.fc_out(rnn_out)

In [37]:
class Seq2SeqModel(nn.Module):
    """Attention seq2seq model whose encoder is primed by a context vector.

    The subject embedding plus gender/tense one-hots go through ``ContextMem``.
    The result, duplicated for both GRU directions, becomes the encoder's
    initial hidden state, so ``context_output_size`` must equal ``hidden_size``.
    One embedding module is shared by the encoder, the decoder and the subject
    lookup. With ``pretrained_embedding`` the embedding comes from
    ``nn.Embedding.from_pretrained`` and is looked up without gradients.
    """

    def __init__(self,
                 vocab_size, embedding_size, hidden_size, output_size,
                 gender_input_size, tense_input_size, context_hidden_size, context_output_size,
                 pad_idx, device, dropout_p, pretrained_embedding=None):
        super(Seq2SeqModel, self).__init__()

        self.device = device

        if pretrained_embedding is not None:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embedding, padding_idx=pad_idx)
            self.pretrained_embedding_loaded = True
        else:
            self.embedding = nn.Sequential(
                nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx),
                nn.Dropout(dropout_p)
            )
            self.pretrained_embedding_loaded = False

        self.context_mem = ContextMem(gender_input_size, tense_input_size, context_hidden_size, context_output_size, embedding_size, device).to(device)
        self.encoder = EncoderRNN(vocab_size, embedding_size, hidden_size,
                                  pad_idx, device, dropout_p,
                                  self.embedding, self.pretrained_embedding_loaded).to(device)
        self.decoder = DecoderRNN(vocab_size, embedding_size, hidden_size, output_size,
                                  pad_idx, device, dropout_p,
                                  self.embedding, self.pretrained_embedding_loaded).to(device)

        self.vocab_size = vocab_size

    def forward(self, input, target, context, teacher_forcing_ratio=0.0):
        """Decode ``target_len`` steps and return the per-step logits.

        Args:
            input: source indexes, ``(src_len, batch)``.
            target: target indexes, ``(target_len, batch)``; ``target[0]`` is
                used as the first decoder input (the ``<sos>`` row).
            context: ``(nsubj, gender, tense)`` with shapes ``(batch,)`` (or
                ``(1, batch)``), ``(batch, gender_input_size)`` and
                ``(batch, tense_input_size)``.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the ground-truth token instead of the
                model's prediction.

        Returns:
            ``(target_len, batch, output_size)`` logits; row 0 stays zero.
        """
        batch_size = input.shape[1]
        target_len = target.shape[0]
        # Size by the decoder's real output width; ``self.vocab_size`` only
        # worked when output_size == vocab_size.
        target_vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=self.device)

        nsubj, gender, tense = context

        if self.pretrained_embedding_loaded:
            with torch.no_grad():
                nsubj_embedding = self.embedding(nsubj)
        else:
            nsubj_embedding = self.embedding(nsubj)
        # Accept nsubj as (batch,) or (1, batch). The old unconditional
        # ``squeeze(0)`` collapsed (1, emb) to (emb,) when batch_size == 1.
        if nsubj_embedding.dim() == 3:
            nsubj_embedding = nsubj_embedding.squeeze(0)
        # nsubj_embedding: (batch, embedding_size)

        hidden = self.context_mem(nsubj_embedding, gender, tense)
        # hidden: (batch, context_output_size)

        hidden = torch.cat([hidden.unsqueeze(0)] * 2, 0)
        # hidden: (2, batch, hidden_size) -- same initial state for both directions

        encoder_states = self.encoder(input, hidden)
        # encoder_states: (src_len, batch, 2 * hidden_size)

        prev_token_idx = target[0]  # <sos> row, shape (batch,)

        for t in range(1, target_len):
            output = self.decoder(prev_token_idx, encoder_states)
            # output: (1, batch, output_size)

            outputs[t] = output.squeeze(0)

            best_prediction = outputs[t].argmax(dim=1)
            # best_prediction: (batch,)

            prev_token_idx = target[t] if random() < teacher_forcing_ratio else best_prediction

        return outputs

## Обучение модели

### Функция сохранения текущего состояния модели

In [38]:
def save_model(model, optimizer, epoch, path, criterion=None):
    """Save a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint holds the model and optimizer ``state_dict``s, the loss
    object and ``epoch``.

    Args:
        criterion: loss object to store. When omitted, it falls back to the
            notebook-global ``criterion`` (the previous, implicit behaviour), or
            ``None`` if that global does not exist.
    """
    if criterion is None:
        criterion = globals().get('criterion')

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'criterion': criterion,
        'epoch': epoch
    }

    torch.save(checkpoint, path)

### Функция загрузки уже тренировавшейся модели

In [39]:
def load_model(model, optimizer, criterion, path, for_inference=True, device=torch.device('cpu')):
    """Restore a checkpoint written by ``save_model`` into ``model``.

    Args:
        model: module whose weights are replaced from the checkpoint.
        optimizer: its state is restored only when ``for_inference`` is False.
        criterion: unused. A loaded loss object could not be pushed back into
            the caller's variable; the parameter is kept for compatibility.
        path: checkpoint file.
        for_inference: when True, only the model weights are loaded.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The epoch stored in the checkpoint (``None`` if absent in inference
        mode). Previously ``None`` was always returned when ``for_inference``
        was True, which broke ``epoch = load_model(...)``.

    NOTE(review): checkpoints that pickle the loss module may need
    ``weights_only=False`` on recent torch versions; only load trusted files.
    """
    checkpoint = torch.load(path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])

    if not for_inference:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch']

    return checkpoint.get('epoch')

### Инициализация гиперпараметров

In [40]:
# Optimisation settings.
learning_rate = 0.001
batch_size = 4
epochs_amount = 50
# GRU hidden size (per encoder direction, and for the decoder).
hidden_size = 768
embedding_size = 300
# Gradient-norm clipping threshold used in train().
max_norm = 1.0
# Dropout applied after the (non-pretrained) embedding.
dropout_p = 0.5
# Lengths of the gender / tense one-hot vectors (see gender_to_vec / tense_to_vec).
gender_input_size = 4
tense_input_size = 3
context_hidden_size = hidden_size // 2
# Must equal hidden_size: the context vector is the encoder's initial hidden state.
context_output_size = hidden_size
# Epochs without validation improvement before early stopping.
patience = 3
output_size = vocab_size
pad_idx = text_transformer.special_tokens_to_idx.get('<pad>')
model_path = './models/'
model_name = 'seq2seq_attention_fixed.model'

In [41]:
# Use the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
# device = torch.device('cpu')

In [43]:
# Build the model with a freshly initialised (not pretrained) embedding.
model = Seq2SeqModel(vocab_size, embedding_size, hidden_size, output_size,
                     gender_input_size, tense_input_size, context_hidden_size, context_output_size, 
                     pad_idx, device, dropout_p).to(device)

In [44]:
# Adam over all model parameters (shared embedding, context, encoder, decoder).
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [45]:
# Padding positions in the target are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [46]:
# Resume training when a checkpoint exists: restore weights *and* optimizer
# state, then continue after the saved epoch. (Loading with
# for_inference=True returned None, which then crashed train().)
try:
    epoch = load_model(model, optimizer, criterion, model_path + model_name, for_inference=False) + 1
    print(f'Loaded model from {model_path}')
except FileNotFoundError:
    # Only a missing checkpoint means "start fresh"; other load errors
    # (e.g. a state_dict mismatch) should surface instead of being hidden.
    print(f'No models found at {model_path}')
    epoch = 1

No models found at ./models/


### Урезание данных для соответствия размеру батча

In [47]:
# Trim every split to a whole number of batches and put samples first.
# The cached tensors are presumably (seq_len, n_samples), as built by
# texts_to_tensor, so each result is (n_samples_trimmed, seq_len).
(train_lemm_tensor_f, train_orig_tensor_f,
 test_lemm_tensor_f, test_orig_tensor_f,
 val_lemm_tensor_f, val_orig_tensor_f) = (
    cut_to_fit_batch(split_tensor, batch_size)
    for split_tensor in (train_lemm_tensor, train_orig_tensor,
                         test_lemm_tensor, test_orig_tensor,
                         val_lemm_tensor, val_orig_tensor)
)

In [48]:
# Context tensors are (n_samples,) or (n_samples, k). cut_to_fit_batch expects
# samples on dim 1 and returns them on dim 0, so: add a leading dim
# ((1, n_samples, ...)), trim (-> (n_trimmed, 1, ...)), then drop that singleton dim 1.
train_context_tensors_f = [cut_to_fit_batch(tensor.unsqueeze(0), batch_size).squeeze(1) for tensor in train_context_tensors]
test_context_tensors_f = [cut_to_fit_batch(tensor.unsqueeze(0), batch_size).squeeze(1) for tensor in test_context_tensors]
val_context_tensors_f = [cut_to_fit_batch(tensor.unsqueeze(0), batch_size).squeeze(1) for tensor in val_context_tensors]

### Инициализация данных итерируемых по батчам

In [49]:
# Each item is (lemm_seq, orig_seq, nsubj, gender, tense) -- the order train() unpacks.
train_dataset = TensorDataset(train_lemm_tensor_f, train_orig_tensor_f, *train_context_tensors_f)
test_dataset = TensorDataset(test_lemm_tensor_f, test_orig_tensor_f, *test_context_tensors_f)
val_dataset = TensorDataset(val_lemm_tensor_f, val_orig_tensor_f, *val_context_tensors_f)

In [50]:
# Reshuffle the training batches every epoch. Validation order does not affect
# the mean validation loss, and train() draws its test sample from a shuffled loader.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

### Определение функции проверки работы сети между эпохами обучения

In [51]:
def test_evaluate(model, input, context, target_len=40):
    """Greedily decode one example (batch size 1) for a progress printout.

    Args:
        model: a ``Seq2SeqModel``; switched to eval mode (not switched back).
        input: source indexes, ``(seq_len, 1)``; moved to the global ``device``.
        context: ``(nsubj, gender, tense)`` tensors, already on the model's device.
        target_len: decoding budget; at most ``target_len - 1`` tokens are produced.

    Returns:
        Predicted tokens without the leading ``<sos>``, stopping before ``<eos>``.
    """
    with torch.no_grad():
        model.eval()

        input = input.to(device)

        nsubj, gender, tense = context
        context_vec = model.context_mem(model.embedding(nsubj), gender, tense)
        # Same initial state for both encoder directions: (2, 1, hidden_size).
        hidden = torch.stack([context_vec, context_vec])

        tokens_map = text_transformer.special_tokens_to_idx
        sos_idx = tokens_map.get('<sos>')
        eos_idx = tokens_map.get('<eos>')

        encoder_states = model.encoder(input, hidden)

        predicted = [sos_idx]
        for _step in range(target_len - 1):
            prev_idx = torch.tensor([predicted[-1]], dtype=torch.long, device=device)
            next_idx = model.decoder(prev_idx, encoder_states).squeeze(0).argmax(dim=1).item()
            if next_idx == eos_idx:
                break
            predicted.append(next_idx)

    return [text_transformer.vocab.idx_to_token(idx) for idx in predicted[1:]]

### Определение функции обучения сети

In [52]:
def train(model, optimizer, criterion, train_data, val_data, test_data, epochs_amount, max_norm, patience=3, current_epoch=1, n_prints=5):
    """Train ``model`` with early stopping on the mean validation loss.

    Each epoch runs the following steps:
    - one pass over ``train_data`` (teacher forcing off, gradient norm clipped to ``max_norm``);
    - a validation pass over ``val_data``;
    - a checkpoint via ``save_model`` whenever the validation loss improves;
    - a greedy decode of one random ``test_data`` example, for eyeballing.

    Training stops after ``patience`` consecutive epochs without improvement.

    Batches are ``(input, target, nsubj, gender, tense)`` with sequences laid
    out ``(batch, seq_len)``. Uses the notebook globals ``device``,
    ``model_path``, ``model_name``, ``text_transformer``, ``save_model`` and
    ``test_evaluate``.
    """
    min_mean_val_loss = float('+inf')
    initial_patience = patience
    # max(..., 1): fewer batches than n_prints previously gave a modulo by zero.
    print_every = max(len(train_data) // n_prints, 1)
    last_iteration = len(train_data) - 1

    for epoch in tqdm(range(current_epoch, epochs_amount + 1), 'Epochs'):
        print(f'\nEpoch [{epoch} / {epochs_amount}]')

        model.train()
        for iteration, batch in enumerate(tqdm(train_data, 'Epoch training iterations')):
            optimizer.zero_grad()
            loss = _batch_loss(model, criterion, batch)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

            optimizer.step()

            # Also print the final batch (the old ``iteration == len(train_data)``
            # check could never be true).
            if iteration % print_every == 0 or iteration == last_iteration:
                print(f'\tIteration #{iteration}: training loss = {loss.item()}')

        with torch.no_grad():
            model.eval()
            val_losses = [_batch_loss(model, criterion, batch).item()
                          for batch in tqdm(val_data, 'Epoch validating iterations')]

            mean_val_loss = sum(val_losses) / len(val_losses) if val_losses else float('+inf')
            print(f'\tValidation loss = {mean_val_loss}')
            if mean_val_loss < min_mean_val_loss:
                try:
                    save_model(model, optimizer, epoch, model_path + model_name)
                    min_mean_val_loss = mean_val_loss
                    patience = initial_patience
                except Exception as exc:
                    # Best effort: a failed save must not abort training.
                    print(exc)
            else:
                patience -= 1

            _print_test_sample(model, test_data)

        if patience == 0:
            print(f'\nModel learning finished due to early stopping')
            break


def _batch_loss(model, criterion, batch):
    """Run ``model`` on one batch; return the loss over target steps 1..end."""
    input, target, nsubj, gender, tense = batch
    # input = lemm_texts, target = orig_texts; both become (seq_len, batch).
    input = torch.transpose(input, 1, 0).to(device)
    target = torch.transpose(target, 1, 0).to(device)
    context = (nsubj.to(device), gender.to(device), tense.to(device))

    output = model(input, target, context)
    # output: (seq_len, batch, vocab); step 0 is the <sos> slot and is skipped.
    vocab_size = output.shape[2]
    # Fixed: validation previously reshaped with the undefined ``orig_vocab_size``.
    return criterion(output[1:].reshape(-1, vocab_size), target[1:].reshape(-1))


def _print_test_sample(model, test_data):
    """Greedily decode one random example of ``test_data`` and print it."""
    sample_loader = DataLoader(test_data.dataset, batch_size=1, shuffle=True)
    for input, target, nsubj, gender, tense in sample_loader:
        target = target.squeeze(0).to(device)
        context = (nsubj.to(device), gender.to(device), tense.to(device))

        input = torch.transpose(input, 1, 0)
        # input: (seq_len, 1)

        output = test_evaluate(model, input, context, target.shape[0])
        decoded_input = [text_transformer.vocab.idx_to_token(idx.item()) for idx in input]
        decoded_target = [text_transformer.vocab.idx_to_token(idx.item()) for idx in target]

        print(f'\tInput: {decoded_input}')
        print(f'\tOutput: {output}')
        print(f'\tTarget: {decoded_target}')
        break


### Определение функции эксплуатации обученной модели

In [53]:
def evaluate(model: Seq2SeqModel, sentence: str, context, max_seq_len=45):
    """Greedily decode the model's output for one lemmatised ``sentence``.

    Args:
        model: trained ``Seq2SeqModel``; switched to eval mode.
        sentence: lemmatised input text, encoded with ``text_transformer``.
        context: ``(nsubj, gender, tense)`` as raw strings; gender/tense must
            be keys of ``gender_to_vec`` / ``tense_to_vec`` (KeyError otherwise).
        max_seq_len: input truncation length, and the decoding budget (at most
            ``max_seq_len - 1`` tokens are produced).

    Returns:
        Predicted tokens, without ``<sos>`` and stopping before ``<eos>``.
    """
    with torch.no_grad():
        model.eval()

        subject_word, gender_key, tense_key = context

        subject = torch.tensor([text_transformer.vocab.token_to_idx(subject_word)], device=device).unsqueeze(0)
        gender_vec = torch.tensor([gender_to_vec[gender_key]], dtype=torch.float32, device=device)
        tense_vec = torch.tensor([tense_to_vec[tense_key]], dtype=torch.float32, device=device)

        # subject (1, 1) -> embedding (1, 1, emb) -> squeeze(0) -> (1, emb)
        subject_embedding = model.embedding(subject).squeeze(0)

        context_vec = model.context_mem(subject_embedding, gender_vec, tense_vec)
        # Same initial state for both encoder directions: (2, 1, hidden_size).
        hidden = torch.stack([context_vec, context_vec])

        source = text_transformer.text_to_tensor(sentence, max_seq_len).to(device).transpose(1, 0)
        # source: (max_seq_len + 2, 1)

        tokens_map = text_transformer.special_tokens_to_idx
        sos_idx = tokens_map.get('<sos>')
        eos_idx = tokens_map.get('<eos>')

        encoder_states = model.encoder(source, hidden)

        predicted = [sos_idx]
        for _step in range(max_seq_len - 1):
            prev_idx = torch.tensor([predicted[-1]], dtype=torch.long, device=device)
            next_idx = model.decoder(prev_idx, encoder_states).squeeze(0).argmax(dim=1).item()
            if next_idx == eos_idx:
                break
            predicted.append(next_idx)

    return [text_transformer.vocab.idx_to_token(idx) for idx in predicted[1:]]

In [54]:
# Train starting from ``epoch`` (1 for a fresh run), with early stopping on validation loss.
train(model, optimizer, criterion, train_loader, val_loader, test_loader, epochs_amount, max_norm, patience, epoch)

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=50.0, style=ProgressStyle(description_width=…


Epoch [1 / 50]


HBox(children=(FloatProgress(value=0.0, description='Epoch training iterations', max=208431.0, style=ProgressS…

	Iteration #0: training loss = 11.751469612121582




KeyboardInterrupt: 

In [None]:
# import gc
# del model
# del optimizer
# gc.collect()
# torch.cuda.empty_cache()
# gc.collect()

In [None]:
# Draw 10 random test rows (no fixed seed, so they differ per run) for a qualitative check.
test_sample = test_df.sample(10)
test_lemm_sents = test_sample.lemm_texts.to_list()
test_target_sents = test_sample.orig_texts.to_list()
test_nsubj = test_sample.nsubj.to_list()
test_gender = test_sample.gender.to_list()
test_tense = test_sample.tense.to_list()
# zip is a one-shot iterator: re-run this cell before iterating it again.
test_input = zip(test_lemm_sents, test_target_sents, test_nsubj, test_gender, test_tense) 

In [None]:
# Print the model's greedy output next to the input, the reference text and the context.
for lemm_sent, target_sent, nsubj, gender, tense in test_input:
    model_output = evaluate(model, lemm_sent, (nsubj, gender, tense))
    print(f'Input: {lemm_sent}')
    print(f'Output: {model_output}')
    print(f'Target: {target_sent}')
    print(f'Nsubj: {nsubj}')
    print(f'Gender: {gender}')
    print(f'Tense: {tense}')
    print('\n')

In [None]:
# Take the first max_samples test rows for the BLEU evaluation below.
max_samples = 15000

test_lemm_sent = test_df.lemm_texts.to_list()[:max_samples]
test_orig_sent = test_df.orig_texts.to_list()[:max_samples]
test_nsubj = test_df.nsubj.to_list()[:max_samples]
test_gender = test_df.gender.to_list()[:max_samples]
test_tense = test_df.tense.to_list()[:max_samples]
# One-shot iterator of (lemm_sent, nsubj, gender, tense), consumed by the outputs cell.
test_input = zip(test_lemm_sent, test_nsubj, test_gender, test_tense)

In [None]:
# Evaluate on the GPU when present; fall back to the CPU instead of failing on
# machines without CUDA. ``evaluate`` reads this global ``device``.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
# Greedy model outputs (token lists) and the tokenised references. Each reference
# is wrapped in a list because bleu_score takes a list of references per candidate.
outputs = [evaluate(model, lemm_sent, (nsubj, gender, tense)) for lemm_sent, nsubj, gender, tense in tqdm(test_input)]
targets = [[text_transformer.tokenize(target)] for target in tqdm(test_orig_sent)]

In [None]:
# Corpus BLEU restricted to unigrams (max_n=1), rounded to 3 decimals.
score = round(bleu_score(outputs, targets, max_n=1, weights=[1]), 3)
score