## Импорт необходимых зависимостей

In [1]:
import pandas as pd
import numpy as np
import nltk
import torch
import torch.nn as nn
import torch.optim
import time
import pickle
import torch.nn.functional as F

from random import random, sample
from typing import List
from collections import Counter
from itertools import chain
from functools import reduce
from tqdm.auto import tqdm
from sklearn import model_selection
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelBinarizer
from torch.utils.data import DataLoader, TensorDataset

## Подготовка данных

In [2]:
df = pd.read_csv('./data/simple_sentences/processed/dataset.csv')

In [3]:
df

Unnamed: 0,lemm_texts,orig_texts,nsubj,gender,tense,number
0,я предлагать оригинальный подарок для малыш!,я предлагаю оригинальный подарок для малыша!,я,undefined,pres,sing
1,я обезательный перезвонить в любой случай.,я обезательно перезвоню в любом случае.,я,undefined,fut,sing
2,цена на память я не помнить.,цены на память я не помню.,я,undefined,pres,sing
3,"я не помнить , где находиться.","я не помню, где находились.",я,undefined,pres,sing
4,я работать на высококачественный американский ...,я работаю на высококачественных американских м...,я,undefined,pres,sing
...,...,...,...,...,...,...
356967,другой ящерица медленно подбрести к свой товарка.,другая ящерица медленно подбрела к своей товарке.,ящерица,fem,past,sing
356968,зелёный ящерица застылый на мраморный ступень.,зеленая ящерица застыла на мраморной ступени.,ящерица,fem,past,sing
356969,больший ящерица шмыгнуть по песок.,большая ящерица шмыгнула по песку.,ящерица,fem,past,sing
356970,домашний ящерица быстро пробежать вдоль штора.,домашняя ящерица быстро пробежала вдоль штор.,ящерица,fem,past,sing


### Определение классов словаря и трансформера текста

In [4]:
class Vocab:
    def __init__(self, tokens: List[str], unk_idx: int):
        self._tokens = tokens
        self._token_to_idx = {token: idx for idx, token in enumerate(tokens)}
        self._unk_idx = unk_idx
        
    def token_to_idx(self, token: str) -> int:
        return self._token_to_idx.get(token, self._unk_idx)
    
    def idx_to_token(self, idx: int) -> str:
        return self._tokens[idx]

In [5]:
class TextTransformer:
    def __init__(self, vocab_size: int):
        self.vocab = None
        self.vocab_size = vocab_size
        self.special_tokens_to_idx = {'<UNK>': 0, '<PAD>': 1, '<SOS>': 2, '<EOS>': 3}
        self._tokenizer = nltk.tokenize.wordpunct_tokenize
    
    def tokenize(self, text) -> List[str]:
        return self._tokenizer(text.lower())
    
    def build_vocab(self, tokens: List[str]):
        tokens_ = [special_token for special_token in self.special_tokens_to_idx.keys()]
        special_tokens_amount = len(self.special_tokens_to_idx)
        
        for token, _ in Counter(tokens).most_common(self.vocab_size - special_tokens_amount):
            tokens_.append(token)
        
        unk_idx = self.special_tokens_to_idx.get('<UNK>')
        self.vocab = Vocab(tokens_, unk_idx)
        
    def transform_text(self, text: str) -> List[int]:
        tokenized_text = self.tokenize(text)
        transformed = [self.vocab.token_to_idx(token) for token in tokenized_text]
        return transformed
    
    def fit_transform(self, texts: List[str]) -> None:
        transformed_texts = []
        
        tokenized_texts = [self.tokenize(text) for text in tqdm(texts, 'Tokenizing texts')]
        tokens = chain(*tokenized_texts)
        self.build_vocab(tokens)
        
        for tokenized_text in tqdm(tokenized_texts, 'Transforming texts'):
            transformed = [self.vocab.token_to_idx(token) for token in tokenized_text]
            transformed_texts.append(transformed)
    
    def transform_texts(self, texts: List[str]) -> List[List[int]]:
        transformed_texts = [transform_text(text) for text in tqdm(texts, 'Transforming texts')]
        return transformed_texts
    
    def text_to_tensor(self, text: str, max_seq_len=8) -> torch.tensor:
        transformed_text = self.transform_text(text)
        pad_idx = self.special_tokens_to_idx.get('<PAD>')
        sos_idx = self.special_tokens_to_idx.get('<SOS>')
        eos_idx = self.special_tokens_to_idx.get('<EOS>')
        
        pad_size = 0
        if len(transformed_text) >= max_seq_len:
            transformed_text = transformed_text[:max_seq_len]
        else:
            pad_size = max_seq_len - len(transformed_text)
            transformed_text.extend([pad_idx] * pad_size)   
        transformed_text.insert(0, sos_idx)
        transformed_text.insert(len(transformed_text) - pad_size, eos_idx)
        
        tensor = torch.tensor(transformed_text, dtype=torch.long)
        return tensor.unsqueeze(0)
    
    def texts_to_tensor(self, texts: List[str], max_seq_len=8) -> torch.tensor:
        pad_idx = self.special_tokens_to_idx.get('<PAD>')
        sos_idx = self.special_tokens_to_idx.get('<SOS>')
        eos_idx = self.special_tokens_to_idx.get('<EOS>')
        transformed_texts = []
        
        for text in tqdm(texts, 'Building tensor'):
            transformed_text = self.transform_text(text)
            pad_size = 0
            if len(transformed_text) >= max_seq_len:
                transformed_text = transformed_text[:max_seq_len]
            else:
                pad_size = max_seq_len - len(transformed_text)
                transformed_text.extend([pad_idx] * pad_size)   
            transformed_text.insert(0, sos_idx)
            transformed_text.insert(len(transformed_text) - pad_size, eos_idx)
            transformed_texts.append(transformed_text)
        
        tensor = torch.tensor(transformed_texts, dtype=torch.long).permute(1, 0)
        return tensor

