In [116]:
import sys
import numpy as np

import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)

In [117]:

# Get the data. The dataset is read from (or downloaded into) the data_dir given below;
# here it is expected to already be present under '../news_data/'.

# CNN/DailyMail articles dataset: each example is an (article, highlights) pair.
# The returned object is called later (with no arguments) to obtain the training stream.
train_stream_fnction = trax.data.TFDS('cnn_dailymail',
                                 data_dir='../news_data/',
                                 keys=('article', 'highlights'),
                                 train=True)

# Same dataset with train=False, used for evaluation.
# This should be much faster as the data is downloaded already.
eval_stream_fnction = trax.data.TFDS('cnn_dailymail',
                                data_dir='../news_data/',
                                keys=('article', 'highlights'),
                                train=False)

#### Create tokenize and detokenize functions

In [118]:
# Helper functions to tokenize and detokenize data. tokenize converts a text sentence into its
# corresponding token list (i.e. a list of indices), splitting words into subwords where needed.
# detokenize performs the reverse conversion, turning a list of tokens back into a sentence.

def tokenize(input_str,EOS=1):
    """Convert a sentence into a list of subword token ids terminated by EOS.

    Args:
        input_str (str): text to tokenize.
        EOS (int): id appended after the last token to mark the end of the sentence.

    Returns:
        list: the token ids of input_str followed by EOS.
    """
    # trax.data.tokenize consumes and produces streams, so wrap the single
    # sentence in a one-element iterator and pull the single result back out.
    token_stream=trax.data.tokenize(iter([input_str]),
                                    vocab_dir='vocab_dir/',
                                    vocab_file='summarize32k.subword.subwords')
    input_string=next(token_stream)
    # Mark the end of the sentence.
    return list(input_string)+[EOS]

def detokenize(input_integers):
    """Turn a sequence of token ids back into text, wrapped by the module-level `wrapper`.

    Args:
        input_integers: token ids to decode.

    Returns:
        str: the decoded text after passing through wrapper.fill.
    """
    return wrapper.fill(
        trax.data.detokenize(
            input_integers,
            vocab_dir='vocab_dir/',
            vocab_file='summarize32k.subword.subwords',
        )
    )


In [119]:
# Language model and preprocessing
# A language model only predicts the next word, so we concatenate the inputs with the targets, placing
# a separator between them. Padding masks are used as well: 0s over the input and 1s over the targets,
# so that the model focuses its attention on the summary.

In [120]:
# Special token ids used when joining an article with its summary.
# They must be real module-level names: preprocess reads them as globals.
SEP=0 # padding or separator
EOS=1 # end-of-sentence token

# Concatenate input tokens and targets, using 0 as the separator.
def preprocess(stream):
    """Join each (article, summary) pair into one language-model example.

    Args:
        stream: iterable of (article, summary) pairs, each a sequence of token ids.

    Yields:
        tuple: (inputs, targets, mask) where inputs and targets are the same array
        laid out as article + [EOS, SEP] + summary + [EOS], and mask is an array of
        the same length holding 0 over the article and its [EOS, SEP] and 1 over the
        summary and its closing EOS.
    """
    for (article,summary) in stream:
        # Materialise each side once so one-shot iterables are not consumed twice.
        article_ids=list(article)
        summary_ids=list(summary)
        combine=np.array(article_ids+[EOS,SEP]+summary_ids+[EOS])
        mask=[0]*(len(article_ids)+2)+[1]*(len(summary_ids)+1)
        yield combine,combine,np.array(mask)

# Build the data pipeline: each stage consumes the stream produced by the previous one.
input_pipeline=trax.data.Serial(
#     first tokenize the raw (article, highlights) text with the subword vocabulary
    trax.data.Tokenize(vocab_dir='vocab_dir/',
                        vocab_file='summarize32k.subword.subwords'),
#     then apply the preprocess function defined above
    preprocess,
#     then filter out the examples longer than 2048
    trax.data.FilterByLength(2048)
)

# Apply the pipeline above to both the training and the evaluation data.
train_stream=input_pipeline(train_stream_fnction())
eval_stream=input_pipeline(eval_stream_fnction())

# Pull a single (input, target, mask) example from the training stream.
train_input,train_target,train_mask=next(train_stream)
# For a language model the input and the target must be the same sequence.
assert sum((train_input-train_target)**2)==0


In [121]:
# Print the mask: 0s over the article (and its <EOS>/separator), 1s over the summary (and its <EOS>).
print(f'Single example mask:\n\n {train_mask}')

Single example mask:

 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0

In [122]:
# Prints the example converted back to text: [Example][<EOS>][<pad>][Example Summary][<EOS>]
print(f'Single example:\n\n {detokenize(train_input)}')

Single example:

 One of the country’s biggest rail terminals will be ‘effectively
