In [1]:
import pandas as pd
import nltk
import torch
import torch.nn as nn
import torch.optim
import numpy as np
import time
import pickle

from pprint import pprint
from random import random, sample
from typing import List
from collections import Counter
from itertools import chain
from functools import reduce
from tqdm.auto import tqdm
from sklearn import model_selection
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchtext.data.metrics import bleu_score
from transformers import AutoTokenizer, AutoModel

In [2]:
RANDOM_STATE = 1

## Подготовка данных

In [3]:
df = pd.read_csv('./data/lenta/dataset.csv')

In [4]:
df = df.sample(frac=0.5, random_state=RANDOM_STATE)
df

Unnamed: 0,orig_texts,lemm_texts,nsubj,gender,tense
1245806,об этом сообщает риа новости со ссылкой на мат...,о это сообщать риа новость с ссылка на мать по...,риа,neut,pres
1594042,генеральный прокурор рф владимир устинов счита...,генеральный прокурор рф владимир устинов счита...,прокурор,masc,pres
705659,"телеканал «дождь» восстановил вещание, прерван...","телеканал « дождь » восстановить вещание, прер...",телеканал,masc,past
603796,соответствующее требование прозвучало во время...,соответствующий требование прозвучать в время ...,требование,neut,past
1430273,"в пятницу вечером на сайтах ""единой россии"", ""...","в пятница вечером на сайт ""единый россия"", ""гр...",заявления,neut,past
...,...,...,...,...,...
1331395,об этом заявил президент - председатель правле...,о это заявить президент - председатель правлен...,президент,masc,past
1710395,"""мы думаем, что они (страны-члены совбеза) пре...","\"" мы думать, что они (страна-член совбез) пре...",мы,undefined,pres
1434398,"социологи ""росгосстраха"" оценили сознательност...","социолог ""росгосстрах""оценить сознательность р...",социологи,masc,past
1832873,российская сборная сохранила за собой 24 строчку.,российский сборная сохранить за себя 24 строчка.,сборная,fem,past


### Загрузка готового токенайзера

In [5]:
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

In [6]:
added_tokens_amount = tokenizer.add_special_tokens({
    'bos_token': '[BOS]',
    'eos_token': '[EOS]'
})

### Разбиение данных на тренировочные, валидационные, тестовые

In [7]:
train_df, test_df = model_selection.train_test_split(df, test_size=0.1, random_state=RANDOM_STATE)

In [8]:
test_df, val_df = model_selection.train_test_split(test_df, test_size=0.5, random_state=RANDOM_STATE)

### Инициализация гиперпараметров

In [9]:
learning_rate = 0.001
batch_size = 64
epochs_amount = 50
hidden_size = 768
embedding_size = 300
max_norm = 1.0
dropout_p = 0.5
gender_input_size = 4
tense_input_size = 3
context_hidden_size = 256
context_output_size = hidden_size
patience = 3
output_size = tokenizer.vocab_size # 119547
vocab_size = output_size
pad_idx = tokenizer.pad_token_id
model_path = './models/'
model_name = 'seq2seq_attention (fixed).model'

### One-hot кодирование значений рода и времени

In [10]:
gender_to_vec = {
    'masc': [1, 0, 0, 0],
    'fem': [0, 1, 0, 0],
    'neut': [0, 0, 1, 0],
    'undefined': [0, 0, 0, 1]
}

In [11]:
tense_to_vec = {
    'pres': [1, 0, 0],
    'past': [0, 1, 0],
    'fut': [0, 0, 1]
}

### Разбиение текстов на батчи

In [12]:
def get_n_batches(df, batch_size=batch_size):
    return len(df) // batch_size

In [13]:
train_n_batches = get_n_batches(train_df)
test_n_batches = get_n_batches(test_df)
val_n_batches = get_n_batches(val_df)

In [18]:
def make_batches(df, n_batches, batch_size=batch_size, cache_path='./data/cached/train_batches_64.btch', download_cache=True):
    try:
        if not download_cache:
            raise Exception
            
        with open(cache_path, 'rb') as f:
            batches = pickle.load(f)
            print(f'Batches are loaded from cache: {cache_path}')
    except:
        batches = []
    
        for n_batch in tqdm(range(n_batches), 'Making batches'):
            batch = {
                'orig_texts': df.orig_texts.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)],
                'lemm_texts': df.lemm_texts.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)],
                'nsubj': df.nsubj.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)],
                'gender': df.gender.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)],
                'tense': df.tense.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
            }

            batches.append(batch)
    
    return batches

In [15]:
train_batches = make_batches(train_df, train_n_batches, download_cache=True)
test_batches = make_batches(test_df, test_n_batches, download_cache=False)
val_batches = make_batches(val_df, val_n_batches, download_cache=False)

Batches are loaded from cache: ./data/cached/train_batches_64.btch


HBox(children=(FloatProgress(value=0.0, description='Making batches', max=723.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Making batches', max=723.0, style=ProgressStyle(descripti…




In [16]:
# with open('./data/cached/train_batches_64.btch', 'wb') as f:
#     pickle.dump(train_batches, f)

### Токенизация текстов, индексация токенов, batch padding

In [19]:
def text_batch_to_tensor(text_batch, tokenizer=tokenizer, add_special_tokens=True):
    tensor = tokenizer(
        text_batch,
        add_special_tokens=add_special_tokens,
        padding='longest',
        return_attention_mask=False,
        return_token_type_ids=False,
        return_tensors='pt'
    ).get('input_ids')
    
    tensor[tensor == tokenizer.cls_token_id] = tokenizer.bos_token_id
    tensor[tensor == tokenizer.sep_token_id] = tokenizer.eos_token_id

    return tensor

In [26]:
def to_tensors(batches, tokenizer=tokenizer):
#     tensors_batches = []
    
    for batch in batches:
        orig_tensor = text_batch_to_tensor(batch.get('orig_texts'))
        lemm_tensor = text_batch_to_tensor(batch.get('lemm_texts'))
        
#         nsubj_tensor = text_batch_to_tensor(batch.get('nsubj'), add_special_tokens=False)
        nsubj_tensor = tokenizer(list(chain(*[batch.get('nsubj') for batch in batches])),
                                 add_special_tokens=False, padding='longest',
                                 return_tensors='pt').get('input_ids')
    
        gender_tensor = torch.tensor([gender_to_vec.get(gender) for gender in batch.get('gender')],
                                     dtype=torch.float32)
        tense_tensor = torch.tensor([tense_to_vec.get(tense) for tense in batch.get('tense')],
                                    dtype=torch.float32)
        
        tensors_batch = [orig_tensor, lemm_tensor,
                         nsubj_tensor, gender_tensor, tense_tensor]
        
        yield tensors_batch

In [27]:
train_tensors = to_tensors(train_batches)
test_tensors = to_tensors(test_batches)
val_tensors = to_tensors(val_batches)

In [28]:
sample = next(train_tensors)

KeyboardInterrupt: 

## Построение модели

In [20]:
class ContextMem(nn.Module):
    def __init__(self, gender_input_size, tense_input_size, hidden_size, output_size, nsubj_embedding_size, device):
        super(ContextMem, self).__init__()
        
        self.device = device

        self.gender_proj = nn.Linear(gender_input_size, hidden_size, bias=False)
        self.tense_proj = nn.Linear(tense_input_size, hidden_size, bias=False)
        self.fc_out = nn.Linear(hidden_size * 2 + nsubj_embedding_size, output_size, bias=False)
        
    def forward(self, nsubj_embedding, gender, tense):
        # nsubj_embedding_shape: (token_len, batch_size, embedding_size)
        # gender_shape: (batch_size, input_size)
        # tense_shape: (batch_size, input_size)
        print('nsubj shape:', nsubj_embedding.shape)
        token_len = nsubj_embedding.shape[0]
        
        gender = self.gender_proj(gender).repeat(token_len, 1, 1)
        # gender_shape: (nsubj_len, batch_size, hidden_size)
        print('gender shape:', gender.shape)
        
        tense = self.tense_proj(tense).repeat(token_len, 1, 1)
        # tense_shape: (nsubj_len, batch_size, hidden_size)    
        print('tense shape:', tense.shape)
        
        context = torch.cat([nsubj_embedding, gender, tense], dim=-1)
        # context_shape: (nsubj_len, batch_size, hidden_size * 2 + embedding_size) 
        
        context = self.fc_out(context)
        # context_shape: (nsubj_len, batch_size, output_size)
        
        return context

In [21]:
class EncoderRNN(nn.Module):
    def __init__(self, vocab_size: int, embedding_size: int, hidden_size: int, pad_idx: int,
                 device, dropout_p: float, embedding=None, pretrained_embedding_loaded=False):
        super(EncoderRNN, self).__init__()
        
        self.device = device
        
        self.hidden_size = hidden_size
        
        self.embedding = embedding
        self.pretrained_embedding_loaded = pretrained_embedding_loaded
    
        self.rnn = nn.GRU(embedding_size, hidden_size, dropout=0.0, bidirectional=True)
        
    def forward(self, sequence, hidden):
        # sequence_shape: (seq_len, batch_size)
        if self.pretrained_embedding_loaded:
            with torch.no_grad():
                embedding = self.embedding(sequence)
        else:
            embedding = self.embedding(sequence)
        # embedding_shape: (seq_len, batch_size, embedding_size)
        encoder_states = self.rnn(embedding, hidden)[0]
        # encoder_states: (seq_len, batch_size, hidden_size * 2)
        
        return encoder_states

In [22]:
class DecoderRNN(nn.Module):
    def __init__(self, vocab_size: int, embedding_size: int, hidden_size: int, output_size: int, pad_idx: int,
                 device, dropout_p: float, embedding=None, pretrained_embedding_loaded=False):
        super(DecoderRNN, self).__init__()
        
        self.device = device
        
        self.hidden_size = hidden_size
        
        self.embedding = embedding
        self.pretrained_embedding_loaded = pretrained_embedding_loaded
        
        self.attn_weights = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size, bias=False),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
            nn.Softmax(dim=1)
        )
        self.rnn = nn.GRU(embedding_size + 2 * hidden_size, hidden_size, dropout=0.0)
        self.fc_out = nn.Linear(hidden_size, output_size)
    
    def forward(self, token, encoder_states):
        # token_shape: (token_len, batch_size)
        
        encoder_states = torch.transpose(encoder_states, 1, 0)
        # encoder_states_shape: (batch_size, seq_len, hidden_size * 2)
        if self.pretrained_embedding_loaded:
            with torch.no_grad():
                embedding = self.embedding(token)
        else:
            embedding = self.embedding(token)
        # embedding_shape: (token_len, batch_size, embedding_size)
        
        seq_len = encoder_states.shape[1]
        token_len = token.shape[0]
        
        attn_weights = self.attn_weights(encoder_states)
        # attn_weights_shape: (batch_size, seq_len, 1)
        
        context_vec = torch.bmm(attn_weights.permute(0, 2, 1), encoder_states).permute(1, 0, 2).repeat(token_len, 1, 1)
        # context_vec_shape: (token_len, batch_size, hidden_size * 2)
        
        combined = torch.cat((embedding, context_vec), dim=2)
        # combined_shape: (token_len, batch_size, embedding_size + 2 * hidden_size)
        
        rnn_out = self.rnn(combined)[0]
        # lstm_out_shape: (seq_len=1, batch_size, hidden_size)
        fc_out = self.fc_out(rnn_out)
        # fc_out_shape: (seq_len=1, batch_size, output_size)
        
        return fc_out

In [23]:
class Seq2SeqModel(nn.Module):
    def __init__(self, 
                 vocab_size, embedding_size, hidden_size, output_size,
                 gender_input_size, tense_input_size, context_hidden_size, context_output_size,
                 pad_idx, device, dropout_p, pretrained_embedding=None):
        super(Seq2SeqModel, self).__init__()
        
        self.device = device
        
        if pretrained_embedding is not None:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embedding, padding_idx=pad_idx)
            self.pretrained_embedding_loaded = True
        else:
            self.embedding = nn.Sequential(
                nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx),
                nn.Dropout(dropout_p)
            )
            self.pretrained_embedding_loaded = False
        
        self.context_mem = ContextMem(gender_input_size, tense_input_size, context_hidden_size, context_output_size, embedding_size, device).to(device)
        self.encoder = EncoderRNN(vocab_size, embedding_size, hidden_size,
                                  pad_idx, device, dropout_p,
                                  self.embedding, self.pretrained_embedding_loaded).to(device)
        self.decoder = DecoderRNN(vocab_size, embedding_size, hidden_size, output_size,
                                  pad_idx, device, dropout_p,
                                  self.embedding, self.pretrained_embedding_loaded).to(device)
        
        self.vocab_size = vocab_size
        
    def forward(self, input, target, context, teacher_forcing_ratio=0.0):
        batch_size = input.shape[1]
        target_len = target.shape[0]
        target_vocab_size = self.vocab_size
        
        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=self.device)
        
        nsubj, gender, tense = context
        nsubj = nsubj.permute(1, 0)
        # nsubj_shape:  (token_len, batch_size)
        # gender_shape: (batch_size, gender_input_size)
        # tense_shape:  (batch_size, tense_input_size)
        
        if self.pretrained_embedding_loaded:
            with torch.no_grad():
                nsubj_embedding = self.embedding(nsubj)
        else:
            nsubj_embedding = self.embedding(nsubj)
            # nsubj_embedding_shape: (token_len, batch_size, embedding_size)
        
        hidden = self.context_mem(nsubj_embedding, gender, tense)[:2]
        # hidden_shape: (2, batch_size, context_output_size=hidden_size)
        
        encoder_states = self.encoder(input, hidden)
        # encoder_states_shape: (seq_len, batch_size, hidden_size * 2)
        
        prev_token_idx = target[0] # sos_token
        # prev_token_shape: (batch_size)
        
        for t in range(1, target_len):
            output, hidden, cell = self.decoder(prev_token_idx, encoder_states)
            # output_shape: (1, batch_size, output_size)
            outputs[t] = output.squeeze(0)
            
            best_prediction = outputs[t].argmax(dim=1)
            # best_prediction_shape: (batch_size)
            prev_token_idx = target[t] if random() < teacher_forcing_ratio else best_prediction
        
        return outputs

## Обучение модели

### Функция сохранения текущего состояния модели

In [24]:
def save_model(model, optimizer, epoch, path):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'criterion': criterion,
        'epoch': epoch
    }
    
    torch.save(checkpoint, path)

### Функция загрузки уже тренировавшейся модели

In [25]:
def load_model(model, optimizer, criterion, path, for_inference=True, device=torch.device('cpu')):
    checkpoint = torch.load(path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    
    if not for_inference:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        criterion = checkpoint['criterion']

        return epoch

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [27]:
model = Seq2SeqModel(vocab_size, embedding_size, hidden_size, output_size,
                     gender_input_size, tense_input_size, context_hidden_size, context_output_size, 
                     pad_idx, device, dropout_p).to(device)

In [28]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [29]:
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [30]:
try:
    epoch = load_model(model, optimizer, criterion, model_path + model_name, for_inference=True)
    print(f'Loaded model from {model_path}')
except:
    print(f'No models found at {model_path}')
    epoch = 1

No models found at ./models/


### Определение функции обучения сети

In [31]:
def train(model, optimizer, criterion, train_data, val_data, test_data, epochs_amount, max_norm, patience=3, current_epoch=1, n_prints=5):
    min_mean_val_loss = float('+inf')
    initial_patiece = patience
    print_every = len(train_df) // n_prints
    
    for epoch in tqdm(range(current_epoch, epochs_amount + 1), 'Epochs'):
        print(f'\nEpoch [{epoch} / {epochs_amount}]')
        
        model.train()
        for iteration, (input, target, nsubj, gender, tense) in enumerate(tqdm(train_data, 'Epoch training iterations')):
            optimizer.zero_grad()
            # input = lemm_texts, target = orig_texts
            
            input = torch.transpose(input, 1, 0).to(device)
            # input_shape: (seq_len, batch_size)
            
            target = torch.transpose(target, 1, 0).to(device)
            # target_shape: (seq_len, batch_size)
            
            context = (nsubj.to(device), gender.to(device), tense.to(device))
            
            output = model(input, target, context)
            # output_shape: (seq_len, batch_size, vocab_size) but need (N, vocab_size)
            
            target = target[1:].reshape(-1)
            # now target_shape is (seq_len * batch_size)
            
            vocab_size = output.shape[2]
            
            output = output[1:].reshape(-1, vocab_size)
            # now output_shape is (seq_len * batch_size, vocab_size)
            
            loss = criterion(output, target)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            
            optimizer.step()
            
            if iteration % print_every == 0:
                print(f'\tIteration #{iteration}: training loss = {loss.item()}')
            elif iteration == len(train_data):
                print(f'\tIteration #{iteration}: training loss = {loss.item()}')
            
            
        with torch.no_grad():
            model.eval()
            val_loss = []
            
            for input, target, nsubj, gender, tense in tqdm(val_data, 'Epoch validating iterations'):
                input = torch.transpose(input, 1, 0).to(device)
                target = torch.transpose(target, 1, 0).to(device)
                context = (nsubj.to(device), gender.to(device), tense.to(device))
                
                output = model(input, target, context)
                vocab_size = output.shape[2]
                output = output[1:].reshape(-1, orig_vocab_size)
                target = target[1:].reshape(-1)
                
                val_loss.append(criterion(output, target).item())
            
            mean_val_loss = sum(val_loss) / len(val_loss)
            print(f'\tValidation loss = {mean_val_loss}')
            if mean_val_loss < min_mean_val_loss:
                try:
                    save_model(model, optimizer, epoch, model_path + model_name)
                    min_mean_val_loss = mean_val_loss
                    patience = initial_patiece
                except Exception as exc:
                    print(exc)
            else:
                patience -= 1
            
            test_data = DataLoader(test_data.dataset, batch_size=1, shuffle=True)
            for input, target, nsubj, gender, tense in test_data:
                target = target.squeeze(0).to(device)
                context = (nsubj.to(device), gender.to(device), tense.to(device))
                
                input = torch.transpose(input, 1, 0)
                target_len = target.shape[0]
                
                output = test_evaluate(model, input, context, target_len)
                decoded_input = [text_transformer.vocab.idx_to_token(idx.item()) for idx in input]
                decoded_target = [text_transformer.vocab.idx_to_token(idx.item()) for idx in target]
                
                print(f'\tInput: {decoded_input}')
                print(f'\tOutput: {output}')
                print(f'\tTarget: {decoded_target}')
                break
        
        if patience == 0:
            print(f'\nModel learning finished due to early stopping')
            break


In [32]:
train(model, optimizer, criterion, train_tensors, val_tensors, test_tensors, epochs_amount, max_norm, patience, epoch)

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=50.0, style=ProgressStyle(description_width=…


Epoch [1 / 50]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch training iterations', layout=Layo…

nsubj shape: torch.Size([3, 64, 300])
gender shape: torch.Size([3, 64, 256])
tense shape: torch.Size([3, 64, 256])




RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED