In [None]:
# Install/upgrade JAX, jaxlib and Trax into the notebook runtime.
# NOTE(review): versions are unpinned, so every fresh runtime may pull different
# releases — consider pinning them for reproducibility.
!pip install --upgrade jax
!pip install --upgrade jaxlib
!pip install --upgrade trax

In [None]:
import requests
import os
# Ask the Colab TPU host to switch to a specific TPU driver build.
# The guard on globals() makes the request run only once per kernel session,
# even when this cell is re-executed.
if 'TPU_DRIVER_MODE' not in globals():
  # Only the part of COLAB_TPU_ADDR before the first ':' is used (presumably the
  # host); the request goes to port 8475. A missing env var raises KeyError.
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  # NOTE(review): the response is stored but its status is never checked.
  resp = requests.post(url)
  # Sentinel recording that the driver request has already been sent.
  TPU_DRIVER_MODE = 1

In [None]:
from jax.config import config
# Point JAX at the remote TPU driver backend.
# NOTE(review): this relies on `jax.config.config.FLAGS`; because jax is installed
# with an unpinned `--upgrade`, confirm the installed release still provides it.
config.FLAGS.jax_xla_backend = "tpu_driver"
# The gRPC target is built from COLAB_TPU_ADDR; a missing env var raises KeyError.
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

In [None]:
import sys
import os

import numpy as np

import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)

In [None]:
# Factory for the training stream of the 'cnn_dailymail' dataset, stored under
# 'data/'. Calling train_stream_fn() later yields the ('article', 'highlights')
# fields of each example (presumably as raw text — confirm against trax.data.TFDS).
train_stream_fn = trax.data.TFDS('cnn_dailymail',
                                 data_dir='data/',
                                 keys=('article', 'highlights'),
                                 train=True)


In [None]:
# Same dataset, directory and keys as the training stream, but requested with
# train=False (presumably the held-out evaluation split — confirm).
eval_stream_fn = trax.data.TFDS('cnn_dailymail',
                                data_dir='data/',
                                keys=('article', 'highlights'),
                                train=False)

In [None]:
#Tokenize and Detokenize
def tokenize(input_str, EOS=1,
             vocab_dir='/content/gdrive/MyDrive/Summarizer/vocab_dir/',
             vocab_file='summarize32k.subword.subwords'):
    """Convert a string into a list of token ids terminated by ``EOS``.

    Args:
        input_str: Text to tokenize.
        EOS: Token id appended to mark the end of the sentence.
        vocab_dir: Directory containing the vocabulary file. Defaults to the
            Google Drive location used throughout this notebook, so Drive must
            be mounted before the function is called with the default.
        vocab_file: Name of the vocabulary file inside ``vocab_dir``.

    Returns:
        A Python list holding the token ids of ``input_str`` followed by ``EOS``.
    """
    # trax.data.tokenize works on a stream of examples, so wrap the single
    # string in an iterator and pull the one result back out.
    token_stream = trax.data.tokenize(iter([input_str]),
                                      vocab_dir=vocab_dir,
                                      vocab_file=vocab_file)
    inputs = next(token_stream)

    # Mark the end of the sentence with EOS
    return list(inputs) + [EOS]

def detokenize(integers,
               vocab_dir='/content/gdrive/MyDrive/Summarizer/vocab_dir/',
               vocab_file='summarize32k.subword.subwords'):
    """Convert a sequence of token ids back into line-wrapped text.

    Args:
        integers: Sequence of token ids to decode.
        vocab_dir: Directory containing the vocabulary file. Defaults to the
            Google Drive location used throughout this notebook, so Drive must
            be mounted before the function is called with the default.
        vocab_file: Name of the vocabulary file inside ``vocab_dir``.

    Returns:
        The decoded string, re-flowed by the notebook-level ``wrapper``
        (a ``textwrap.TextWrapper``).
    """
    s = trax.data.detokenize(integers,
                             vocab_dir=vocab_dir,
                             vocab_file=vocab_file)

    return wrapper.fill(s)

