In [None]:
import os
import torch
import pandas as pd
from transformers import MBartForConditionalGeneration, MBart50Tokenizer, Trainer, TrainingArguments

In [None]:
# Directory the CSV files are read from (relative to the notebook).
data_dir = '../data/'
# Checkpoint identifier handed to the tokenizer and model `from_pretrained` calls.
model_name = 'facebook/mbart-large-50'
# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
class MBARTDataset(torch.utils.data.Dataset):
    """
    Map-style dataset pairing tokenized inputs with tokenized targets.

    `encodings` is a mapping from field name to a per-example sequence
    (it must contain 'input_ids'); `labels` is a mapping whose 'input_ids'
    entry holds the target token ids for each example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of examples is the number of encoded input sequences.
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        # Convert every encoded field of example `idx` to a tensor.
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        # The target token ids are exposed under the 'labels' key.
        sample['labels'] = torch.tensor(self.labels['input_ids'][idx])
        return sample

      
def prepare_data(model_name, 
                 train_texts, train_labels, 
                 val_texts=None, val_labels=None, 
                 test_texts=None, test_labels=None):
    """
    Prepare input data for model fine-tuning.

    Loads the mBART-50 tokenizer for `model_name` (source and target language
    both set to "ru_RU"), tokenizes each (texts, labels) pair with truncation
    and padding, and wraps the result in an `MBARTDataset`.

    Args:
        model_name: checkpoint name or path passed to
            `MBart50Tokenizer.from_pretrained`.
        train_texts, train_labels: source strings and target strings for the
            training split.
        val_texts, val_labels: optional validation split; it is built only
            when both are given.
        test_texts, test_labels: optional test split; it is built only when
            both are given.

    Returns:
        Tuple `(train_dataset, val_dataset, test_dataset)`; the validation
        and test entries are `None` when their inputs were not supplied.
    """
    tokenizer = MBart50Tokenizer.from_pretrained(model_name, src_lang="ru_RU", tgt_lang="ru_RU")
    pad_token_id = tokenizer.pad_token_id

    prepare_val = val_texts is not None and val_labels is not None
    prepare_test = test_texts is not None and test_labels is not None

    def tokenize_data(texts, labels):
        encodings = tokenizer(texts, truncation=True, padding=True)
        decodings = tokenizer(labels, truncation=True, padding=True)
        # Padding positions in the targets must not contribute to the loss:
        # the Transformers loss ignores label index -100, so swap the pad id
        # for -100 (the model maps -100 back to the pad id when it builds the
        # decoder inputs from the labels).
        decodings['input_ids'] = [
            [token_id if token_id != pad_token_id else -100 for token_id in sequence]
            for sequence in decodings['input_ids']
        ]
        return MBARTDataset(encodings, decodings)

    train_dataset = tokenize_data(train_texts, train_labels)
    val_dataset = tokenize_data(val_texts, val_labels) if prepare_val else None
    test_dataset = tokenize_data(test_texts, test_labels) if prepare_test else None

    return train_dataset, val_dataset, test_dataset

In [None]:
# Both splits are tab-separated files under `data_dir`.
# `read_csv` has no `index` keyword (passing it raises TypeError); the
# parameter that controls the row index is `index_col`.
train = pd.read_csv(os.path.join(data_dir, 'train.csv'), index_col=None, sep='\t')
# NOTE(review): this reads 'train.csv' again, so evaluation runs on the
# training data. Point it at a held-out file if one exists -- TODO confirm.
val = pd.read_csv(os.path.join(data_dir, 'train.csv'), index_col=None, sep='\t')

In [None]:
# The 'text' column supplies the model inputs and the 'title' column the
# targets; both are converted to plain Python lists for the tokenizer.
train_texts = train['text'].tolist()
train_labels = train['title'].tolist()
val_texts = val['text'].tolist()
val_labels = val['title'].tolist()

In [None]:
# Build the training and validation datasets; no test split is passed, so
# the third returned value is None and is discarded.
train_dataset, val_dataset, _ = prepare_data(
    model_name,
    train_texts=train_texts,
    train_labels=train_labels,
    val_texts=val_texts,
    val_labels=val_labels,
)

In [None]:
# Load the pretrained checkpoint, then move it to the selected device.
model = MBartForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

In [None]:
# Settings for the fine-tuning run, grouped by purpose.
training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir='../checkpoints',
    logging_dir='../logs',
    # Optimisation.
    num_train_epochs=2,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    weight_decay=0.01,
    # Logging, evaluation and checkpoint cadence (in steps).
    logging_steps=1000,
    evaluation_strategy='steps',
    eval_steps=5000,
    save_steps=5000,
    save_total_limit=5,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
# Run fine-tuning; as the last expression of the cell, the value returned by
# `train()` is shown as the cell output.
trainer.train()