## Импорт необходимых зависимостей

In [1]:
import pandas as pd
import numpy as np

import nltk

import torch
import torch.nn as nn
import torch.optim

import pickle
import pathlib

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, TensorDataset

from sklearn import model_selection

from pprint import pprint
from random import choice
from typing import List, Union
from collections import Counter
from itertools import chain

from tqdm.auto import tqdm

ModuleNotFoundError: No module named 'tensorboard'

In [None]:
nltk.download('punkt', quiet=True);

## Подготовка данных

In [None]:
RANDOM_STATE = 1

In [None]:
df = pd.read_csv('./data/lenta/dataset.csv')

In [5]:
df = df.sample(n=50000, random_state=RANDOM_STATE)
df

Unnamed: 0,orig_texts,lemm_texts,nsubj,gender,tense,sent_length
782094,"депутат от ""справедливой россии"" дмитрий гудко...","депутат от ""справедливый россия""дмитрий гудков...",депутат,masc,past,23.0
1589446,"как сообщается, бои идут между иракцами, амери...","как сообщаться, бой идти между иракец, америка...",бои,masc,pres,15.0
1268862,"обращение подписали уже 15 тысяч человек, пише...","обращение подписать уже 15 тысяча человек, пис...",тысяч,fem,past,15.0
392907,конфликт между ними произошел из-за брата осуж...,конфликт между они произойти из-за брат осуждё...,конфликт,masc,past,23.0
486500,речь идет об эпизоде «суровый дом» (восьмая се...,речь идти о эпизод « суровый дом » (восьмой се...,речь,fem,pres,23.0
...,...,...,...,...,...,...
1534870,чем больше времени дети проводят у телевизоров...,чем большой время ребёнок проводить у телевизо...,дети,masc,pres,13.0
647177,этим и воспользовалась женщина (имя которой не...,это и воспользоваться женщина (имя который не ...,женщина,fem,past,29.0
595915,представитель «народной армии донбасса» требов...,представитель « народный армия донбасс » требо...,представитель,masc,past,11.0
1081552,по размеру розничной сети почтовый банк станет...,по размер розничный сеть почтовый банк стать к...,банк,masc,fut,15.0


### Определение классов словаря и подготовщика данных

In [6]:
class Vocab:
    def __init__(self, tokens=None, unk_id=None):
        self.unk_id = unk_id
        
        self.tokens = tokens
        self.tokens_to_ids = {token: id for id, token in enumerate(tokens)} if tokens is not None else None
    
    def id_to_token(self, id):
        return self.tokens[id]
    
    def token_to_id(self, token):
        return self.tokens_to_ids.get(token, self.unk_id)

In [7]:
class Tokenizer:
    def __init__(self, vocab_size):
        
        self.tokenizer = nltk.tokenize.word_tokenize
        
        self.special_tokens = {
            'pad_token'        : '<pad>',
            'unk_token'        : '<unk>',
            'sos_token'        : '<sos>',
            'eos_token'        : '<eos>',
            'g_masc_token'     : '<masc>',
            'g_fem_token'      : '<fem>',
            'g_neut_token'     : '<neut>',
            'g_undefined_token': '<undef>',
            't_past_token'     : '<past>',
            't_pres_token'     : '<pres>',
            't_fut_token'      : '<fut>'
        }
        
        self.special_ids = {token: id for id, token in enumerate(self.special_tokens.keys())}
        
        self.pad_token = {'id'   : self.special_ids['pad_token'],
                          'token': self.special_tokens['pad_token']}
        
        self.unk_token = {'id'   : self.special_ids['unk_token'],
                          'token': self.special_tokens['unk_token']}
        
        self.sos_token = {'id'   : self.special_ids['sos_token'],
                          'token': self.special_tokens['sos_token']}
        
        self.eos_token = {'id'   : self.special_ids['eos_token'],
                          'token': self.special_tokens['eos_token']}
        
        self.gender_tokens = {
            'masc'     : {'id'   : self.special_ids['g_masc_token'],
                          'token': self.special_tokens['g_masc_token']},
            
            'fem'      : {'id'   : self.special_ids['g_fem_token'],
                          'token': self.special_tokens['g_fem_token']},
            
            'neut'     : {'id'   : self.special_ids['g_neut_token'],
                          'token': self.special_tokens['g_neut_token']},
            
            'undefined': {'id'   : self.special_ids['g_undefined_token'],
                          'token': self.special_tokens['g_undefined_token']}
        }
        
        self.tense_tokens = {
            'past': {'id'   : self.special_ids['t_past_token'],
                     'token': self.special_tokens['t_past_token']},
            
            'pres': {'id'   : self.special_ids['t_pres_token'],
                     'token': self.special_tokens['t_pres_token']},
            
            'fut' : {'id'   : self.special_ids['t_fut_token'],
                     'token': self.special_tokens['t_fut_token']}
        }
        
        self.vocab_size = vocab_size
        
        self.vocab_cache_path = {
            'dir': './data/cached',
            'filename': 'vocab.pkl'
        }
        
        self.vocab = None
        
    def _tokenize(self, input: Union[List[str], str]):
        """Input (Union[List[str], str]): a list of string sequences or a single sequence."""
        
        if type(input) is list:
            tokens = list(chain(*[self.tokenizer(text, 'russian') for text in tqdm(input, 'Tokenizing texts')]))
        else:
            tokens = self.tokenizer(input, 'russian')
        return tokens
    
    def build_vocab(self, texts=None, save_vocab=True, load_vocab=False):
        if load_vocab:
            self.vocab = self.load_vocab(dir_path=self.vocab_cache_path['dir'], filename=self.vocab_cache_path['filename'])
            if self.vocab is not None:
                return
        
        print('Building vocab from texts...')
        tokens = self._tokenize(texts)
        
        n_first = self.vocab_size - len(self.special_tokens)
        
        tokens = [token for token, _ in Counter(tokens).most_common(n_first)]
        tokens = list(self.special_tokens.values()) + tokens
                
        self.vocab = Vocab(tokens, self.unk_token['id'])
        print('Success')
        
        if save_vocab:
            dir_path = self.vocab_cache_path['dir']
            filename = self.vocab_cache_path['filename']
            file_path = dir_path + '/' + filename
            
            print(f'Saving vocab at {file_path} ...')
            self.save_vocab(dir_path=dir_path, filename=filename)
    
    def save_vocab(self, dir_path, filename):
        try:
            pathlib.Path(dir_path).mkdir(exist_ok=True)
            file_path = dir_path + '/' + filename

            with open(file_path, 'wb') as f:
                pickle.dump(self.vocab, f)

            print(f'Vocab is saved successfully at {file_path}')
            
        except Exception as e:
            print(f'Failed to save vocab due to:\n{e}')
    
    def load_vocab(self, dir_path, filename):
        try:
            file_path = dir_path + '/' + filename
            
            with open(file_path, 'rb') as f:
                vocab = pickle.load(f)
            
            print(f'Vocab is loaded successfully from {file_path}')
            
            return vocab
            
        except Exception as e:
            print(f'Failed to load vocab due to:\n{e}')
            
            return None
        
    def _pad_sequence(self, ids: List[int], max_seq_len) -> List[int]:
        if len(ids) >= max_seq_len:
            ids = ids[:max_seq_len]
        else:
            pad_len = max_seq_len - len(ids)
            ids.extend(pad_len * [self.pad_token.get('id')])
        
        return ids
    
    def _add_special_tokens(self, ids: List[int], context=None) -> None:
        ids.insert(0, self.sos_token.get('id'))
        try:
            eos_position = ids.index(self.pad_token.get('id'))
        except ValueError:
            eos_position = len(ids)
        ids.insert(eos_position, self.eos_token.get('id'))
        
        if context is not None:
            nsubj, gender, tense = context
            
            context_info = [
                self.tense_tokens[tense].get('id'),
                self.gender_tokens[gender].get('id'),
                self.vocab.token_to_id(nsubj)
            ]
            
            for item in context_info:
                ids.insert(0, item)
        
    def encode(self, input: Union[str, List[str]], context=None, add_special_tokens=True, max_seq_len=None, return_tensor=False):
        """Input (Union[List[str], str]): a list of string sequences or a single sequence."""
        # context = (nsubj, gender, tense)
        if type(input) is str:
            tokens = self._tokenize(input)
            ids = [self.vocab.token_to_id(token) for token in tokens]
            
            if max_seq_len is not None:
                ids = self._pad_sequence(ids, max_seq_len)
            
            if add_special_tokens:
                self._add_special_tokens(ids, context)
            
            if not return_tensor:
                return ids
            else:
                return torch.tensor(ids)
            
        else:
            tokenized_sents = [self._tokenize(sent) for sent in input]
            sents_ids = [[self.vocab.token_to_id(token) for token in sent] for sent in tokenized_sents]
            
            max_seq_len = max(map(len, sents_ids)) if max_seq_len is None else max_seq_len
            
            padded_sequences = [self._pad_sequence(ids, max_seq_len) for ids in sents_ids]
            padded_seq_and_context = zip(padded_sequences, context) if context is not None else None
            
            if add_special_tokens:
                if padded_seq_and_context is not None:
                    for ids, context in padded_seq_and_context:
                        self._add_special_tokens(ids, context)
                else:
                    for ids in padded_sequences:
                        self._add_special_tokens(ids)
            
            if not return_tensor:
                return padded_sequences
            else:
                return torch.tensor(padded_sequences)
            
    def decode(self, encoded_seq: Union[List[int], torch.Tensor], remove_special_tokens=False, return_tokenized=True):
        if type(encoded_seq) is list:
            
            if remove_special_tokens:
                encoded_seq = encoded_seq[4:-1]
            
            decoded_seq = [self.vocab.id_to_token(id) for id in encoded_seq]
            
            if return_tokenized:
                return decoded_seq
            else:
                return ' '.join(decoded_seq)
        
        else:
            if len(encoded_seq.shape) > 1:
                encoded_seq.squeeze_(0)
                
            if remove_special_tokens:
                encoded_seq = encoded_seq[4:-1]
            
            decoded_seq = [self.vocab.id_to_token(id.item()) for id in encoded_seq]
            
            if return_tokenized:
                return decoded_seq
            else:
                return ' '.join(decoded_seq)
            

In [8]:
vocab_size = 100000
max_seq_len = None
batch_size = 8

### Разбиение данных на обучающие, тестовые и валидационные

In [9]:
train_df, test_df = model_selection.train_test_split(df, train_size=0.9)
test_df, val_df = model_selection.train_test_split(test_df, test_size=0.5)

### Подготовка словаря

In [10]:
tokenizer = Tokenizer(vocab_size)

In [11]:
texts = train_df.lemm_texts.to_list() + train_df.orig_texts.to_list()

In [12]:
tokenizer.build_vocab(texts)

Vocab is loaded successfully from ./data/cached/vocab.pkl


### Разбиение данных на батчи

In [13]:
def make_batched_dataset(df, max_seq_len=max_seq_len, tokenizer=tokenizer, batch_size=batch_size):
    n_batches = len(df) // batch_size
    
    for n_batch in range(n_batches):
        
        orig   = df.orig_texts.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        lemm   = df.lemm_texts.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        nsubj  = df.nsubj.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        gender = df.gender.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        tense  = df.tense.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        
        context = list(zip(nsubj, gender, tense))
        
        encoded_input  = tokenizer.encode(lemm, context=context, add_special_tokens=True, max_seq_len=max_seq_len, return_tensor=True)
        encoded_target = tokenizer.encode(orig, add_special_tokens=True, max_seq_len=max_seq_len, return_tensor=True)
        
        batch = (encoded_input.permute(1, 0), encoded_target.permute(1, 0))
        
        yield batch

In [14]:
train_n_batches = len(train_df) // batch_size
val_n_batches = len(val_df) // batch_size
test_n_batches = len(test_df)

In [15]:
train_data = make_batched_dataset(train_df)
val_data = make_batched_dataset(val_df)
test_data = make_batched_dataset(test_df, batch_size=1)

In [16]:
def save_processed_data(train_data, val_data, test_data):
    path = {
        'dir': './data/cached',
        'name': 'processed_data.pkl'
    }
    
    try:
        pathlib.Path(path['dir']).mkdir(exist_ok=True)
        file_path = path['dir'] + '/' + path['name']

        with open(file_path, 'wb') as f:
            pickle.dump((train_data, val_data, test_data), f)

        print(f'Data is saved successfully at {file_path}')

    except Exception as e:
        print(f'Failed to save data due to:\n{e}')

In [17]:
def load_processed_data(path='./data/cached/processed_data.pkl'):
    try:
        with open(path, 'rb') as f:
            data = pickle.load(f)

        print(f'Data is loaded successfully from {path}')

        return data

    except Exception as e:
        print(f'Failed to load data due to:\n{e}')

        return [None] * 3

In [18]:
load_data = True
save_data = True

if load_data:
    train_data, val_data, test_data = load_processed_data()

if not load_data or train_data is None:
    train_data = [batch for batch in tqdm(make_batched_dataset(train_df), desc='Unpacking train batches', total=train_n_batches)]
    val_data = [batch for batch in tqdm(make_batched_dataset(val_df), desc='Unpacking validation batches', total=val_n_batches)]
    test_data = [batch for batch in tqdm(make_batched_dataset(test_df, batch_size=1), desc='Unpacking test batches', total=test_n_batches)]

if save_data:
    save_processed_data(train_data, val_data, test_data)

Data is loaded successfully from ./data/cached/processed_data.pkl
Data is saved successfully at ./data/cached/processed_data.pkl


## Определение модели

### Определение класса модели

In [19]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, embedding_size, nhead,
                 num_encoder_layers, num_decoder_layers, 
                 dim_feedforward, dropout, vocab_size,
                 max_seq_len, pad_token_id, device):
        
        super(Seq2SeqTransformer, self).__init__()
        
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len
        self.embedding_size = embedding_size
        
        self.to(device)
        
        self.device = device
        
        self.word_embedding = nn.Embedding(vocab_size, embedding_size, pad_token_id)
        self.input_pos_encoding = nn.Embedding(max_seq_len, embedding_size)
        self.target_pos_encoding = nn.Embedding(max_seq_len, embedding_size)
        
        self.transformer = nn.Transformer(embedding_size, nhead, num_encoder_layers,
                                          num_decoder_layers, dim_feedforward, dropout)
        
        self.fc_out = nn.Linear(embedding_size, vocab_size)
        
    def get_padding_mask(self, input):
        # input shape: (seq_len, batch_size)
        padding_mask = input.permute(1, 0) == self.pad_token_id
        return padding_mask.to(self.device)
    
    def forward(self, input, target):
        # input shape: (input_seq_len, batch_size)
        # target shape: (target_seq_len, batch_size)
    
        embedded_input = self.word_embedding(input)
        embedded_target = self.word_embedding(target)
        # embedded_input shape: (input_seq_len, batch_size, embedding_size)
        # embedded_target shape: (target_seq_len, batch_size, embedding_size)
        
        batch_size = input.shape[1]
        
        input_seq_len = input.shape[0]
        target_seq_len = target.shape[0]
    
        input_positions = torch.arange(0, input_seq_len).unsqueeze(1).expand(input_seq_len, batch_size).to(self.device)
        target_positions = torch.arange(0, target_seq_len).unsqueeze(1).expand(target_seq_len, batch_size).to(self.device)
        # input_positions shape: (input_seq_len, batch_size)
        # target_positions shape: (target_seq_len, batch_size)
        
        input_positions = self.input_pos_encoding(input_positions)
        target_positions = self.target_pos_encoding(target_positions)
        # input_positions shape: (input_seq_len, batch_size, embedding_size)
        # target_positions shape: (target_seq_len, batch_size, embedding_size)
        
        embedded_input += input_positions
        embedded_target += target_positions
        
        input_padding_mask = self.get_padding_mask(input)
        # input_padding_mask shape: (batch_size, input_seq_len)
        
        target_mask = self.transformer.generate_square_subsequent_mask(target_seq_len).to(self.device)
        # target_mask shape: (target_seq_len, target_seq_len)
        
        output = self.transformer(embedded_input, embedded_target,
                                  tgt_mask=target_mask,
                                  src_key_padding_mask=input_padding_mask)
        # output shape: (target_seq_len, batch_size, embedding_size)
        
        output = self.fc_out(output)
        # output shape: (target_seq_len, batch_size, vocab_size)
        
        return output

## Определение функций-утилит

### Сохранение модели

In [20]:
def save_model(model, optimizer, epoch, val_loss, train_loss, path='./models/seq2seq_transformer.model'):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'val_loss': val_loss,
        'train_loss': train_loss
    }
    
    torch.save(checkpoint, path)
    print(f'\n\tModel saved successfully at {path}')

### Загрузка модели

In [21]:
def load_model(model, optimizer=None, path='./model/seq2seq_transformer.model', device=torch.device('cpu')):
    checkpoint = torch.load(path, map_location=device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
    epoch      = checkpoint['epoch']
    val_loss   = checkpoint['val_loss']
    train_loss = checkpoint['train_loss']

    return {'epoch': epoch, 'val_loss': val_loss, 'train_loss': train_loss}

## Место хранения модели

In [22]:
model_path = {
    'dir': './models/',
    'name': 'seq2seq_transformer.model'
}

## Обучение модели

### Определение параметров обучения

In [23]:
learning_params = {
    'learning_rate': 1e-03,
    'epochs': 10,
    'max_norm': 1.0,
    'patience': 3
}

### Определение параметров сети

In [24]:
params = {
    'embedding_size': 512,
    'nhead': 8,
    'num_encoder_layers': 6,
    'num_decoder_layers': 6,
    'dim_feedforward': 2048,
    'dropout': 0.1,
    'vocab_size': vocab_size,
    'max_seq_len': 50,
    'pad_token_id': tokenizer.pad_token['id'],
    'device': torch.device('cuda')
}

### Инициализация модели, оптимизатора и функции потерь

In [25]:
model = Seq2SeqTransformer(**params).to(params['device'])

In [26]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_params['learning_rate'])

In [27]:
criterion = nn.CrossEntropyLoss(ignore_index=params['pad_token_id'])

In [28]:
load_pretrained_model = False
train_state = None

if load_pretrained_model:
    try:
        train_state = load_model(
            model, optimizer,
            model_path['dir'] + model_path['name'],
            params['device']
        )
        print(f"Model loaded successfully from {model_path.get('dir') + model_path.get('name')}")
    
    except Exception as e:
        print(f'Load failed due to:\n{e}')

epoch = train_state['epoch'] if train_state is not None else 0

In [29]:
train_loss_writer = SummaryWriter('./runs/loss')

### Train-скрипт

In [30]:
def train(
    model, optimizer, criterion,
    train_data, val_data, test_data,
    epochs, max_norm, patience, current_epoch,
    device, tokenizer, model_path, max_seq_len,
    train_loss_writer, n_prints=10
):
    
    min_mean_val_loss = float('+inf')
    initial_patience = patience
    print_every = len(train_data) // n_prints
    
    for epoch in tqdm(range(current_epoch, epochs), 'Epochs'):
        print(f'\nEpoch [{epoch} / {epochs}]')
        
        model.train()
        for iteration, (input, target) in enumerate(tqdm(train_data, desc='Training iterations')):
            input  = input.to(device)
            target = target.to(device)
            # input shape : (input_seq_len, batch_size)
            # target shape: (target_seq_len, batch_size)
            
            optimizer.zero_grad()
            
            output = model(input, target[:-1])
            # output shape: (target_seq_len, batch_size, vocab_size)
            
            vocab_size = output.shape[2]
            
            output = output.reshape(-1, vocab_size)
            # output shape: (target_seq_len * batch_size, vocab_size)
            
            target = target[1:].reshape(-1)
            # target shape: (target_seq_len * batch_size)
            
            loss = criterion(output, target)
            loss.backward()
            
            global_step = epoch * (len(train_data) + 1) + iteration
            train_loss_writer.add_scalar('Training loss', loss, global_step=global_step)
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            
            optimizer.step()
            
            if iteration % print_every == 0:
                print(f'\tIteration #{iteration}: training loss = {loss.item()}')
            elif iteration == len(train_data):
                print(f'\tIteration #{iteration}: training loss = {loss.item()}')
            
            
        with torch.no_grad():
            model.eval()
            
            val_loss = []
            
            for iteration, (input, target) in enumerate(tqdm(val_data, desc='Validating iterations')):
                input  = input.to(device)
                target = target.to(device)
                
                output = model(input, target[:-1])
                vocab_size = output.shape[2]
                output = output.reshape(-1, vocab_size)
                
                target = target[1:].reshape(-1)
                
                local_val_loss = criterion(output, target)
                val_loss.append(local_val_loss.item())
                
            mean_val_loss = sum(val_loss) / len(val_loss)
            print(f'\tValidation loss = {mean_val_loss}')
            
            if mean_val_loss < min_mean_val_loss:
                try:
                    save_model(model, optimizer, epoch, mean_val_loss, loss)
                    min_mean_val_loss = mean_val_loss
                    patience = initial_patience
                except Exception as e:
                    print(f'Model training stopped due to unhandled exception:\n{e}')
            else:
                patience -= 1
                
            
            test_sample = choice(test_data)
            
            predictions = [tokenizer.sos_token.get('id')]
            for i in range(max_seq_len):
                target = torch.tensor(predictions, device=device).unsqueeze(1)
                
                output = model(test_sample[0].to(device), target)
                best_prediction = output.argmax(2)[-1].item()
                predictions.append(best_prediction)
                
                if best_prediction == tokenizer.eos_token.get('id'):
                    break
            
            decoded_output = tokenizer.decode(predictions,    return_tokenized=False)
            decoded_input  = tokenizer.decode(test_sample[0], return_tokenized=False)
            decoded_target = tokenizer.decode(test_sample[1], return_tokenized=False)
            
            print(f'\tInput : {decoded_input}')
            print(f'\tOutput: {decoded_output}')
            print(f'\tTarget: {decoded_target}')
            
        if patience == 0:
            print(f'\nModel learning finished due to early stopping')

In [31]:
train(
    model, optimizer, criterion,
    train_data, val_data, test_data,
    learning_params['epochs'], learning_params['max_norm'],
    learning_params['patience'], epoch, params['device'], tokenizer,
    model_path['dir'] + model_path['name'], params['max_seq_len'],
    train_loss_writer
)

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]


Epoch [0 / 10]


Training iterations:   0%|          | 0/5625 [00:00<?, ?it/s]

	Iteration #0: training loss = 11.73180866241455


KeyboardInterrupt: 