In [None]:
# Mount Google Drive at /content/gdrive. The vocabulary and model paths used
# elsewhere in this notebook live under /content/gdrive/MyDrive/Summarizer/, so
# this cell has to run before any cell that actually reads those files.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Special tokens
SEP = 0 # Padding or separator token
EOS = 1 # End of sentence token

# Concatenate tokenized inputs and targets using 0 as separator.
def preprocess(stream):
    """Turn (article, summary) token pairs into language-model training triples.

    Args:
        stream: Iterable of ``(article, summary)`` pairs, each element being an
            iterable of token ids.

    Yields:
        ``(joint, joint, mask)`` where ``joint`` is the array
        ``article + [EOS, SEP] + summary + [EOS]`` (the same object is used as
        both input and target) and ``mask`` is an array of the same length that
        is 0 over the article and the ``[EOS, SEP]`` pair and 1 over the
        summary and its final ``EOS``.
    """
    for (article, summary) in stream:
        # Materialize each sequence exactly once. Calling list() repeatedly on
        # a one-shot iterable would return an empty list the second time and
        # produce a mask shorter than `joint`.
        article_tokens = list(article)
        summary_tokens = list(summary)

        joint = np.array(article_tokens + [EOS, SEP] + summary_tokens + [EOS])
        mask = [0] * (len(article_tokens) + 2) + [1] * (len(summary_tokens) + 1)
        yield joint, joint, np.array(mask)


# Data pipeline applied to both the training and evaluation streams. The stages
# run in the order listed.
input_pipeline = trax.data.Serial(
   
    # 1. Tokenize the stream using the subword vocabulary stored on Google Drive.
    trax.data.Tokenize(vocab_dir='/content/gdrive/MyDrive/Summarizer/vocab_dir',
                       vocab_file='summarize32k.subword.subwords'),

    # 2. Build (joint, joint, mask) triples with the preprocess generator above.
    preprocess,

    # 3. Length filter with limit 2048 (presumably drops longer examples — confirm).
    trax.data.FilterByLength(2048)
)

# Apply preprocessing to data streams.
train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

# Peek at one example. Note this consumes the first item of train_stream, so it
# is not seen again by cells that read from train_stream afterwards.
train_input, train_target, train_mask = next(train_stream)

# Sanity check: input and target are identical (preprocess yields the same array twice).
assert sum((train_input - train_target)**2) == 0  

In [None]:
# Print the loss mask of the peeked example: 0 over the article and the
# [EOS, SEP] pair, 1 over the summary tokens and the final EOS.
print(f'Single example mask:\n\n {train_mask}')

Single example mask:

 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0

In [None]:
# Decode the peeked example back to text (article followed by its summary).
print(f'Single example:\n\n {detokenize(train_input)}')

Single example:

 Chelsea's early season form may have led to comparisons with the
Arsenal 'Invincibles' side, but Gary Neville believes they aren't even
as good as the Chelsea side from 10 years ago. Jose Mourinho's side
are currently four points clear at the top of the Premier League, but
after letting leads slip against both Manchester City and United,
their killer instinct has been called into question. 'If a team are
going to be playing for a 1-0 then you better see it out,' Neville
said on Monday Night Football. 'When I saw Jose Mourinho two weeks ago
he talked about the 2005 (Chelsea) team and (compared) the team he had
then to the team he has now and he said the killer instinct's missing.
Chelsea have dropped more points from winning positions this season
than they did in the whole of 2004/05. Chelsea took the lead against
both Manchester United and Manchester City, but drew both matches.
'When I look at the statistics they are staggering - 28 times they
(the 2004/05 team) scor

In [None]:
# Length-bucketing configuration.
# There are 4 boundaries but 5 batch sizes; presumably the extra entry is the
# batch size for sequences longer than the last boundary — confirm against
# trax.data.BucketByLength.
boundaries =  [128, 256,  512, 1024]
batch_sizes = [8,    8,    8,    8, 8]

