In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.amp import GradScaler, autocast
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
from datasets import load_dataset
from tqdm import tqdm
import os

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory used as the Hugging Face cache for the tokenizer and model weights.
# It can be overridden per machine through the HF_CACHE_DIR environment
# variable; the previous hard-coded location remains the default.
custom_cache_dir = os.environ.get("HF_CACHE_DIR", "/home/maantonov_1/HF_data")

# Checkpoint that is fine-tuned below; tokenizer and model come from the same repo.
model_name = "IlyaGusev/rugpt3medium_sum_gazeta"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir=custom_cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=custom_cache_dir).to(device)
# Dataset of (text, summary) pairs; only its "train" split is used further down.
dataset = load_dataset("RussianNLP/Mixed-Summarization-Dataset")

In [None]:
def tokenize_function(examples):
    """Turn a batch of (text, summary) pairs into causal-LM training sequences.

    Each sequence is laid out as ``article + [separator] + summary + [EOS]``.
    Only the summary (and its closing EOS) is supervised: the article and the
    separator get the label -100, which is the default ``ignore_index`` of
    PyTorch's cross-entropy loss.

    Args:
        examples: batched mapping with parallel lists under the keys ``"text"``
            and ``"summary"`` (the format passed by ``Dataset.map(batched=True)``).

    Returns:
        dict with the keys ``"input_ids"``, ``"attention_mask"`` and ``"labels"``,
        each a list with one (unpadded) list of ints per example.

    Raises:
        ValueError: if the tokenizer provides neither an EOS token nor an id
            for the fallback separator ``"\\n"``.

    Note:
        A tokenized dataset that was cached on disk before the closing EOS was
        introduced has to be regenerated to pick the change up.
    """
    eos_token_id = tokenizer.eos_token_id
    # The article/summary separator is EOS when available, otherwise whatever
    # id the tokenizer maps "\n" to.
    sep_token_id = eos_token_id if eos_token_id is not None else tokenizer.convert_tokens_to_ids("\n")
    if sep_token_id is None:
        # Fail here rather than silently writing None into input_ids.
        raise ValueError("tokenizer has no EOS token and no id for the fallback separator")

    # Close every summary with EOS so the model is actually trained to stop;
    # without it, generation only ends when the length limit is hit.
    terminator = [eos_token_id] if eos_token_id is not None else []

    input_ids_list, attention_mask_list, labels_list = [], [], []
    for text, summary in zip(examples["text"], examples["summary"]):
        article_tokens = tokenizer(text, max_length=600, truncation=True, add_special_tokens=False)["input_ids"]
        summary_tokens = tokenizer(summary, max_length=128, truncation=True, add_special_tokens=False)["input_ids"]
        target_tokens = summary_tokens + terminator

        input_ids = article_tokens + [sep_token_id] + target_tokens
        input_ids_list.append(input_ids)
        # No padding happens at this stage, so every position is a real token.
        attention_mask_list.append([1] * len(input_ids))
        # Mask the article and the separator; supervise summary + terminator.
        labels_list.append([-100] * (len(article_tokens) + 1) + target_tokens)
    return {"input_ids": input_ids_list, "attention_mask": attention_mask_list, "labels": labels_list}

from datasets import load_from_disk

# Reuse the tokenized training split cached on disk when it exists; otherwise
# tokenize it once and store the result at the same location for later runs.
dataset_path = "DL_PROJECT/data3"

try:
    tokenized_dataset = load_from_disk(dataset_path)
except FileNotFoundError:
    # load_from_disk raises FileNotFoundError when the path is missing or does
    # not hold a saved dataset. Any other failure (including Ctrl-C) is left to
    # propagate instead of being silently turned into a full re-tokenization.
    tokenized_dataset = dataset['train'].map(tokenize_function, batched=True)
    tokenized_dataset.save_to_disk(dataset_path)

def collate_fn(batch):
    """Right-pad a list of tokenized examples to a common length and tensorize them.

    Args:
        batch: non-empty list of dicts, each holding the lists ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` of one example.

    Returns:
        dict with the same three keys, each a ``torch.long`` tensor with one row
        per example, padded to the longest ``input_ids`` in the batch. Padding
        uses the pad id for ``input_ids``, 0 for ``attention_mask`` and -100
        for ``labels``.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Some GPT-style tokenizers define no pad token, and a None inside the
        # id list would make torch.tensor fail. Padded positions carry
        # attention 0 and label -100, so which id fills them is irrelevant.
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    max_length = max(len(item["input_ids"]) for item in batch)
    input_ids_batch, attention_mask_batch, labels_batch = [], [], []
    for item in batch:
        pad_len = max_length - len(item["input_ids"])
        input_ids_batch.append(item["input_ids"] + [pad_id] * pad_len)
        attention_mask_batch.append(item["attention_mask"] + [0] * pad_len)
        labels_batch.append(item["labels"] + [-100] * pad_len)
    return {
        "input_ids": torch.tensor(input_ids_batch, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask_batch, dtype=torch.long),
        "labels": torch.tensor(labels_batch, dtype=torch.long)
    }

# Shuffled training batches of 8 examples, loaded by 2 worker processes;
# collate_fn assembles each list of examples into a batch of tensors.
train_dataloader = DataLoader(
    tokenized_dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=8,
    num_workers=2,
)

In [None]:
# Fine-tuning loop: AdamW with gradient accumulation, gradient clipping and,
# on CUDA only, automatic mixed precision.
optimizer = AdamW(model.parameters(), lr=5e-5)
# AMP is tied to the device actually in use. On a CPU run both autocast and the
# GradScaler are explicitly disabled instead of being requested for 'cuda'.
use_amp = device.type == "cuda"
scaler = GradScaler(device.type, enabled=use_amp)
accumulation_steps = 8

num_batches = len(train_dataloader)
# Number of micro-batches in the final, shorter accumulation group of an epoch
# (0 when the epoch length is a multiple of accumulation_steps).
tail_batches = num_batches % accumulation_steps

model.train()
for epoch in range(3):
    optimizer.zero_grad()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
    running_loss = 0.0
    for i, batch in enumerate(progress_bar):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Average gradients over the real size of the current accumulation
        # group, so the trailing partial group is not under-weighted.
        in_tail = tail_batches > 0 and i >= num_batches - tail_batches
        group_size = tail_batches if in_tail else accumulation_steps

        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss / group_size
        scaler.scale(loss).backward()

        # Step once per full group, and once more for the trailing partial one.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            # Gradients must be unscaled before clipping so that max_norm
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Read the unscaled micro-batch loss once per step for reporting.
        batch_loss = outputs.loss.item()
        running_loss += batch_loss
        progress_bar.set_postfix({"Loss": running_loss / (i + 1), "cur loss": batch_loss})
    progress_bar.close()

    # Save a checkpoint (model + tokenizer) at the end of every epoch.
    save_path = f"./gpt3_sum_epoch_{epoch + 1}"
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