### Разбиение данных на обучающую, тестовую и валидационную выборки

In [6]:
train_df, test_df = model_selection.train_test_split(df, test_size=0.1)

In [7]:
test_df, val_df = model_selection.train_test_split(test_df, test_size=0.25)

### Токенизация текстов и индексация токенов

In [8]:
lemm_vocab_size = 35000
orig_vocab_size = 70000

In [9]:
lemm_text_transformer = TextTransformer(lemm_vocab_size)
orig_text_transformer = TextTransformer(orig_vocab_size)

In [10]:
lemm_text_transformer.fit_transform(train_df.lemm_texts)

Tokenizing texts:   0%|          | 0/321274 [00:00<?, ?it/s]

Transforming texts:   0%|          | 0/321274 [00:00<?, ?it/s]

In [11]:
orig_text_transformer.fit_transform(train_df.orig_texts)

Tokenizing texts:   0%|          | 0/321274 [00:00<?, ?it/s]

Transforming texts:   0%|          | 0/321274 [00:00<?, ?it/s]

### Перевод данных в тензоры

In [12]:
train_lemm_tensor = lemm_text_transformer.texts_to_tensor(train_df.lemm_texts.to_list())
test_lemm_tensor = lemm_text_transformer.texts_to_tensor(test_df.lemm_texts.to_list())
val_lemm_tensor = lemm_text_transformer.texts_to_tensor(val_df.lemm_texts.to_list())

Building tensor:   0%|          | 0/321274 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/26773 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/8925 [00:00<?, ?it/s]

In [13]:
train_orig_tensor = orig_text_transformer.texts_to_tensor(train_df.orig_texts.to_list())
test_orig_tensor = orig_text_transformer.texts_to_tensor(test_df.orig_texts.to_list())
val_orig_tensor = orig_text_transformer.texts_to_tensor(val_df.orig_texts.to_list())

Building tensor:   0%|          | 0/321274 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/26773 [00:00<?, ?it/s]

Building tensor:   0%|          | 0/8925 [00:00<?, ?it/s]

In [14]:
def transform_context(df, df_type: str):
    gender_rows = pd.get_dummies(df.gender).iterrows()
    tense_rows = pd.get_dummies(df.tense).iterrows()
    nsubj_to_idx = orig_text_transformer.vocab.token_to_idx
    
    transformed_genders = [row[1].to_list() for row in tqdm(gender_rows, f'Transforming genders ({df_type})')]
    transformed_tenses = [row[1].to_list() for row in tqdm(tense_rows, f'Transforming tenses ({df_type})')]
    transformed_nsubjes = [nsubj_to_idx(nsubj) for nsubj in tqdm(df.nsubj, f'Transforming nsubjes ({df_type})')]
    
    context = [transformed_nsubjes, transformed_genders, transformed_tenses]
    return context

