In [1]:
from transformers import PegasusXForConditionalGeneration, PegasusXConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from datasets import load_from_disk
from transformers import AutoTokenizer
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# Trainer checkpoint directory: loaded as model weights/config below and later
# passed to trainer.train(resume_from_checkpoint=...) to resume training state.
checkpoint_path = (
    "/kaggle/input/m/frostedpilot/pegasus-x/transformers/default/1/"
    "checkpoint-3784"
)

In [3]:
# Load the three MSLR2022 splits (train / test-as-validation / inference) from disk.
train_data, test_data, infer_data = (
    load_from_disk(split_path)
    for split_path in (
        '/kaggle/input/mslr2022/uncleaned_mslr/mslr_train',
        '/kaggle/input/mslr2022/uncleaned_mslr/mslr_test',
        '/kaggle/input/mslr2022/uncleaned_mslr/mslr_inference',
    )
)

In [4]:
# Convert each split to a pandas DataFrame; PegasusDataset reads rows from these frames.
train_data, test_data, infer_data = (
    split.to_pandas() for split in (train_data, test_data, infer_data)
)

In [5]:
# Token-length limits for the encoder input and the decoder target
# (same values as PegasusDataset's defaults).
MAX_INPUT_LENGTH = 4_096
MAX_OUTPUT_LENGTH = 512
# NOTE(review): not referenced in any visible cell — confirm it is still needed.
NON_MASK_RATIO = 0.5

In [6]:
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-base")
# Register <SEP> as an extra special token (used to mark document boundaries).
# replace_additional_special_tokens=False appends to the tokenizer's existing
# additional special tokens (if any) instead of replacing that list with just
# ['<SEP>'], so they keep being treated as special (e.g. by
# skip_special_tokens when decoding). Token ids should be unaffected.
tokenizer.add_special_tokens(
    {'additional_special_tokens': ['<SEP>']},
    replace_additional_special_tokens=False,
)

tokenizer_config.json:   0%|          | 0.00/2.02k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/1.91M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/6.60M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

1

In [7]:
def get_src_tgt_and_mask(truncated_docs, target, tokenizer, max_len_input, max_len_output, sep_token="<SEP>"):
    """Tokenize one (source, target) pair into fixed-length training tensors.

    Both sides are padded/truncated to a fixed length (``padding="max_length"``),
    so every sample yields tensors of identical length.

    Args:
        truncated_docs: source text passed to ``tokenizer``.
        target: reference summary text.
        tokenizer: Hugging Face tokenizer that already contains ``sep_token``.
        max_len_input: padded/truncated source length.
        max_len_output: padded/truncated target length.
        sep_token: separator token whose positions receive global attention.

    Returns:
        dict of 1-D ``torch.long`` tensors:
        ``input_ids`` / ``attention_mask`` (length ``max_len_input``),
        ``global_attention_mask`` (1 at position 0 and at every ``sep_token``),
        ``labels`` (length ``max_len_output``; pad ids replaced by -100 so the
        loss ignores them).
        NOTE(review): PegasusX's forward may not accept ``global_attention_mask``;
        the Trainer presumably drops unused keys — confirm it is still needed.

    Raises:
        KeyError: if ``sep_token`` is not in the tokenizer's vocabulary.
    """
    src = tokenizer(truncated_docs, max_length=max_len_input, padding="max_length", truncation=True)
    tgt = tokenizer(target, max_length=max_len_output, padding="max_length", truncation=True)
    input_ids = src.input_ids

    # Resolve the separator id once; convert_tokens_to_ids avoids materializing
    # the whole vocab dict (``tokenizer.vocab``) for every sample. An unknown
    # token maps to the unk id, which must not be treated as a separator.
    sep_id = tokenizer.convert_tokens_to_ids(sep_token)
    if sep_id is None or sep_id == tokenizer.unk_token_id:
        raise KeyError(sep_token)

    # Compare element-wise: ``input_ids == sep_id`` on a Python list is a single
    # ``False`` (i.e. index 0), which would never mark separator positions.
    global_attention_mask = [1 if token_id == sep_id else 0 for token_id in input_ids]
    if global_attention_mask:
        global_attention_mask[0] = 1  # first token always attends globally

    labels = [token_id if token_id != tokenizer.pad_token_id else -100 for token_id in tgt.input_ids]
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(src.attention_mask, dtype=torch.long),
        "global_attention_mask": torch.tensor(global_attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
    }