closed’ today because of over-running engineering works. There will be
no trains in or out of King’s Cross in London due to delays to Network
Rail works north of the station. The disruption comes on one of the
busiest travel days of the year, as thousands of people try to return
home after visiting family for Christmas. Scroll down for video .
Services in and out of London Kings Cross station have been cancelled
today, it has been announced . Frustration: Travellers at the London
station, one of the busiest in the country, where services are
cancelled . The disruptions at the station, which is managed by
Network Rail, will affect those planning to travel on East Coast,
First Hull Trains, Grand Central and Great Northern services. East
Coast Trains made the announcement on its website yesterday evening,
where it advised passengers to delay their travel if possible. It also
said that a revised timetable is

In [123]:
# Inputs come in different lengths, and padding them all with 0s would waste computational resources.
# A better approach is to group examples of similar length and process them together, so we use buckets
# to create batched generators. Buckets are defined by boundaries and batch sizes: batch_size[i] is the
# batch size for items with length < boundaries[i]. Here that means batches of 16 for length < 128,
# 8 for length < 256, 4 for length < 512, and so on.
# NOTE(review): the fifth batch size (1) presumably applies to items beyond the last boundary - confirm.

boundaries =[128,256,512,1024]
batch_size=[16,8,4,2,1]

# Now create the bucketed streams for training and evaluation.
train_batch_stream=trax.data.BucketByLength(boundaries,batch_size)(train_stream)
eval_batch_stream=trax.data.BucketByLength(boundaries,batch_size)(eval_stream)


In [124]:
# Pull one batch; different articles will be produced every time this cell runs.
# The middle element (the targets) is discarded here.
input_batch,_,mask_batch=next(train_batch_stream)
input_batch.shape

(1, 1095)

In [125]:
# Let's see the corresponding integer values (token ids).
input_batch

