## Импорт необходимых зависимостей

In [2]:
import pandas as pd
import nltk
import torch
import torch.nn as nn
import torch.optim
import numpy as np
import time
import pickle
import pathlib

from pprint import pprint
from random import random, sample
from typing import List, Union
from collections import Counter
from itertools import chain
from functools import reduce
from tqdm.auto import tqdm
from sklearn import model_selection
from torch.utils.data import DataLoader, TensorDataset

In [3]:
# Fetch NLTK's 'punkt' resource without progress output (quiet=True).
# The trailing ';' keeps the call's return value out of the cell output.
nltk.download('punkt', quiet=True);

In [4]:
# Seed for the stochastic steps of this notebook (passed to DataFrame.sample below).
RANDOM_STATE = 1

## Подготовка данных

In [5]:
# Load the dataset from a path relative to the notebook's working directory.
# NOTE(review): the preview below shows an 'Unnamed: 0' column, which usually means the
# CSV was written together with its index; consider index_col=0 — confirm against the file.
df = pd.read_csv('./data/lenta/dataset.csv')

In [6]:
# Keep a reproducible 50,000-row random subset (seeded with RANDOM_STATE).
# NOTE: this rebinds `df`, so re-running the cell samples from the already reduced frame.
df = df.sample(n=50000, random_state=RANDOM_STATE)
df

Unnamed: 0,orig_texts,lemm_texts,nsubj,gender,tense
1245806,об этом сообщает риа новости со ссылкой на мат...,о это сообщать риа новость с ссылка на мать по...,риа,neut,pres
1594042,генеральный прокурор рф владимир устинов счита...,генеральный прокурор рф владимир устинов счита...,прокурор,masc,pres
705659,"телеканал «дождь» восстановил вещание, прерван...","телеканал « дождь » восстановить вещание, прер...",телеканал,masc,past
603796,соответствующее требование прозвучало во время...,соответствующий требование прозвучать в время ...,требование,neut,past
1430273,"в пятницу вечером на сайтах ""единой россии"", ""...","в пятница вечером на сайт ""единый россия"", ""гр...",заявления,neut,past
...,...,...,...,...,...
1741764,в россии может появиться новый деловой телевиз...,в россия мочь появиться новый деловой телевизи...,канал,masc,pres
243076,пресс-секретарь президента эрнесто абелла назв...,пресс-секретарь президент эрнесто абелла назва...,пресс-секретарь,masc,past
1809379,"а вот винсент дамфусс, который должен был игра...","а вот винсент дамфусс, который должный быть иг...",винсент,masc,fut
690325,"задача исследователей заключается в том, чтобы...","задача исследователь заключаться в тот, чтобы ...",задача,fem,pres


### Определение классов словаря и подготовщика данных

In [7]:
class Vocab:
    """Two-way lookup between tokens and their integer ids.

    Attributes:
        unk_id: Id returned by ``token_to_id`` for tokens that are not in the vocabulary.
        tokens: Sequence of tokens; the position of a token is its id (``None`` if not given).
        tokens_to_ids: Reverse mapping token -> id (``None`` if ``tokens`` is ``None``).
    """

    def __init__(self, tokens=None, unk_id=None):
        self.unk_id = unk_id
        self.tokens = tokens

        if tokens is None:
            self.tokens_to_ids = None
        else:
            # When a token occurs more than once, its last position wins.
            lookup = {}
            for position, entry in enumerate(tokens):
                lookup[entry] = position
            self.tokens_to_ids = lookup

    def id_to_token(self, id):
        """Return the token stored at position ``id``."""
        return self.tokens[id]

    def token_to_id(self, token):
        """Return the id of ``token``, or ``unk_id`` when the token is unknown."""
        return self.tokens_to_ids.get(token, self.unk_id)