In [15]:
train_context = transform_context(train_df, 'train')
test_context = transform_context(test_df, 'test')
val_context = transform_context(val_df, 'validation')

Transforming genders (train): 0it [00:00, ?it/s]

Transforming tenses (train): 0it [00:00, ?it/s]

Transforming nsubjes (train):   0%|          | 0/321274 [00:00<?, ?it/s]

Transforming genders (test): 0it [00:00, ?it/s]

Transforming tenses (test): 0it [00:00, ?it/s]

Transforming nsubjes (test):   0%|          | 0/26773 [00:00<?, ?it/s]

Transforming genders (validation): 0it [00:00, ?it/s]

Transforming tenses (validation): 0it [00:00, ?it/s]

Transforming nsubjes (validation):   0%|          | 0/8925 [00:00<?, ?it/s]

In [16]:
def context_to_tensors(nsubj_list, gender_list, tense_list):
    nsubj_tensor = torch.tensor(nsubj_list)
    gender_tensor = torch.tensor(gender_list, dtype=torch.float32)
    tense_tensor = torch.tensor(tense_list, dtype=torch.float32)
    
    context_tensors = [nsubj_tensor, gender_tensor, tense_tensor]
    return context_tensors

In [17]:
train_context_tensors = context_to_tensors(*train_context)
test_context_tensors = context_to_tensors(*test_context)
val_context_tensors = context_to_tensors(*val_context)

In [18]:
def cut_to_fit_batch(tensor: torch.Tensor, batch_size: int):
    n_samples = tensor.shape[1]
    new_n_samples = (n_samples // batch_size) * batch_size
    result, _ = tensor.split(new_n_samples, dim=1)
    return torch.transpose(result, 1, 0)

In [19]:
# class ContextMem(nn.Module):
#     def __init__(self, input_dim, output_dim, hidden_ff):
#         super(ContextMem, self).__init__()
#         self.fc_norm = nn.Linear(input_dim, hidden_ff)
#         self.hff     = nn.Linear(hidden_ff, output_dim)
#         self.fc_gate = nn.Linear(output_dim, output_dim)

#     def forward(self, context):
#         #context shape = (batch_size, input_dim=3)
#         context = self.fc_norm(context)
#         #context shape = (batch_size, hidden_ff)
#         context = self.hff(context)
#         #context shape = (batch_size, output_dim)
#         context_norm = F.tanh(context)
        
#         #context shape = (batch_size, output_dim)
#         context_gate = self.fc_gate(context)
#         #context_gate shape = (batch_size, output_dim)
#         context_gate = F.sigmoid(context)

#         return context_norm * context_gate

## Построение модели

In [20]:
class ContextMem(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, nsubj_embedding_size, device):
        super(ContextMem, self).__init__()
        
        self.device = device
        
        self.gender_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.tense_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.fc_out = nn.Linear(hidden_size * 2 + nsubj_embedding_size, output_size, bias=False)
        
    def forward(self, nsubj_embedding, gender, tense):
        # nsubj_embedding_shape: (batch_size, embedding_size)
        # gender_shape: (batch_size, input_size)
        # tense_shape: (batch_size, input_size)
        
        gender = self.gender_proj(gender)
        # gender_shape: (batch_size, hidden_size)
        
        tense = self.tense_proj(tense)
        # tense_shape: (batch_size, hidden_size)    
        
        context = torch.cat([nsubj_embedding, gender, tense], dim=1)
        # context_shape: (batch_size, hidden_size * 2 + embedding_size) 
        
        context = self.fc_out(context)
        # context_shape: (batch_size, output_size)
        
        return context

In [21]:
class EncoderRNN(nn.Module):
    def __init__(self,
                 vocab_size, embedding_size, hidden_size,
                 pad_idx, device, num_layers, dropout_p):
        
        super(EncoderRNN, self).__init__()
        
        self.device = device
        self.num_layers = num_layers
        
        self.hidden_size = hidden_size
        
        self.embedding = nn.Sequential(
            nn.Embedding(vocab_size, embedding_size, pad_idx),
            nn.Dropout(dropout_p)
        )
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=dropout_p)
        
    def forward(self, x, hidden, cell):
        # x_shape: (seq_len, batch_size)
        embedding = self.embedding(x)
        # embedding_shape: (seq_len, batch_size, embedding_size)
        output, (hidden, cell) = self.lstm(embedding, (hidden, cell))
        # output_shape: (seq_len, batch_size, hidden_size)
        # hidden_shape: (num_layers, batch_size, hidden_size)
        # cell_shape: (num_layers, batch_size, hidden_size)
        return hidden, cell