# Create the streams.
# The same bucketing configuration is applied to training and evaluation data.
train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(train_stream)

eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(eval_stream)

In [None]:
# Pull one batch from the training stream for inspection; the middle element
# (targets) is discarded.
input_batch, _, mask_batch = next(train_batch_stream)
# NOTE(review): this consumes a second batch and `batch` is not used in any
# visible cell — confirm whether it is needed.
batch=next(train_batch_stream)
# Shape of the input_batch
input_batch.shape


(8, 1024)

In [None]:
# Print the token ids of the whole batch. The array is printed in full because
# np.set_printoptions(threshold=sys.maxsize) was set in the imports cell.
print(input_batch)

In [None]:
# Decode and print the first sequence of the batch (article followed by its summary).
print('Article:\n\n', detokenize(input_batch[0]))

Article:

 By. Deborah Arthurs. UPDATED:. 13:47 EST, 4 January 2012. As Burberry
today unveils its new spring/summer campaign in store windows and
across social media networking sites, millions of ardent fans will be
watching. Just last month, it was announced that the British company
had become the world's most successful  luxury fashion brand on
Facebook and Twitter, with a record 10million fans on Facebook, and
almost 700,000 following the brand's regular UK feed on Twitter.
Meanwhile, they have thousands more global Twitter fans following
their international feeds and post exclusive content on their own
YouTube channel. Behind the scenes: Eddie Redmayne and Cara Delevingne
on set of the latest Burberry campaign. Burberry's social media
success has grown exponentially - and it is still growing fast. The
secret, say consumer experts, is the fact that Burberry share so much
unique content exclusively with their followers on social networking
platforms, and post new and different conte

In [None]:
# Create the Transformer language model used for summarization, in training mode
# (mode='train', dropout=0.1).
# vocab_size=33000 is presumably chosen to cover the 32k subword vocabulary — confirm.
# NOTE(review): the name `model` is rebound to an eval-mode model in a later cell.
model = trax.models.TransformerLM(vocab_size=33000, d_model=512, d_ff=2048,
                                  n_layers=6, n_heads=8, max_len=4096, dropout=0.1,
                                  mode='train', ff_activation=tl.Relu)

In [None]:
# Bare expression: the notebook displays the model's representation as the cell output.
model

In [None]:
#Training and Evaluation
from trax.supervised import training

# Training task: data, loss, optimizer, learning-rate schedule and checkpoint cadence.
train_task = training.TrainTask(
  labeled_data=train_batch_stream, # The training generator
  # The training log reported a *negative* "CrossEntropyLoss", which a true
  # cross-entropy can never be. That is the symptom of the deprecated
  # tl.CrossEntropyLoss (which expects log-probabilities) being fed raw logits.
  # WeightedCategoryCrossEntropy takes logits and consumes the same
  # (output, target, mask) triple that the data stream yields.
  loss_layer=tl.WeightedCategoryCrossEntropy(), # Loss function
  optimizer=trax.optimizers.Adam(0.01), # Optimizer
  lr_schedule=trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01),
  # NOTE(review): checkpointing every 2 steps also triggers evaluation every
  # 2 steps in the logged output; raise this for real training runs.
  n_steps_per_checkpoint=2,
  n_steps_per_permanent_checkpoint=100
)

In [None]:
# Evaluation task: metrics computed on the held-out batch stream.
eval_task = training.EvalTask(
  labeled_data=eval_batch_stream, # The evaluation generator
  # Use the logit-based, mask-weighted metrics. The deprecated
  # tl.CrossEntropyLoss produced negative values in the evaluation log, which
  # is impossible for a true cross-entropy and indicates it was applied to raw
  # logits instead of log-probabilities.
  metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()]
)

Evaluation


