In [1]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np

from torch.utils.data import DataLoader, Dataset

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from transformers.optimization import AdamW

from random import choice, random

from tqdm.auto import tqdm

from sklearn import model_selection

## Подготовка данных

In [2]:
RANDOM_STATE = 1

In [3]:
df = pd.read_csv('../input/lenta-dataset/dataset.csv')

In [4]:
df.gender.replace('undefined', 'undefined_g', inplace=True)
df.number.replace('undefined', 'undefined_n', inplace=True)

In [5]:
df.tense = df.tense.apply(lambda t: np.nan if t == 'past' and random() >= 0.5 else t)
df.dropna(inplace=True)

In [6]:
df = df.sample(frac=0.2, random_state=RANDOM_STATE)

In [7]:
df

Unnamed: 0,orig_texts,part_lemm_texts,length,nsubj,gender,tense,number
1637988,"Он также пояснил, что компания не может контро...","Он также пояснить, что компания не может контр...",27,Он,masc,past,sing
1406376,Каким именно образом алкоголизм связан с видео...,какой именно образом алкоголизм связать с виде...,15,корреспонденты,masc,past,plur
642225,"К такому выводу пришли ученые, прочитавшие ген...","к такой вывод прислать учёный, прочитавший ген...",20,ученые,masc,past,plur
1668448,Городок аттракционов Paradise Amusements работ...,городок аттракцион paradise amusements работае...,18,Городок,masc,pres,sing
1791533,В результате Клинтон потерял равновесие и пока...,в результат Клинтон потерять равновесие и пока...,10,Клинтон,masc,past,sing
...,...,...,...,...,...,...,...
197938,"Глава МВД Владимир Колокольцев, прибывший в Кр...","глава МВД владимир колоколец, прибывший в крат...",22,Глава,masc,past,sing
1091943,"По мнению экспертов, неудачные улучшения не т...","по мнение эксперт, неудачные улучшение не толь...",20,улучшения,neut,pres,plur
552397,Мужчина по имени Джек Дэниэлс (Jack Daniel’s) ...,мужчина по имя джек дэниэлс (jack daniel ’ s) ...,29,Мужчина,masc,past,sing
832537,"Выход ""Клипперс"" в плей-офф стал возможен благ...","выход "" клипперс "" в плей-офф стать возможный ...",22,Выход,masc,past,sing


### Разбиение данных на обучающие, тестовые и валидационные

In [8]:
train_df, test_df = model_selection.train_test_split(df, train_size=0.9)
test_df, val_df = model_selection.train_test_split(test_df, test_size=0.5)

### Загрузка претренированной модели

In [9]:
model_name = "facebook/mbart-large-50"

In [10]:
model = MBartForConditionalGeneration.from_pretrained(model_name)

In [11]:
tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang='ru_RU', tgt_lang='ru_RU')

In [12]:
tokenizer