In [22]:
class DecoderRNN(nn.Module):
    def __init__(self,
                 vocab_size, embedding_size, hidden_size, output_size,
                 pad_idx, device, num_layers, dropout_p):
        
        super(DecoderRNN, self).__init__()
        
        self.device = device
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.embedding = nn.Sequential(
            nn.Embedding(vocab_size, embedding_size, pad_idx),
            nn.Dropout(dropout_p)
        )
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=dropout_p)
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x, hidden, cell):
        x = x.unsqueeze(0)
        # x_shape:      (seq_len=1, batch_size)
        # hidden_shape: (num_layers, batch_size, hidden_size)
        # cell_shape:   (num_layers, batch_size, hidden_size)
        
        embedding = self.embedding(x)
        # embedding_shape: (seq_len=1, batch_size, embedding_size)
        
        lstm_out, (hidden, cell) = self.lstm(embedding, (hidden, cell))
        # lstm_out_shape: (seq_len=1, batch_size, hidden_size)
        
        fc_out = self.fc(lstm_out)
        # fc_out_shape: (seq_len=1, batch_size, output_size)
        # output_shape: (seq_len=1, batch_size, output_size)
        
        return fc_out, hidden, cell

In [23]:
class Seq2SeqModel(nn.Module):
    def __init__(self,
                 encoder_vocab_size, decoder_vocab_size,
                 embedding_size, hidden_size, output_size,
                 context_input_size, context_hidden_size, context_output_size,
                 pad_idx, device, num_layers, dropout_p):
        
        super(Seq2SeqModel, self).__init__()
        
        self.device = device
        
        self.num_layers = num_layers
        self.decoder_vocab_size = decoder_vocab_size
        
        self.context_mem = ContextMem(context_input_size, context_hidden_size, context_output_size, embedding_size, device).to(device)
        self.encoder = EncoderRNN(encoder_vocab_size, embedding_size, hidden_size, pad_idx, device, num_layers, dropout_p).to(device)
        self.decoder = DecoderRNN(decoder_vocab_size, embedding_size, hidden_size, output_size, pad_idx, device, num_layers, dropout_p).to(device)
        
    def forward(self, input, target, context, teacher_forcing_ratio=0.5):
        batch_size = input.shape[1]
        target_len = target.shape[0]
        target_vocab_size = self.decoder_vocab_size
        
        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=self.device)

        nsubj, gender, tense = context
        # nsubj_shape:  (batch_size)
        # gender_shape: (batch_size, context_input_size)
        # tense_shape:  (batch_size, context_input_size)
        
        nsubj_embedding = self.decoder.embedding(nsubj).squeeze(0)
        # nsubj_embedding_shape: (batch_size, embedding_size)
        
        hidden = self.context_mem(nsubj_embedding, gender, tense)
        cell = hidden.clone()
        # hidden, cell shapes: (batch_size, context_output_size=hidden_size)
        
        if self.num_layers == 1:
            hidden.unsqueeze_(0)
            cell.unsqueeze_(0)
            # hidden, cell shapes: (1, batch_size, context_output_size=hidden_size)
        else:
            hidden = torch.cat([hidden.unsqueeze(0)] * self.num_layers, 0)
            cell = torch.cat([cell.unsqueeze(0)] * self.num_layers, 0)
            # hidden, cell shapes: (num_layers, batch_size, context_output_size=hidden_size)
        
        hidden, cell = self.encoder(input, hidden, cell)
        # hidden, cell shapes: (num_layers, batch_size, hidden_size)
        
        prev_token_idx = target[0]
        # prev_token_shape: (batch_size)
        
        for t in range(1, target_len):
            output, hidden, cell = self.decoder(prev_token_idx, hidden, cell)
            outputs[t] = output.squeeze(0)
            
            best_prediction = outputs[t].argmax(dim=1)
            # best_prediction_shape: (batch_size)
            prev_token_idx = target[t] if random() < teacher_forcing_ratio else best_prediction
        
        return outputs

## Обучение модели

