In [69]:
import sys
import os
import numpy as np

import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)

In [70]:

# Build the CNN/DailyMail data streams through trax's TFDS wrapper.
# NOTE(review): data_dir is set to '../news_data/', so the dataset is
# presumably expected to be there already (otherwise downloaded) -- confirm.

# Training split. Each example is an ('article', 'highlights') pair.
# The returned object is called later (with no arguments) to obtain the stream.
train_stream_fnction = trax.data.TFDS('cnn_dailymail',
                                 data_dir='../news_data/',
                                 keys=('article', 'highlights'),
                                 train=True)

# Evaluation split (train=False), read from the same data_dir.
eval_stream_fnction = trax.data.TFDS('cnn_dailymail',
                                data_dir='../news_data/',
                                keys=('article', 'highlights'),
                                train=False)

#### Create tokenize and detokenize functions

In [71]:
# Now we need to create helper functions to tokenize and detokenize data. tokenize converts a text sentence to its
# corresponding token list (i.e. a list of indices); it also splits words into subwords.
# Similarly, we need a detokenize function to convert the tokens back into a sentence.

def tokenize(input_str,EOS=1):
    """Convert a text string into a list of token ids terminated by EOS.

    Args:
        input_str (str): text to tokenize.
        EOS (int): id appended after the tokens to mark the end of the
            sequence (default 1).

    Returns:
        list: the token ids of ``input_str`` followed by ``EOS``.
    """
    # trax.data.tokenize takes a stream and returns a stream, so wrap the
    # single string in a one-element iterator and pull the single result out.
    token_ids=next(trax.data.tokenize(iter([input_str]),
                                      vocab_dir='vocab_dir/',
                                      vocab_file='summarize32k.subword.subwords'))
    # Fix: the result used to be stored as `input_sting` while the return
    # statement read `input_string`, which raised NameError on every call.
    return list(token_ids)+[EOS]

def detokenize(input_integers):
    """Decode a sequence of token ids into text wrapped for display.

    Args:
        input_integers: token ids to decode.

    Returns:
        str: the decoded text, line-wrapped by the module-level ``wrapper``.
    """
    # Same vocabulary settings as the tokenizer.
    vocab_options=dict(vocab_dir='vocab_dir/',
                       vocab_file='summarize32k.subword.subwords')
    decoded_text=trax.data.detokenize(input_integers,**vocab_options)
    return wrapper.fill(decoded_text)


In [72]:
# Language model and preprocessing
# A language model only predicts the next word, so we concatenate each input (article) with its target (summary),
# separated by a separator token. Padding masks are also used: 0s over the input and 1s over the target,
# so that the model focuses on the summary.

In [73]:
# Special token ids
SEP=0 # padding / separator
EOS=1 # end of sentence

def preprocess(stream):
    """Join each (article, summary) pair into one language-model example.

    Args:
        stream: iterable of (article, summary) pairs, each a sequence of
            token ids.

    Yields:
        tuple: (inputs, targets, mask). ``inputs`` and ``targets`` are the
        same array ``article + [EOS, SEP] + summary + [EOS]``. ``mask`` is 0
        over the article and its two trailing markers, and 1 over the summary
        and its final EOS.
    """
    for article,summary in stream:
        # Materialise each side exactly once. Calling list() twice would
        # exhaust a one-shot iterable and leave the mask shorter than the
        # token array.
        article_ids=list(article)
        summary_ids=list(summary)
        tokens=np.array(article_ids+[EOS,SEP]+summary_ids+[EOS])
        mask=np.array([0]*(len(article_ids)+2)+[1]*(len(summary_ids)+1))
        yield tokens,tokens,mask

# Build the data pipeline: the steps below are applied in order to a stream
# of (article, summary) examples.
input_pipeline=trax.data.Serial(
#     1. tokenize with the summarize32k subword vocabulary
    trax.data.Tokenize(vocab_dir='vocab_dir/',
                        vocab_file='summarize32k.subword.subwords'),
#     2. concatenate article and summary and build the mask (preprocess above)
    preprocess,
#     3. length filter with a limit of 2048 (an earlier comment said 2018);
#        presumably drops examples longer than the limit -- confirm in trax docs
    trax.data.FilterByLength(2048)
)

