In [1]:
import torch
import nltk.tokenize as tk
from rouge_metric import PerlRouge
from datasets import load_dataset

In [None]:
data_train = load_dataset('xsum', split='train')
data_test = load_dataset('xsum', split='test')
data_valid = load_dataset('xsum', split='validation')

In [4]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
except ImportError:
    raise ImportError("import failure")

In [5]:
def sum_gen(tokenizer,model,article,maxlength):
    global torch_device 
    torch_device= torch.device('cuda:0')
    model=model.to(torch_device)
    article_input_ids = tokenizer.encode_plus(article.replace('\n',''), return_tensors='pt', max_length=maxlength,truncation=True)['input_ids'].to(torch_device)
    summary_ids = model.generate(article_input_ids,num_beams=4,length_penalty=4.0,max_length=maxlength,no_repeat_ngram_size=3)
    summary = tokenizer.decode(summary_ids.squeeze(), skip_special_tokens=True)
    return summary


def sum_doc_gen(filestr,n_phrase_para,tokenizer,model):
    #load doc and generate original summary
    doc=filestr
    tokens = tk.sent_tokenize(doc)
    sum_ori_doc=sum_gen(tokenizer,model,doc,512)

    n_phrase=len(tokens)
    n_para=int(n_phrase/n_phrase_para)+1

    #divide doc into paragraphs
    para=[]
    for i in range(n_para):
        one_para=""
        for j in range(n_phrase_para):
            if(i*n_phrase_para+j<n_phrase):
                one_para+=tokens[i*n_phrase_para+j]
        
        para.append(one_para)
    #now para[] is a list where each element is one paragraph(one string)

    #generate summary for each paragraph
    sum_para=[]
    for i in range(n_para):
        sum_para.append(sum_gen(tokenizer,model,para[i].replace('\n',''),256))

    #change each paragraph for its summary in doc
    docs_modified=[]
    for i in range(n_para):
        one_text=""
        for j in range(n_para):
            if j==i:
                one_text+=sum_para[j]
            else:
                one_text+=para[j]
        docs_modified.append(one_text)

    #generate summaries for modified docs
    sums_modified=[]
    for i in range(n_para):
        sums_modified.append(sum_gen(tokenizer,model,docs_modified[i].replace('\n',''),512))

    return sum_ori_doc,n_para,para,sum_para,sums_modified

def evaluate(n_para,sum_ori_doc,sums_modified):
    rouge = PerlRouge(rouge_n_max=3)

    score=0
    for i in range(n_para):
        scores = rouge.evaluate([sums_modified[i]], [[sum_ori_doc]])
        score+=scores['rouge-2']['r']
    return score/n_para

def evaluate_rouge(sum,ref):
    rouge = PerlRouge(rouge_n_max=3)
    score = rouge.evaluate([sum],[[ref]])
    return score['rouge-2']['r']


In [6]:
tokenizer_bart = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model_bart = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

tokenizer_distilbart = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
model_distilbart = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6")


tokenizer_bert = AutoTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
model_bert = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")

tokenizer_bart_base = AutoTokenizer.from_pretrained("facebook/bart-base")
model_bart_base = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

'''
tokenizer_pega = AutoTokenizer.from_pretrained("google/pegasus-xsum")
model_pega = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-xsum")
'''

'\ntokenizer_pega = AutoTokenizer.from_pretrained("google/pegasus-xsum")\nmodel_pega = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-xsum")\n'

In [7]:
score_my_bart=[]
score_rouge_bart=[]
for i in range(50):
    sum_ori_doc,n_para,para,sum_para,sums_modified=sum_doc_gen(data_test[i]['document'].replace('\n',''),4,tokenizer_bart,model_bart)
    score_my_bart.append(evaluate(n_para,sum_ori_doc,sums_modified))
    score_rouge_bart.append(evaluate_rouge(sum_ori_doc,data_test[i]['summary'].replace('\n','')))

score_my_distilbart=[]
score_rouge_distilbart=[]
for i in range(50):
    sum_ori_doc,n_para,para,sum_para,sums_modified=sum_doc_gen(data_test[i]['document'].replace('\n',''),4,tokenizer_distilbart,model_distilbart)
    score_my_distilbart.append(evaluate(n_para,sum_ori_doc,sums_modified))
    score_rouge_distilbart.append(evaluate_rouge(sum_ori_doc,data_test[i]['summary'].replace('\n','')))

score_my_bart_base=[]
score_rouge_bart_base=[]
for i in range(50):
    sum_ori_doc,n_para,para,sum_para,sums_modified=sum_doc_gen(data_test[i]['document'].replace('\n',''),4,tokenizer_bart_base,model_bart_base)
    score_my_bart_base.append(evaluate(n_para,sum_ori_doc,sums_modified))
    score_rouge_bart_base.append(evaluate_rouge(sum_ori_doc,data_test[i]['summary'].replace('\n','')))

score_my_bert=[]
score_rouge_bert=[]
for i in range(50):
    sum_ori_doc,n_para,para,sum_para,sums_modified=sum_doc_gen(data_test[i]['document'].replace('\n',''),4,tokenizer_bert,model_bert)
    score_my_bert.append(evaluate(n_para,sum_ori_doc,sums_modified))
    score_rouge_bert.append(evaluate_rouge(sum_ori_doc,data_test[i]['summary'].replace('\n','')))

In [8]:
print(sum(score_my_bart)/50,sum(score_rouge_bart)/50)
print()
print(sum(score_my_distilbart)/50,sum(score_rouge_distilbart)/50)
print()
print(sum(score_my_bart_base)/50,sum(score_rouge_bart_base)/50)
print()
print(sum(score_my_bert)/50,sum(score_rouge_bert)/50)

0.7901536508730159 0.06358719999999998

0.765154726825397 0.06568919999999999

0.8305175104761904 0.15791059999999996

0.6961957437301588 0.07380739999999998


In [None]:
import matplotlib.pyplot as plt
plt.figure
plt.plot([sum(score_my_bart)/50,sum(score_my_distilbart)/50,sum(score_my_bart_base)/50,sum(score_my_bert)/50],label="our method")
plt.plot([sum(score_rouge_bart)/50,sum(score_rouge_distilbart)/50,sum(score_rouge_bart_base)/50,sum(score_rouge_bert)/50],label='ROUGE method')
plt.legend()
plt.title('comparison')