### Функция сохранения текущего состояния модели

In [24]:
def save_model(model: Seq2SeqModel, optimizer, epoch, path):
    checkpoint = {
        'context_mem_state_dict': model.context_mem.state_dict(),
        'encoder_state_dict': model.encoder.state_dict(),
        'decoder_state_dict': model.decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'criterion': criterion,
        'epoch': epoch
    }
    
    torch.save(checkpoint, path)
#     with open(path, mode='wb') as f:
#         pickle.dump(checkpoint, f)

### Функция загрузки уже тренировавшейся модели

In [25]:
def load_model(model: Seq2SeqModel, optimizer, criterion, path, device=torch.device('cpu'), for_inference=True):
#     with open(path, mode='rb') as f:
#         checkpoint = pickle.load(f)
    checkpoint = torch.load(path, map_location=device)
    
    model.context_mem.load_state_dict(checkpoint['context_mem_state_dict'])
    model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
    model.decoder.load_state_dict(checkpoint['decoder_state_dict'])
    
    if not for_inference:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        criterion = checkpoint['criterion']

        return epoch

### Инициализация гиперпараметров

In [26]:
learning_rate = 0.001
batch_size = 256
epochs_amount = 8
lemm_vocab_size = 35000
orig_vocab_size = 70000
hidden_size = 512
embedding_size = 300
num_layers = 2
max_norm = 1.0
dropout_p = 0.5
context_input_size = 4
context_hidden_size = 256
context_output_size = hidden_size
patience = 2
output_size = orig_vocab_size
pad_idx = lemm_text_transformer.special_tokens_to_idx.get('<PAD>')
model_path = './models/'
model_name = 'simple_seq2seq_with_context.model'

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
model = Seq2SeqModel(lemm_vocab_size, orig_vocab_size, embedding_size, hidden_size, output_size,
                     context_input_size, context_hidden_size, context_output_size,
                     pad_idx, device, num_layers, dropout_p).to(device)

In [29]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [30]:
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [31]:
try:
    epoch = load_model(model, optimizer, criterion, model_path + model_name, device)
    print(f'Loaded model from {model_path}')
except:
    print(f'No models found at {model_path}')
    epoch = 1

Loaded model from ./models/


### Урезание данных для соответствия размеру батча

In [32]:
train_lemm_tensor_f = cut_to_fit_batch(train_lemm_tensor, batch_size)
train_orig_tensor_f = cut_to_fit_batch(train_orig_tensor, batch_size)

test_lemm_tensor_f = cut_to_fit_batch(test_lemm_tensor, batch_size)
test_orig_tensor_f = cut_to_fit_batch(test_orig_tensor, batch_size)

val_lemm_tensor_f = cut_to_fit_batch(val_lemm_tensor, batch_size)
val_orig_tensor_f = cut_to_fit_batch(val_orig_tensor, batch_size)

In [33]:
train_context_tensors_f = [cut_to_fit_batch(tensor.unsqueeze(0), batch_size).squeeze(1) for tensor in train_context_tensors]
test_context_tensors_f = [cut_to_fit_batch(tensor.unsqueeze(0), batch_size).squeeze(1) for tensor in test_context_tensors]
val_context_tensors_f = [cut_to_fit_batch(tensor.unsqueeze(0), batch_size).squeeze(1) for tensor in val_context_tensors]

### Инициализация данных итерируемых по батчам

In [34]:
train_dataset = TensorDataset(train_lemm_tensor_f, train_orig_tensor_f, *train_context_tensors_f)
test_dataset = TensorDataset(test_lemm_tensor_f, test_orig_tensor_f, *test_context_tensors_f)
val_dataset = TensorDataset(val_lemm_tensor_f, val_orig_tensor_f, *val_context_tensors_f)

In [35]:
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

### Определение функции проверки работы сети между эпохами обучения