In [8]:
class DataPreperator:
    """Convert texts to padded id sequences (and back) with a frequency-based vocabulary.

    The first vocabulary ids are reserved for the entries of ``special_tokens``
    (padding, unknown, start/end of sequence, gender and tense markers); the remaining
    slots are filled with the most frequent tokens of the corpus given to ``build_vocab``.
    """

    def __init__(self, vocab_size):
        """
        Args:
            vocab_size: Maximum vocabulary size, special tokens included.
        """
        # NLTK word tokenizer; every call in this class passes 'russian' as the language.
        self.tokenizer = nltk.tokenize.word_tokenize

        # Insertion order matters: the position of each entry becomes its vocabulary id.
        self.special_tokens = {
            'pad_token'        : '<pad>',
            'unk_token'        : '<unk>',
            'sos_token'        : '<sos>',
            'eos_token'        : '<eos>',
            'g_masc_token'     : '<masc>',
            'g_fem_token'      : '<fem>',
            'g_neut_token'     : '<neut>',
            'g_undefined_token': '<undef>',
            't_past_token'     : '<past>',
            't_pres_token'     : '<pres>',
            't_fut_token'      : '<fut>'
        }

        self.special_ids = {name: idx for idx, name in enumerate(self.special_tokens.keys())}

        def entry(name):
            # Bundle the id and the surface form of one special token.
            return {'id': self.special_ids[name], 'token': self.special_tokens[name]}

        self.pad_token = entry('pad_token')
        self.unk_token = entry('unk_token')
        self.sos_token = entry('sos_token')
        self.eos_token = entry('eos_token')

        # Keyed by the values expected in the `gender` / `tense` context fields.
        self.gender_tokens = {
            'masc'     : entry('g_masc_token'),
            'fem'      : entry('g_fem_token'),
            'neut'     : entry('g_neut_token'),
            'undefined': entry('g_undefined_token')
        }

        self.tense_tokens = {
            'past': entry('t_past_token'),
            'pres': entry('t_pres_token'),
            'fut' : entry('t_fut_token')
        }

        self.vocab_size = vocab_size

        self.vocab_cache_path = {
            'dir': './data/cached',
            'filename': 'vocab.pkl'
        }

        self.vocab = None

    def _tokenize(self, input: Union[List[str], str]):
        """Tokenize a single string, or tokenize several strings into one flat token list.

        Input (Union[List[str], str]): a list of string sequences or a single sequence.
        """
        if isinstance(input, str):
            return self.tokenizer(input, 'russian')

        return list(chain.from_iterable(
            self.tokenizer(text, 'russian') for text in tqdm(input, 'Tokenizing texts')
        ))

    def build_vocab(self, texts=None, save_vocab=True, load_vocab=False):
        """Populate ``self.vocab`` from the cache file or from ``texts``.

        Args:
            texts: Corpus (list of strings or one string) used when the vocabulary has to be built.
            save_vocab: Write the freshly built vocabulary to the cache location.
            load_vocab: Try the cache first; fall back to building when loading fails.

        Raises:
            ValueError: if the vocabulary has to be built but ``texts`` is ``None``.
        """
        if load_vocab:
            self.vocab = self.load_vocab(dir_path=self.vocab_cache_path['dir'],
                                         filename=self.vocab_cache_path['filename'])
            if self.vocab is not None:
                return

        if texts is None:
            # Fail early with a clear message instead of an obscure tokenizer error.
            raise ValueError('`texts` must be provided when the vocab is not loaded from cache')

        print('Building vocab from texts...')
        tokens = self._tokenize(texts)

        # Slots left for ordinary tokens once the special tokens are accounted for.
        n_first = self.vocab_size - len(self.special_tokens)

        tokens = [token for token, _ in Counter(tokens).most_common(n_first)]
        tokens = list(self.special_tokens.values()) + tokens

        self.vocab = Vocab(tokens, self.unk_token['id'])
        print('Success')

        if save_vocab:
            dir_path = self.vocab_cache_path['dir']
            filename = self.vocab_cache_path['filename']
            file_path = dir_path + '/' + filename

            print(f'Saving vocab at {file_path} ...')
            self.save_vocab(dir_path=dir_path, filename=filename)

    def save_vocab(self, dir_path, filename):
        """Pickle ``self.vocab`` to ``dir_path/filename``; failures are reported, not raised."""
        try:
            # parents=True: the cache directory may be nested in a directory that does not exist yet.
            pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
            file_path = dir_path + '/' + filename

            with open(file_path, 'wb') as f:
                pickle.dump(self.vocab, f)

            print(f'Vocab is saved successfully at {file_path}')

        except Exception as e:
            # Saving is best-effort: report the problem and carry on.
            print(f'Failed to save vocab due to:\n{e}')

    def load_vocab(self, dir_path, filename):
        """Unpickle and return a vocabulary from ``dir_path/filename``, or ``None`` on failure."""
        try:
            file_path = dir_path + '/' + filename

            # SECURITY: pickle.load can execute arbitrary code; only load cache files you created.
            with open(file_path, 'rb') as f:
                vocab = pickle.load(f)

            print(f'Vocab is loaded successfully from {file_path}')

            return vocab

        except Exception as e:
            print(f'Failed to load vocab due to:\n{e}')

            return None

    def _pad_sequence(self, ids: List[int], max_seq_len) -> List[int]:
        """Return a copy of ``ids`` truncated or right-padded with the pad id to ``max_seq_len``.

        The input list is never modified.
        """
        trimmed = list(ids[:max_seq_len])
        n_missing = max_seq_len - len(trimmed)
        return trimmed + n_missing * [self.pad_token.get('id')]

    def _add_special_tokens(self, ids: List[int], context=None) -> None:
        """Add <sos>/<eos> and, optionally, the context prefix to ``ids`` in place.

        <eos> is placed right before the first pad id (or at the end when there is no
        padding). With ``context = (nsubj, gender, tense)`` the result starts with
        ``[nsubj_id, gender_id, tense_id, <sos>, ...]``.
        """
        pad_id = self.pad_token.get('id')

        ids.insert(0, self.sos_token.get('id'))
        eos_position = ids.index(pad_id) if pad_id in ids else len(ids)
        ids.insert(eos_position, self.eos_token.get('id'))

        if context is not None:
            nsubj, gender, tense = context
            ids[:0] = [
                self.vocab.token_to_id(nsubj),
                self.gender_tokens[gender].get('id'),
                self.tense_tokens[tense].get('id'),
            ]

    def _strip_special_tokens(self, ids: List[int]) -> List[int]:
        """Return only the content ids: drop everything up to <sos>, and <eos>/padding after it."""
        sos_id = self.sos_token.get('id')
        eos_id = self.eos_token.get('id')
        pad_id = self.pad_token.get('id')

        if sos_id in ids:
            # Also removes the optional context prefix that precedes <sos>.
            ids = ids[ids.index(sos_id) + 1:]

        if eos_id in ids:
            return ids[:ids.index(eos_id)]

        # No <eos> (e.g. a truncated sequence): only trailing padding can be removed.
        end = len(ids)
        while end > 0 and ids[end - 1] == pad_id:
            end -= 1
        return ids[:end]

    def encode(self, input: Union[str, List[str]], context=None, add_special_tokens=True, max_seq_len=None, return_tensor=False):
        """Encode one text or a batch of texts into vocabulary ids.

        Input (Union[List[str], str]): a list of string sequences or a single sequence.

        Args:
            context: ``(nsubj, gender, tense)`` for a single text, or one such tuple per
                text for a batch. Only used when ``add_special_tokens`` is true.
            add_special_tokens: Add <sos>/<eos> (and the context prefix) to every sequence.
            max_seq_len: Length to pad/truncate to *before* special tokens are added. For a
                batch, ``None`` means the length of the longest sequence in the batch.
            return_tensor: Return a ``torch.Tensor`` instead of (nested) lists.

        Raises:
            ValueError: if a batch context does not contain exactly one entry per text.
        """
        if isinstance(input, str):
            ids = [self.vocab.token_to_id(token) for token in self._tokenize(input)]

            if max_seq_len is not None:
                ids = self._pad_sequence(ids, max_seq_len)

            if add_special_tokens:
                self._add_special_tokens(ids, context)

            return torch.tensor(ids) if return_tensor else ids

        sents_ids = [[self.vocab.token_to_id(token) for token in self._tokenize(sent)]
                     for sent in input]

        if max_seq_len is None:
            # default=0 keeps an empty batch from raising.
            max_seq_len = max(map(len, sents_ids), default=0)

        padded_sequences = [self._pad_sequence(ids, max_seq_len) for ids in sents_ids]

        if add_special_tokens:
            if context is None:
                contexts = [None] * len(padded_sequences)
            else:
                contexts = list(context)
                if len(contexts) != len(padded_sequences):
                    # A silent zip() truncation would leave some rows without special tokens.
                    raise ValueError(
                        f'got {len(contexts)} context entries for {len(padded_sequences)} sequences'
                    )

            for ids, seq_context in zip(padded_sequences, contexts):
                self._add_special_tokens(ids, seq_context)

        return torch.tensor(padded_sequences) if return_tensor else padded_sequences

    def decode(self, encoded_seq: Union[List[int], torch.Tensor], remove_special_tokens=False, return_tokenized=True):
        """Map ids back to tokens.

        Args:
            encoded_seq: A sequence of ids (list or 1-D tensor). A tensor with a leading
                batch dimension of size 1 is decoded as a single sequence; a larger batch
                is decoded row by row and a list with one result per row is returned.
                The input is never modified.
            remove_special_tokens: Drop the context prefix, <sos>, <eos> and padding.
            return_tokenized: Return a list of tokens; otherwise a space-joined string.
        """
        if isinstance(encoded_seq, torch.Tensor):
            if encoded_seq.dim() > 1:
                if encoded_seq.shape[0] == 1:
                    return self.decode(encoded_seq[0], remove_special_tokens, return_tokenized)
                return [self.decode(row, remove_special_tokens, return_tokenized)
                        for row in encoded_seq]
            ids = encoded_seq.tolist()
        else:
            ids = list(encoded_seq)

        if remove_special_tokens:
            ids = self._strip_special_tokens(ids)

        decoded_seq = [self.vocab.id_to_token(id) for id in ids]

        return decoded_seq if return_tokenized else ' '.join(decoded_seq)
            