# Apply the pipeline to both the training and the evaluation data
train_stream=input_pipeline(train_stream_fnction())
eval_stream=input_pipeline(eval_stream_fnction())

# Pull a single example from the training stream
train_input,train_target,train_mask=next(train_stream)
# For a language model, input and target must be the same sequence
assert sum((train_input-train_target)**2)==0


In [74]:
# Print the mask of the example drawn above: 0s over the article (and its
# EOS/SEP markers), 1s over the summary (and its EOS).
print(f'Single example mask:\n\n {train_mask}')

Single example mask:

 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0

In [75]:
# Decode the same example back to text.
# Layout: [article][<EOS>][<pad>][summary][<EOS>]
print(f'Single example:\n\n {detokenize(train_input)}')

Single example:

 A curry house has been fined £3,000 after two dead mice were
discovered next to a sack of onions. Health inspectors also found
rodent droppings in the kitchen and food stored inside a dirty shed
when they visited Khans Tandoori and Balti Takeaway in Portsmouth,
Hampshire. The kitchen sink was also dry and had no soap and the staff
toilet was in a poor condition, while layers of dirt, grease and
debris had built up in the areas where food was handled and stored.
Khans Tandoori and Balti Takeaway in Portsmouth, Hampshire, has been
fined £3,000 after these dead mice were discovered next to a sack of
onions . Magistrates ordered the restaurant to pay a £1,800 fine,
£1,117 in costs and a £36 victim surcharge, after owner of Harold
Southsea Ltd, admitted five charges under hygiene legislation on
behalf of the takeaway. Health officer Christopher Larkin from
Portsmouth City Council had discovered the dead mice on sticky boards
that had been put out to catch the vermin when h

In [76]:
# Inputs have different lengths, and padding all of them with 0s to one common
# length would waste computation. A better approach is to group examples of
# similar length and batch them together; we use buckets to create the batched
# generators.
# Buckets are defined by boundaries and batch sizes: batch_size[i] is the batch
# size for examples with length < boundaries[i]. So here we use batches of 16
# for length < 128, 8 for length < 256, 4 for length < 512 and 2 for
# length < 1024. The final entry (1) is presumably for anything longer --
# confirm in the trax docs.

boundaries =[128,256,512,1024]
batch_size=[16,8,4,2,1]

# Create the bucketed batch streams for training and evaluation
train_batch_stream=trax.data.BucketByLength(boundaries,batch_size)(train_stream)
eval_batch_stream=trax.data.BucketByLength(boundaries,batch_size)(eval_stream)


In [77]:
# Draw one batch from the training stream (different articles every time).
# The middle element (targets) is discarded; only inputs and mask are kept.
input_batch,_,mask_batch=next(train_batch_stream)
# Show the batch shape
input_batch.shape

(2, 1024)

In [78]:
# Let's see the corresponding integer (token id) values
input_batch

