In [None]:
from IPython.display import clear_output
!pip install -q triton bitsandbytes accelerate hf_xet
clear_output()

In [None]:
import gc
import os
from getpass import getpass

import numpy as np
import pandas as pd

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from huggingface_hub import login
import torch.utils.checkpoint

# `use_reentrant` is a keyword argument of `torch.utils.checkpoint.checkpoint`, not a
# module-level setting: assigning `torch.utils.checkpoint.use_reentrant` does not change how
# checkpointing runs. Non-reentrant checkpointing has to be requested where checkpointing is
# enabled, e.g.
#   model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

In [None]:
BASE_MODEL = "google/long-t5-tglobal-base"
REPO_NAME = "Mels22/longt5-scisummnet"
DATA_CSV = "/kaggle/input/scisummnet-corpus/scisumm.csv"
LOCAL_DIR = "./checkpoint"

# Tokenization: source documents are cut into windows of CHUNK_SIZE tokens that overlap by
# OVERLAP_SIZE tokens (passed to the tokenizer as `max_length` / `stride`); summaries are
# truncated/padded to MAX_TARGET_LENGTH tokens.
CHUNK_SIZE = 8192
OVERLAP_SIZE = 512
MAX_TARGET_LENGTH = 512

BATCH_SIZE = 4  # rows tokenized per `Dataset.map` call (preprocessing only)
TRAIN_BATCH_SIZE = 4  # per-device training batch size
EVAL_BATCH_SIZE = 4  # NOTE(review): not referenced by the training code in this notebook
GRADIENT_ACCUMULATION = 4  # effective batch = TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION per device

LR = 5e-4

# Never hard-code credentials in a notebook. Read the token from the environment (e.g. a
# Kaggle/Colab secret exported as HF_TOKEN) and otherwise prompt for it without echoing.
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=HF_TOKEN)

In [None]:
class ScisummnetDataset:
    """Load the SciSummNet CSV and tokenize it into fixed-length seq2seq examples.

    Each source document (``text`` column) is split by the tokenizer into
    overlapping windows of ``chunk_size`` tokens. Every window becomes its own
    example and is paired with the same tokenized ``summary`` as labels.
    """

    def __init__(
        self,
        path,
        tokenizer,
        chunk_size=CHUNK_SIZE,
        overlap=OVERLAP_SIZE,
        max_target_length=MAX_TARGET_LENGTH,
        map_batch_size=BATCH_SIZE,
    ):
        """
        Args:
            path: Path to a CSV file with at least ``text`` and ``summary`` columns.
            tokenizer: Hugging Face tokenizer used for both inputs and targets.
            chunk_size: Maximum number of tokens per input window.
            overlap: Number of tokens shared by consecutive windows (tokenizer
                ``stride``); must be smaller than ``chunk_size``.
            max_target_length: Length the summaries are truncated/padded to.
            map_batch_size: Number of CSV rows tokenized per ``Dataset.map`` call.

        Raises:
            ValueError: if ``overlap >= chunk_size`` or a required column is missing.
        """
        if overlap >= chunk_size:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
            )

        df = pd.read_csv(path)
        missing = {"text", "summary"} - set(df.columns)
        if missing:
            raise ValueError(f"{path} is missing required column(s): {sorted(missing)}")
        # Rows with a missing text or summary cannot be tokenized; drop them and
        # reset the index so `from_pandas` does not add an `__index_level_0__` column.
        df = df.dropna(subset=["text", "summary"]).reset_index(drop=True)

        self.hf_dataset = Dataset.from_pandas(df)
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.max_target_length = max_target_length
        self.map_batch_size = map_batch_size

    def _process_data_to_model_inputs(self, batch):
        """Tokenize a batch of rows into chunked inputs with padded, masked labels.

        Args:
            batch: Column-oriented batch from ``Dataset.map(batched=True)`` holding
                ``text`` and ``summary`` lists.

        Returns:
            Dict of ``int64`` NumPy arrays: ``input_ids`` and ``attention_mask`` with
            ``chunk_size`` columns, and ``labels`` with ``max_target_length`` columns.
            The number of rows can exceed the batch size, because a long document
            produces several overlapping windows.
        """
        all_input_ids = []
        all_attention_masks = []
        all_labels = []
        pad_id = self.tokenizer.pad_token_id

        for text, summary in zip(batch["text"], batch["summary"]):
            tokenized_inputs = self.tokenizer(
                text,
                return_overflowing_tokens=True,
                stride=self.overlap,
                truncation=True,
                max_length=self.chunk_size,
                padding="max_length",
            )
            tokenized_outputs = self.tokenizer(
                summary,
                truncation=True,
                max_length=self.max_target_length,
                padding="max_length",
            )

            # The summary is identical for every window of this document, so mask it
            # once: pad positions become -100, the index ignored by the seq2seq loss.
            labels = [
                -100 if token == pad_id else token
                for token in tokenized_outputs["input_ids"]
            ]

            for input_ids, attention_mask in zip(
                tokenized_inputs["input_ids"], tokenized_inputs["attention_mask"]
            ):
                all_input_ids.append(input_ids)
                all_attention_masks.append(attention_mask)
                all_labels.append(labels)

        return {
            "input_ids": np.array(all_input_ids, dtype=np.int64),
            "attention_mask": np.array(all_attention_masks, dtype=np.int64),
            "labels": np.array(all_labels, dtype=np.int64),
        }

    def _tokenize_split(self, split):
        """Chunk and tokenize one split, exposing only the model tensors as torch."""
        tokenized = split.map(
            self._process_data_to_model_inputs,
            batched=True,
            batch_size=self.map_batch_size,
            # Chunking changes the number of rows, so *every* original column must be
            # dropped (not just text/summary); otherwise `map` fails on mismatched lengths.
            remove_columns=split.column_names,
        )
        tokenized.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "labels"],
            output_all_columns=False,  # keep only the tensors the model consumes
        )
        return tokenized

    def get_data(self, test_size=0.1, seed=None):
        """Split documents into train/validation and tokenize both splits.

        The split is made on whole documents *before* chunking, so all windows of a
        document end up in the same split.

        Args:
            test_size: Fraction (or absolute number) of rows held out for validation.
            seed: Optional seed for a reproducible split; ``None`` keeps it unseeded.

        Returns:
            ``({"train": ..., "val": ...}, val_ds)``: the tokenized, torch-formatted
            datasets, and the raw (untokenized) validation split.
        """
        split_data = self.hf_dataset.train_test_split(test_size=test_size, seed=seed)
        val_ds = split_data["test"]

        tokenized = {
            "train": self._tokenize_split(split_data["train"]),
            "val": self._tokenize_split(val_ds),
        }
        return tokenized, val_ds

