# Required installs

In [None]:
!pip install -q transformers
!pip install -q sentencepiece

# Imports

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import transformers

from transformers import PegasusTokenizer, TFPegasusModel, TFPegasusForConditionalGeneration
from transformers import T5Tokenizer, TFT5Model, TFT5ForConditionalGeneration

# Utility functions

Wrap text around for aesthetics.

In [None]:
from IPython.display import HTML, display

def set_css(*_args, **_kwargs):
  """Display a <style> snippet that makes <pre> output wrap long lines.

  This function is registered below as a 'pre_run_cell' callback, so it is
  invoked by IPython's event system rather than called by hand. Any arguments
  the event system passes along (e.g. the cell execution info) are accepted
  and ignored, so the callback works whether or not the hook supplies them.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

# Re-apply the style before every cell execution.
get_ipython().events.register('pre_run_cell', set_css)

# CNN/DM dataset loading

Download and load the raw data. It comes as a `tf.data.Dataset` whose elements are binary (byte-string) tensors; we process it later.

In [None]:
# Load every split of CNN/DailyMail together with the dataset metadata.
data, info = tfds.load(
    name='cnn_dailymail',
    with_info=True,
)

INFO:absl:No config specified, defaulting to first: cnn_dailymail/plain_text
INFO:absl:Load dataset info from /root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0
INFO:absl:Reusing dataset cnn_dailymail (/root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0)
INFO:absl:Constructing tf.data.Dataset for split None, from /root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0


Extract train, val, and test data

In [None]:
# Pull each split out of the dict returned by tfds.load.
train_data = data['train']
val_data = data['validation']
test_data = data['test']

Each example in the datasets has an "article" and its "highlights". Here we extract only the articles, since those are what we want to summarize.

In [None]:
# Keep only the 'article' field of every example, for each of the three splits.
X_train, X_val, X_test = (
    split.map(lambda example: example['article'])
    for split in (train_data, val_data, test_data)
)

Print a few example articles

In [None]:
# Number of articles to show; the loop stops once this many have been printed.
num_example = 2
for position, article in enumerate(X_train, start=1):
  # Each element is a scalar byte-string tensor: convert to bytes, then to str.
  print(article.numpy().decode())
  print('\n')
  if position < num_example:
    continue
  # Last requested article reached: print a footer with the element dtype and stop.
  print('--------------------')
  print('Each element of X_train is a:', article.dtype)
  break


By . Associated Press . PUBLISHED: . 14:11 EST, 25 October 2013 . | . UPDATED: . 15:36 EST, 25 October 2013 . The bishop of the Fargo Catholic Diocese in North Dakota has exposed potentially hundreds of church members in Fargo, Grand Forks and Jamestown to the hepatitis A virus in late September and early October. The state Health Department has issued an advisory of exposure for anyone who attended five churches and took communion. Bishop John Folda (pictured) of the Fargo Catholic Diocese in North Dakota has exposed potentially hundreds of church members in Fargo, Grand Forks and Jamestown to the hepatitis A . State Immunization Program Manager Molly Howell says the risk is low, but officials feel it's important to alert people to the possible exposure. The diocese announced on Monday that Bishop John Folda is taking time off after being diagnosed with hepatitis A. The diocese says he contracted the infection through contaminated food while attending a conference for newly ordained b

# Pegasus model

Download and create the model, get associated tokenizer

In [None]:
# Load the pretrained Pegasus model and its tokenizer from the same checkpoint.
# NOTE(review): the checkpoint name suggests XSum fine-tuning while the data
# used in this notebook is CNN/DM -- confirm this pairing is intended.
pegasus_checkpoint = 'google/pegasus-xsum'
pegasus_model = TFPegasusForConditionalGeneration.from_pretrained(pegasus_checkpoint)
pegasus_tokenizer = PegasusTokenizer.from_pretrained(pegasus_checkpoint)

All model checkpoint layers were used when initializing TFPegasusForConditionalGeneration.

All the layers of TFPegasusForConditionalGeneration were initialized from the model checkpoint at google/pegasus-xsum.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFPegasusForConditionalGeneration for predictions without further training.


Check model summary (pretty impressive!)

In [None]:
# Print the Keras summary (layers, output shapes, parameter counts) of the Pegasus model.
pegasus_model.summary()

Model: "tf_pegasus_for_conditional_generation_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 model (TFPegasusMainLayer)  multiple                  569748480 
                                                                 
Total params: 569,844,583
Trainable params: 569,748,480
Non-trainable params: 96,103
_________________________________________________________________


Construct list of articles to summarize and tokenize

In [None]:
# Number of articles to summarize. The selection below is driven by this
# value (not by a variable left over from an earlier cell).
num_articles = 2

# Decode the first `num_articles` training articles from byte tensors to str.
summarize = [article.numpy().decode() for article in X_train.take(num_articles)]

# Pad to the longest article in the batch; truncate anything longer than the
# tokenizer's maximum input length so over-long articles cannot break inference.
pegasus_inputs = pegasus_tokenizer(summarize, return_tensors='tf', padding=True, truncation=True)
print('There are', pegasus_inputs['input_ids'].shape[0], 'articles of length', pegasus_inputs['input_ids'].shape[1], 'to summarize')

There are 2 articles of length 456 to summarize


Run inference: use the model to summarize the input articles from the CNN/DM data

In [None]:
# Beam-search generation. The batch was padded by the tokenizer, so the
# attention mask is passed explicitly to make sure padded positions are
# masked instead of relying on generate() to infer them.
# NOTE(review): min_length=100 may exceed the checkpoint's default max_length
# -- confirm, and pass max_length explicitly if so.
pegasus_summary_ids = pegasus_model.generate(
    pegasus_inputs['input_ids'],
    attention_mask=pegasus_inputs['attention_mask'],
    min_length=100,
    num_beams=5,
    no_repeat_ngram_size=1,
)

# Turn every generated id sequence back into text, dropping special tokens.
print([pegasus_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
       for g in pegasus_summary_ids])



# T5 Model

Download and create the model, get associated tokenizer

In [None]:
# Load the pretrained T5 model and its tokenizer from the same checkpoint.
t5_checkpoint = 't5-large'
t5_model = TFT5ForConditionalGeneration.from_pretrained(t5_checkpoint)
t5_tokenizer = T5Tokenizer.from_pretrained(t5_checkpoint)

Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.75G [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-large.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

Check model summary

In [None]:
# Print the Keras summary (layers, output shapes, parameter counts) of the T5 model.
t5_model.summary()

Model: "tft5_for_conditional_generation"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 shared (TFSharedEmbeddings)  multiple                 32899072  
                                                                 
 encoder (TFT5MainLayer)     multiple                  302040576 
                                                                 
 decoder (TFT5MainLayer)     multiple                  402728448 
                                                                 
Total params: 737,668,096
Trainable params: 737,668,096
Non-trainable params: 0
_________________________________________________________________


Tokenize the list of articles + add 'summarize:' because T5 needs to know the task

In [None]:
# Number of articles to summarize. The selection below is driven by this
# value (not by a variable left over from an earlier cell).
num_articles = 2

# Decode the first `num_articles` training articles and prepend the
# 'summarize: ' prefix that tells T5 which task to perform.
summarize = ['summarize: ' + article.numpy().decode()
             for article in X_train.take(num_articles)]

# Pad to the longest article in the batch.
t5_inputs = t5_tokenizer(summarize, return_tensors='tf', padding=True)
print('There are', t5_inputs['input_ids'].shape[0], 'articles of length', t5_inputs['input_ids'].shape[1], 'to summarize')

There are 2 articles of length 548 to summarize


Run inference: use the model to summarize the input articles from the CNN/DM data

In [None]:
# Beam-search generation. The batch was padded by the tokenizer, so the
# attention mask is passed explicitly to make sure padded positions are
# masked instead of relying on generate() to infer them.
# NOTE(review): the recorded summaries look cut off mid-word; presumably the
# default generation length limit applies -- consider passing max_length.
t5_summary_ids = t5_model.generate(
    t5_inputs['input_ids'],
    attention_mask=t5_inputs['attention_mask'],
    num_beams=3,
    no_repeat_ngram_size=1,
)

# Turn every generated id sequence back into text, dropping special tokens.
print([t5_tokenizer.decode(g, skip_special_tokens=True, 
                           clean_up_tokenization_spaces=False) for g in t5_summary_ids])

['bishop of the fargo Catholic diocese in north Dakota has exposed potentially hundreds church members to', 'police lieutenant Ralph mata worked in the division that investigates allegations of wrongdo']