array([[    7,  6062, 19330,    21, 15570,  3385, 10077,    23,    46,
        11555,   527,  1147,  1865, 23891,  6412,   379,   514,  3558,
         8998,  5754,   527, 17296,  6933, 18453,  2337,   320, 26377,
           14,   186,  6157,  2416,    23,    46, 11555,   527,  1147,
         1865, 23891,  6412,     3,  3385, 10077,  6503, 20889, 24810,
           16,    72,   441,     6,   104,     6, 23324,  7511,    41,
          806,  1019,   134,   809,    15,  9158, 15019,   410,   395,
          799,    28,  2145,   809,  9744,  5429,   812,     3,     9,
         2823,     6,   104,     6,   292, 13431,  1353,   592,   233,
           19,  8272,   527,  1147,  1865, 23891,  6412,   214,   213,
           72,   331,    35,  6122,    87,   527,    15,  4175,   117,
          143,    18,    46,   576,   132,   213,  2297,   138,  1099,
         5400,    22,   379, 17895,   102,   213,   347,     2,   368,
        10077,   127,    22,  1425,   117,  3965, 19330,    21,    80,
      

In [79]:
# The input batch above consists of:
#     - the token ids of the article; the first 1 is the article's <EOS>,
#       followed by a 0 (separator/pad)
#     - after that first 0, the token ids of the summary; the second 1 is the
#       summary's <EOS>
#     - all remaining 0s are padding, keeping every row at the same length
#       (the maximum length of the bucket)
# Show the first example of the batch decoded back to text
print("Article and summary:\n\n",detokenize(input_batch[0]))

Article and summary:

 'Vindicated': Stephen Perry has been cleared of seven sexual offences
. An insurance boss accused of hiring attractive teenage girls to ogle
and grope has been cleared of seven sexual offences. Stephen Perry
denied sexually assaulting two 17-year-olds when they worked for him
at his Rotherham business throughout a trial at Sheffield Crown Court.
The 58-year-old widow was today found not guilty of seven sexual
offences against the two women but admitted some of his behaviour
'could have been taken in the wrong way'. rothe . Speaking after the
case, Mr Perry said he felt 'vindicated' by the outcome of the trial.
'I feel I have been vindicated. These allegations were unfounded and I
told the police as much. 'I am a friendly man and one of them
described me as a fatherly figure which is about right. I encouraged a
friendly, relaxed atmosphere at work and they misconstrued things.'
The insurance chief added he would not hire teenagers to work at his
firm in the future

In [80]:
# So the structure is: article <EOS> <pad> article summary <EOS> <pads>
# The loss is taken only on the summary, using cross entropy as the loss function.
# Now create helper functions to create tensors and to display them, using jax numpy arrays.

def create_tensor(tensor):
    """Build a jax-numpy array from a (nested) list of numbers.

    Args:
        tensor: list of lists (or any array-like) holding the values.

    Returns:
        The values as a ``jnp`` array.
    """
    as_array=jnp.array(tensor)
    return as_array

def display_tensor(tensor,name):
    """Print ``name`` with the tensor's shape, then the tensor's values.

    Args:
        tensor: object exposing ``.shape`` (e.g. a numpy or jax array).
        name: label printed in front of the shape.
    """
    header=f'{name} shape: {tensor.shape}\n'
    body=f'{tensor}\n'
    # Each part is printed separately, so each is followed by a blank line.
    for part in (header,body):
        print(part)
    

# Try with dummy data and build (scaled) dot-product attention
$$
\text{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right) V \tag{1}
$$

In [81]:
# Create dotproduct attention
def DotProductAttention(query,key,value,mask):
    """Scaled dot-product attention with an optional mask.

    Args:
        query: array of query representations, shape (..., L_q, d).
        key: array of key representations, shape (..., L_k, d).
        value: array of value representations, shape (..., L_k, d).
        mask: attention mask broadcastable to (..., L_q, L_k). Positions
            where it is falsy are replaced by -1e9 before the softmax.
            Pass None to attend everywhere.

    Returns:
        Attention output: softmax(QK^T / sqrt(d)) applied to ``value``,
        shape (..., L_q, d).
    """
    assert query.shape[-1]==key.shape[-1]==value.shape[-1],"Embeedding dimesions of q,k,v must be same"

    # Embedding depth, used to scale the dot products down.
    d_k=query.shape[-1]

    # Scaled query-key scores: Q K^T / sqrt(d_k).
    key_t=jnp.swapaxes(key,-1,-2)
    scores=jnp.matmul(query,key_t)/jnp.sqrt(d_k)

    # Masked-out positions get a large negative score, so their softmax
    # weight is effectively zero.
    if mask is not None:
        blocked=jnp.full_like(scores,-1e9)
        scores=jnp.where(mask,scores,blocked)

    # Softmax over the key axis, computed as exp(scores - logsumexp(scores)).
    log_norm=trax.fastmath.logsumexp(scores,axis=-1, keepdims=True)
    weights=jnp.exp(scores-log_norm)

    # Weighted sum of the values.
    return jnp.matmul(weights,value)
    

In [82]:
# Now implement causal attention: multi-headed attention with a mask so that each position attends only to the words that occurred before it.
# 1. Compute attention heads - take an input of shape (batch_size, seqlen, n_heads X d_head), split the last (depth)
# dimension and stack the heads onto the zeroth dimension to allow matrix multiplication: (batch_size X n_heads, seqlen, d_head)

# 2. Dot-product self-attention
# 3. Compute the attention output

In [83]:
def compute_attention_heads_closure(n_heads,d_head):
    """Build the head-splitting function used inside CausalAttention.

    Args:
        n_heads (int): number of attention heads.
        d_head (int): dimensionality of each head.

    Returns:
        function: ``compute_attention_heads``, closed over ``n_heads`` and
        ``d_head``.
    """

    def compute_attention_heads(x):
        """Split the feature axis into heads and fold the heads into the batch axis.

        Args:
            x: tensor of shape (batch_size, seqlen, n_heads X d_head).

        Returns:
            Tensor of shape (batch_size X n_heads, seqlen, d_head).
        """
        n_batch,n_steps=x.shape[0],x.shape[1]
        # (batch, seqlen, n_heads*d_head) -> (batch, seqlen, n_heads, d_head)
        per_head=jnp.reshape(x,(n_batch,n_steps,n_heads,d_head))
        # Swap the seqlen and head axes -> (batch, n_heads, seqlen, d_head)
        heads_first=jnp.transpose(per_head,(0,2,1,3))
        # Merge batch and head axes -> (batch*n_heads, seqlen, d_head)
        return jnp.reshape(heads_first,(-1,n_steps,d_head))

    return compute_attention_heads
        

In [84]:
# Now create the dotproduct self attention with mask
def dot_product_self_attention(q,k,v):
    """Causally masked dot-product self-attention.

    Args:
        q: queries, shape (batch_size, L_q, d).
        k: keys.
        v: values.

    Returns:
        The result of ``DotProductAttention`` under a lower-triangular mask,
        so that each position attends only to itself and earlier positions.
    """
    # The mask is square in the query length (second-to-last axis of q).
    seqlen=q.shape[-2]
    # Boolean matrix of shape (1, seqlen, seqlen): True on and below the
    # diagonal, False above it.
    all_true=jnp.ones((1, seqlen, seqlen), dtype=jnp.bool_)
    causal_mask=jnp.tril(all_true, k=0)

    return DotProductAttention(q, k, v, causal_mask)

In [85]:
# Now compute attention output, change the dims back

In [86]:
def compute_attention_output_closure(n_heads,d_head):
    """Build the head-merging function used inside CausalAttention.

    Args:
        n_heads (int): number of attention heads.
        d_head (int): dimensionality of each head.

    Returns:
        function: ``compute_attention_output``, closed over ``n_heads`` and
        ``d_head``.
    """

    def compute_attention_output(x):
        """Undo the head split: move heads back next to the feature axis.

        Args:
            x: tensor of shape (batch_size X n_heads, seqlen, d_head).

        Returns:
            Tensor of shape (batch_size, seqlen, n_heads X d_head).
        """
        n_steps=x.shape[1]
        # (batch*n_heads, seqlen, d_head) -> (batch, n_heads, seqlen, d_head)
        unfolded=jnp.reshape(x,(-1,n_heads,n_steps,d_head))
        # Swap the head and seqlen axes -> (batch, seqlen, n_heads, d_head)
        steps_first=jnp.transpose(unfolded,(0,2,1,3))
        # Concatenate the heads -> (batch, seqlen, n_heads*d_head)
        return jnp.reshape(steps_first,(-1,n_steps,n_heads*d_head))

    return compute_attention_output
        
        
    

In [87]:
# now put everything together to causal attention

In [88]:
def CausalAttention(d_feature,
                    n_heads,
                    compute_attention_heads_closure=compute_attention_heads_closure,
                    dot_product_self_attention=dot_product_self_attention,
                    compute_attention_output_closure=compute_attention_output_closure,
                    mode='train'):
    """Multi-head causal self-attention layer.

    Args:
        d_feature (int): feature embedding dimensionality; must be divisible
            by ``n_heads``.
        n_heads (int): number of attention heads.
        compute_attention_heads_closure (function): factory for the function
            that splits activations into heads.
        dot_product_self_attention (function): masked attention on (q, k, v).
        compute_attention_output_closure (function): factory for the function
            that merges the heads back together.
        mode (str): 'train' or 'eval'. Accepted for interface compatibility;
            it is not used in this function's body.

    Returns:
        trax Serial layer: three parallel Dense + head-split branches
        (queries, keys, values), masked attention, head merge, final Dense.
    """
    # The feature dimension is divided evenly among the heads.
    assert d_feature%n_heads ==0
    d_head=d_feature//n_heads

    # One head-splitting layer instance, reused by all three branches.
    split_heads=tl.Fn('Attention_heads',compute_attention_heads_closure(n_heads,d_head),n_out=1)

    # Queries, keys and values each get their own Dense projection.
    qkv_branches=[[tl.Dense(d_feature),split_heads] for _ in range(3)]

    masked_attention=tl.Fn('DotProductAttention',dot_product_self_attention,n_out=1)
    merge_heads=tl.Fn('AttentionOutput',compute_attention_output_closure(n_heads,d_head),n_out=1)

    return tl.Serial(
        tl.Branch(*qkv_branches),
        masked_attention,
        merge_heads,
        tl.Dense(d_feature) # final dense layer
    )

In [89]:
# Inspect the layer structure of a causal attention block (512 features, 8 heads)
print(CausalAttention(d_feature=512, n_heads=8))

Serial[
  Branch_out3[
    [Dense_512, Attention_heads]
    [Dense_512, Attention_heads]
    [Dense_512, Attention_heads]
  ]
  DotProductAttention_in3
  AttentionOutput
  Dense_512
]


#### Transformer decoder block

In [90]:
def DecoderBlock(d_model,depth_of_ff_layer,n_heads,dropout,mode,ff_activation):
    """Build the layers of one transformer decoder block.

    Args:
        d_model (int): depth of the embedding.
        depth_of_ff_layer (int): number of units in the first feed-forward
            Dense layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate.
        mode (str): 'train' or 'eval'.
        ff_activation (function): factory for the feed-forward non-linearity
            (called with no arguments to create the layer).

    Returns:
        list: two residual layers, masked self-attention followed by
        feed-forward, mapping an activation tensor to an activation tensor.
    """
    def _dropout():
        # Every dropout layer in the block uses the same rate and mode.
        return tl.Dropout(rate=dropout,mode=mode)

    # Masked multi-head attention.
    attention=CausalAttention(d_model,n_heads=n_heads,mode=mode)

    # Feed-forward sub-layer: normalise, expand, activate, project back.
    ff_layers=[
        tl.LayerNorm(),
        tl.Dense(depth_of_ff_layer),
        ff_activation(),
        _dropout(),
        tl.Dense(d_model),
        _dropout(),
    ]

    # Each sub-layer is wrapped in a residual connection.
    attention_sublayer=tl.Residual(
        tl.LayerNorm(),
        attention,
        _dropout())
    feed_forward_sublayer=tl.Residual(ff_layers)

    return [attention_sublayer,feed_forward_sublayer]

In [91]:
# Inspect the two residual sub-layers of a single decoder block
print(DecoderBlock(d_model=512, depth_of_ff_layer=2048, n_heads=8, dropout=0.1, mode='train', ff_activation=tl.Relu))

[Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Serial[
        Branch_out3[
          [Dense_512, Attention_heads]
          [Dense_512, Attention_heads]
          [Dense_512, Attention_heads]
        ]
        DotProductAttention_in3
        AttentionOutput
        Dense_512
      ]
      Dropout
    ]
  ]
  Add_in2
], Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Dense_2048
      Serial[
        Relu
      ]
      Dropout
      Dense_512
      Dropout
    ]
  ]
  Add_in2
]]