In [None]:
class LongT5Model:
    """LongT5 loaded in 4-bit NF4 and wrapped with a DoRA/LoRA adapter for fine-tuning.

    The base model, adapter and tokenizer are loaded eagerly on construction.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_checkpoint_dir = None  # NOTE(review): not used by this class
        self.resume_training = None  # NOTE(review): not used by this class
        self.data_collator = None
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        self._load_model()

    def _load_model(self):
        """Load the quantized base model, attach the adapter and load the tokenizer."""
        # LongT5 is a native transformers architecture, so there is no need to allow
        # executing code shipped with the checkpoint repo (`trust_remote_code`).
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=self.bnb_config,
        )
        lora_config = LoraConfig(
            use_dora=True,
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            target_modules=["q", "k", "v", "out"],  # T5 attention projection names
            bias="none",
            task_type="SEQ_2_SEQ_LM",
        )
        self.model = prepare_model_for_kbit_training(base_model)
        self.model = get_peft_model(self.model, lora_config)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        # NOTE(review): this collator is not passed to the trainer in `train`. The
        # datasets are already padded to fixed lengths, and the trainer falls back to
        # its default collator.
        self.data_collator = DataCollatorForSeq2Seq(self.tokenizer, model=self.model)

    def train(
        self,
        train_data,
        epochs,
        lr=LR,
        commit_message="Done training",
        unfreeze_float_params=False,
    ):
        """Fine-tune the adapter, save it and push it to the Hugging Face Hub.

        Args:
            train_data: Mapping whose ``"train"`` entry is the tokenized training set.
            epochs: Number of training epochs.
            lr: Peak learning rate for the cosine schedule.
            commit_message: Commit message used for the final Hub push.
            unfreeze_float_params: If True, also mark every floating-point parameter as
                trainable (the notebook's previous behaviour). Off by default: it trains
                non-adapter base weights (e.g. embeddings and layer norms), which are
                not written when a PEFT model saves only its adapter. The pushed
                checkpoint would then not match what was trained.

        Raises:
            RuntimeError: if the model has not been loaded.
        """
        if self.model is None:
            # A bare `raise` here would only report "No active exception to reraise".
            raise RuntimeError("Model is not loaded; call _load_model() first.")
        gc.collect()
        torch.cuda.empty_cache()

        self.model.train()
        self.model.config.use_cache = False  # the KV cache is not used while training
        # Request non-reentrant checkpointing explicitly; it is a keyword argument of
        # torch's checkpoint function and cannot be set via a module attribute.
        self.model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

        if unfreeze_float_params:
            float_dtypes = (
                torch.float32,
                torch.float16,
                torch.bfloat16,
                torch.complex64,
                torch.complex128,
            )
            for param in self.model.parameters():
                if param.dtype in float_dtypes:
                    param.requires_grad = True

        training_args = Seq2SeqTrainingArguments(
            num_train_epochs=epochs,
            output_dir=LOCAL_DIR,
            learning_rate=lr,
            save_strategy="epoch",
            save_total_limit=1,
            weight_decay=0.01,
            optim="paged_adamw_8bit",
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            bf16=torch.cuda.is_bf16_supported(),
            label_names=["labels"],  # PEFT-wrapped models need label names given explicitly
            per_device_train_batch_size=TRAIN_BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION,
            report_to="none",
            logging_steps=0.1,  # a float < 1 means a fraction of total training steps
            push_to_hub=True,
            hub_model_id=REPO_NAME,
            hub_strategy="checkpoint",
        )

        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data["train"],
        )

        trainer.train()
        trainer.save_model()
        trainer.push_to_hub(commit_message)
        print("Pushed to HUB")

In [None]:
# Build the model wrapper: loads the 4-bit quantized LongT5 base model, attaches the
# DoRA/LoRA adapter and loads the matching tokenizer (all done in the constructor).
t5 = LongT5Model()

In [None]:
# Tokenize the SciSummNet CSV into chunked train/validation datasets.
#   data_loader -> {"train": ..., "val": ...} of tokenized, torch-formatted datasets
#                  (a dict of datasets, not a torch DataLoader)
#   val_df      -> the raw, untokenized validation split (a `datasets.Dataset`, not a
#                  pandas DataFrame); not used by the cells shown here
scisummnet = ScisummnetDataset(path=DATA_CSV, tokenizer=t5.tokenizer)
data_loader, val_df = scisummnet.get_data(test_size=0.1)

In [None]:
# Fine-tune for three epochs on data_loader["train"], then push the result to the Hub.
t5.train(
    train_data=data_loader,
    epochs=3,
    commit_message="Train for 3 epochs",
)