In [None]:
import os

import datasets
import numpy as np
import torch
from datasets import load_metric
from transformers import (
    default_data_collator,
    EarlyStoppingCallback,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    PegasusForConditionalGeneration,
    PegasusTokenizer)

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Folder holding sum_train.csv / sum_val.csv / sum_test.csv. It can be
# overridden with the DATASET_PATH environment variable. File names are
# appended to this path by plain string concatenation, so
# os.path.join(..., '') guarantees the trailing separator that requires.
dataset_path = os.path.join(os.environ.get('DATASET_PATH', '../datasets/'), '')

In [None]:
max_source_length = 512
max_target_length = 64
padding = 'max_length' 

In [None]:
# Hub checkpoint shared by the tokenizer and the model; define it once so the
# two can never drift apart.
model_name = 'google/pegasus-xsum'

tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)

In [None]:
def load_csv_split(file_name):
    """Load one CSV file from `dataset_path` as a `datasets.Dataset`.

    `load_dataset('csv', ...)` returns a DatasetDict; a single data file ends up
    under the 'train' key, which is why that key is selected here.
    """
    return datasets.load_dataset('csv', data_files=dataset_path + file_name)['train']


train_dataset = load_csv_split('sum_train.csv')
val_dataset = load_csv_split('sum_val.csv')
test_dataset = load_csv_split('sum_test.csv')

In [None]:
# Raw (untokenized) splits keyed by split name.
data = dict(train=train_dataset, validation=val_dataset, test=test_dataset)

In [None]:
def preprocess_function(samples):
    """Tokenize a batch of (argument, key_point) pairs for seq2seq training.

    Used with `Dataset.map(..., batched=True)`, so `samples` maps each column
    name to the list of values in the batch.

    Sources ('argument') are encoded up to `max_source_length` tokens and
    targets ('key_point') up to `max_target_length`. Both use the notebook's
    `padding` strategy and truncation. Pad ids in the targets are replaced by
    -100 so the loss ignores them.

    Returns:
        The source encoding with an added 'labels' entry holding the masked
        target ids.
    """
    # Empty task prefix: a no-op today, kept as the place to add one.
    source_prefix = ''
    sources = [source_prefix + text for text in samples['argument']]
    targets = samples['key_point']

    encoded = tokenizer(sources, max_length=max_source_length, padding=padding, truncation=True)

    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

    pad_id = tokenizer.pad_token_id
    encoded['labels'] = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in target_encoding['input_ids']
    ]
    return encoded

In [None]:
def tokenize_split(split):
    """Apply `preprocess_function` to one raw split in batches.

    The cache is bypassed so edits to the preprocessing are always picked up.
    """
    return split.map(
        preprocess_function,
        batched=True,
        num_proc=1,
        load_from_cache_file=False,
    )


# Map from the raw splits kept in `data` rather than from names this cell
# overwrites, so re-running the cell is idempotent.
train_dataset = tokenize_split(data['train'])
eval_dataset = tokenize_split(data['validation'])
test_dataset = tokenize_split(data['test'])

In [None]:
# Metric objects used by compute_metrics.
# NOTE(review): datasets.load_metric is deprecated in favour of
# `evaluate.load` and is gone from newer `datasets` releases.
bertscore_metric, rouge_metric = load_metric('bertscore'), load_metric('rouge')

def compute_metrics(eval):
    """Score generated summaries against the references with ROUGE and BERTScore.

    Passed to `Seq2SeqTrainer(compute_metrics=...)`. With
    `predict_with_generate=True`, `eval.predictions` holds generated token ids.
    `eval.label_ids` holds the reference ids from `preprocess_function`, where
    pad positions are set to -100.

    Args:
        eval: object exposing `predictions` and `label_ids`. The name is kept
            for compatibility even though it shadows the builtin.

    Returns:
        dict mapping each ROUGE type to its mid F-measure (percent, 1 decimal),
        plus BERTScore 'f1', 'precision' and 'recall' averaged over all
        evaluation examples.
    """
    pad_id = tokenizer.pad_token_id

    prediction_ids = eval.predictions
    # Some model outputs come back as a tuple; the token ids are first.
    if isinstance(prediction_ids, tuple):
        prediction_ids = prediction_ids[0]

    # The tokenizer cannot decode -100. Labels carry it as the loss mask, and
    # predictions concatenated across eval batches can be padded with it.
    # np.where returns new arrays, so the trainer's buffers are not mutated.
    prediction_ids = np.where(prediction_ids != -100, prediction_ids, pad_id)
    label_ids = np.where(eval.label_ids != -100, eval.label_ids, pad_id)

    prediction_str = tokenizer.batch_decode(prediction_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    rouge = rouge_metric.compute(predictions=prediction_str, references=label_str, use_stemmer=True)
    result = {key: round(value.mid.fmeasure * 100, 1) for key, value in rouge.items()}

    # BERTScore takes one string per example and returns one score per
    # example, so report the mean rather than only the first example's score.
    bert_scores = bertscore_metric.compute(predictions=prediction_str, references=label_str, lang='en')
    for name in ('f1', 'precision', 'recall'):
        result[name] = float(np.mean(bert_scores[name]))

    return result

In [None]:
output_dir = './models/pegasus_xsum/'

# Checkpoints and evaluations share a 100-step cadence: load_best_model_at_end
# requires save_steps to be a round multiple of eval_steps.
training_args = Seq2SeqTrainingArguments(
    do_train=True,
    do_eval=True,
    predict_with_generate=True,  # compute_metrics decodes generated token ids
    evaluation_strategy='steps',
    output_dir=output_dir,
    num_train_epochs=30,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=50,
    weight_decay=0.01,
    logging_dir=output_dir,
    save_steps=100,
    eval_steps=100,
    logging_steps=100,
    optim='adamw_hf',
    load_best_model_at_end=True,
    # Spelled out: these are the values Transformers falls back to when
    # load_best_model_at_end is set without a metric. Early stopping below
    # monitors the same metric.
    metric_for_best_model='loss',
    greater_is_better=False,
    seed=42,  # Transformers' default, pinned for visibility
    save_total_limit=4,
)

# Stop after 2 consecutive evaluations whose loss fails to improve by at least
# 0.005. The model was already moved to `device` when loaded, and the Trainer
# handles device placement itself, so no extra .to(device) call is needed here.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.005)],
)

In [None]:
# Fine-tune; the returned TrainOutput (global step, training loss, metrics) is
# kept and displayed as the cell's result.
train_result = trainer.train()
train_result

In [None]:
# Persist the final model (the best checkpoint, since load_best_model_at_end is
# enabled) to the run's output directory.
trainer.save_model(trainer.args.output_dir)