In [36]:
def test_evaluate(model, input, context, target_len=8):
    with torch.no_grad():
        model.eval()
        
        input = input.to(device)

        nsubj, gender, tense = context
        nsubj_embedding = model.decoder.embedding(nsubj)

        hidden = model.context_mem(nsubj_embedding, gender, tense)
        cell = hidden.clone()

        if model.num_layers == 1:
            hidden.unsqueeze_(0)
            cell.unsqueeze_(0)
            # hidden, cell shapes: (1, batch_size, context_output_size=hidden_size)
        else:
            hidden = torch.cat([hidden.unsqueeze(0)] * model.num_layers, 0)
            cell = torch.cat([cell.unsqueeze(0)] * model.num_layers, 0)
            # hidden, cell shapes: (num_layers, batch_size, context_output_size=hidden_size)

        sos_idx = lemm_text_transformer.special_tokens_to_idx.get('<SOS>')
        eos_idx = lemm_text_transformer.special_tokens_to_idx.get('<EOS>')

        hidden, cell = model.encoder(input, hidden, cell)

        predicted_indexes = [sos_idx]

        for _ in range(1, target_len):
            prev_idx = torch.tensor([predicted_indexes[-1]], dtype=torch.long, device=device)

            output, hidden, cell = model.decoder(prev_idx, hidden, cell)
            output = output.squeeze(0)

            best_prediction = output.argmax(dim=1).item()

            if best_prediction == eos_idx:
                break

            predicted_indexes.append(best_prediction)


        predicted_tokens = [orig_text_transformer.vocab.idx_to_token(idx) for idx in predicted_indexes]
        return predicted_tokens[1:]

### Определение функции обучения сети

In [37]:
def train(model, optimizer, criterion, train_data, val_data, test_data, epochs_amount, max_norm, context, patience=3, current_epoch=1, n_prints=5):
    min_mean_val_loss = float('+inf')
    initial_patiece = patience
    print_every = len(train_data) // n_prints
    
    for epoch in tqdm(range(current_epoch, epochs_amount + 1), 'Epochs'):
        print(f'\nEpoch [{epoch} / {epochs_amount}]')
        
        model.train()
        for iteration, (input, target, nsubj, gender, tense) in enumerate(tqdm(train_data, 'Epoch training iterations')):
            optimizer.zero_grad()
            
            input = torch.transpose(input, 1, 0).to(device)   
            # input_shape: (seq_len, batch_size)
            
            target = torch.transpose(target, 1, 0).to(device)
            # target_shape: (seq_len, batch_size)
            
            context = (nsubj.to(device), gender.to(device), tense.to(device))
            
            output = model(input, target, context)
            # output_shape: (seq_len, batch_size, orig_vocab_size) but need (N, orig_vocab_size)
            
            target = target[1:].reshape(-1)
            # now target_shape is (seq_len * batch_size)
            
            orig_vocab_size = output.shape[2]
            
            output = output[1:].reshape(-1, orig_vocab_size)
            # now output_shape is (seq_len * batch_size, orig_vocab_size)
            
            loss = criterion(output, target)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            
            optimizer.step()
            
            if iteration % print_every == 0:
                print(f'\tIteration #{iteration}: training loss = {loss.item()}')
            elif iteration == len(train_data):
                print(f'\tIteration #{iteration}: training loss = {loss.item()}')
            
            
        with torch.no_grad():
            model.eval()
            val_loss = []
            
            for input, target, nsubj, gender, tense in tqdm(val_data, 'Epoch validating iterations'):
                input = torch.transpose(input, 1, 0).to(device)
                target = torch.transpose(target, 1, 0).to(device)
                context = (nsubj.to(device), gender.to(device), tense.to(device))
                output = model(input, target, context)
                
                orig_vocab_size = output.shape[2]
                output = output[1:].reshape(-1, orig_vocab_size)
                target = target[1:].reshape(-1)
                
                val_loss.append(criterion(output, target).item())
            
            mean_val_loss = sum(val_loss) / len(val_loss)
            print(f'\tValidation loss = {mean_val_loss}')
            if mean_val_loss < min_mean_val_loss:
                try:
                    save_model(model, optimizer, epoch, model_path + model_name)
                    min_mean_val_loss = mean_val_loss
                    patience = initial_patiece
                except Exception as exc:
                    print(exc)
            else:
                patience -= 1
            
            test_data = DataLoader(test_data.dataset, batch_size=1, shuffle=True)
            for input, target, nsubj, gender, tense in test_data:
                target = target.squeeze(0).to(device)
                context = (nsubj.to(device), gender.to(device), tense.to(device))
                
                input = torch.transpose(input, 1, 0)
                target_len = target.shape[0]
                
                output = test_evaluate(model, input, context, target_len)
                decoded_input = [lemm_text_transformer.vocab.idx_to_token(idx.item()) for idx in input]
                decoded_target = [orig_text_transformer.vocab.idx_to_token(idx.item()) for idx in target]
                
                print(f'\tInput: {decoded_input}')
                print(f'\tOutput: {output}')
                print(f'\tTarget: {decoded_target}')
                break
        
        if patience == 0:
            print(f'\nModel learning finished due to early stopping')
            break