In [9]:
# Upper bound on the vocabulary size, special tokens included (see DataPreperator.build_vocab).
vocab_size = 100000
# With None, DataPreperator.encode pads a batch to its longest sequence.
# NOTE(review): no visible call passes this variable to encode — confirm it is used.
max_seq_len = None
# Number of examples per batch produced by make_batched_dataset.
batch_size = 8

### Разбиение данных на обучающую, тестовую и валидационную выборки

In [10]:
# 90 % of the rows go to training; the remaining 10 % is halved into test and validation.
# Both splits are seeded so the partition is identical after every kernel restart
# (the vocabulary is built from the training rows and cached on disk, so an unseeded
# split would silently pair the cache with a different training set).
train_df, holdout_df = model_selection.train_test_split(df, train_size=0.9, random_state=RANDOM_STATE)
test_df, val_df = model_selection.train_test_split(holdout_df, test_size=0.5, random_state=RANDOM_STATE)

### Подготовка словаря

In [11]:
# Shared preprocessing object; its vocabulary stays None until build_vocab is called.
data_preperator = DataPreperator(vocab_size)

In [12]:
# Vocabulary corpus: lemmatized texts followed by original texts, training rows only.
texts = train_df['lemm_texts'].to_list() + train_df['orig_texts'].to_list()

