# Summarizing text with transformers

This notebook addresses the task of text summarization using a transformer-based language model. An interesting thing to note is that the transformer relies heavily on attention instead of recurrence: it does not have to process the sequence one token at a time, which allows for parallel computing. 
Also, those more familiar with ML and NLP will notice that many things in this notebook are implemented "from scratch" (or less abstracted than they could be). In production code this approach would generally not be recommended, but this notebook is more focused on exploring the concepts and experimenting with the code :)

In [35]:
import sys
import os

import numpy as np

import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)

<a name='1'></a>
## Part 1: Importing the dataset

In [None]:
# Importing CNN/DailyMail articles dataset
train_stream_fn = trax.data.TFDS('cnn_dailymail',
                                 data_dir='data/',
                                 keys=('article', 'highlights'),
                                 train=True)

eval_stream_fn = trax.data.TFDS('cnn_dailymail',
                                data_dir='data/',
                                keys=('article', 'highlights'),
                                train=False)

<a name='1.1'></a>
## 1.1 Tokenize & Detokenize helper functions

NLP models usually work with tokenized inputs rather than raw text. With that in mind, we'll use some handy functions to get from a text input to its tokenized version, and from the tokenized version back to text (after all, we want to be able to read things!). More specifically: 

- <span style='color:green'> tokenize: </span> converts a text sentence to its corresponding token list (i.e. list of indices); also converts words to subwords.
- <span style='color:green'> detokenize: </span> converts a token list to its corresponding sentence (string).

In [31]:
def tokenize(input_str, EOS=1):
    """Convert an input string to a list of token ids ending with EOS, ready for inference"""
  
    inputs =  next(trax.data.tokenize(iter([input_str]),
                                      vocab_dir='vocab_dir/',
                                      vocab_file='summarize32k.subword.subwords'))
    
    # Mark the end of the sentence with EOS
    return list(inputs) + [EOS]

def detokenize(integers):
    """List of ints to str"""
  
    s = trax.data.detokenize(integers,
                             vocab_dir='vocab_dir/',
                             vocab_file='summarize32k.subword.subwords')
    
    return wrapper.fill(s)

<a name='1.2'></a>

## 1.2 Preprocessing for Language Models: Concatenate It!

To create a single input suitable for a language model, we concatenate each tokenized article with its summary, putting an `EOS` token and a separator token (`0`) in between. We also need to create a mask, with 0s over the article and 1s over the summary, so that the model is not penalized for mis-predicting the article and only focuses on the summary. 
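Here is a minimal standalone sketch of the same concatenation on made-up token ids (the ids 10, 11, 12, 20, 21 are arbitrary, and `concat_with_mask` is a hypothetical helper that mirrors the `preprocess` function in the next cell):

```python
import numpy as np

SEP = 0  # separator / padding token id (same convention as this notebook)
EOS = 1  # end-of-sentence token id

def concat_with_mask(article_ids, summary_ids):
    """Join article and summary into one LM sequence plus a loss mask."""
    joint = np.array(list(article_ids) + [EOS, SEP] + list(summary_ids) + [EOS])
    # 0 over the article and its EOS + SEP, 1 over the summary and its EOS
    mask = np.array([0] * (len(article_ids) + 2) + [1] * (len(summary_ids) + 1))
    return joint, mask

# Hypothetical token ids, just for illustration
joint, mask = concat_with_mask([10, 11, 12], [20, 21])
print(joint)  # [10 11 12  1  0 20 21  1]
print(mask)   # [0 0 0 0 0 1 1 1]
```

Only the last three positions (the summary and its `EOS`) will contribute to the loss.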

In [37]:
# Special tokens
SEP = 0 # Padding or separator token
EOS = 1 # End of sentence token

# Concatenate tokenized inputs and targets using 0 as separator.
def preprocess(stream):
    for (article, summary) in stream:
        joint = np.array(list(article) + [EOS, SEP] + list(summary) + [EOS])
        mask = [0] * (len(list(article)) + 2) + [1] * (len(list(summary)) + 1) # Accounting for EOS and SEP
        yield joint, joint, np.array(mask)

# Tokenize, concatenate (see preprocess) and drop examples longer than 2048 tokens
input_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_dir='vocab_dir/',
                       vocab_file='summarize32k.subword.subwords'),

    preprocess,

    trax.data.FilterByLength(2048)
)


train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

train_input, train_target, train_mask = next(train_stream)

<a name='1.3'></a>

## 1.3 Batching with bucketing

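Our model (as is common practice in ML in general, in case you're less familiar with the field) is trained on batches of data. To keep padding to a minimum, we use bucketing: sequences are grouped by length, and each length range (delimited by `boundaries`) gets its own batch size. Short sequences are batched 16 at a time, while the longest ones (beyond the last boundary, 1024 tokens) are processed one at a time.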
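A simplified sketch of how `boundaries` and `batch_sizes` fit together. `pick_bucket` is a hypothetical helper, and we assume a sequence goes to the first bucket whose boundary is at least its length; Trax's `BucketByLength` may differ in such edge details, and it also pads the sequences of each batch to a common length. Note that `FilterByLength(2048)` in the pipeline above has already dropped anything longer than 2048 tokens.

```python
import math

boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]  # one more entry than boundaries: the last bucket is unbounded

def pick_bucket(length, boundaries, batch_sizes):
    """Return (bucket index, batch size) for a sequence of the given length.

    Assumes a sequence goes to the first bucket whose boundary is >= its length;
    anything longer than the last boundary goes to the final, unbounded bucket.
    """
    for i, boundary in enumerate(boundaries + [math.inf]):
        if length <= boundary:
            return i, batch_sizes[i]

for n in [100, 300, 1024, 1500]:
    print(n, pick_bucket(n, boundaries, batch_sizes))
# 100 (0, 16)
# 300 (2, 4)
# 1024 (3, 2)
# 1500 (4, 1)
```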

In [42]:
# Bucketing to create batched generators
boundaries =  [128, 256,  512, 1024]
batch_sizes = [16,    8,   4,  2, 1]

# Create the streams
train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(train_stream)

eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(eval_stream)

In [None]:

input_batch, _, mask_batch = next(train_batch_stream)

# Shape of the input_batch
input_batch.shape

In [None]:
# print the article and its summary
print(input_batch[0])
print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=')
print('Article:\n\n', detokenize(input_batch[0]))

You can see that the data has the following structure:
- <span style='color:green'> [Article] </span> -> `<EOS>` -> `<pad>` -> <span style='color:green'> [Article Summary] </span> -> `<EOS>` -> (possibly) multiple `<pad>`

The loss is computed only on the summary: the mask gives the article tokens a weight of 0, and cross-entropy is used as the loss function. 
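A minimal NumPy sketch of a masked (weighted) cross-entropy on made-up numbers. `masked_cross_entropy` is a hypothetical helper; we assume that during training the mask yielded as the third element of each example is what weights the per-token loss.

```python
import numpy as np

def masked_cross_entropy(log_probs, targets, mask):
    """Mean cross-entropy over positions where mask == 1.

    log_probs: (seq_len, vocab_size) log-probabilities (e.g. LogSoftmax output)
    targets:   (seq_len,) integer token ids
    mask:      (seq_len,) 0/1 weights
    """
    # Log-probability the model assigned to each correct token
    picked = log_probs[np.arange(len(targets)), targets]
    return -np.sum(picked * mask) / np.sum(mask)

# Toy example: 3 positions, vocabulary of 2 tokens
log_probs = np.log(np.array([[0.5, 0.5],
                             [0.9, 0.1],
                             [0.2, 0.8]]))
targets = np.array([1, 0, 1])
mask = np.array([0, 1, 1])  # position 0 (the "article") is ignored
loss = masked_cross_entropy(log_probs, targets, mask)
print(round(float(loss), 4))  # 0.1643
```

The first position contributes nothing, whatever the model predicted there: the loss is just the average of $-\log 0.9$ and $-\log 0.8$.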

In [46]:
def create_tensor(t):
    """Create tensor from list of lists"""
    return jnp.array(t)


def display_tensor(t, name):
    """Display shape and tensor"""
    print(f'{name} shape: {t.shape}\n')
    print(f'{t}\n')

<a name='ex01'></a>

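Now, we'll implement scaled dot-product attention. As a reminder, this is the equation, where $d_k$ is the depth of the keys and $M$ is the mask, with 0 where a query may attend to a key and a large negative number (effectively $-\infty$) where it may not: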


$$
\text{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right) V \tag{1}
$$
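To check the implementation against equation (1), here is a plain NumPy reference (`attention_reference` is our own hypothetical helper) together with small example tensors. The test cells after `DotProductAttention` and `dot_product_self_attention` use `q_with_batch`, `k_with_batch`, `v_with_batch` and `m_bool`, which are not defined anywhere else in this notebook; the values below are arbitrary choices of ours. The mask here is causal, so the first query can only see the first key, and its output is exactly the first row of values, `[0, 1, 0]`.

```python
import numpy as np

def attention_reference(q, k, v, mask):
    """NumPy version of equation (1) with a boolean mask (True = may attend)."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    # Additive mask M: 0 where allowed, a large negative number where blocked
    scores = scores + np.where(mask, 0.0, -1e9)
    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

# Small example tensors with a batch dimension of 1 (arbitrary values)
q_with_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
k_with_batch = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
v_with_batch = np.array([[[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]])
m_bool = np.array([[True, False], [True, True]])  # causal: query i sees keys 0..i

print(attention_reference(q_with_batch, k_with_batch, v_with_batch, m_bool))
```

`DotProductAttention(q_with_batch, k_with_batch, v_with_batch, m_bool)` should give the same numbers up to floating-point precision.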

In [49]:

def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.
    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (..., L_q, d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (..., L_k, d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (..., L_k, d_v), i.e. L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): boolean attention mask (True = may attend), broadcastable to (..., L_q, L_k)

    Returns:
        jax.interpreters.xla.DeviceArray: self-attention array with shape (..., L_q, d_v)
    """

    depth = query.shape[-1]

    # Scaled dot products Q K^T / sqrt(d_k)
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    # Apply the mask: blocked positions get a large negative score (the M term in equation 1)
    if mask is not None: 
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Stable softmax: e^(dots - logsumexp(dots)) = e^dots / sum(e^dots)
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    dots = jnp.exp(dots - logsumexp)

    # Weighted sum of the values
    attention = jnp.matmul(dots, value)

    return attention

In [None]:
DotProductAttention(q_with_batch, k_with_batch, v_with_batch, m_bool)

Now, we can define some support functions.

### Support Functions

<span style='color:green'> compute_attention_heads </span>: Takes an input $x$ of shape (batch_size, seqlen, n_heads $\times$ d_head), splits the last (depth) dimension into heads and moves the heads into the zeroth (batch) dimension, giving shape (batch_size $\times$ n_heads, seqlen, d_head). This way all heads can be processed with a single batched matrix multiplication.
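A NumPy sketch of the same reshaping, handy for checking shapes by hand (`split_heads` is our hypothetical name for it). It also defines `tensor3dc3b`, the input used by the test cell after `compute_attention_heads_closure`, which is not defined elsewhere in this notebook; the values 0, 1, 2, ... are an arbitrary choice of ours.

```python
import numpy as np

def split_heads(x, n_heads, d_head):
    """(batch, seqlen, n_heads * d_head) -> (batch * n_heads, seqlen, d_head)"""
    batch_size, seqlen, _ = x.shape
    x = x.reshape(batch_size, seqlen, n_heads, d_head)  # split the depth axis
    x = x.transpose(0, 2, 1, 3)                          # (batch, heads, seqlen, d_head)
    return x.reshape(batch_size * n_heads, seqlen, d_head)

# Hypothetical test input: batch of 2, seqlen 2, depth 6 = 2 heads x 3
tensor3dc3b = np.arange(2 * 2 * 6, dtype=np.float32).reshape(2, 2, 6)
heads = split_heads(tensor3dc3b, n_heads=2, d_head=3)
print(heads.shape)  # (4, 2, 3)
print(heads[1])     # head 1 of batch item 0: the last 3 features of each position
```

Index `b * n_heads + h` of the output holds head `h` of batch item `b`.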

In [52]:

def compute_attention_heads_closure(n_heads, d_head):
    """ Closure that fixes n_heads and d_head, as they would be inside the CausalAttention function.
    Args:
        d_head (int):  dimensionality of heads.
        n_heads (int): number of attention heads.
    Returns:
        function: compute_attention_heads function
    """

    def compute_attention_heads(x):
        """ Compute the attention heads.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size, seqlen, n_heads X d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size X n_heads, seqlen, d_head).
        """


        batch_size = x.shape[0]
        seqlen = x.shape[1]
        x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
        x = jnp.transpose(x, (0, 2, 1, 3))
        x = jnp.reshape(x, (batch_size*n_heads, seqlen, d_head))
        
        return x
    
    return compute_attention_heads

In [None]:
display_tensor(tensor3dc3b, "input tensor")
result_cah = compute_attention_heads_closure(2,3)(tensor3dc3b)
display_tensor(result_cah, "output tensor")

<span style='color:green'> dot_product_self_attention </span>: Creates a mask matrix with `True` values on and below the diagonal and `False` values above it, then calls DotProductAttention, which implements dot-product self-attention. This is important so that our model "only pays attention" to the current token and the ones that come before it.
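For a sequence of 4 tokens, the mask built with `tril` looks like this (shown with NumPy and printed as integers for readability). Row $i$ is query position $i$ and column $j$ is key position $j$:

```python
import numpy as np

seqlen = 4
# True on and below the diagonal: position i may attend to positions 0..i
causal_mask = np.tril(np.ones((seqlen, seqlen), dtype=bool))
print(causal_mask.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```

In `dot_product_self_attention` the mask gets an extra leading axis of size 1, so it broadcasts over the batch_size $\times$ n_heads dimension.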

In [54]:

def dot_product_self_attention(q, k, v):
    """ Masked dot product self attention.
    Args:
        q (jax.interpreters.xla.DeviceArray): queries.
        k (jax.interpreters.xla.DeviceArray): keys.
        v (jax.interpreters.xla.DeviceArray): values.
    Returns:
        jax.interpreters.xla.DeviceArray: masked dot product self attention tensor.
    """
    mask_size = q.shape[-2]
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)
    
    return DotProductAttention(q, k, v, mask)

In [None]:
dot_product_self_attention(q_with_batch, k_with_batch, v_with_batch)

<span style='color:green'> compute_attention_output </span>: Undoes compute_attention_heads by splitting first (vertical) dimension and stacking in the last (depth) dimension (batch_size, seqlen, n_heads $\times$ d_head). These operations concatenate (stack/merge) the heads. 
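Since `compute_attention_output` undoes `compute_attention_heads`, applying one after the other should give back the original tensor. A NumPy sketch of that round trip (`split_heads` and `merge_heads` are hypothetical helpers mirroring the two closures):

```python
import numpy as np

def split_heads(x, n_heads, d_head):
    """(batch, seqlen, n_heads * d_head) -> (batch * n_heads, seqlen, d_head)"""
    batch_size, seqlen, _ = x.shape
    x = x.reshape(batch_size, seqlen, n_heads, d_head).transpose(0, 2, 1, 3)
    return x.reshape(batch_size * n_heads, seqlen, d_head)

def merge_heads(x, n_heads, d_head):
    """(batch * n_heads, seqlen, d_head) -> (batch, seqlen, n_heads * d_head)"""
    seqlen = x.shape[1]
    x = x.reshape(-1, n_heads, seqlen, d_head)      # recover the batch axis
    x = x.transpose(0, 2, 1, 3)                     # (batch, seqlen, heads, d_head)
    return x.reshape(-1, seqlen, n_heads * d_head)  # concatenate the heads

x = np.random.default_rng(0).normal(size=(2, 5, 6))
roundtrip = merge_heads(split_heads(x, 2, 3), 2, 3)
print(np.array_equal(roundtrip, x))  # True
```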

In [55]:

def compute_attention_output_closure(n_heads, d_head):
    """ Closure that fixes n_heads and d_head, as they would be inside the CausalAttention function.
    Args:
        d_head (int):  dimensionality of heads.
        n_heads (int): number of attention heads.
    Returns:
        function: compute_attention_output function
    """
    
    def compute_attention_output(x):
        """ Compute the attention output.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size X n_heads, seqlen, d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size, seqlen, n_heads X d_head).
        """
        seqlen = x.shape[1]
        x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
        x = jnp.transpose(x, (0, 2, 1, 3))

        
        # to allow to concatenate the heads
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
    
    return compute_attention_output

In [None]:
display_tensor(result_cah, "input tensor")
result_cao = compute_attention_output_closure(2,3)(result_cah)
display_tensor(result_cao, "output tensor")

### Causal Attention Function

Now we can combine the components defined above into the `CausalAttention` (masked multi-head attention) layer:
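As a cross-check of the overall data flow (project to queries, keys and values, split heads, causal attention, merge heads, output projection), here is a NumPy forward pass with random weights. `causal_multihead_attention` and `params` are hypothetical names; biases are omitted for brevity, so this is a sketch of the structure, not Trax's implementation. The final check confirms the causal property: changing the last token leaves the outputs at earlier positions untouched.

```python
import numpy as np

def causal_multihead_attention(x, params, n_heads):
    """NumPy forward pass mirroring the structure of CausalAttention.

    x: (batch, seqlen, d_feature); params: dict of (d_feature, d_feature)
    weight matrices 'wq', 'wk', 'wv', 'wo' (no biases in this sketch).
    """
    batch, seqlen, d_feature = x.shape
    d_head = d_feature // n_heads

    def split(t):  # (batch, seqlen, d_feature) -> (batch * heads, seqlen, d_head)
        t = t.reshape(batch, seqlen, n_heads, d_head).transpose(0, 2, 1, 3)
        return t.reshape(batch * n_heads, seqlen, d_head)

    q, k, v = (split(x @ params[name]) for name in ("wq", "wk", "wv"))
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    mask = np.tril(np.ones((seqlen, seqlen), dtype=bool))
    scores = np.where(mask, scores, -1e9)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v  # (batch * heads, seqlen, d_head)

    # Merge heads back and apply the output projection
    out = out.reshape(batch, n_heads, seqlen, d_head).transpose(0, 2, 1, 3)
    return out.reshape(batch, seqlen, d_feature) @ params["wo"]

rng = np.random.default_rng(0)
d_feature, n_heads = 8, 2
params = {name: rng.normal(size=(d_feature, d_feature)) for name in ("wq", "wk", "wv", "wo")}
x = rng.normal(size=(1, 4, d_feature))
y = causal_multihead_attention(x, params, n_heads)

# Causality check: changing the last token must not change earlier outputs
x_changed = x.copy()
x_changed[0, -1] += 10.0
y_changed = causal_multihead_attention(x_changed, params, n_heads)
print(np.allclose(y[0, :-1], y_changed[0, :-1]))  # True
```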

In [57]:

def CausalAttention(d_feature, 
                    n_heads, 
                    compute_attention_heads_closure=compute_attention_heads_closure,
                    dot_product_self_attention=dot_product_self_attention,
                    compute_attention_output_closure=compute_attention_output_closure,
                    mode='train'):
    """Transformer-style multi-headed causal attention.

    Args:
        d_feature (int):  dimensionality of feature embedding.
        n_heads (int): number of attention heads.
        compute_attention_heads_closure (function): Closure around compute_attention heads.
        dot_product_self_attention (function): dot_product_self_attention function. 
        compute_attention_output_closure (function): Closure around compute_attention_output. 
        mode (str): 'train' or 'eval'.

    Returns:
        trax.layers.combinators.Serial: Multi-headed self-attention model.
    """
    
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads

    ComputeAttentionHeads = tl.Fn('AttnHeads', compute_attention_heads_closure(n_heads, d_head), n_out=1)
        

    return tl.Serial(
        tl.Branch( 
            [tl.Dense(d_feature), ComputeAttentionHeads], # queries
            [tl.Dense(d_feature), ComputeAttentionHeads], # keys
            [tl.Dense(d_feature), ComputeAttentionHeads], # values
        ),
        
        tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1), 
        tl.Fn('AttnOutput', compute_attention_output_closure(n_heads, d_head), n_out=1), # merge the heads back together
        tl.Dense(d_feature)
    )


In [None]:
# You can see the architecture so far with this command
print(CausalAttention(d_feature=512, n_heads=8))

## 2.3 Transformer decoder block

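Now we will implement the transformer decoder block. It consists of two residual sub-blocks, each applying layer normalization before its main layer: first causal attention (followed by dropout), then a feed-forward network.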
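Reading the code below, each decoder block maps an input $x$ to an output $y$ as follows, where $\sigma$ is `ff_activation` and $W_1, b_1$ and $W_2, b_2$ are the parameters of the two `tl.Dense` layers (this notation is ours):

$$
\begin{aligned}
h &= x + \operatorname{Dropout}\big(\operatorname{CausalAttention}(\operatorname{LayerNorm}(x))\big) \\
y &= h + \operatorname{Dropout}\Big(W_2\,\operatorname{Dropout}\big(\sigma(W_1\,\operatorname{LayerNorm}(h) + b_1)\big) + b_2\Big)
\end{aligned}
$$

Because the layer normalization sits inside each residual branch, before the sub-layer, this arrangement is usually called pre-norm (Pre-LN).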

In [59]:
def DecoderBlock(d_model, d_ff, n_heads,
                 dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation tensor to an activation tensor.
    """
    
    causal_attention = CausalAttention( 
                        d_model,
                        n_heads=n_heads,
                        mode=mode
                        )

    feed_forward = [ 
        tl.LayerNorm(),

        tl.Dense(d_ff),

        ff_activation(), 

        tl.Dropout(rate=dropout, mode=mode),

        tl.Dense(d_model),

        tl.Dropout(rate=dropout, mode=mode)
    ]

    # Add list of two Residual blocks: the attention with normalization and dropout and feed-forward blocks
    return [
      tl.Residual(

          tl.LayerNorm(),

          causal_attention,

          tl.Dropout(rate=dropout, mode=mode)
        ),
      tl.Residual(

          feed_forward
        ),
      ]


In [None]:
# Take a look at the decoder block
print(DecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1, mode='train', ff_activation=tl.Relu))

## 2.4 Transformer Language Model
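Using what we have built so far, we can put together the full transformer language model: an embedding with positional encoding, a stack of `n_layers` decoder blocks, and a final projection to log-probabilities over the vocabulary. Note the `ShiftRight` layer at the start, which shifts the input one position to the right so that each position only sees the tokens before it when predicting the next one.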
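To see why the shift matters, here is a NumPy sketch of what we understand `tl.ShiftRight` to do in `'train'` mode: shift by one position and pad with 0 (`shift_right` is a hypothetical helper, not the Trax layer). Since each training example uses the same `joint` sequence as both input and target, after the shift the output at position $t$ is computed only from tokens $0, \dots, t-1$ and is compared with token $t$. This is also why `next_symbol` later reads the prediction at index `token_length`.

```python
import numpy as np

def shift_right(tokens, pad_id=0):
    """Shift a (batch, seqlen) array one step to the right, padding with pad_id."""
    shifted = np.full_like(tokens, pad_id)
    shifted[:, 1:] = tokens[:, :-1]  # the last token is dropped
    return shifted

tokens = np.array([[5, 6, 7, 1]])
print(shift_right(tokens))  # [[0 5 6 7]]
```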

In [61]:
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
        vocab_size (int): vocab size.
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_layers (int): number of decoder layers.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train', 'eval' or 'predict', predict mode is for fast inference.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        trax.layers.combinators.Serial: A Transformer language model as a layer that maps from a tensor of tokens
        to activations over a vocab set.
    """
    # Embedding inputs and positional encoder
    positional_encoder = [ 
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode)]

    decoder_blocks = [ 
        DecoderBlock(d_model, d_ff, n_heads,
                    dropout, mode, ff_activation) for _ in range(n_layers)]

    # Create the complete model
    return tl.Serial(
        tl.ShiftRight(mode=mode),
        positional_encoder,
        decoder_blocks,
        tl.LayerNorm(),

        tl.Dense(vocab_size),
        tl.LogSoftmax()
    )

In [None]:
# Take a look at the transformer
print(TransformerLM(n_layers=1))

<a name='3'></a>
# Part 3: Training

Finally, we can train our model. We'll go through how to set up training, but we'll also use a pretrained model to save time (unsurprisingly, transformer models can take a long time to train).

### 3.1 Training the model

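Using Trax's `TrainTask`, `EvalTask` and `Loop`, we can easily define the training pipeline for our model. Note that the model built inside `training_loop` is deliberately tiny (`d_model=4`, a single layer), so it trains quickly but is unlikely to produce useful summaries.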
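The `training_loop` below uses `trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01)`. As we understand it, that schedule rises linearly to `max_value` over the warmup steps and then decays proportionally to $1/\sqrt{\text{step}}$. The hypothetical `warmup_rsqrt` helper below reproduces that shape so you can see roughly which learning rate each step gets; it is not Trax's source code.

```python
import math

def warmup_rsqrt(step, n_warmup_steps=1000, max_value=0.01):
    """Linear warmup to max_value, then decay proportional to 1/sqrt(step).

    A sketch of the schedule's shape only; not Trax's implementation.
    """
    step = max(step, 1)
    return max_value * min(step / n_warmup_steps, math.sqrt(n_warmup_steps / step))

for step in [1, 500, 1000, 4000]:
    print(step, warmup_rsqrt(step))
```

If the Trax schedule does behave like this, `loop.run(1000)` covers only the warmup phase: the learning rate peaks at 0.01 on the last step and never decays.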

In [65]:
from trax.supervised import training

def training_loop(TransformerLM, train_gen, eval_gen, output_dir = "~/model"):
    '''
    Input:
        TransformerLM (trax.layers.combinators.Serial): The model you are building.
        train_gen (generator): Training stream of data.
        eval_gen (generator): Evaluation stream of data.
        output_dir (str): folder to save your file.
        
    Returns:
        trax.supervised.training.Loop: Training loop.
    '''
    output_dir = os.path.expanduser(output_dir)  
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01)

    train_task = training.TrainTask( 
      labeled_data=train_gen, 
      loss_layer=tl.CrossEntropyLoss(),
      optimizer=trax.optimizers.Adam(0.01), 
      lr_schedule=lr_schedule,
      n_steps_per_checkpoint=10
    )

    eval_task = training.EvalTask( 
      labeled_data=eval_gen, 
      metrics=[tl.CrossEntropyLoss(), tl.Accuracy()] 
    )


    loop = training.Loop(TransformerLM(d_model=4,
                                       d_ff=16,
                                       n_layers=1,
                                       n_heads=2,
                                       mode='train'),
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
    
    return loop

In [None]:
# Will probably take some time to run, but later on we will import a pretrained model, so do not worry if you do not want to run this
!rm -f ~/model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
loop.run(1000)

<a name='4'></a>
# Part 4: Evaluation

<a name='4.1'></a>
### 4.1 Loading in a trained model

For evaluation, we'll load a pretrained model, as mentioned before. It uses the same `TransformerLM` architecture as the model we trained above; only its size differs, as the two constructor calls show:

   `Original (pretrained) model:`

    TransformerLM(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6, n_heads=8,
                  dropout=0.1, max_len=4096, ff_activation=tl.Relu)

   `Our model:`

    TransformerLM(d_model=4, d_ff=16, n_layers=1, n_heads=2)

   **Only the parameters listed for our model were changed; all others keep their default values, which match the pretrained model.**

In [None]:
# Get the model architecture
model = TransformerLM(mode='eval')

model.init_from_file('model.pkl.gz', weights_only=True)

<a name='5'></a>
# Part 5: Generating outputs

There are several algorithms for generating sequences with our model. One of them is greedy decoding, which simply picks the output with the highest probability at each step (that's why it is called greedy).
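To make the idea concrete, here is greedy decoding on a made-up "model": a hypothetical table of next-token probabilities that depends only on the previous token, with token 1 playing the role of EOS as in this notebook. Starting from token 0, greedy decoding returns `[2, 1]`, whose total probability is 0.55 × 0.4 = 0.22.

```python
import numpy as np

EOS = 1

# Hypothetical toy "model": rows = previous token, columns = next token (vocab 0..3)
next_token_probs = np.array([
    [0.0, 0.0, 0.55, 0.45],  # after token 0 (start)
    [1.0, 0.0, 0.0, 0.0],    # after EOS (never expanded)
    [0.0, 0.4, 0.35, 0.25],  # after token 2
    [0.0, 0.9, 0.06, 0.04],  # after token 3
])

def toy_greedy_decode(start_token, max_len=10):
    """Pick the most probable next token at every step until EOS."""
    output, last = [], start_token
    for _ in range(max_len):
        last = int(np.argmax(next_token_probs[last]))
        output.append(last)
        if last == EOS:
            break
    return output

print(toy_greedy_decode(0))  # [2, 1]
```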

In [68]:
def next_symbol(cur_output_tokens, model):
    """Returns the next symbol for a given sentence.

    Args:
        cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at the end.
        model (trax.layers.combinators.Serial): The transformer model.

    Returns:
        int: tokenized symbol.
    """
    token_length = len(cur_output_tokens)
    padded_length = 2**int(np.ceil(np.log2(token_length+1)))

    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :] 

    output, _ = model((padded_with_batch, padded_with_batch))
    log_probs= output[0, token_length, :]
    
    return int(np.argmax(log_probs))

In [None]:
# let's see how it works
sentence_test_nxt_symbl = "I want to fly in the sky."
detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model)])

Now that we can get the next symbol in the sequence, it becomes easy to implement greedy decoding itself!

In [70]:

def greedy_decode(input_sentence, model):
    """Greedy decode function.

    Args:
        input_sentence (string): a sentence or article.
        model (trax.layers.combinators.Serial): Transformer model.

    Returns:
        string: summary of the input.
    """
    
    cur_output_tokens = tokenize(input_sentence) + [0]
    generated_output = [] 
    cur_output = 0 
    EOS = 1 
    
    while cur_output != EOS:
        cur_output = next_symbol(cur_output_tokens, model)
        cur_output_tokens.append(cur_output)
        generated_output.append(cur_output)
        print(detokenize(generated_output))
    
    return detokenize(generated_output)

In [None]:
# give it a try
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model))

Well, greedy decoding works for demonstration purposes and is easy to understand, but it does not necessarily produce the best results. That's because it picks the best local result (the output with the highest probability at each step), which is not necessarily the "global best". Considering several candidate output sequences and choosing the one with the highest conditional probability (you can think of it as the total probability of the sentence) might work better... and that's the idea behind beam search! We'll implement it now, and you can use it instead of greedy decoding if you want to.
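Here is the idea on a tiny made-up next-token table (a hypothetical stand-in for the transformer; token 1 is EOS). On this table greedy decoding picks token 2 first (probability 0.55) and ends with `[2, 1]`, total probability 0.55 × 0.4 = 0.22. A beam of width 2 also keeps the runner-up token 3 alive and finds `[3, 1]`, with probability 0.45 × 0.9 = 0.405. `toy_beam_search` is a hypothetical helper; both candidate sequences have the same length here, so no length normalization is needed.

```python
import heapq
import math
import numpy as np

EOS = 1

# Hypothetical toy "model": rows = previous token, columns = next token (vocab 0..3)
next_token_probs = np.array([
    [0.0, 0.0, 0.55, 0.45],  # after token 0 (start)
    [1.0, 0.0, 0.0, 0.0],    # after EOS (never expanded)
    [0.0, 0.4, 0.35, 0.25],  # after token 2
    [0.0, 0.9, 0.06, 0.04],  # after token 3
])

def toy_beam_search(start_token, beam_width=2, max_len=10):
    """Keep the beam_width partial sequences with the highest total log-probability."""
    beams = [(0.0, [start_token])]  # (cumulative log-prob, tokens)
    finished = []
    for _ in range(max_len):
        candidates = []
        for log_p, tokens in beams:
            probs = next_token_probs[tokens[-1]]
            for token in np.argsort(-probs)[:beam_width]:
                if probs[token] > 0:
                    candidates.append((log_p + math.log(probs[token]), tokens + [int(token)]))
        beams = []
        for log_p, tokens in heapq.nlargest(beam_width, candidates, key=lambda c: c[0]):
            if tokens[-1] == EOS:
                finished.append((log_p, tokens))
            else:
                beams.append((log_p, tokens))
        if not beams:
            break
    best_log_p, best_tokens = max(finished or beams, key=lambda c: c[0])
    return best_tokens[1:], math.exp(best_log_p)

tokens, prob = toy_beam_search(0)
print(tokens, round(prob, 3))  # [3, 1] 0.405
```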

In [None]:
import heapq

def get_next_symbols(cur_output_tokens, model, beam_width):
    """Returns the top beam_width next symbols for a given sentence.

    Args:
        cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at the end
        model (trax.layers.combinators.Serial): The transformer model
        beam_width (int): number of top candidates to consider

    Returns:
        list: (log probability, token) pairs for the top beam_width tokens
    """
    token_length = len(cur_output_tokens)
    padded_length = 2**int(np.ceil(np.log2(token_length+1)))

    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]

    output, _ = model((padded_with_batch, padded_with_batch))
    # Same index as in next_symbol: because of ShiftRight, the prediction for
    # the token that follows the current sequence sits at position token_length
    log_probs = output[0, token_length, :]

    top_tokens = np.argsort(-log_probs)[:beam_width]
    return [(float(log_probs[token]), int(token)) for token in top_tokens]

def beam_search_decode(input_sentence, model, beam_width=4, max_length=100):
    """Beam search decode function.

    Args:
        input_sentence (string): a sentence or article
        model (trax.layers.combinators.Serial): Transformer model
        beam_width (int): beam width for search
        max_length (int): maximum length of generated summary

    Returns:
        string: summary of the input
    """
    EOS = 1

    # Prompt: article tokens + EOS (added by tokenize) + separator, as in preprocess
    prompt_tokens = tokenize(input_sentence) + [0]
    prompt_length = len(prompt_tokens)

    # Each beam item is (cumulative log_prob, tokens)
    beams = [(0.0, prompt_tokens)]
    finished_beams = []

    for _ in range(max_length):
        candidates = []

        # Expand every active beam with its beam_width most likely next tokens
        for beam_log_prob, beam_tokens in beams:
            for token_log_prob, token in get_next_symbols(beam_tokens, model, beam_width):
                candidates.append((beam_log_prob + token_log_prob, beam_tokens + [token]))

        # Keep the beam_width candidates with the HIGHEST log probability
        # (heapq is a min-heap, so popping from it would return the worst ones)
        beams = []
        for log_prob, tokens in heapq.nlargest(beam_width, candidates, key=lambda c: c[0]):
            if tokens[-1] == EOS:
                finished_beams.append((log_prob, tokens))
            else:
                beams.append((log_prob, tokens))

        # Early stopping if all beams are finished
        if not beams:
            break

    # If no beam finished within max_length, fall back to the unfinished ones
    if not finished_beams:
        finished_beams = beams

    # Length-normalized score: average log probability per generated token
    _, best_tokens = max(
        finished_beams, key=lambda b: b[0] / max(len(b[1]) - prompt_length, 1))
    generated_tokens = best_tokens[prompt_length:]

    return detokenize(generated_tokens)

Now, let's see how the model works at inference time, using a longer input "article":

In [None]:
article = "It’s the posing craze sweeping the U.S. after being brought to fame by skier Lindsey Vonn, soccer star Omar Cummings, baseball player Albert Pujols - and even Republican politician Rick Perry. But now four students at Riverhead High School on Long Island, New York, have been suspended for dropping to a knee and taking up a prayer pose to mimic Denver Broncos quarterback Tim Tebow. Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were all suspended for one day because the ‘Tebowing’ craze was blocking the hallway and presenting a safety hazard to students. Scroll down for video. Banned: Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll (all pictured left) were all suspended for one day by Riverhead High School on Long Island, New York, for their tribute to Broncos quarterback Tim Tebow. Issue: Four of the pupils were suspended for one day because they allegedly did not heed to warnings that the 'Tebowing' craze at the school was blocking the hallway and presenting a safety hazard to students."
print(wrapper.fill(article), '\n')
print(greedy_decode(article, model))