In [2]:
import simplifier
import os
from tqdm import tqdm
import pickle
import time
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch

In [3]:
# Number of articles to process; the count is also embedded in the output file name.
num_test = 100
name_to_save = "../coreference_resolution/data/{}_news.pkl".format(num_test)

In [4]:
# Load the `google/pegasus-cnn_dailymail` checkpoint and its tokenizer; the model is placed on CPU.
model_name = "google/pegasus-cnn_dailymail"
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = (
    PegasusForConditionalGeneration
    .from_pretrained(model_name)
    .to("cpu")
)
print("model and tokenizer loaded")

model and tokenizer loaded


In [5]:
# Load the original pickled article list and keep only the first `num_test` entries.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
with open("../coreference_resolution/data/sample_news.pkl", "rb") as fh:
    news_list = pickle.load(fh)[:num_test]

len(news_list)

100

In [8]:
# Generate and simplify summaries.
#
# For each article, two pipelines are run:
#   * post-process: summarize the original content, then simplify the generated summary;
#   * pre-process:  simplify the content first, then summarize the simplified text.
# One dict per article is collected and the list is pickled to `name_to_save`.


def generate_summary(text, max_input_length=1024, num_beams=2, max_summary_length=50):
    """Summarize `text` with the globally loaded `model` / `tokenizer`.

    Args:
        text: input string; truncated to `max_input_length` tokens.
        max_input_length: tokenizer truncation length.
        num_beams: beam width passed to `model.generate`.
        max_summary_length: `max_length` passed to `model.generate`.

    Returns:
        The decoded summary string for the single input.
    """
    inputs = tokenizer(text, max_length=max_input_length, return_tensors="pt", truncation=True).to("cpu")
    summary_ids = model.generate(inputs["input_ids"], num_beams=num_beams, max_length=max_summary_length)
    return tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]


# Slice here too so the cell is correct even if `news_list` was not truncated earlier.
articles = news_list[:num_test]
print("number of articles to process is ", len(articles), '\n')

total_news_list = []
ts = time.time()
for news in tqdm(articles, total=len(articles), desc="generate and simplify"):
    content = news['content']

    # post-process: summarize the original input, then simplify the summary
    gen_summary = generate_summary(content)
    gen_sim_summary = simplifier.simplify(gen_summary)

    # pre-process: simplify the input, then summarize the simplified text
    sim_content = simplifier.simplify(content)
    sim_gen_summary = generate_summary(sim_content)

    total_news_list.append({
        "ori_content": content,
        "ref_summary": news['summary'],
        "gen_summary": gen_summary,
        "gen_sim_summary": gen_sim_summary,
        "sim_content": sim_content,
        "sim_gen_summary": sim_gen_summary,
    })

# `total_news_list` is a plain list; `.to_pickle` exists only on pandas objects,
# so serialize it with the pickle module directly.
with open(name_to_save, "wb") as f:
    pickle.dump(total_news_list, f)
print("DONE! generate and simplify. Time spent: {:.2f}".format(time.time() - ts))

number of articles to process is  100 

DONE! generate and simplify
