In [None]:
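# Google Colab only: uncomment the lines below to install the required libraries.
# NOTE: `%%capture` is a cell magic and must be the first line of the cell, so delete these two comment lines when uncommenting.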
# %%capture
# !pip install transformers datasets evaluate
# !pip install rouge-score
# !pip install nltk
# !pip install sentencepiece

In [None]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from datasets import load_dataset
import evaluate
import torch
import nltk
from nltk import sent_tokenize

In [None]:
# `sent_tokenize` needs the Punkt sentence-tokenizer data. Newer NLTK releases
# (3.9+) look it up under 'punkt_tab', older ones under 'punkt', so fetch both.
# A resource that the installed NLTK index does not know is reported by
# `nltk.download` and skipped (it returns False rather than raising).
for punkt_resource in ('punkt', 'punkt_tab'):
    nltk.download(punkt_resource)

In [None]:
# Backend preference: Apple MPS first, then CUDA, otherwise fall back to the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [None]:
# ROUGE metric from the `evaluate` hub; the fixed seed is forwarded to the
# metric module so repeated runs use the same random state.
rouge = evaluate.load(path='rouge', seed=2022)

In [None]:
# Hub identifier shared by the tokenizer and the model weights.
model_checkpoint = 'google/pegasus-xsum'

# The two loads are independent; the model is moved to the selected device last.
model = PegasusForConditionalGeneration.from_pretrained(model_checkpoint)
tokenizer = PegasusTokenizer.from_pretrained(model_checkpoint)

model.to(device)

In [None]:
# Only the test split is needed for evaluation. `trust_remote_code=True`
# allows the dataset's loading script to be executed.
data_test = load_dataset(path='xsum', split='test', trust_remote_code=True)

In [None]:
# Number of documents summarised per `generate_summary` call.
batch_size: int = 8

In [None]:
def generate_summary(batch):
    """Summarise one batch of documents with the loaded PEGASUS model.

    Written as the callback for ``Dataset.map(..., batched=True)``.

    Args:
        batch: Batch of examples indexed by column name; ``batch['document']``
            must hold the list of source texts to summarise.

    Returns:
        The same ``batch`` object, with a new ``'prediction'`` entry holding
        one decoded summary string per input document.
    """
    # Special tokens are left enabled (the tokenizer default) so that the
    # end-of-sequence token is appended to every source text, matching the
    # input format shown in the Transformers PEGASUS usage example.
    # Truncation to 512 tokens is applied with that token accounted for.
    inputs = tokenizer(batch['document'],
                       padding=True,
                       truncation=True,
                       max_length=512,
                       return_tensors='pt')
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # Decoding parameters set according to the checkpoint's config.json
    outputs = model.generate(input_ids,
                             attention_mask=attention_mask,
                             max_length=64,  # config
                             num_beams=8,  # config
                             length_penalty=0.6  # config
                             )

    # Strip all special tokens (padding, EOS, ...) from the decoded text.
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    batch['prediction'] = output_str

    return batch

In [None]:
# Run generation over the whole test split in batches. The source text and id
# columns are dropped, leaving the reference summary and the new prediction.
results = data_test.map(
    generate_summary,
    batched=True,
    batch_size=batch_size,
    remove_columns=['document', 'id'],
)

# Reference summaries and model outputs, in matching order.
labels, predictions = results['summary'], results['prediction']

In [None]:
# ROUGE expects a newline after each sentence: turn the literal '<n>' marker in
# the model output into a line break, then put every detected sentence on its own line.
clean_preds = ["\n".join(sent_tokenize(pred.replace('<n>', '\n'))) for pred in predictions]
# Remove the space that precedes a full stop in the reference summaries.
clean_labels = [label.replace(" .", ".") for label in labels]

# Print only a small sample as a sanity check. Printing every item of the test
# split would flood the notebook output. Increase `num_preview` to inspect more.
num_preview = 5
preview_pairs = zip(clean_labels[:num_preview], clean_preds[:num_preview])
for i, (reference, prediction) in enumerate(preview_pairs):
    print(f"Item {i}:")
    print(f"Ground truth: {reference}")
    print(f"Prediction: {prediction}")
    print("\n")

In [None]:
# Score the cleaned predictions against the cleaned references (with stemming).
rouge_output = rouge.compute(predictions=clean_preds, references=clean_labels, use_stemmer=True)
rouge_metrics = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

# Depending on the installed version, each entry of `rouge_output` is either a
# plain float (already the aggregated F-measure) or an aggregate-score object
# that exposes it as `.mid.fmeasure`. Accept both and report percentages
# rounded to two decimals.
rouge_scores = {}
for metric in rouge_metrics:
    score = rouge_output[metric]
    if hasattr(score, "mid"):
        score = score.mid.fmeasure
    rouge_scores[metric] = round(float(score) * 100, 2)

print(rouge_scores)