In [13]:
# Try the cached vocabulary first; the training texts are passed so that the vocabulary
# can be rebuilt when the cache file is missing or unreadable.
data_preperator.build_vocab(texts, load_vocab=True)

Vocab is loaded successfully from ./data/cached/vocab.pkl


### Разбиение данных на батчи

In [16]:
def make_batched_dataset(df, data_preperator=data_preperator, batch_size=batch_size):
    """Lazily yield ``(encoded_input, encoded_target)`` tensor pairs, one per full batch.

    Args:
        df: Frame with the columns ``orig_texts``, ``lemm_texts``, ``nsubj``, ``gender``
            and ``tense``.
        data_preperator: Encoder used for both sides (defaults to the notebook-level
            instance as it was when this function was defined).
        batch_size: Number of rows per batch.

    Yields:
        Tuple of two tensors: the lemmatized texts encoded with their
        ``(nsubj, gender, tense)`` context, and the original texts encoded without context.
        Each side is padded to its own longest sequence within the batch.

    Note:
        Rows beyond the last full batch (``len(df) % batch_size``) are not yielded.
        This is a generator: the returned object can be iterated only once.
    """
    n_batches = len(df) // batch_size

    for n_batch in range(n_batches):
        # Slice the rows of this batch once instead of converting whole columns per batch.
        rows = df.iloc[batch_size * n_batch:batch_size * (n_batch + 1)]

        orig   = rows.orig_texts.to_list()
        lemm   = rows.lemm_texts.to_list()
        nsubj  = rows.nsubj.to_list()
        gender = rows.gender.to_list()
        tense  = rows.tense.to_list()

        encoded_input  = data_preperator.encode(lemm, context=list(zip(nsubj, gender, tense)), return_tensor=True)
        encoded_target = data_preperator.encode(orig, return_tensor=True)

        yield (encoded_input, encoded_target)