In [92]:
# Put together everything build language model


In [93]:
def transformer_language_model(vocab_size=33300,
                               d_model=512,
                               depth_of_ff_layer=2048,
                               n_layers=6,
                               n_heads=8,
                               dropout=0.1,
                               max_len=4096,
                               mode='train',
                               ff_activation=tl.Relu):
    """Build a decoder-only transformer language model.

    Args:
        vocab_size (int): vocabulary size.
        d_model (int): depth of the embedding.
        depth_of_ff_layer (int): number of units in the feed-forward layers.
        n_layers (int): number of decoder blocks.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate.
        max_len (int): maximum sequence length for the positional encoding.
        mode (str): 'train' or 'eval'.
        ff_activation (function): factory for the feed-forward non-linearity.

    Returns:
        trax Serial layer mapping a tensor of tokens to log-probabilities
        over the vocabulary.
    """
    # Token embedding, then dropout, then positional encoding.
    embedding_stack=[
        tl.Embedding(vocab_size,d_model),
        tl.Dropout(rate=dropout,mode=mode),
        tl.PositionalEncoding(max_len=max_len,mode=mode),
    ]

    # n_layers identical decoder blocks (each one is a list of layers).
    decoder_stack=[]
    for _ in range(n_layers):
        decoder_stack.append(
            DecoderBlock(d_model,depth_of_ff_layer,n_heads,dropout,mode,ff_activation))

    model_layers=[
        tl.ShiftRight(mode=mode),  # shift the inputs right (teacher forcing)
        embedding_stack,
        decoder_stack,
        tl.LayerNorm(),
        tl.Dense(vocab_size),      # project onto the vocabulary
        tl.LogSoftmax(),
    ]
    return tl.Serial(*model_layers)
    