### Определение функции эксплуатации обученной модели

In [38]:
def evaluate(model, sentence: str, context, max_seq_len=8):
    with torch.no_grad():
        model.eval()
        
        nsubj, gender, tense = context
        
        nsubj = torch.tensor([orig_text_transformer.vocab.token_to_idx(nsubj)], device=device).unsqueeze(0)
        gender = torch.tensor([gender_label_to_vec[gender]], dtype=torch.float32, device=device)
        tense = torch.tensor([tense_label_to_vec[tense]], dtype=torch.float32, device=device)
        
        nsubj_embedding = model.decoder.embedding(nsubj).squeeze(0)

        hidden = model.context_mem(nsubj_embedding, gender, tense)
        cell = hidden.clone()

        if model.num_layers == 1:
            hidden.unsqueeze_(0)
            cell.unsqueeze_(0)
            # hidden, cell shapes: (1, batch_size, context_output_size=hidden_size)
        else:
            hidden = torch.cat([hidden.unsqueeze(0)] * model.num_layers, 0)
            cell = torch.cat([cell.unsqueeze(0)] * model.num_layers, 0)
            # hidden, cell shapes: (num_layers, batch_size, context_output_size=hidden_size)
        
        input_tensor = lemm_text_transformer.text_to_tensor(sentence, max_seq_len).to(device)
        input_tensor = torch.transpose(input_tensor, 1, 0)
        sos_idx = lemm_text_transformer.special_tokens_to_idx.get('<SOS>')
        eos_idx = lemm_text_transformer.special_tokens_to_idx.get('<EOS>')

    
        hidden, cell = model.encoder(input_tensor, hidden, cell)
        
        predicted_indexes = [sos_idx]
        
#         while True:
#             prev_idx = torch.tensor([predicted_indexes[-1]], dtype=torch.long, device=device)
            
#             output, hidden, cell = model.decoder(prev_idx, hidden, cell)
#             output = output.squeeze(0)
            
#             best_prediction = output.argmax(dim=1).item()
            
#             if best_prediction == eos_idx:
#                 break
            
#             predicted_indexes.append(best_prediction)
                       
        
        for _ in range(1, max_seq_len):
            prev_idx = torch.tensor([predicted_indexes[-1]], dtype=torch.long, device=device)
            
            output, hidden, cell = model.decoder(prev_idx, hidden, cell)
            output = output.squeeze(0)
            
            best_prediction = output.argmax(dim=1).item()
            
            if best_prediction == eos_idx:
                break
                
            predicted_indexes.append(best_prediction)
        
    predicted_tokens = [orig_text_transformer.vocab.idx_to_token(idx) for idx in predicted_indexes]
    return predicted_tokens[1:]

In [39]:
# train(model, optimizer, criterion, train_loader, val_loader, test_loader, epochs_amount, max_norm, patience, epoch)

In [40]:
# load_model(model, optimizer, criterion, model_path + model_name)
# model.eval()

In [41]:
# import gc
# del model
# del optimizer
# gc.collect()
# torch.cuda.empty_cache()
# gc.collect()

In [42]:
gender_label_to_vec = {
    'masc'     : [0, 1, 0, 0],
    'fem'      : [1, 0, 0, 0],
    'undefined': [0, 0, 0, 1],
    'neut'     : [0, 0, 1, 0]
}

tense_label_to_vec = {
    'past'     : [0, 1, 0, 0],
    'pres'     : [0, 0, 1, 0],
    'fut'      : [1, 0, 0, 0],
    'undefined': [0, 0, 0, 1]
}

In [43]:
def vec_to_label(label_type: str, input_vec):
    if label_type == 'gender':
        for label, vec in gender_label_to_vec.items():
            if vec == input_vec:
                return label
        
    elif label_type == 'tense':
        for label, vec in tense_label_to_vec.items():
            if vec == input_vec:
                return label
    else:
        return None

In [44]:
test_sample = train_df.sample(50)

test_input = test_sample.lemm_texts.to_list()
test_target = test_sample.orig_texts.to_list()

