In [None]:
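# Uncomment the lines below if this notebook is run in Google Colab.
# NOTE: `%%capture` is a cell magic, so it must be the first line of the cell once uncommented.
# # Install required libraries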
# %%capture
# !pip install transformers datasets evaluate
# !pip install rouge-score
# !pip install nltk

In [None]:
from transformers import BartForConditionalGeneration, BartTokenizer
from datasets import load_dataset
import evaluate
import torch, random
import nltk
from nltk import sent_tokenize

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# sent_tokenize needs the Punkt sentence-tokenizer data. Recent NLTK releases (3.9+)
# load it from the 'punkt_tab' package, while older releases read the pickled 'punkt'
# package, so fetch both to keep the notebook working with either.
for punkt_resource in ('punkt', 'punkt_tab'):
    nltk.download(punkt_resource)

In [None]:
# Load the ROUGE metric through the `evaluate` library.
# NOTE(review): `seed` is presumably forwarded to the metric so that any random
# resampling it performs is repeatable across runs — confirm against the evaluate docs.
rouge = evaluate.load('rouge', seed=2022)

In [None]:
# Checkpoint identifier passed to `from_pretrained` for both the tokenizer and the model.
model_checkpoint = 'facebook/bart-large-xsum'

In [None]:
# Tokenizer and model are both restored from the same checkpoint; the model is
# moved to `device` straight after loading.
tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
model = BartForConditionalGeneration.from_pretrained(model_checkpoint).to(device)

# Leave the model as the last expression so the cell still displays its architecture.
model

In [None]:
# Show the model's parameter count as the cell output.
model.num_parameters()

In [None]:
# Load only the test split of XSum.
# NOTE(review): `trust_remote_code=True` lets the dataset's loading script run locally;
# newer `datasets` releases may no longer accept script-based datasets — confirm against
# the installed version.
data_test = load_dataset('xsum', split='test', trust_remote_code=True)

In [None]:
# Number of examples handed to the batched map function per call.
batch_size = 8

In [None]:
# Batched map function: summarise every document and store the text under 'pred'.
def generate_summary(batch):
    """Generate one abstractive summary per document in ``batch``.

    Meant to be passed to ``Dataset.map(..., batched=True)``. Relies on the
    notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        batch: Mapping whose ``'document'`` entry is a list of source texts.

    Returns:
        The same ``batch`` object with an added ``'pred'`` entry: the decoded
        summaries (special tokens removed), in the same order as the documents.
    """
    # Keep the tokenizer's default special tokens: BART's encoder expects every
    # source sequence wrapped as <s> ... </s>, so they must not be stripped here.
    inputs = tokenizer(batch['document'],
                       padding=True,
                       truncation=True,
                       max_length=1024,
                       return_tensors='pt')
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # The following decoding parameters are set according to the BART's config.json file
    outputs = model.generate(input_ids,
                             attention_mask=attention_mask,
                             min_length=11,
                             max_length=62,
                             num_beams=6,
                             no_repeat_ngram_size=3,
                             length_penalty=1.0,
                             early_stopping=True
                             )

    # all special tokens will be removed
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch['pred'] = output_str

    return batch

In [None]:
# Summarise the whole test split in batches. The article text and the id are
# removed from the result, which leaves the reference 'summary' column next to
# the newly generated 'pred' column.
results = data_test.map(
    generate_summary,
    batched=True,
    batch_size=batch_size,
    remove_columns=['document', 'id'],
)

# References and model outputs, index-aligned.
labels, predictions = results['summary'], results['pred']

In [None]:
# rougeLsum treats every newline as a sentence boundary, so predictions AND references
# are both rewritten with one sentence per line; splitting only one side would make the
# summary-level score asymmetric.
def _one_sentence_per_line(text):
    """Return ``text`` with its sentences separated by newlines."""
    return "\n".join(sent_tokenize(text))


clean_preds = [_one_sentence_per_line(pred.replace('[X_SEP]', ' ')) for pred in predictions]
clean_labels = [_one_sentence_per_line(label.replace(" .", ".")) for label in labels]

# Preview only the first few pairs: printing the entire test split floods the notebook
# output. Raise this number to inspect more examples.
num_examples_to_show = 5

preview_pairs = zip(clean_labels[:num_examples_to_show], clean_preds[:num_examples_to_show])
for i, (label_text, pred_text) in enumerate(preview_pairs):
    print(f"Item {i}:")
    print("Label:")
    print(label_text)
    print("\n")
    print("Prediction:")
    print(pred_text)
    print("\n")

In [None]:
# Score the cleaned predictions against the cleaned references.
rouge_output = rouge.compute(predictions=clean_preds, references=clean_labels, use_stemmer=True)
rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]


def _f_measure(score):
    """Return the F-measure from a ROUGE result entry.

    The `evaluate` ROUGE metric returns plain floats when aggregating, whereas the
    legacy `datasets` metric returned AggregateScore objects exposing `.mid.fmeasure`.
    Both shapes are accepted so the cell works with either.
    """
    return score.mid.fmeasure if hasattr(score, "mid") else score


# Report each score as a percentage rounded to two decimals.
rouge_dict = {rn: round(float(_f_measure(rouge_output[rn])) * 100, 2) for rn in rouge_names}

print(rouge_dict)