In [106]:
#!pip install transformers==2.8.0
#!pip install torch==1.4.0

In [88]:
import torch
import json 
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

model = T5ForConditionalGeneration.from_pretrained('t5-small')
if torch.cuda.is_available():
  model.to('cuda')
tokenizer = T5Tokenizer.from_pretrained('t5-small')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




In [91]:
def summarize(text, num_beams=4, no_repeat_ngram_size=2, min_length=30, max_length=100, early_stopping=True, skip_special_tokens=True, do_sample=False):
    preprocess_text = text.strip().replace("\n","")
    t5_prepared_Text = "summarize: "+preprocess_text
    #print ("original text preprocessed: \n", preprocess_text)

    tokenized_text = tokenizer.encode(t5_prepared_Text, return_tensors="pt", max_length=512).to(device)


    # summmarize 
    summary_ids = model.generate(tokenized_text,
                                    num_beams=4,
                                    no_repeat_ngram_size=2,
                                    min_length=30,
                                    max_length=100,
                                    early_stopping=True)

    output = tokenizer.decode(summary_ids[0], skip_special_tokens=skip_special_tokens)
    return output

In [92]:
from generators import get_cnn_dm_both_generator

output = []
#i = 0

test_data_path = './dataset/chunked/test_*.bin'
for article, abstract in get_cnn_dm_both_generator(test_data_path):
    article_len = len(article)
    if article_len > 5000:
        print(f'Skipping text - len={article_len}!')
        # Tokenizer cannot handle inputs longer than that
        continue

    print(f'Summarizing text - len={article_len}')
    t5_abstract = summarize(article)
    output.append({
        'article': article,
        'abstract': abstract,
        't5_abstract': t5_abstract
    })    

with open('t5_output_.json', 'w') as fout:
    json.dump(output, fout, indent=2)

Skipping text - len=7367!
Summarizing text - len=3739


  beam_id = beam_token_id // vocab_size


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Summarizing text - len=4440
Skipping text - len=5044!
Skipping text - len=7428!
Skipping text - len=8885!
Summarizing text - len=4052
Skipping text - len=8057!
Summarizing text - len=2990
Summarizing text - len=3765
Skipping text - len=5546!
Summarizing text - len=2709
Summarizing text - len=3812
Skipping text - len=5766!
Skipping text - len=7668!
Skipping text - len=6688!
Summarizing text - len=3814
Summarizing text - len=4533
Summarizing text - len=4018
Skipping text - len=6377!
Skipping text - len=5575!
Summarizing text - len=4586
Summarizing text - len=4878
Summarizing text - len=4312
Skipping text - len=5555!
Summarizing text - len=3038
Summarizing text - len=3236
Summarizing text - len=2391
Summarizing text - len=3518
Summarizing text - len=3800
Summarizing text - len=3052
Skipping text - len=7005!
Skipping text - len=6644!
Summarizing text - len=4425
Summarizing text - len=4192
Summarizing text - len=2707
Summarizi

Test Results

In [93]:
# test output
with open('t5_output_.json', 'r') as openfile:
 
    # Reading from json file
    json_object = json.load(openfile)

print("######### abstract #########")
print(json_object[400]['abstract'])
print("######### t5_abstract #########")
print(json_object[400]['t5_abstract'])



######### abstract #########
<s> promoter : manny pacquiao 's has between 800 and 900 friends who want tickets to historic las vegas clash . </s> <s> kenny bayless named as referee in mayweather-pacquiao 's bout , dubbed the `` fight of the century '' </s>
######### t5_abstract #########
manny pacquiao will fight floyd mayweather in las vegas on may 2. nevada state athletic commission says kenny bayless will be the referee of next month's fight with psg and the mgm grand in october. the cheapest tickets for the bout are priced at $ 1,500 but demand is such that some seats could fetch as much as $ 11,000 on the secondary


Eval results with rouge

In [96]:
import pandas as pd

df = pd.read_json('t5_output_.json')


In [97]:
df

Unnamed: 0,article,abstract,t5_abstract
0,-lrb- cnn -rrb- the palestinian authority offi...,<s> membership gives the icc jurisdiction over...,the formal accession was marked with a ceremon...
1,"-lrb- cnn -rrb- on may 28 , 2014 , some 7,000 ...",<s> amnesty international releases its annual ...,"some 7,000 people gathered in a stadium in chi..."
2,"-lrb- cnn -rrb- seventy years ago , anne frank...",<s> museum : anne frank died earlier than prev...,anne frank died of typhus in a nazi concentrat...
3,-lrb- cnn -rrb- a duke student has admitted to...,<s> student is no longer on duke university ca...,the prestigious private school didn't identify...
4,-lrb- cnn -rrb- never mind cats having nine li...,"<s> theia , a bully breed mix , was apparently...",stray pooch in washington state has used up at...
...,...,...,...
8286,crown princess mary attended an anzac day cere...,<s> australian-born royal placed a wreath at a...,crown princess mary attended an anzac day cere...
8287,telecom watchdogs are to stop a rip-off that a...,<s> operators are charging up to 20p a minute ...,telecom watchdogs are to stop a rip-off that a...
8288,it is a week which has seen him in deep water ...,<s> hardy was convicted of domestic abuse agai...,it is a week which has seen him in deep water ...
8289,an hiv self-testing kit is on sale for the fir...,<s> the 99.7 per cent accurate biosure hiv sel...,the 99.7 per cent accurate biosure hiv self te...


In [79]:
!pip install rouge


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rouge
  Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: rouge
Successfully installed rouge-1.0.1
[0m

In [98]:
from rouge import Rouge

rouge = Rouge()


In [99]:
pred_str = df['t5_abstract']
label_str = df['abstract']

rouge_output = rouge.get_scores(pred_str, label_str)

print(rouge_output)

Output hidden; open in https://colab.research.google.com to view.