test_nsubj = test_sample.nsubj.to_list()
test_gender = test_sample.gender.to_list()
test_tense = test_sample.tense.to_list()

test_zipped = list(zip(test_input, test_target, test_nsubj, test_gender, test_tense))

In [47]:
state_dict = torch.load('./models/params.model')

In [65]:
model.state_dict()

OrderedDict([('context_mem.gender_proj.weight',
              tensor([[-0.4963,  0.0215, -0.7154, -0.2743],
                      [-0.2054,  0.0732, -0.3534,  0.0271],
                      [ 0.1722,  0.1258,  0.0351, -0.1378],
                      ...,
                      [-0.2794, -0.2679, -0.5680, -0.0552],
                      [ 0.0308,  0.2176, -0.1868,  0.1538],
                      [-0.3556,  0.2144,  0.3736,  0.1786]], device='cuda:0')),
             ('context_mem.tense_proj.weight',
              tensor([[-0.6928, -0.0233,  0.4835, -0.2829],
                      [ 0.5565, -0.0669, -0.0504,  0.6564],
                      [ 0.4251,  0.0579,  0.4597, -0.3516],
                      ...,
                      [-0.2581,  0.1273, -0.1396, -0.1242],
                      [ 0.6218, -0.0855,  0.4336,  0.4036],
                      [ 0.1899,  0.0572,  0.1556,  0.0835]], device='cuda:0')),
             ('context_mem.fc_out.weight',
              tensor([[-0.0824, -0.0193,  0.0476

In [64]:
for k, v in state_dict.items():
    res = torch.all(v == model.state_dict()[k]).item()
    print(k, res)

context_mem.gender_proj.weight True
context_mem.tense_proj.weight True
context_mem.fc_out.weight True
encoder.embedding.0.weight True
encoder.lstm.weight_ih_l0 True
encoder.lstm.weight_hh_l0 True
encoder.lstm.bias_ih_l0 True
encoder.lstm.bias_hh_l0 True
encoder.lstm.weight_ih_l1 True
encoder.lstm.weight_hh_l1 True
encoder.lstm.bias_ih_l1 True
encoder.lstm.bias_hh_l1 True
decoder.embedding.0.weight True
decoder.lstm.weight_ih_l0 True
decoder.lstm.weight_hh_l0 True
decoder.lstm.bias_ih_l0 True
decoder.lstm.bias_hh_l0 True
decoder.lstm.weight_ih_l1 True
decoder.lstm.weight_hh_l1 True
decoder.lstm.bias_ih_l1 True
decoder.lstm.bias_hh_l1 True
decoder.fc.weight True
decoder.fc.bias True


In [45]:
for input_sent, target_sent, nsubj, gender, tense in test_zipped:
    output = evaluate(model, input_sent, (nsubj, gender, tense))
    
    print(f'Input: {input_sent}')
    print(f'Nsubj: {nsubj}\ngender: {gender}\ntense: {tense}')
    print(f'Output: {output}')
    print(f'Target: {target_sent}')
    print('\n')

Input: я работать на россия.
Nsubj: я
gender: fem
tense: past
Output: ['я', 'уехал', 'на', '<UNK>', '.']
Target: я работала на россию.


Input: я часто ездить к мать.
Nsubj: я
gender: undefined
tense: pres
Output: ['я', 'официант', 'роль', 'к', 'первый', '.']
Target: я часто езжу к матери.


Input: ленка всегда спасть с какой - нибыть железяка.
Nsubj: ленка
gender: fem
tense: past
Output: ['<UNK>', 'удивленно', 'ника', 'с', 'ногу', '-', '-']
Target: ленка всегда спала с какой-нибудь железякой.


Input: христос стоять на колено ...
Nsubj: христос
gender: masc
tense: past
Output: ['христос', 'шел', 'на', 'досады', '...']
Target: христос стоял на коленях...


Input: как я скучать , как скучать!
Nsubj: я
gender: fem
tense: past
Output: ['оказался', 'я', 'вздохнул', ',', 'оказался', 'правилам', '!']
Target: как я скучала, как скучала!


Input: кэти нажать на курок.
Nsubj: кэти
gender: fem
tense: past
Output: ['вивиана', '<UNK>', 'на', 'рукам', '.']
Target: кэти нажала на курок.


Input: я в