array([[    9,  3890,  2726,  1442,  1475,    31, 13062,  1379,  1329,
          132, 10448,  5403,   251,  4018,   691, 14850,    28, 14496,
          340,  1733,     6,   112,  2365,    95,  4240, 14662,    17,
         9012,  5121,  2796,  2301,  9528,    79,     3,  1803,  4219,
            6,   605,  4970, 10492,  8338,  1619,   213,   344,     2,
        13022,    16,   441,   694,   186, 15432,   595,  3890,  2726,
         1442,    80,   269, 13628,  4961,   412,   444,     6,   179,
         4970,  7372, 15301,   246,  3060,     2,   192,  2262,  7158,
         5745,  5047,    58,     2, 12746,    75,     6,   605, 12929,
            5, 16647,   186,   203,  1223,  4970,  7358,    43,  2696,
         6948,     3,  2301,  9528,    79,  2262, 11728, 14920,    79,
        20227, 26377,   471,  2696,    28,   532, 25222,   232,  1124,
            2,  5684,   691,   492, 22100,     4,   186,  8349,  1831,
         4219,     6,   605,  2814, 18685, 21351,     2,    35,    77,
      

In [126]:
# The input batch above consists of:
#     the token ids of the article, then the first 1, which is the <EOS> of the article, followed by a 0 (pad);
#     after that first 0 come the token ids of the summary, and the second 1 is the <EOS> tag of the summary;
#     all the remaining 0s are padding that keeps the length consistent - up to the max length specified in the bucket.
# Show the processed input data batch as text.
print("Article and summary:\n\n",detokenize(input_batch[0]))

Article and summary:

 The Ospreys opened their Champions Cup campaign in emphatic fashion by
posting a bonus point 42-7 victory over outclassed Liberty Stadium
visitors Treviso. Wales fly-half Dan Biggar ran the show, kicking 17
points and orchestrating Ospreys' best attacking moments as full-back
Dan Evans touched down twice, while wing Jeff Hassler, scrum-half Rhys
Webb and number eight Dan Baker also scored tries. Treviso wing
Ludovico Nitoglia scored a late consolation try, converted by former
Worcester and Wasps fly-half Joe Carlisle, but there was little for
the Italians to cheer on another dismal European occasion for them.
Dan Evans scores his second, and Osprey's fifth, try of the match in
what was a routine win over Treviso . It meant the unbeaten Guinness
PRO12 leaders took early charge of Pool Five, teeing up nicely next
Saturday's clash against reigning Aviva Premiership champions
Northampton at Franklin's Gardens. Saints will start as favourites -
they beat Ospreys twice

In [127]:
# So the structure is: article <EOS> <pad> summary <EOS> <pads>
# The loss is taken only on the summary, using cross entropy as the loss function.
# Now create helper functions to build tensors and to display them, using JAX numpy arrays.

def create_tensor(tensor):
    """Build a jnp array from a list of lists (or any array-like input)."""
    as_array = jnp.array(tensor)
    return as_array

def display_tensor(tensor,name):
    """Print a labelled shape line for the tensor, then the tensor itself.

    Each of the two printed pieces is followed by an empty line.
    """
    shape_line = f'{name} shape: {tensor.shape}\n'
    contents = f'{tensor}\n'
    for piece in (shape_line, contents):
        print(piece)
    

# Try with dummy data and build dot-product attention:
$$
\text{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right) V \tag{1}
$$

In [128]:
# Create dotproduct attention
def DotProductAttention(query,key,value,mask):
    """Scaled dot-product attention with an optional mask.

    Computes softmax(Q K^T / sqrt(d) + M) V, where the mask M is realised by
    replacing the scores at disallowed positions with a large negative number.

    Args:
        query: array of query representations with shape (..., L_q, d).
        key: array of key representations with shape (..., L_k, d).
        value: array of value representations with shape (..., L_k, d).
        mask: boolean attention mask broadcastable to (..., L_q, L_k); True marks
            positions that may be attended to. Pass None to attend everywhere.

    Returns:
        Attention output with shape (..., L_q, d).

    Raises:
        AssertionError: if the last (embedding) dimensions of query, key and value differ.
    """

    assert query.shape[-1]==key.shape[-1]==value.shape[-1],"Embeedding dimesions of q,k,v must be same"
    # Depth of the embedding, used to scale down the dot products.
    depth=query.shape[-1]

    # Scaled query-key scores: (..., L_q, d) @ (..., d, L_k) -> (..., L_q, L_k).
    dots=jnp.matmul(query,jnp.swapaxes(key,-1,-2))/jnp.sqrt(depth)

    # Apply the mask only when one was supplied: masked-out scores become -1e9,
    # which turns into a weight of (effectively) zero after the softmax.
    if mask is not None:
        dots=jnp.where(mask,dots,jnp.full_like(dots,-1e9))

    # Softmax over the key axis, computed as exp(x - logsumexp(x)) for numerical stability.
    logsumexp=trax.fastmath.logsumexp(dots,axis=-1, keepdims=True)
    dots=jnp.exp(dots-logsumexp)

    # Weight the values by the attention distribution.
    attention=jnp.matmul(dots,value)

    return attention
    

In [129]:
# Now implement causal attention: multi-headed attention with a mask, so that each position attends only to the words that occurred before it.
# 1. Compute attention heads - take an input with dimensions (batch_size, seqlen, n_heads X d_head), split the last (depth)
# dimension and stack it onto the zeroth dimension to allow matrix multiplication: (batch_size X n_heads, seqlen, d_head)

# 2. Dot product self attention
# 3. Compute attention output

In [130]:
def compute_attention_heads_closure(n_heads,d_head):
    """Return a head-splitting function with n_heads and d_head bound in.

    Args:
        n_heads (int): number of attention heads.
        d_head (int): dimensionality of each head.

    Returns:
        function: compute_attention_heads, which reshapes a
        (batch_size, seqlen, n_heads X d_head) tensor into
        (batch_size X n_heads, seqlen, d_head).
    """

    def compute_attention_heads(x):
        """Split the depth axis of x into heads and fold the heads into the batch axis.

        Args:
            x: tensor with shape (batch_size, seqlen, n_heads X d_head).

        Returns:
            Tensor with shape (batch_size X n_heads, seqlen, d_head).
        """
        n_batch, n_steps = x.shape[0], x.shape[1]
        # (batch, seqlen, n_heads*d_head) -> (batch, seqlen, n_heads, d_head)
        per_head = jnp.reshape(x, (n_batch, n_steps, n_heads, d_head))
        # Swap the seqlen and head axes: -> (batch, n_heads, seqlen, d_head)
        heads_first = jnp.transpose(per_head, (0, 2, 1, 3))
        # Merge batch and head axes: -> (batch*n_heads, seqlen, d_head)
        return jnp.reshape(heads_first, (-1, n_steps, d_head))

    return compute_attention_heads
        

In [140]:
# Now create the dotproduct self attention with mask
def dot_product_self_attention(q,k,v):
    """Dot product self attention called with a lower-triangular mask.

    Args:
        q: queries, with the query length on the second-to-last axis.
        k: keys.
        v: values.

    Returns:
        The result of DotProductAttention(q, k, v, mask), where mask is a boolean
        lower-triangular matrix of shape (1, L_q, L_q).
    """
    # The mask is square, sized by the query length (second-to-last axis of q).
    seq_len = q.shape[-2]
    # All-True boolean matrix; keeping only the lower triangle (diagonal included)
    # leaves True on and below the diagonal and False above it.
    all_true = jnp.ones((1, seq_len, seq_len), dtype=jnp.bool_)
    causal_mask = jnp.tril(all_true, k=0)

    return DotProductAttention(q, k, v, causal_mask)