In [None]:
!pip install transformers==2.8.0
!pip install torch==1.4.0

In [None]:
import torch
import json 
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

# Resolve the target device once: GPU when torch can see one, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained t5-small tokenizer and model; the model is placed on `device`
# (a no-op on CPU-only machines).
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)




In [None]:
def summarize(text, num_beams=4, no_repeat_ngram_size=2, min_length=30, max_length=100, early_stopping=True, skip_special_tokens=True, do_sample=False):
    """Summarize ``text`` with the notebook-level T5 model.

    Uses the globals ``tokenizer``, ``model`` and ``device`` defined in the
    setup cell.

    Args:
        text: Input document as a string.
        num_beams: Forwarded to ``model.generate``.
        no_repeat_ngram_size: Forwarded to ``model.generate``.
        min_length: Forwarded to ``model.generate``.
        max_length: Forwarded to ``model.generate``.
        early_stopping: Forwarded to ``model.generate``.
        skip_special_tokens: Forwarded to ``tokenizer.decode``.
        do_sample: Forwarded to ``model.generate``.

    Returns:
        The decoded first sequence returned by ``model.generate``.
    """
    # Replace newlines with a space (not the empty string) so the last word of
    # one line is not fused with the first word of the next.
    preprocess_text = text.strip().replace("\n", " ")
    t5_prepared_text = "summarize: " + preprocess_text

    # max_length=512 is handed to the tokenizer to cap the encoded input.
    tokenized_text = tokenizer.encode(t5_prepared_text, return_tensors="pt", max_length=512).to(device)

    # Forward the caller's generation settings instead of hard-coded literals,
    # so the function's keyword arguments actually take effect.
    summary_ids = model.generate(
        tokenized_text,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
        min_length=min_length,
        max_length=max_length,
        early_stopping=early_stopping,
        do_sample=do_sample,
    )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=skip_special_tokens)

In [None]:
from generators import get_cnn_dm_both_generator

# One record per summarized article: source text, reference abstract, T5 summary.
output = []

test_data_path = './dataset/chunked/test_*.bin'
for article, abstract in get_cnn_dm_both_generator(test_data_path):
    article_len = len(article)
    if article_len <= 5000:
        print(f'Summarizing text - len={article_len}')
        t5_abstract = summarize(article)
        output.append({
            'article': article,
            'abstract': abstract,
            't5_abstract': t5_abstract
        })
    else:
        # Tokenizer cannot handle inputs longer than that, so the pair is dropped.
        print(f'Skipping text - len={article_len}!')

# Persist every collected record so later cells can reload without re-running T5.
with open('t5_output_.json', 'w') as json_file:
    json.dump(output, json_file, indent=2)

Test Results

In [None]:
# Reload the saved records and print one reference/prediction pair for a manual check.
with open('t5_output_.json', 'r') as saved_results:
    json_object = json.load(saved_results)

# Position of the record to inspect; the file must hold at least this many + 1 entries.
SAMPLE_INDEX = 400

print("######### abstract #########")
print(json_object[SAMPLE_INDEX]['abstract'])
print("######### t5_abstract #########")
print(json_object[SAMPLE_INDEX]['t5_abstract'])



Evaluation results with ROUGE

In [None]:
import pandas as pd

# Load the summaries written by the generation cell into a DataFrame.
# The evaluation cell reads the 'abstract' and 't5_abstract' columns from it.
df = pd.read_json('t5_output_.json')


In [None]:
# Show the loaded frame (Jupyter renders the value of a cell's last expression).
df

In [None]:
!pip install rouge


In [None]:
from rouge import Rouge

# Single scorer instance with default settings; the evaluation cell calls
# its get_scores() method.
rouge = Rouge()


In [None]:
# Hypotheses (T5 summaries) and references (gold abstracts) as DataFrame columns.
pred_str, label_str = df['t5_abstract'], df['abstract']

# Score predictions against references.
# NOTE(review): no `avg` flag is passed, so this presumably yields one score
# entry per pair rather than a corpus-level average — confirm that is intended.
rouge_output = rouge.get_scores(pred_str, label_str)
print(rouge_output)