In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW, Adam
from torch.amp import GradScaler, autocast
from transformers import MBartTokenizer, MBartForConditionalGeneration
from transformers import T5TokenizerFast, AutoModelForSeq2SeqLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, load_from_disk
from tqdm import tqdm

In [None]:
# Tokenizer files and model weights downloaded from the Hub are cached here.
custom_cache_dir = "/home/maantonov_1/HF_data"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and weights are both pulled from the same checkpoint.
model_name = "Kirili4ik/mbart_ruDialogSum"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    cache_dir=custom_cache_dir,
)
model = MBartForConditionalGeneration.from_pretrained(
    model_name,
    cache_dir=custom_cache_dir,
)
model = model.to(device)

# No cache_dir is passed here, so the corpus uses the default datasets cache.
dataset = load_dataset("RussianNLP/Mixed-Summarization-Dataset")

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of documents and their reference summaries.

    Args:
        examples: Batch mapping produced by ``Dataset.map(batched=True)``;
            the ``'text'`` and ``'summary'`` columns are read.

    Returns:
        The tokenizer output for the documents (padded/truncated to 600
        tokens) with an added ``"labels"`` entry holding the summary token
        ids (padded/truncated to 150 tokens). Padding positions in the
        labels are replaced by ``-100`` so that they do not contribute to
        the training loss.
    """
    model_inputs = tokenizer(
        list(examples['text']),
        max_length=600,
        truncation=True,
        padding="max_length",
    )
    # `text_target=` tokenizes in target mode; it replaces the deprecated
    # `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=examples['summary'],
        max_length=150,
        truncation=True,
        padding="max_length",
    )
    # -100 is the index ignored by the cross-entropy loss. Without this
    # masking most of each fixed-length label row would be padding that the
    # model is trained to predict.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

In [None]:
# Tokenizing the whole corpus is slow, so the result is cached on disk and
# reused on later runs. The cache is NOT invalidated automatically: delete the
# directory after changing the tokenizer or preprocess_function.
path = 'DL_PROJECT/data_mbart_tok'
try:
    tokenized_datasets = load_from_disk(path)
except FileNotFoundError:
    # Nothing usable at `path` yet: tokenize once and persist the result.
    # Only this error is caught so that real failures (a missing import, a
    # corrupt cache, ...) surface instead of silently re-tokenizing.
    tokenized_datasets = dataset.map(preprocess_function, batched=True)
    tokenized_datasets.save_to_disk(path)


train_dataset = tokenized_datasets['train']
eval_dataset = tokenized_datasets['test']

In [None]:
def collate_fn(batch):
    """Turn a list of tokenized examples into a dict of batched tensors.

    Each example must provide ``'input_ids'``, ``'attention_mask'`` and
    ``'labels'`` sequences; all examples need the same length per field
    because the rows are stacked without any padding.

    Returns:
        A dict with the same three keys, each a ``torch.long`` tensor whose
        first dimension indexes the examples in ``batch``.
    """
    def _stack(field):
        # One int64 row per example, stacked along a new leading batch axis.
        rows = [torch.tensor(sample[field], dtype=torch.long) for sample in batch]
        return torch.stack(rows)

    return {
        field: _stack(field)
        for field in ('input_ids', 'attention_mask', 'labels')
    }

# Training batches are reshuffled every epoch; evaluation keeps dataset order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
eval_dataloader = DataLoader(
    dataset=eval_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Fine-tuning loop: mixed-precision training with gradient accumulation.
torch.cuda.empty_cache()
optimizer = Adam(model.parameters(), lr=5e-5)


# Gradients from this many consecutive batches are summed before each
# optimizer step.
accumulation_steps = 16

# Only request CUDA autocast / loss scaling when actually running on CUDA.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

# The final window of an epoch is shorter than `accumulation_steps` when the
# number of batches is not a multiple of it. Its losses must be averaged over
# the real window size, otherwise the last update of every epoch is
# under-weighted by a factor of tail_size / accumulation_steps.
num_batches = len(train_dataloader)
tail_size = num_batches % accumulation_steps
tail_start = num_batches - tail_size

model.train()
for epoch in range(3):
    optimizer.zero_grad()

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", leave=False)

    running_loss = 0

    for i, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with autocast('cuda', enabled=use_amp):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

        # Track the un-normalized loss for the progress bar.
        running_loss += loss.item()

        # Average over the window this batch belongs to, so the accumulated
        # gradient is the mean over the window rather than the sum.
        window_size = tail_size if i >= tail_start else accumulation_steps
        loss = loss / window_size
        scaler.scale(loss).backward()

        # Step at the end of every full window and once more after the last
        # batch so that a trailing partial window is not dropped.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        progress_bar.set_postfix({"Running Loss": running_loss / (i+1)})

    progress_bar.close()

In [None]:
# Write the fine-tuned model and the tokenizer into the same directory.
save_directory = "/home/maantonov_1/HF_data/mbart"

for artifact in (model, tokenizer):
    artifact.save_pretrained(save_directory)