In [1]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from datasets import load_dataset, list_datasets, load_metric, list_metrics
import torch
from tqdm import tqdm
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

2022-04-10 19:03:06.818725: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [2]:
# Params
batch_size = 8


In [3]:
# Model and tokenizer
model_name = "google/pegasus-cnn_dailymail"
device = "cuda:2" if torch.cuda.is_available() else "cpu"
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)

In [4]:
# Load data
#xsum_train = load_dataset('xsum', split='train')
#xsum_valid = load_dataset('xsum', split='validation')
xsum_test = load_dataset('ccdv/cnn_dailymail','3.0.0', split='test')

Reusing dataset cnn_dailymail (/home/huangwenhao/.cache/huggingface/datasets/ccdv___cnn_dailymail/1.0.0/1.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f)


In [5]:
# Data preprocess
#print(tokenizer(xsum_valid[0]['document'],truncation=True, padding="longest", return_tensors="pt"))
#valid_dataset = xsum_valid.map(lambda e: tokenizer(e['document'],truncation=True, padding="longest"), batch_size=batch_size, batched=True)
#test_dataset = xsum_test.map(lambda e: tokenizer(e['document'],truncation=True, padding="longest"), batch_size=batch_size, batched=True)

test_dataset = xsum_test.map(lambda e: tokenizer(e['article'],truncation=True, padding="longest"), batch_size=batch_size, batched=True)
#valid_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])


Loading cached processed dataset at /home/huangwenhao/.cache/huggingface/datasets/ccdv___cnn_dailymail/1.0.0/1.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f/cache-b3e3564c736e665a.arrow


In [6]:
#valid_ld = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_ld = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
valid_predict = []
for batch in tqdm(valid_ld):
    #print(batch['input_ids'].shape)
    batch = {x:y.to(device) for x,y in batch.items()}
    #print(batch)
    translated = model.generate(**batch)
    valid_predict.extend(tokenizer.batch_decode(translated, skip_special_tokens=True))

100%|██████████| 709/709 [26:25<00:00,  2.24s/it]


In [7]:
test_predict = []
for batch in tqdm(test_ld):
    #print(batch['input_ids'].shape)
    batch = {x:y.to(device) for x,y in batch.items()}
    #print(batch)
    translated = model.generate(**batch)
    test_predict.extend(tokenizer.batch_decode(translated, skip_special_tokens=True))

100%|██████████| 1437/1437 [1:44:24<00:00,  4.36s/it]


In [50]:
src_text = [
    """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
]
batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt").to(device)


translated = model.generate(**batch)
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)

In [20]:
test_predict[:10]

['"I laughed and learned more from Jimmie in one hour than from anyone else," says "Hazzard" co-star John Schneider.<n>James Best played bumbling sheriff Rosco P. Coltrane on "The Dukes of Hazzard"',
 'Dr. Anthony Moschetto is charged in what authorities say was a failed plot to have another physician hurt or killed.<n>"None of anything in this case has any evidentiary value," his attorney says.',
 "CNN's John Sutter sat down with President Obama for a one-on-one interview.<n>He asked him about the science behind climate change and public health.<n>The President encouraged ordinary citizens, doctors and nurses to start putting some pressure on elected officials.",
 'A Russian TV channel aired Hillary Clinton\'s first campaign video with a rating stamp that means it\'s for mature audiences.<n>A clip of the video, which features a gay couple holding hands, got the 18+ rating from the independent TV Rain channel.<n>The channel told CNN that it didn\'t want to break the controversial law, 

In [19]:
xsum_test['highlights'][:10]

['James Best, who played the sheriff on "The Dukes of Hazzard," died Monday at 88 . "Hazzard" ran from 1979 to 1985 and was among the most popular shows on TV .',
 'A lawyer for Dr. Anthony Moschetto says the charges against him are baseless . Moschetto, 54, was arrested for selling drugs and weapons, prosecutors say . Authorities allege Moschetto hired accomplices to burn down the practice of former associate .',
 '"No challenge poses more of a public threat than climate change," the President says . He credits the Clean Air Act with making Americans "a lot" healthier .',
 "Presidential hopeful's video, featuring gay couple, gets mature rating in Russia . Russian TV channel feared airing it would break the country's anti-gay propaganda law . Clinton announced her support for same-sex marriage in 2013 .",
 "Raul Reyes: In seeking Latino vote, Marco Rubio his own worst enemy on two key issues: immigration reform, Cuba relations . He says on health care, climate change and other issues, 

In [21]:
#Metric
Rouge = load_metric('rouge')
result = Rouge.compute(predictions=test_predict, references=[tokenizer.convert_tokens_to_string(tokenizer.tokenize(s)[:125]) for s in xsum_test['highlights']])
print(result)

{'rouge1': AggregateScore(low=Score(precision=0.39772253224976767, recall=0.4651015078073494, fmeasure=0.4165496928304979), mid=Score(precision=0.4004648470303975, recall=0.4679978523142452, fmeasure=0.4189722014293267), high=Score(precision=0.403188436133444, recall=0.4708428967445666, fmeasure=0.4214240252205675)), 'rouge2': AggregateScore(low=Score(precision=0.19195002049237347, recall=0.22294554804299171, fmeasure=0.2001840360691533), mid=Score(precision=0.1943959145899981, recall=0.22552212112230757, fmeasure=0.2026123566478052), high=Score(precision=0.19691520457198555, recall=0.2283083467016058, fmeasure=0.20512438928062812)), 'rougeL': AggregateScore(low=Score(precision=0.28302725448104965, recall=0.33112382365638643, fmeasure=0.29649984925005723), mid=Score(precision=0.2853817948420493, recall=0.3338381561898218, fmeasure=0.2986779331210263), high=Score(precision=0.2880322689032132, recall=0.33659922587129976, fmeasure=0.30117401430473656)), 'rougeLsum': AggregateScore(low=Sco

In [2]:
from rouge import Rouge
rouge_score = rouge.get_scores(test_predict, [tokenizer.convert_tokens_to_string(tokenizer.tokenize(s)[:125]) for s in xsum_test['highlights']])

NameError: name 'rouge' is not defined

In [12]:
with open('predict.txt','w',encoding='utf-8') as f:
    for i in [tokenizer.convert_tokens_to_string(tokenizer.tokenize(s)[:125]) for s in xsum_test['highlights']]:
        f.write(i+'\n')