In [None]:
# Re-create the same architecture in 'eval' mode for inference.
# NOTE(review): this rebinds `model`, replacing the training-mode instance created
# earlier; any later cell that trains with `model` would get this eval-mode
# object instead. Consider a distinct name.
model = trax.models.TransformerLM(vocab_size=33000, d_model=512, d_ff=2048,
                                  n_layers=6, n_heads=8, max_len=4096, dropout=0.1,
                                  mode='eval', ff_activation=tl.Relu)

In [None]:
# Load weights from the checkpoint on Google Drive into the eval-mode model
# (weights_only=True). Requires Drive to be mounted.
model.init_from_file('/content/gdrive/MyDrive/Summarizer/models/model.pkl.gz',
                    weights_only=True)

In [None]:

def next_symbol(cur_output_tokens, model):
    """Return the id of the most likely next token after ``cur_output_tokens``.

    Args:
        cur_output_tokens: List of token ids generated so far (prompt included).
        model: Callable taking a 2-tuple of batched token arrays and returning a
            pair whose first element is indexed as ``[batch, position, vocab]``.

    Returns:
        The argmax token id (a Python ``int``) at the position right after the
        last supplied token.
    """
    n_tokens = len(cur_output_tokens)

    # Pad up to the smallest power of two that is strictly greater than
    # n_tokens, so the slot at index n_tokens always exists.
    target_len = 1 << n_tokens.bit_length()
    zero_fill = [0] * (target_len - n_tokens)

    # Add a leading batch axis of size 1.
    batch = np.array(cur_output_tokens + zero_fill)[None, :]

    # The same array is supplied as both elements of the input tuple.
    predictions, _ = model((batch, batch))

    # Scores for the first padded slot, i.e. the prediction for the next token.
    next_scores = predictions[0, n_tokens, :]
    return int(np.argmax(next_scores))

In [None]:
# Test Input!
# Tokenize a sentence, append the separator 0, and decode the single token the
# model predicts next.
sentence_test_nxt_symbl = "I want to fly in the sky."
detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model)])

'The'

In [None]:

def greedy_decode(input_sentence, model, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize,
                  max_length=None, verbose=True):
    """Greedily generate text for ``input_sentence`` one token at a time.

    The prompt is ``tokenize(input_sentence)`` followed by the separator ``0``.
    At each step the most likely next token is appended, until the
    end-of-sentence token (``1``) is produced or ``max_length`` is reached.

    Args:
        input_sentence: Text to condition on (for example an article).
        model: Model forwarded unchanged to ``next_symbol``.
        next_symbol: Callable ``(tokens, model) -> int`` giving the next token id.
        tokenize: Callable turning a string into a list of token ids.
        detokenize: Callable turning a list of token ids into a string.
        max_length: Optional cap on the number of generated tokens. ``None``
            (the default) keeps the original behaviour of looping until EOS,
            which never terminates if the model does not emit EOS.
        verbose: When True (the default, matching the original behaviour),
            print the partial decoding after every generated token.

    Returns:
        The detokenized generated tokens. The EOS token, when produced, is part
        of the decoded sequence.
    """
    EOS = 1

    # Use tokenize()
    cur_output_tokens = tokenize(input_sentence) + [0]
    generated_output = []
    cur_output = 0

    while cur_output != EOS:
        # Optional safety bound so a model that never emits EOS cannot loop forever.
        if max_length is not None and len(generated_output) >= max_length:
            break

        # Get next symbol
        cur_output = next_symbol(cur_output_tokens, model)
        # Append next symbol to original sentence
        cur_output_tokens.append(cur_output)
        # Append next symbol to generated sentence
        generated_output.append(cur_output)

        if verbose:
            print(detokenize(generated_output))

    return detokenize(generated_output)

In [None]:
# Test it out on a sentence!
# Prints the wrapped input, then every intermediate decoding step, then the
# final returned string.
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model))

It was a sunny day when I went to the market to buy some flowers. But
I only found roses, not tulips. 

