In [1]:
import re

from datasets import load_dataset, load_metric
import evaluate
import nltk
import nltk.data
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import (
    AdamW, AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

In [2]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\milan\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

## Setup and Preprocessing for Model Load

In [3]:
DATASET_NAME = "multi_x_science_sum"
DOC_SEP = " ||||| "
BATCH_SIZE = 16
MAX_LENGTH_ENC = 4096
MAX_LENGTH_DEC = 256

rouge = load_metric("rouge")

dataset = load_dataset(DATASET_NAME)

pat = re.compile("@cite_[0-9]+")

  import sys
Found cached dataset multi_x_science_sum (C:/Users/milan/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
def preprocess_dataset(example):

    abstracts = example["abstract"].split("| Abstract: ")[-1]
    related_work = pat.sub("@cite", example["related_work"])
    ref_abstracts = filter(bool, example["ref_abstract"]["abstract"])
    output = {
        "abstracts": f"{abstracts}{DOC_SEP}{DOC_SEP.join(ref_abstracts)}",
        "related_work": related_work
    }
    return output

def preprocess_dataset_batched(example):
    abstracts = [
        abstract.split("| Abstract: ")[-1] + DOC_SEP + DOC_SEP.join([x for x in ref_abstract["abstract"] if x])
        for abstract, ref_abstract in zip(example["abstract"], example["ref_abstract"])
    ]
    related_work = [pat.sub("@cite", rw) for rw in example["related_work"]]
    output = {
        "abstracts": abstracts,
        "related_work": related_work,
    }
    return output

dataset_processed = {}
for split in dataset.keys():
    dataset_processed[split] = dataset[split].map(
        preprocess_dataset_batched,
        remove_columns=dataset[split].column_names,
        batched=True,
        batch_size=BATCH_SIZE,
    )

Loading cached processed dataset at C:\Users\milan\.cache\huggingface\datasets\multi_x_science_sum\default\1.1.0\2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729\cache-b4725fa052a1384b.arrow
Loading cached processed dataset at C:\Users\milan\.cache\huggingface\datasets\multi_x_science_sum\default\1.1.0\2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729\cache-cee45552c5bd8e14.arrow
Loading cached processed dataset at C:\Users\milan\.cache\huggingface\datasets\multi_x_science_sum\default\1.1.0\2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729\cache-e4494002ad8ba52d.arrow


In [5]:
dataset_processed

{'train': Dataset({
     features: ['related_work', 'abstracts'],
     num_rows: 30369
 }),
 'test': Dataset({
     features: ['related_work', 'abstracts'],
     num_rows: 5093
 }),
 'validation': Dataset({
     features: ['related_work', 'abstracts'],
     num_rows: 5066
 })}

## Loading the Model and Tokenizer

In [6]:
def get_tokenizer(host_tokenizer: str):
  """return the tokenizer and model for LLM training"""

  return (AutoTokenizer.from_pretrained(host_tokenizer), AutoModelForSeq2SeqLM.from_pretrained(host_tokenizer))


centrum_tokenizer, centrum_model = get_tokenizer("ratishsp/Centrum") 

In [7]:
print(centrum_tokenizer, centrum_model)

LEDTokenizerFast(name_or_path='ratishsp/Centrum', vocab_size=50265, model_max_length=16384, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)}) LEDForConditionalGeneration(
  (led): LEDModel(
    (shared): Embedding(50266, 768, padding_idx=1)
    (encoder): LEDEncoder(
      (embed_tokens): E

In [8]:
def tokenize_dataset_batched(example):
    # Tokenizer input
    input_encoding = centrum_tokenizer(
        example["abstracts"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH_ENC,
        return_tensors="pt",
    )

    # Tokenizer output
    output_encoding = centrum_tokenizer(
        example["related_work"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH_DEC,
        return_tensors="pt",
    )

    # Modify output encoding to ignore padding in loss function
    # torch ignore -100 in loss function computation
    labels = output_encoding["input_ids"].clone()
    labels[labels == centrum_tokenizer.pad_token_id] = -100

    # Global attention with vectorized operations (optimized for GPU)
    input_ids = input_encoding["input_ids"]
    docsep_token_id = centrum_tokenizer.convert_tokens_to_ids(DOC_SEP)
    global_attention_mask = (input_ids == centrum_tokenizer.cls_token_id) | (input_ids == docsep_token_id)

    return {
        "input_ids": input_encoding["input_ids"],
        "attention_mask": input_encoding["attention_mask"],
        "global_attention_mask": global_attention_mask.float(),
        "labels": labels,
    }

centrum_tokenizer.add_tokens(DOC_SEP, special_tokens=True)
centrum_model.resize_token_embeddings(len(centrum_tokenizer))
docsep_token_id = centrum_tokenizer.convert_tokens_to_ids(DOC_SEP)

dataset_tokenized = {}
for split in dataset_processed.keys():
    dataset_tokenized[split] = (
        dataset_processed[split]
        .select(range(200))
        .map(
            tokenize_dataset_batched,
            remove_columns=dataset_processed[split].column_names,
            batched=True,
            batch_size=BATCH_SIZE,
        )
    )

Loading cached processed dataset at C:\Users\milan\.cache\huggingface\datasets\multi_x_science_sum\default\1.1.0\2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729\cache-97a7b1f93f662561.arrow


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\milan\.cache\huggingface\datasets\multi_x_science_sum\default\1.1.0\2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729\cache-17898f0ab0a58fc9.arrow


In [9]:
dataset_tokenized

{'train': Dataset({
     features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
     num_rows: 200
 }),
 'test': Dataset({
     features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
     num_rows: 200
 }),
 'validation': Dataset({
     features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
     num_rows: 200
 })}

## Fine-tuning the Model with Trainer

In [10]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=1000,
    num_train_epochs=1,
    learning_rate=2e-4,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    warmup_steps=250,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=500
)

In [11]:
trainer = Seq2SeqTrainer(
    model=centrum_model,
    args=training_args,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["validation"],
)

In [12]:
trainer.train()



  0%|          | 0/13 [00:00<?, ?it/s]

RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 226492416 bytes.