In [94]:
# Take a look at the Transformer (a single decoder layer keeps the printout short)
print(transformer_language_model(n_layers=1))

Serial[
  Serial[
    ShiftRight(1)
  ]
  Embedding_33300_512
  Dropout
  PositionalEncoding
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Serial[
          Branch_out3[
            [Dense_512, Attention_heads]
            [Dense_512, Attention_heads]
            [Dense_512, Attention_heads]
          ]
          DotProductAttention_in3
          AttentionOutput
          Dense_512
        ]
        Dropout
      ]
    ]
    Add_in2
  ]
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Dense_2048
        Serial[
          Relu
        ]
        Dropout
        Dense_512
        Dropout
      ]
    ]
    Add_in2
  ]
  LayerNorm
  Dense_33300
  LogSoftmax
]


In [95]:
# Implement the training

In [96]:
from trax.supervised import training

def training_loop(transformer_language_model,train_gen,eval_gen,output_dir="./"):
    """Build the trax training loop for the transformer language model.

    Args:
        transformer_language_model (function): model factory; it is called
            with the hyper-parameters below and a ``mode`` keyword.
        train_gen (generator): training stream of data.
        eval_gen (generator): evaluation stream of data.
        output_dir (str): location where the trained model is saved;
            a leading ``~`` is expanded.

    Returns:
        trax.supervised.training.Loop: the training loop.
    """
    output_dir=os.path.expanduser(output_dir)

    # The peak learning rate is shared by the optimizer and the schedule.
    peak_lr=0.01
    lr_schedule=trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000,max_value=peak_lr)

    train_task=training.TrainTask(
        labeled_data=train_gen,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(peak_lr),
        lr_schedule=lr_schedule,
        n_steps_per_checkpoint=10)

    eval_task=training.EvalTask(
        labeled_data=eval_gen,
        metrics=[tl.CrossEntropyLoss(),tl.Accuracy()])

    model_kwargs=dict(d_model=512,
                      depth_of_ff_layer=2048,
                      n_layers=6,
                      n_heads=8)

    # Fix: without an explicit eval_model the Loop evaluates with the
    # training-mode model, so dropout stays active during evaluation.
    # Passing a mode='eval' copy makes the eval metrics deterministic;
    # the Loop copies the trained weights into it before evaluating.
    loop=training.Loop(transformer_language_model(mode='train',**model_kwargs),
                       train_task,
                       eval_model=transformer_language_model(mode='eval',**model_kwargs),
                       eval_tasks=[eval_task],
                       output_dir=output_dir)
    return loop

In [None]:
# Build the training loop on the bucketed streams; checkpoints go to the
# default output_dir ("./").
loop = training_loop(transformer_language_model, train_batch_stream, eval_batch_stream)
# Run 10 training steps
loop.run(10)


Step      1: Total number of trainable weights: 55144980
Step      1: Ran 1 train steps in 56.83 secs
Step      1: train CrossEntropyLoss |  10.46865654
Step      1: eval  CrossEntropyLoss |  10.43511200
Step      1: eval          Accuracy |  0.00000000
