In [None]:
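# Uncomment the lines below if the notebook is run in Google Colab.
# NOTE: `%%capture` is a cell magic and only works as the very first line of a
# cell, so move it to the top of the cell when uncommenting.
# # Install required libraries
# %%capture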
# !pip install transformers datasets evaluate
# !pip install rouge-score
# !pip install nltk
# !pip install sentencepiece

In [None]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from datasets import load_dataset
import evaluate
import torch, random
import nltk
from nltk import sent_tokenize

In [None]:
# Sentence-tokenizer data used by `sent_tokenize` further down.
# Recent NLTK releases (3.9+) load the 'punkt_tab' resource instead of the
# pickled 'punkt' models, so fetch both to keep the notebook working across
# NLTK versions.
for resource in ('punkt', 'punkt_tab'):
    nltk.download(resource)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# ROUGE metric loaded through `evaluate`.
# NOTE(review): the fixed seed is presumably there to make the metric's
# bootstrap aggregation reproducible - confirm against the `evaluate` docs.
rouge = evaluate.load(path='rouge', seed=2022)

In [None]:
# Hub identifier shared by the tokenizer and the model weights.
model_checkpoint = "google/pegasus-cnn_dailymail"

# Both objects are loaded from the same checkpoint.
tokenizer = PegasusTokenizer.from_pretrained(model_checkpoint)
model = PegasusForConditionalGeneration.from_pretrained(model_checkpoint)

# Kept as the last expression of the cell: moves the model to `device` and
# lets the notebook display the returned module.
model.to(device)

In [None]:
# Test split of the CNN/DailyMail dataset, configuration '3.0.0'.
data_test = load_dataset(
    path='cnn_dailymail',
    name='3.0.0',
    split='test',
    trust_remote_code=True,
)

In [None]:
# Number of articles handed to `generate_summary` per `Dataset.map` call.
batch_size: int = 4

In [None]:
# Map function: generates a summary for every article in a batch
def generate_summary(batch):
    """Summarise one batch of articles and store the result under 'pred'.

    Args:
        batch: Mapping with an 'article' entry holding the texts to summarise.

    Returns:
        The same `batch` object, with a new 'pred' entry containing one decoded
        summary string per article.
    """
    # Pad to the longest article in the batch and truncate at 1024 tokens.
    # NOTE(review): add_special_tokens=False means the tokenizer does not append
    # its special tokens (e.g. EOS) to the inputs - confirm this is intentional.
    encoded = tokenizer(
        batch['article'],
        padding=True,
        truncation=True,
        max_length=1024,
        add_special_tokens=False,
        return_tensors='pt',
    )

    # Decoding parameters set according to PEGASUS's config.json file.
    decoding_kwargs = {
        'min_length': 32,
        'max_length': 128,
        'num_beams': 8,
        'length_penalty': 0.8,
    }
    generated_ids = model.generate(
        encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
        **decoding_kwargs,
    )

    # Special tokens are stripped while decoding.
    batch['pred'] = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return batch

In [None]:
# Run `generate_summary` over the test split in batches; the raw articles are
# dropped from the result because only references and predictions are used below.
results = data_test.map(
    generate_summary,
    batched=True,
    batch_size=batch_size,
    remove_columns=['article'],
)

# Reference summaries and generated summaries.
labels, predictions = results['highlights'], results['pred']

In [None]:
# ROUGE expects a newline after each sentence.
# Variant 1: the literal "<n>" marker in the predictions becomes a space, then
# each prediction is split into sentences that are re-joined with newlines.
without_marker = (pred.replace("<n>", " ") for pred in predictions)
clean_preds = ["\n".join(sent_tokenize(text)) for text in without_marker]

# Remove the space in front of full stops in the references.
clean_labels = [label.replace(" .", ".") for label in labels]

# Show every reference next to its cleaned prediction.
for i, (truth, summary) in enumerate(zip(clean_labels, clean_preds)):
    print(f"Item {i}:")
    print(f"Ground truth: {truth}")
    print(f"Prediction: {summary}")
    print("\n")

In [None]:
# Score the cleaned predictions against the cleaned references (with stemming).
rouge_output = rouge.compute(predictions=clean_preds, references=clean_labels, use_stemmer=True)
rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

# Report each F-measure as a percentage rounded to two decimals.
rouge_dict = {}
for rn in rouge_names:
    score = rouge_output[rn]
    # Current `evaluate` releases return the aggregated F-measure as a plain
    # float, while older releases (and `datasets.load_metric`) return an
    # AggregateScore exposing `.mid.fmeasure`; accept both so the cell does not
    # fail with AttributeError on a float.
    f_measure = score.mid.fmeasure if hasattr(score, "mid") else score
    # float() keeps the printed dict free of NumPy scalar reprs.
    rouge_dict[rn] = round(float(f_measure) * 100, 2)

print(rouge_dict)

In [None]:
# ROUGE expects a newline after each sentence.
# Variant 2: the literal "<n>" marker in the predictions becomes a newline, then
# each prediction is split into sentences that are re-joined with newlines.
with_line_breaks = (pred.replace("<n>", "\n") for pred in predictions)
clean_preds = ["\n".join(sent_tokenize(text)) for text in with_line_breaks]

# Remove the space in front of full stops in the references.
clean_labels = [label.replace(" .", ".") for label in labels]

# Show every reference next to its cleaned prediction.
for i, (truth, summary) in enumerate(zip(clean_labels, clean_preds)):
    print(f"Item {i}:")
    print(f"Ground truth: {truth}")
    print(f"Prediction: {summary}")
    print("\n")

In [None]:
# Score the cleaned predictions against the cleaned references (with stemming).
rouge_output = rouge.compute(predictions=clean_preds, references=clean_labels, use_stemmer=True)
rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

# Report each F-measure as a percentage rounded to two decimals.
rouge_dict = {}
for rn in rouge_names:
    score = rouge_output[rn]
    # Current `evaluate` releases return the aggregated F-measure as a plain
    # float, while older releases (and `datasets.load_metric`) return an
    # AggregateScore exposing `.mid.fmeasure`; accept both so the cell does not
    # fail with AttributeError on a float.
    f_measure = score.mid.fmeasure if hasattr(score, "mid") else score
    # float() keeps the printed dict free of NumPy scalar reprs.
    rouge_dict[rn] = round(float(f_measure) * 100, 2)

print(rouge_dict)