In [1]:
from datasets import load_dataset
import torch
import evaluate
import nltk
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large')
model = PegasusForConditionalGeneration.from_pretrained('google/pegasus-large')

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-large and are newly initialized: ['model.encoder.embed_positions.weight', 'model.decoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
nltk.download('punkt')
rouge = evaluate.load('rouge')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\DuongNgoKien\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
dataset = load_dataset('xsum')

In [4]:
train_texts, train_labels = dataset['train']['document'][:1000], dataset['train']['summary'][:1000]

In [5]:
def process_data(document):
    labels = ""
    l_sentences = document.split("\n")
    sentences = []
    for s in l_sentences:
        s = nltk.sent_tokenize(s)
        sentences.extend(s)
    len_doc = len(sentences)
    n_mask = int(len_doc / 5)
    if n_mask == 0: n_mask =1
    ma_rouge = torch.zeros(size=(len_doc, len_doc))
    for i in range(len_doc):
        for j in range(i+1, len_doc):
            m = rouge.compute(predictions=[sentences[i]], references=[[sentences[j]]])
            ma_rouge[i,j] = m['rouge1']
            ma_rouge[j,i] = m['rouge1']
    mean_rouge = torch.mean(ma_rouge, 1)
    _, indexes = torch.topk(mean_rouge, n_mask)
    indexes = torch.sort(indexes, 0).values
    for i in range(indexes.size()[0]):
        masked_sentence = sentences[indexes[i]]
        document = document.replace(masked_sentence, "<mask_1>")
        labels += masked_sentence + " "
    return document.replace('\n', ' '), labels

In [9]:
nltk.sent_tokenize('The long road to the 2026 World Cup in the United States, Canada and Mexico began in earnest in Asia with 36 teams in action Thursday.Australia, who reached the last 16 at the Qatar 2022 World Cup before bowing out 2-1 to Lionel Messi and eventual champions Argentina, hammered Bangladesh 7-0 in Melbourne')

['The long road to the 2026 World Cup in the United States, Canada and Mexico began in earnest in Asia with 36 teams in action Thursday.Australia, who reached the last 16 at the Qatar 2022 World Cup before bowing out 2-1 to Lionel Messi and eventual champions Argentina, hammered Bangladesh 7-0 in Melbourne']

In [6]:
f_text = open("texts1.txt", "a", encoding='utf-8')
f_label = open("labels1.txt", "a", encoding='utf-8')
for i in tqdm(range(500, 1000)):
    doc, label = process_data(train_texts[i])
    f_text.write(doc + "\n")
    f_label.write(label + "\n")
f_text.close()
f_label.close()

100%|██████████| 500/500 [13:21:22<00:00, 96.16s/it]     