class PegasusDataset(Dataset):
    """Map-style dataset that tokenizes DataFrame rows on access.

    Each item is the dict produced by ``get_src_tgt_and_mask`` for one row.

    Args:
        data: pandas DataFrame with a source-text column and a target column.
        tokenizer: tokenizer forwarded to ``get_src_tgt_and_mask``.
        max_input_len: padded/truncated source length.
        max_output_len: padded/truncated target length.
        src_column: name of the source-text column.
        tgt_column: name of the target-text column.
    """

    def __init__(self, data, tokenizer, max_input_len=4096, max_output_len=512,
                 src_column="title_abstract", tgt_column="target"):
        self.data = data
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.tokenizer = tokenizer
        self.src_column = src_column
        self.tgt_column = tgt_column

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        # Positional lookup: samplers request indices 0..len-1 (matching
        # __len__), which only coincides with .loc labels for a 0-based
        # RangeIndex. .iloc works for any DataFrame index.
        row = self.data.iloc[index]
        return get_src_tgt_and_mask(row[self.src_column], row[self.tgt_column], self.tokenizer, self.max_input_len, self.max_output_len)

In [8]:
# Pass the configured length limits explicitly so editing MAX_INPUT_LENGTH /
# MAX_OUTPUT_LENGTH actually changes tokenization (the class defaults would
# otherwise silently win).
train_dataset = PegasusDataset(train_data, tokenizer, MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH)
val_dataset = PegasusDataset(test_data, tokenizer, MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH)
infer_dataset = PegasusDataset(infer_data, tokenizer, MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH)

In [9]:
# Load the fine-tuned Pegasus-X weights together with the checkpoint's own config.
config = PegasusXConfig.from_pretrained(checkpoint_path)
model = PegasusXForConditionalGeneration.from_pretrained(
    checkpoint_path,
    config=config,
)

In [10]:
# Match the embedding/output matrices to the tokenizer, which includes the added
# <SEP> token (presumably a no-op if the checkpoint already has this size).
vocab_with_sep = len(tokenizer)
model.resize_token_embeddings(vocab_with_sep)
# NOTE(review): this edits the config of an already-constructed model; it likely
# does not rebuild existing modules — confirm what this setting is meant to affect.
model.config.max_decoder_position_embeddings = 512
# Recompute activations during backward to reduce GPU memory use.
model.gradient_checkpointing_enable()

In [11]:
# Effective batch per optimizer step =
#   batch_size_per_device * num_devices * gradient_accumulation_steps.
batch_size = 3
num_devices = torch.cuda.device_count()
# Clamp both sides: a CPU-only run reports 0 devices (ZeroDivisionError), and
# more devices than batch_size would give an invalid per-device batch of 0.
batch_size_per_device = max(1, batch_size // max(1, num_devices))
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    per_device_train_batch_size=batch_size_per_device,
    per_device_eval_batch_size=batch_size_per_device,
    output_dir="./results/",
    logging_dir="./logs/test/",
    save_strategy="steps",
    save_steps=250,
    logging_strategy="steps",
    logging_steps=20,
    eval_strategy="steps",
    eval_steps=250,
    save_total_limit=2,
    #load_best_model_at_end=True,
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    gradient_accumulation_steps=5,
    num_train_epochs=6,
    max_grad_norm=1.0,
    learning_rate=1e-5,
    warmup_steps=1000,
    lr_scheduler_type="linear",
    report_to="wandb",
    run_name="Restart",
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

In [12]:
import wandb
from kaggle_secrets import UserSecretsClient
# Read the W&B API key from Kaggle's secret store rather than hard-coding it.
_WANDB_SECRET_NAME = "wandb_key"
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret(_WANDB_SECRET_NAME)

In [13]:
# Release cached-but-unused GPU memory held by PyTorch's allocator ahead of training.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [14]:
# Authenticate with Weights & Biases using the key loaded from Kaggle secrets;
# the cell displays the login result.
logged_in = wandb.login(key=wandb_key)
logged_in

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtuantruongvu[0m ([33mtuantruongvu-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [15]:
# Resume from checkpoint_path: the trainer restores optimizer and RNG state and
# continues from the saved global step.
# NOTE(review): after resuming, the returned training_loss (~1.0) is far below
# the logged per-step loss (~3.0); it appears to be averaged over all global
# steps including skipped ones — use the logged values for comparison.
train_output = trainer.train(resume_from_checkpoint=checkpoint_path)
train_output

There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].
  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250429_034040-4rglqfpr[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mRestart[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/tuantruongvu-hanoi-university-of-science-and-technology/huggingface[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/tuantruongvu-hanoi-university-of-science-and-technology/huggingface/runs/4rglqfpr[0m
  checkpoint_rng_state = torch.load(rng_file)


Step,Training Loss,Validation Loss
4000,2.9632,2.858593
4250,3.1001,2.8553
4500,3.0143,2.851815
4750,3.0439,2.849363
5000,2.9923,2.847518
5250,3.0499,2.846164
5500,3.0817,2.845488


TrainOutput(global_step=5676, training_loss=1.0028466049258855, metrics={'train_runtime': 36652.3138, 'train_samples_per_second': 2.323, 'train_steps_per_second': 0.155, 'total_flos': 4.1504592632374886e+17, 'train_loss': 1.0028466049258855, 'epoch': 6.0})

In [16]:
# Free cached GPU memory left over from training before the next stage.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [17]:
from tqdm import tqdm

In [18]:
!pip install rouge-score -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone
