In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset

In [None]:
# Start from a Pegasus checkpoint trained on the CNN/DailyMail dataset;
# it is fine-tuned further on the SAMSum conversation dataset below.
MODEL_CHECKPOINT = 'google/pegasus-cnn_dailymail'

# Load tokenizer and model from the same checkpoint so their vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

In [3]:
# Fine-tuning data: the SAMSum conversation dataset. Its 'dialogue' and
# 'summary' columns are consumed by preprocess_function.
# NOTE(review): trust_remote_code=True lets `datasets` execute the dataset
# repository's loading script locally — only use it for sources you trust.
dataset = load_dataset(
    'samsum',
    trust_remote_code=True,
)

In [None]:
dataset  # Inspect the loaded splits before preprocessing

In [5]:
# Token-length limits used by preprocess_function for truncation and padding:
# dialogues (encoder inputs) are capped at 512 tokens, summaries (labels) at 128.
max_input_length, max_target_length = 512, 128

In [6]:
def preprocess_function(examples):
    '''Tokenize a batch of dialogue/summary pairs for seq2seq training.

    Parameters
    ----------
    examples : dict
        A batch as passed by ``Dataset.map(..., batched=True)``; must contain
        the 'dialogue' and 'summary' columns (lists of strings).

    Returns
    -------
    BatchEncoding
        Tokenized dialogues (truncated/padded to ``max_input_length``) plus a
        'labels' entry: the tokenized summaries (truncated/padded to
        ``max_target_length``) with every padding position set to -100.
    '''
    model_inputs = tokenizer(
        examples['dialogue'],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )

    # `text_target=` tokenizes the targets; it replaces the deprecated
    # `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=examples['summary'],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    # Because labels are padded to a fixed length, padding positions must be
    # -100 (the index ignored by the cross-entropy loss); otherwise the model is
    # also trained to predict pad tokens after every summary.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

In [None]:
# Apply preprocess_function batch-wise to every split; the new token columns
# are added alongside the original text columns.
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
)

In [None]:
# Training configuration — adjust the values to the available hardware.
# - fp16: mixed-precision arithmetic, which speeds up training and lowers memory use.
# - gradient_checkpointing: recomputes activations during the backward pass,
#   trading extra compute (slower steps) for much lower activation memory.
# - gradient_accumulation_steps=4 with per_device_train_batch_size=2 gives an
#   effective batch of 8 examples per device per optimizer step.
# NOTE(review): Pegasus checkpoints have been reported to produce NaN losses
# under fp16; if that happens, try bf16=True on supporting hardware or disable fp16.
training_args = Seq2SeqTrainingArguments(
    output_dir="./pegasus-samsum",
    # NOTE(review): newer Transformers releases rename this to `eval_strategy`
    # and eventually drop `evaluation_strategy`; update if the installed version rejects it.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    fp16=True,
    gradient_checkpointing=True,
    weight_decay=0.01,
    save_total_limit=2,              # keep at most two checkpoint-* folders
    num_train_epochs=3,
    gradient_accumulation_steps=4,
    predict_with_generate=False,     # evaluation reports loss only, no generated summaries
    logging_dir="./logs",
    logging_steps=100,
    push_to_hub=False,
)

In [None]:
# Bundle the model, tokenizer and training configuration with the tokenized
# train/validation splits into a Seq2SeqTrainer.
# NOTE(review): newer Transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch if the installed version warns.
train_split = tokenized_datasets["train"]
eval_split = tokenized_datasets["validation"]
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_split,
    eval_dataset=eval_split,
)

In [None]:
# Train the model. During training the Trainer only writes checkpoint-*
# subfolders under training_args.output_dir ("./pegasus-samsum");
# save_model() additionally writes the final model (and tokenizer) to that
# directory itself.
train_result = trainer.train()
trainer.save_model()
train_result  # display the TrainOutput (global step, training loss, metrics)