In [17]:
# make_batched_dataset is a generator function, so each of these objects can be
# iterated only once; call it again to obtain a fresh pass over the data.
train_data = make_batched_dataset(train_df)
val_data = make_batched_dataset(val_df)
test_data = make_batched_dataset(test_df)

### Определение параметров сети

In [24]:
# Hyperparameters for the Seq2SeqTransformer defined below.
# NOTE(review): the key here is 'nhead' while Seq2SeqTransformer.__init__ names the
# parameter `n_head` — unpacking with **params would fail; confirm how this dict is consumed.
# NOTE(review): 'embedding_size' (300) differs from 'd_model' (512); verify that embeddings
# are projected to d_model before they reach nn.Transformer.
params = {
    'd_model': 512,
    'nhead': 8,
    'num_encoder_layers': 6,
    'num_decoder_layers': 6,
    'dim_feedforward': 2048,
    'dropout': 0.1,
    'activation': 'relu',
    'layer_norm_eps': 1e-05,
    'embedding_size': 300,
    'vocab_size': vocab_size,
    'pad_token_id': data_preperator.pad_token['id'],
    'batch_first': True,
    'device': torch.device('cpu')
}

In [None]:
# Reference (appears to be the torch.nn.Transformer constructor argument list):
# d_model, nhead, num_encoder_layers, num_decoder_layers,
#                  dim_feedforward, dropout, activation,
#                  custom_encoder, custom_decoder, layer_norm_eps,
#                  batch_first, device

In [None]:
class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        d_model, n_head,
        num_encoder_layers, num_decoder_layers,
        dim_feedforward, dropout,
        activation, layer_norm_eps,
        embedding_size, vocab_size,
        pad_token_id, batch_first,
        device):
        
        super(Seq2SeqTransformer, self).__init__()
        
        self.transformer = nn.Transformer(d_model, n_head, num_encoder_layers,
                                          num_decoder_layers, dim_feedforward,
                                          dropout, activation,
                                          layer_norm_eps=layer_norm_eps,
                                          batch_first=batch_first, device=device)
        
        self.embedding = nn.Embedding(vocab_size, embedding_size, pad_token_id, device=device)