:
: I
: I just
: I just found
: I just found ros
: I just found roses
: I just found roses,
: I just found roses, not
: I just found roses, not tu
: I just found roses, not tulips
: I just found roses, not tulips
: I just found roses, not tulips.
: I just found roses, not tulips.<EOS>
: I just found roses, not tulips.<EOS>


In [None]:
# Test it out with a whole article!
article = "It’s the posing craze sweeping the U.S. after being brought to fame by skier Lindsey Vonn, soccer star Omar Cummings, baseball player Albert Pujols - and even Republican politician Rick Perry. But now four students at Riverhead High School on Long Island, New York, have been suspended for dropping to a knee and taking up a prayer pose to mimic Denver Broncos quarterback Tim Tebow. Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were all suspended for one day because the ‘Tebowing’ craze was blocking the hallway and presenting a safety hazard to students. Scroll down for video. Banned: Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll (all pictured left) were all suspended for one day by Riverhead High School on Long Island, New York, for their tribute to Broncos quarterback Tim Tebow. Issue: Four of the pupils were suspended for one day because they allegedly did not heed to warnings that the 'Tebowing' craze at the school was blocking the hallway and presenting a safety hazard to students."
# Show the wrapped article, then summarize it with greedy decoding
# (intermediate decoding steps are printed as they are generated).
print(wrapper.fill(article), '\n')
print(greedy_decode(article, model))

It’s the posing craze sweeping the U.S. after being brought to fame by
skier Lindsey Vonn, soccer star Omar Cummings, baseball player Albert
Pujols - and even Republican politician Rick Perry. But now four
students at Riverhead High School on Long Island, New York, have been
suspended for dropping to a knee and taking up a prayer pose to mimic
Denver Broncos quarterback Tim Tebow. Jordan Fulcoly, Wayne Drexel,
Tyler Carroll and Connor Carroll were all suspended for one day
because the ‘Tebowing’ craze was blocking the hallway and presenting a
safety hazard to students. Scroll down for video. Banned: Jordan
Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll (all pictured
left) were all suspended for one day by Riverhead High School on Long
Island, New York, for their tribute to Broncos quarterback Tim Tebow.
Issue: Four of the pupils were suspended for one day because they
school was blocking the hallway and presenting a safety hazard to
students. 

Jordan
Jordan Ful
Jordan Fulcol


In [None]:
# Run 1000 more training steps.
# NOTE(review): `loop` is not defined in any visible cell (no training.Loop is
# constructed from train_task / eval_task here), so this fails on a fresh
# Restart & Run All unless it is created elsewhere. Also note that `model` has
# been rebound to the eval-mode instance by this point.
loop.run(1000)


Step     14: Ran 2 train steps in 28.24 secs
Step     14: train CrossEntropyLoss | -0.40733588
Step     14: eval  CrossEntropyLoss | -0.48936749
Step     14: eval          Accuracy |  0.05418719

Step     16: Ran 2 train steps in 140.57 secs
Step     16: train CrossEntropyLoss | -0.48392025
Step     16: eval  CrossEntropyLoss | -0.54614329
Step     16: eval          Accuracy |  0.03454895

Step     18: Ran 2 train steps in 127.37 secs
Step     18: train CrossEntropyLoss | -0.56032300
Step     18: eval  CrossEntropyLoss | -0.58724588
Step     18: eval          Accuracy |  0.03107344

Step     20: Ran 2 train steps in 27.10 secs
Step     20: train CrossEntropyLoss | -0.63797355
Step     20: eval  CrossEntropyLoss | -0.71727920
Step     20: eval          Accuracy |  0.03369066

Step     22: Ran 2 train steps in 119.53 secs
Step     22: train CrossEntropyLoss | -0.70920670
Step     22: eval  CrossEntropyLoss | -0.79851902
Step     22: eval          Accuracy |  0.03864734

Step     24: Ran