PreTrainedTokenizerFast(name_or_path='facebook/mbart-large-50', vocab_size=250054, model_max_len=1024, is_fast=True, padding_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>', 'additional_special_tokens': ['ar_AR', 'cs_CZ', 'de_DE', 'en_XX', 'es_XX', 'et_EE', 'fi_FI', 'fr_XX', 'gu_IN', 'hi_IN', 'it_IT', 'ja_XX', 'kk_KZ', 'ko_KR', 'lt_LT', 'lv_LV', 'my_MM', 'ne_NP', 'nl_XX', 'ro_RO', 'ru_RU', 'si_LK', 'tr_TR', 'vi_VN', 'zh_CN', 'af_ZA', 'az_AZ', 'bn_IN', 'fa_IR', 'he_IL', 'hr_HR', 'id_ID', 'ka_GE', 'km_KH', 'mk_MK', 'ml_IN', 'mn_MN', 'mr_IN', 'pl_PL', 'ps_AF', 'pt_XX', 'sv_SE', 'sw_KE', 'ta_IN', 'te_IN', 'th_TH', 'tl_XX', 'uk_UA', 'ur_PK', 'xh_ZA', 'gl_ES', 'sl_SI']})

In [13]:
special_tokens = {
    'masc': '<masc_g>',
    'fem': '<fem_g>',
    'neut': '<neut_g>',
    'undefined_g': '<undef_g>',
    'past': '<past_t>',
    'pres': '<pres_t>',
    'fut': '<fut_t>',
    'sing': '<sing_n>',
    'plur': '<plur_n>',
    'undefined_n': '<undef_n>'
}

In [14]:
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': list(special_tokens.values())})

In [15]:
model.set_input_embeddings(model.resize_token_embeddings(num_added_tokens + tokenizer.vocab_size))

### Разбиение данных на батчи

In [16]:
batch_size = 4

In [17]:
def make_batched_dataset(df, tokenizer=tokenizer, batch_size=batch_size):
    n_batches = len(df) // batch_size
    
    for n_batch in range(n_batches):
        
        orig_texts   = df.orig_texts.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        lemm_texts   = df.part_lemm_texts.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        nsubj_list   = df.nsubj.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        gender_list  = df.gender.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        tense_list   = df.tense.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        number_list  = df.number.to_list()[batch_size * n_batch:batch_size * (n_batch + 1)]
        
        bos_token = tokenizer.bos_token
        eos_token = tokenizer.eos_token
        
        inputs = zip(lemm_texts, nsubj_list, gender_list, tense_list, number_list)
        
        inputs = [f'{nsubj} {special_tokens[gender]} {special_tokens[tense]} {special_tokens[number]} {bos_token} {lemm} {eos_token}'
                  for lemm, nsubj, gender, tense, number in inputs]
        
        inputs = tokenizer(inputs, add_special_tokens=False, padding='longest',
                           return_tensors='pt')
        
        targets = [f'{bos_token} {orig} {eos_token}' for orig in orig_texts]
        
        with tokenizer.as_target_tokenizer():
            targets = tokenizer(targets, add_special_tokens=False, padding='longest',
                                return_tensors='pt', return_attention_mask=False, return_token_type_ids=False).input_ids
        
        yield inputs, targets
        

In [18]:
train_n_batches = len(train_df) // batch_size
val_n_batches = len(val_df) // batch_size
test_n_batches = len(test_df)

In [19]:
# def save_processed_data(train_data, val_data, test_data):
#     path = {
#         'dir': './data/cached',
#         'name': 'processed_data_mt5.pkl'
#     }
    
#     try:
#         pathlib.Path(path['dir']).mkdir(exist_ok=True)
#         file_path = path['dir'] + '/' + path['name']

#         with open(file_path, 'wb') as f:
#             pickle.dump((train_data, val_data, test_data), f)

#         print(f'Data is saved successfully at {file_path}')

#     except Exception as e:
#         print(f'Failed to save data due to:\n{e}')

In [20]:
# def load_processed_data(path='./data/cached/processed_data_mt5.pkl'):
#     try:
#         with open(path, 'rb') as f:
#             data = pickle.load(f)

#         print(f'Data is loaded successfully from {path}')

#         return data

#     except Exception as e:
#         print(f'Failed to load data due to:\n{e}')

#         return [None] * 3

In [21]:
train_data = make_batched_dataset(train_df)
val_data = make_batched_dataset(val_df)
test_data = []
for i, batch in enumerate(tqdm(make_batched_dataset(test_df, batch_size=1), desc='Unpacking test batches', total=500)):
    test_data.append(batch)
    if i == 500:
        break
# test_data = [batch for batch in tqdm(make_batched_dataset(test_df, batch_size=1), desc='Unpacking test batches', total=test_n_batches)]

Unpacking test batches:   0%|          | 0/500 [00:00<?, ?it/s]

In [22]:
# load_data = False
# save_data = True

# if load_data:
#     train_data, val_data, test_data = load_processed_data()

# if not load_data or train_data is None:
#     train_data = [batch for batch in tqdm(make_batched_dataset(train_df), desc='Unpacking train batches', total=train_n_batches)]
#     val_data = [batch for batch in tqdm(make_batched_dataset(val_df), desc='Unpacking validation batches', total=val_n_batches)]
#     test_data = [batch for batch in tqdm(make_batched_dataset(test_df, batch_size=1), desc='Unpacking test batches', total=test_n_batches)]

# if save_data:
#     save_processed_data(train_data, val_data, test_data)

In [23]:
class BatchedDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [24]:
def save_model(model, optimizer, path='./seq2seq_mbart_finetuned.model'):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    
    torch.save(checkpoint, path)
    print(f'\n\tModel saved successfully at {path}\n')

In [25]:
def load_model(model, optimizer, device, path='../input/mt5-model-finetuned/seq2seq_mt5_finetuned.model'):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    print(f'\n\tModel loaded successfully from {path}\n')

## Обучение модели

### Определение параметров обучения

In [26]:
params = {
    'learning_rate': 5e-05,
    'epochs': 10,
    'max_norm': 1.0,
    'device': torch.device('cuda'),
    'max_seq_len': 150,
    'epochs': 5
}

In [27]:
# train_data = DataLoader(BatchedDataset(train_data), batch_size=None, shuffle=True, num_workers=0)
# val_data = DataLoader(BatchedDataset(val_data), batch_size=None, shuffle=False, num_workers=0)

In [28]:
optimizer = AdamW(model.parameters(), lr=params['learning_rate'])

In [29]:
model = model.to(params['device'])

In [30]:
load_pretrained_model = False

if load_pretrained_model:
    load_model(model, optimizer, params['device'])

In [31]:
def train(model, optimizer, max_seq_len,
          train_data, val_data, test_data,
          epochs, max_norm, device,
          tokenizer, model_path, n_prints=10):
    
    min_mean_val_loss = float('+inf')
    print_every = train_n_batches // n_prints
    
    train_samples = []
    val_samples = []
    to_dataloader = True
    
    for epoch in tqdm(range(0, epochs), 'Epochs'):
        running_train_loss = 0.0
        print(f'\nEpoch [{epoch} / {epochs}]')
        
        model.train()
        tqdm_iter_batch = tqdm(train_data, desc='Training iterations', total=train_n_batches)
        for iteration, (input, target) in enumerate(tqdm_iter_batch):
            train_samples.append((input, target))
            
            input = {k: v.to(device) for k, v in input.items()}
            target = target.to(device)
            
            optimizer.zero_grad()
            
            output = model(**input, labels=target)
            
            loss = output.loss
            
            running_train_loss += loss.item()
            
            tqdm_iter_batch.set_postfix({'train_loss': loss.item()})
            
            loss.backward()
            
#             global_step = epoch * (len(train_data) + 1) + iteration
#             train_loss_writer.add_scalar('Training loss', loss, global_step=global_step)
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            
            optimizer.step()
            
            if iteration % print_every == 0:
                mean_train_loss = running_train_loss / print_every if iteration != 0 else running_train_loss
                running_train_loss = 0
                print(f'\n\tIteration #{iteration}: training loss = {mean_train_loss}\n')
                
                if iteration != 0:
                    save_model(model, optimizer)
                
                test_sample = choice(test_data)
                seq_len = test_sample[0].input_ids.shape[1]
                
                generated = model.generate(test_sample[0].input_ids.to(device), min_length=seq_len)

                decoded_output = tokenizer.decode(generated[0])
                decoded_input  = tokenizer.decode(test_sample[0].input_ids.squeeze(0))
                decoded_target = tokenizer.decode(test_sample[1].squeeze(0))

                print(f'\tInput : {decoded_input}')
                print(f'\tOutput: {decoded_output}')
                print(f'\tTarget: {decoded_target}')
                
                
            if to_dataloader:
                train_data = DataLoader(BatchedDataset(train_samples), batch_size=None, shuffle=True)
                val_data = DataLoader(BatchedDataset(val_samples), batch_size=None, shuffle=False)
                to_dataloader = False

In [32]:
train(model, optimizer, params['max_seq_len'],
      train_data, val_data, test_data,
      params['epochs'], params['max_norm'], params['device'],
      tokenizer, './seq2seq_mt5_finetuned.model')

Epochs:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch [0 / 5]


Training iterations:   0%|          | 0/54193 [00:00<?, ?it/s]


	Iteration #0: training loss = 4.643941879272461

	Input : правда <fem_g> <pres_t> <sing_n> <s> о это в среда, 9 ноября, сообщать « Украинская правда ». </s>
	Output: </s><s> <fem_g> <pres_t> <unk>................... о это в среда, 9 ноября, сообщать « Украинская правда ». </s>
	Target: <s> Об этом в среду, 9 ноября, сообщает «Украинская правда». </s>

	Iteration #5419: training loss = 0.22350331658865583


	Model saved successfully at ./seq2seq_mbart_finetuned.model

	Input : я <masc_g> <past_t> <sing_n> <s> именно поэтому я хотеть вы навестить », — сказать генерал в время встреча в среда, 6 апрель, с начальником генштаб вс рф валерий герасимов. </s>
	Output: </s><s> Именно поэтому я хочу вас навестить», — сказал генерал во время встречи в среду, 6 апреля, с начальником генштаба МВД РФ Валерием Герасиным. ✎✎✎✎✎✎<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

Training iterations:   0%|          | 0/54193 [00:00<?, ?it/s]


	Iteration #0: training loss = 0.07508623600006104

	Input : нарушение <neut_g> <pres_t> <sing_n> <s> по данные гибдд, именно это нарушение являться причина большинство авария с смертельный исходом. </s>
	Output: </s><s> По данным ГИБДД, это нарушение является причиной большинства аварий со смертельным исходом. именно именно, помнит гибдд, это нарушение является причиной большинства аварий со смертельным. </s>
	Target: <s> По данным ГИБДД, именно это нарушение является причиной большинства аварий со смертельным исходом. </s>


KeyboardInterrupt: 

In [33]:
torch.save(model.state_dict(), './seq2seq_finetuned_mbart_inference.model')