In [1]:
import sys
import numpy as np

import textwrap
# Wraps detokenized text (see `detokenize`) to 70-character lines for readable printing.
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
# trax's numpy-compatible API, used below to build tensors and attention.
from trax.fastmath import numpy as jnp

# to print the entire np array
# (threshold=sys.maxsize disables numpy's summarisation of large arrays, so long
# token/mask arrays are printed in full rather than elided with "...")
np.set_printoptions(threshold=sys.maxsize)

In [60]:

# Load the CNN/DailyMail articles dataset from the local copy in '../news_data/'.
# Each example is an (article, highlights) pair; the highlights act as the summary.
# The generator below builds the train split first and then the eval split, using
# identical arguments apart from `train`.
train_stream_fnction, eval_stream_fnction = (
    trax.data.TFDS('cnn_dailymail',
                   data_dir='../news_data/',
                   keys=('article', 'highlights'),
                   train=is_train_split)
    for is_train_split in (True, False)
)

#### Create tokenize and detokenize functions

In [61]:
# Helper functions to tokenize and detokenize data. `tokenize` converts a text string
# into its corresponding token list (i.e. a list of indices), splitting words into
# subwords where needed. `detokenize` converts a token list back into its text.

def tokenize(input_str, EOS=1):
    """Convert an input string into a list of subword token ids ending in EOS.

    Args:
        input_str: text to tokenize.
        EOS: id appended after the last token (defaults to 1, the
            end-of-sentence id used throughout this notebook).

    Returns:
        list of token ids from the 'summarize32k' subword vocabulary,
        followed by EOS.
    """
    # trax.data.tokenize maps a stream of strings to a stream of token arrays,
    # so wrap the single string in a one-element iterator and take its only output.
    input_tokens = next(trax.data.tokenize(iter([input_str]),
                                           vocab_dir='vocab_dir/',
                                           vocab_file='summarize32k.subword.subwords'))
    # Mark the end of the sentence.
    return list(input_tokens) + [EOS]

def detokenize(input_integers):
    """Turn a sequence of subword token ids back into text.

    Args:
        input_integers: token ids from the 'summarize32k' subword vocabulary.

    Returns:
        the decoded text, line-wrapped by the module-level `wrapper`.
    """
    decoded_text = trax.data.detokenize(
        input_integers,
        vocab_dir='vocab_dir/',
        vocab_file='summarize32k.subword.subwords',
    )
    return wrapper.fill(decoded_text)


In [62]:
# Language model and preprocessing
# A language model only predicts the next word, so each article (input) is concatenated
# with its summary (target), with an <EOS> and a separator token between them.
# A loss mask of 0s over the article and 1s over the summary then makes the model's
# loss focus on the summary.

In [68]:
# Special token ids used when packing an (article, summary) pair.
SEP = 0  # padding / separator
EOS = 1  # end of sentence


def preprocess(stream):
    """Pack each (article, summary) token pair into a language-model example.

    Args:
        stream: iterable of (article, summary) pairs of token ids.

    Yields:
        (inputs, targets, mask) where inputs and targets are the same numpy
        array article + [EOS, SEP] + summary + [EOS], and mask is 0 over the
        article and its EOS and SEP and 1 over the summary and its final EOS,
        so the loss is taken only on the summary.
    """
    for article, summary in stream:
        # Materialise each side once: calling list() twice on an iterator would
        # consume it and make the mask length disagree with the packed sequence.
        article_tokens = list(article)
        summary_tokens = list(summary)
        packed = np.array(article_tokens + [EOS, SEP] + summary_tokens + [EOS])
        mask = [0] * (len(article_tokens) + 2) + [1] * (len(summary_tokens) + 1)
        yield packed, packed, np.array(mask)

# Maximum example length kept by the pipeline (article + summary + special tokens).
MAX_LENGTH = 2048

# Data pipeline: tokenize both fields, pack each pair with `preprocess`,
# then drop over-long examples.
input_pipeline = trax.data.Serial(
    # Convert article and summary text into subword token ids.
    trax.data.Tokenize(vocab_dir='vocab_dir/',
                       vocab_file='summarize32k.subword.subwords'),
    # Pack each (article, summary) pair into (inputs, targets, mask).
    preprocess,
    # Filter out examples longer than MAX_LENGTH tokens
    # (see trax.data.FilterByLength for the exact bound and measured field).
    trax.data.FilterByLength(MAX_LENGTH)
)

# Apply the same pipeline to both the train and evaluation data.
train_stream = input_pipeline(train_stream_fnction())
eval_stream = input_pipeline(eval_stream_fnction())

# Pull one example to sanity-check the pipeline.
train_input, train_target, train_mask = next(train_stream)
# For language-model training the inputs and targets are the same sequence.
assert np.array_equal(train_input, train_target), "inputs and targets should be identical"


In [69]:
# Show the loss mask of the first training example: 0s over the article
# (plus its <EOS> and separator), 1s over the summary (plus its final <EOS>).
mask_report = f'Single example mask:\n\n {train_mask}'
print(mask_report)

Single example mask:

 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0

In [70]:
# Decode the packed training example back to text; its layout is
# [Example][<EOS>][<pad>][Example Summary][<EOS>].
example_text = detokenize(train_input)
print(f'Single example:\n\n {example_text}')

Single example:

 By . Ian Parkes, Press Association . Max Chilton has every confidence
he will be retained by Marussia for a third consecutive season.
Chilton started the campaign relatively strongly, claiming the best
results of his Formula One career by finishing 13th in the season-
opening race in Australia and again in Bahrain. In Monaco, however,
team-mate Jules Bianchi stole Chilton's thunder as the Frenchman
scored Marussia's first points from their four and a half years in F1
with ninth place in Monaco. Centre of attention: Max Chilton remains
hopeful of being retained by Marussia for the 2015 season . Since then
Chilton has struggled for form and results, but the 23-year-old from
Reigate in Surrey sees no reason why Marussia would not retain him for
2015. 'I naturally want to stay with the team,' said Chilton. 'Like a
lot of these things they filter down from the top, and there are a lot
of rumours with regard to the top of the grid, with people moving
around and you don't re

In [105]:
# Inputs vary a lot in length, so padding everything to one global maximum would waste
# compute. Instead, examples are grouped into buckets of similar length and batched
# per bucket: batch_size[i] is the batch size for examples with length < boundaries[i],
# and the extra final entry covers examples beyond the last boundary.
# Here: 16 for length < 128, 8 for < 256, 4 for < 512, 2 for < 1024, 1 otherwise.
boundaries = [128, 256, 512, 1024]
batch_size = [16, 8, 4, 2, 1]

# Bucket both splits the same way (train first, then eval).
train_batch_stream, eval_batch_stream = (
    trax.data.BucketByLength(boundaries, batch_size)(split_stream)
    for split_stream in (train_stream, eval_stream)
)


In [109]:
# different articles will be produced every time
# Each next() pulls a new batch from the bucketed train stream as (inputs, targets, mask).
# NOTE(review): binding `_` here shadows IPython's last-output alias for the session.
input_batch,_,mask_batch=next(train_batch_stream)
# (batch size, padded length) -- both depend on the bucket this batch came from.
input_batch.shape

(2, 1024)

In [110]:
# let's see the corresponding integer values
# (1 is <EOS>, 0 is the separator / padding, as set up in preprocessing)
input_batch

array([[   27,  6327, 19480,   649,  3331,   540,   163, 10871,  5258,
         7511,    22,  4194,   320,   399,    28, 18870,   330, 15702,
          132,   163, 25209,  1560,  8833,   186,  5564,   213,   157,
         3408,   134,  1353,    36,   527,    15,  8866, 14161,     5,
            3, 10851,   800, 19480,   649,  3331,  2958,  6211,  5861,
          142,  1353,  3156,  8751,  2631,    78,   213,  6327,     6,
          766,  9217,   138,    78,  2613,  4979,  7511,    22,   980,
          213,  5556,   186,  1410,   320,  6909,    95,   186,   399,
            3,  6211,  5861,   142,     2,  1874,     2,   721,   320,
          399,    36,   157,  1124,   320,  6166,   213,  8833,   179,
           78,    50, 16612,   186,   102,   290,   181,   409,   469,
         1241,   132,    41,    25,   217,   320,   211,   213,  8833,
        14643,  7023,     3,  6327, 19480,   649,  3331,  2958,  6211,
         5861,   142,     2,   231,     2,   540,   163, 10871,  5258,
      

In [113]:
# Each row of the input batch contains:
#   - the article's token ids, then 1 (the article's <EOS>) and 0 (the separator);
#   - the summary's token ids, then 1 (the summary's <EOS>);
#   - trailing 0s that pad the row to the length of its bucket.
# Decode and show the first example of the batch.
first_example_text = detokenize(input_batch[0])
print("Article and summary:\n\n", first_example_text)

Article and summary:

 A Baltimore Orioles fan got an unexpected surprise when he stopped to
help a motorist trapped in an overturned truck and realized the man
helping him was one of his sporting idols. Diehard Orioles fan Mike
Soukup was driving southbound on the Baltimore-Washington Parkway on
Monday afternoon when he saw the accident and decided to pull over and
help. Soukup, 55, started to help one man try to push the truck back
on its wheels and after four or five others joined in they were able
to get the truck upright. Baltimore Orioles fan Mike Soukup, right,
got an unexpected surprise when he stopped to help a motorist trapped
in an overturned truck and realized the man helping him was his
sporting idols, Chris Davis . When Soukup turned to congratulate the
achievement with the man next to him - the one who was the first on
the scene - he recognized him as Orioles corner infielder Chris Davis.
'I turned to high five the guy for a good job done getting this truck
up, and I tho

In [114]:
# So each packed example is: article <EOS> <pad> summary <EOS>, followed by the <pad>s added by bucketing.
# The loss is taken only on the summary (where the mask is 1), using cross entropy as the loss function.
# Now create helper functions to create tensors and to display them, using JAX numpy arrays.

def create_tensor(tensor):
    """Build an array with the fastmath numpy backend.

    Args:
        tensor: a (possibly nested) list of numbers.

    Returns:
        the corresponding `jnp` array.
    """
    as_array = jnp.array(tensor)
    return as_array

def display_tensor(tensor, name):
    """Print a labelled tensor.

    Prints "<name> shape: <shape>" and then the tensor's values, each
    followed by a blank line.

    Args:
        tensor: any object with a `.shape` attribute (e.g. a numpy/jnp array).
        name: label shown before the shape.
    """
    header_line = f'{name} shape: {tensor.shape}\n'
    values_line = f'{tensor}\n'
    for line in (header_line, values_line):
        print(line)
    

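# Try it with dummy data and build dot-product attention
$$
\text{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right) V \tag{1}
$$
where $M$ is the mask term: $0$ at positions that may be attended to and a very large negative value (here $-10^{9}$) at masked positions.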

In [115]:
# Create dot-product attention
def DotProductAttention(query, key, value, mask):
    """Scaled dot-product attention, softmax(Q K^T / sqrt(d) + M) V (formula 1).

    Args:
        query: array of query representations with shape L_q by d
            (leading batch dimensions, if any, must broadcast under matmul).
        key: array of key representations with shape L_k by d.
        value: array of value representations with shape L_k by d
            (L_v == L_k).
        mask: boolean-like array broadcastable to L_q by L_k; positions where
            it is False are excluded from attention. Pass None to attend to
            every position.

    Returns:
        array of shape L_q by d: for each query, the attention-weighted
        combination of the value vectors.
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "Embeedding dimesions of q,k,v must be same"
    # Depth (embedding size) of the queries, used to scale down the dot products.
    depth = query.shape[-1]

    # Scaled query-key dot products, as in formula (1).
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    # Apply the mask only when one is given: masked-out positions get a very
    # large negative logit and therefore ~zero weight after the softmax.
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax over the key axis, computed as exp(x - logsumexp(x)) for stability.
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)
    dots = jnp.exp(dots - logsumexp)

    # Weight the values by the attention probabilities.
    attention = jnp.matmul(dots, value)

    return attention
    

In [None]:
# Now implement causal attention: multi-headed attention with a mask so that each position attends only to
# the words at or before it.
# 1. Compute attention heads - take input with dimension (batch_size, seqlen, n_heads x d_head), split the last
#    (depth) dimension and stack it onto the zeroth dimension to allow matrix multiplication: (batch_size x n_heads, seqlen, d_head).

# 2. Dot-product self-attention.
# 3. Compute the attention output.