# **Text Summarization using Transformer Model**

In [None]:
!pip install trax

Collecting trax
[?25l  Downloading https://files.pythonhosted.org/packages/a8/04/0c04116bbb372f459ad0a73bf306c5000f9fd63a8419bb179381f54773aa/trax-1.3.5-py2.py3-none-any.whl (416kB)
[K     |▉                               | 10kB 14.2MB/s eta 0:00:01[K     |█▋                              | 20kB 2.5MB/s eta 0:00:01[K     |██▍                             | 30kB 3.4MB/s eta 0:00:01[K     |███▏                            | 40kB 3.9MB/s eta 0:00:01[K     |████                            | 51kB 3.1MB/s eta 0:00:01[K     |████▊                           | 61kB 3.4MB/s eta 0:00:01[K     |█████▌                          | 71kB 3.6MB/s eta 0:00:01[K     |██████▎                         | 81kB 3.9MB/s eta 0:00:01[K     |███████                         | 92kB 4.1MB/s eta 0:00:01[K     |███████▉                        | 102kB 4.2MB/s eta 0:00:01[K     |████████▋                       | 112kB 4.2MB/s eta 0:00:01[K     |█████████▍                      | 122kB 4.2MB/s eta 0:00

In [None]:
import sys
import os

import numpy as np

import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)

# **1. Importing the Dataset**

In [None]:
# Importing CNN/DailyMail articles dataset.
# `train` must be a real boolean: the strings 'True' and 'False' are both truthy,
# so passing train='False' would silently select the training split for evaluation.
train_stream_fn = trax.data.TFDS('cnn_dailymail', keys = ('article', 'highlights'), train = True)
eval_stream_fn = trax.data.TFDS('cnn_dailymail', keys = ('article', 'highlights'), train = False)

[1mDownloading and preparing dataset cnn_dailymail/plain_text/3.0.0 (download: 558.32 MiB, generated: 1.27 GiB, total: 1.82 GiB) to /root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0.incompleteZ5001O/cnn_dailymail-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=287113.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0.incompleteZ5001O/cnn_dailymail-validation.tfrecord


HBox(children=(FloatProgress(value=0.0, max=13368.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0.incompleteZ5001O/cnn_dailymail-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=11490.0), HTML(value='')))



[1mDataset cnn_dailymail downloaded and prepared to /root/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0. Subsequent calls will reuse this data.[0m


### **1.1 Tokenize and Detokenize functions**

In [None]:
# Subword vocabulary file name passed to the trax tokenize/detokenize helpers below.
vocab_file = 'summarize32k.subword.subwords'
# Directory passed as vocab_dir to those helpers.
# NOTE(review): absolute Colab-style path — presumably the vocab file must be placed here before running; confirm.
vocab_dir = '/content/subwords'

def tokenize(input_str, EOS = 1):
  """Encode a string as a list of subword ids followed by `EOS`.

  Args:
    input_str: text to encode.
    EOS: id appended after the encoded text (defaults to 1).

  Returns:
    A list holding the encoded ids of `input_str` with `EOS` as last element.
  """
  # trax.data.tokenize consumes a stream, so wrap the single string in a one-item iterator.
  encoded_stream = trax.data.tokenize(iter([input_str]), vocab_dir = vocab_dir, vocab_file = vocab_file)
  token_ids = next(encoded_stream)
  return [*token_ids, EOS]

def detokenize(integers):
  """Decode a sequence of subword ids to text and line-wrap it.

  Args:
    integers: sequence of token ids to decode.

  Returns:
    The decoded string, wrapped with the module-level `wrapper`.
  """
  decoded_text = trax.data.detokenize(integers, vocab_dir = vocab_dir, vocab_file = vocab_file)
  return wrapper.fill(decoded_text)

### **1.2 Preprocessing**

In [None]:
# Language models only predict the next token; they have no notion of separate inputs and targets.
# To build a single sequence suitable for a language model, we concatenate the article and the summary, putting a separator in between.
# We also create a mask with 0s over the article (including its EOS and the separator) and 1s over the summary, so that the model is not penalized for mispredicting the article and focuses only on the summary.
SEP = 0  # padding / separator token id
EOS = 1  # end-of-sentence token id

def preprocess(stream):
  """Turn (article, summary) token pairs into language-model training triples.

  Args:
    stream: iterable of (article, summary) pairs, each a sequence of token ids.

  Yields:
    (inputs, targets, mask) where inputs and targets are the same array
    `article + [EOS, SEP] + summary + [EOS]`, and mask is 0 over the article,
    its EOS and SEP, and 1 over the summary and its closing EOS.
  """
  for article, summary in stream:
    article_ids = list(article)
    summary_ids = list(summary)

    tokens = np.array(article_ids + [EOS, SEP] + summary_ids + [EOS])

    # +2 covers the EOS/SEP pair after the article; +1 covers the final EOS.
    n_ignored = len(article_ids) + 2
    n_scored = len(summary_ids) + 1
    loss_mask = np.array([0] * n_ignored + [1] * n_scored)

    yield tokens, tokens, loss_mask

# Tokenize the raw (article, summary) text pairs, build LM examples, and drop over-long ones.
input_pipeline = trax.data.Serial(
    # Reuse the module-level vocab settings so the pipeline always encodes with the
    # same vocabulary as tokenize()/detokenize().
    trax.data.Tokenize(vocab_dir = vocab_dir, vocab_file = vocab_file),
    preprocess,
    # Filters out examples longer than 2048
    trax.data.FilterByLength(2048))

# Apply preprocessing to data streams
train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

# Pull one example to inspect (this consumes it from train_stream).
train_input, train_target, train_mask = next(train_stream)

# Inputs and targets are the same sequence in a Language Model (LM).
assert np.array_equal(train_input, train_target)

In [None]:
print('SIngle mask from train:', train_mask, '\n')

SIngle mask from train: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 

In [None]:
print('Single Input article from train:\n\n', detokenize(train_input), '\n')

Single Input article from train:

 By . Chris Pleasance . PUBLISHED: . 08:09 EST, 15 September 2013 . | .
UPDATED: . 06:05 EST, 16 September 2013 . With divorce fast becoming
the norm it seems true love is harder and harder to find these days.
But as these letters reveal, one smitten couple found a love so strong
it was able to endure even while they were separated during World War
II. Not only that, it lasted for a total of 70 years of marriage until
they both died in 2011, he aged 96 and she aged 92, within just three
days of one another. How time passes: Frederick and Elizabeth Noble
wed on New Year's Day 1941 while Frederick . was on 48-hour leave from
the Royal Tank Corps and stayed married for . the next 70 years . But,
as remarkable as their relationship had seemed in life, in death it
was about to give up its biggest and most heartwarming surprise. After
Frederick and Elizabeth Noble had passed relatives discovered they had
kept 250 love letters, telegrams, notes and Valentine'

In [None]:
print('SIngle input summary (target) from train:\n\n', detokenize(train_target))

SIngle input summary (target) from train:

 It has been claimed that CIA agents on the ground during the deadly
attack on the U.S. Consulate in Benghazi twice asked for permission to
help Ambassador Chris Stevens and twice were told to stand down.
Furthermore sources present during the deadly six-hour assault have
said that a desperate last request for military assistance once the
CIA themselves came under attack was denied, even though elite
counter-terrorism units were only two hours away. And it has been
claimed there was full communication between the CIA annex in Benghazi
and the U.S. military, casting further doubts on the Obama
administration's assertion that there wasn't enough information to
deploy forces - deepening the crisis over their handling of the attack
on September 11th and its aftermath . Scroll down for video .
Revelations: It has been claimed today that CIA operatives at the
Benghazi consulate compound repeatedly had their requests for help
denied during the deadly

### **1.3 Batching with Bucketing**

In [None]:
# Length boundaries of the buckets and the batch size used for each bucket.
# batch_sizes has one more entry than boundaries; presumably the last entry (1) applies to
# examples longer than the last boundary — TODO confirm against the BucketByLength docs.
boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]

# Group the preprocessed examples into batches by length, for training and evaluation.
train_batch_stream = trax.data.BucketByLength(boundaries, batch_sizes)(train_stream)
eval_batch_stream = trax.data.BucketByLength(boundaries, batch_sizes)(eval_stream)

In [None]:
input_batch, target_batch, mask_batch = next(train_batch_stream)

print('input_batch shape:', input_batch.shape, '\n')
print('target_batch shape:', target_batch.shape, '\n')
print('mask_batch shape:', mask_batch.shape, '\n')

input_batch shape: (2, 1024) 

target_batch shape: (2, 1024) 

mask_batch shape: (2, 1024) 



In [None]:
print('input_batch at position 0:\n\n', input_batch[0], '\n')
print('target_batch at position 0:\n\n', target_batch[0], '\n')
print('mask_batch at position 0:\n\n', mask_batch[0], '\n')

input_batch at position 0:

 [11772   236    11 17771  7724     4 10829  4576     2 12370    21     2
 12365  5856   320  3977    28 13684   379    27 20157 14063  5364   779
  1779  1768    28   290     6  1807  6890   236  1248   864   132    28
  1668  2642 12365  5019    15  2639   809    28 13684   186 13200    17
   117    13     7   371  3977    38   527    20     7    38  1099  2713
   127   684 10829  4576     2  2577     2   133   213  6765  1838    28
  2642   796   102  1987 23267 17463   478    95   213   945   527    15
   293    70    28  2380   132  1385   662   809  6972  1400  4305  3171
     2   542  4180     3 14308   864   186    28 11121  5550   248 22730
    78   213  1351  1177    70   186  1667  9027 10829  4576   320 15798
   102    28   290     6  1807  6890   236    78  1895   736     3 10829
  4576  1353  3873  1248 22440 11578    17  7726  1248    28 17418  9046
   186   229   144   475    78    28   281   286    88   226  8107     2
    28  1602  1838   2

In [None]:
# Decode the first example of the batch back to text to sanity-check the pipeline.
print('Detokenized input_batch at position 0:\n\n', detokenize(input_batch[0]), '\n')
print('Detokenized target_batch at position 0:\n\n', detokenize(target_batch[0]), '\n')
# NOTE(review): the mask is built from 0/1 weights in preprocess(), not token ids, so
# detokenizing it is unlikely to be meaningful — consider printing mask_batch[0] directly.
print('Detokenized mask_batch at position 0:\n\n', detokenize(mask_batch[0]))

Detokenized input_batch at position 0:

 Standoff: Goerge Pickering, pictured, allegedly threatened to kill a
nurse . A distraught father who caused a four-hour standoff with
police in a Texas hospital allegedly pointed his gun at a nurse and
yelled 'I'll kill all of y'all'. Police said George Pickering, 57,
made the threats from a hospital room after becoming inconsolable over
the treatment of his son - a patient in critical care at Tomball
Regional Hospital, near Houston. Armed police and a SWAT team
descended on the medical center - and eventually convinced Pickering
to surrender after a four-hour standoff on Saturday night. Pickering
was charged with aggravated assault with a deadly weapon and is being
held on a $30,000 bond, a statement from the Tomball Police department
said. Detectives said Pickering was in the room with his son and
family, waited for a nurse to come, then aimed his 9mm pistol at her.
He then allegedly barricaded the room and threatened to kill anybody
who came 

# **2. Transformer**

In [None]:
def create_tensor(t):
  """Build a jax numpy array from `t` (for example a list of lists)."""
  tensor = jnp.array(t)
  return tensor

def display_tensor(t, name):
  """Print the shape of `t` labelled with `name`, then `t` itself, each followed by a blank line."""
  header = f'{name} shape: {t.shape}\n'
  body = f'{t}\n'
  print(header, body, sep = '\n')

In [None]:
# Toy examples for queries (Q), keys (K), values (V) and a mask (M):
# two positions each; Q, K and V have depth 3, M is 2 x 2.
q = create_tensor([[1,0,1], [1,1,0]])
display_tensor(q, 'Query')
k = create_tensor([[1,2,3], [4,5,6]])
display_tensor(k, 'Keys')
v = create_tensor([[0,1,0], [1,1,1]])
display_tensor(v, 'Values')
# Additive mask: it is added to the scaled scores in a later cell, where the -1e9 entry
# swamps the score for (query 1, key 0).
m = create_tensor([[0,0], [-1e9, 0]])
display_tensor(m, 'Mask')

Query shape: (2, 3)

[[1 0 1]
 [1 1 0]]

Keys shape: (2, 3)

[[1 2 3]
 [4 5 6]]

Values shape: (2, 3)

[[0 1 0]
 [1 1 1]]

Mask shape: (2, 2)

[[ 0.e+00  0.e+00]
 [-1.e+09  0.e+00]]





In [None]:
q_dot_k = q @ k.T / jnp.sqrt(3)
display_tensor(q_dot_k, 'query dot key')

query dot key shape: (2, 2)

[[2.309401  5.773503 ]
 [1.7320509 5.1961527]]



In [None]:
masked = q_dot_k + m
display_tensor(masked, 'masked query dot key')

masked query dot key shape: (2, 2)

[[ 2.3094010e+00  5.7735028e+00]
 [-1.0000000e+09  5.1961527e+00]]



In [None]:
display_tensor(masked @ v, 'masked query dot key dot value')

masked query dot key dot value shape: (2, 3)

[[ 5.7735028e+00  8.0829039e+00  5.7735028e+00]
 [ 5.1961527e+00 -1.0000000e+09  5.1961527e+00]]



In [None]:
# q, k, v with a leading batch dimension of size 1 (indexing with None adds the axis),
# plus a boolean mask.
q_with_batch = q[None, :]
display_tensor(q_with_batch, 'query with batch dim')

k_with_batch = k[None,:]
display_tensor(k_with_batch, 'key with batch dim')

v_with_batch = v[None, :]
display_tensor(v_with_batch, 'value with batch dim')

# Boolean form of the mask: the False entry sits where the additive mask `m` holds -1e9.
m_bool = create_tensor([[True, True], [False, True]])
display_tensor(m_bool, 'boolean_mask')

query with batch dim shape: (1, 2, 3)

[[[1 0 1]
  [1 1 0]]]

key with batch dim shape: (1, 2, 3)

[[[1 2 3]
  [4 5 6]]]

value with batch dim shape: (1, 2, 3)

[[[0 1 0]
  [1 1 1]]]

boolean_mask shape: (2, 2)

[[ True  True]
 [False  True]]



## **2.1 Dot Product Attention**

In [None]:
def DotProductAttention(query, key, value, mask):
  """Scaled dot-product attention with an optional boolean mask.

  Args:
    query: array whose last axis is the embedding depth.
    key: array with the same last-axis size as `query`.
    value: array with the same last-axis size as `query`.
    mask: boolean array usable with jnp.where against the score matrix, or None.
      Scores where the mask is False are replaced by -1e9 before the softmax.

  Returns:
    softmax(query @ key^T / sqrt(depth)) @ value, with the softmax taken over the last axis.
  """
  assert query.shape[-1] == key.shape[-1] == value.shape[-1], "Embedding dims of q, k, v, and m are not all the same"

  # The scores are scaled down by the square root of the embedding depth.
  d_k = query.shape[-1]
  key_transposed = jnp.swapaxes(key, -1, -2)
  scores = jnp.matmul(query, key_transposed) / jnp.sqrt(d_k)

  if mask is not None:
    # A very negative score turns into a ~0 weight after the softmax.
    blocked = jnp.full_like(scores, -1e9)
    scores = jnp.where(mask, scores, blocked)

  # Softmax computed as exp(scores - logsumexp(scores)).
  log_normalizer = trax.fastmath.logsumexp(scores, axis = -1, keepdims = True)
  weights = jnp.exp(scores - log_normalizer)

  return jnp.matmul(weights, value)

In [None]:
DotProductAttention(q_with_batch,k_with_batch, v_with_batch, m_bool)

DeviceArray([[[0.96964884, 0.99999994, 0.96964884],
              [0.96964884, 0.99999994, 0.96964884]]], dtype=float32)

## **Causal (Self) Attention**

In [None]:
# Example tensors of different ranks built from the toy query matrix `q`,
# used below to illustrate the head-splitting and head-merging reshapes.
tensor2d = create_tensor(q)
display_tensor(tensor2d, 'query matrix (2D tensor)')

tensor4d2b = create_tensor([[q,q], [q,q]])
display_tensor(tensor4d2b, 'batch of two (multi-head) collections of query matrices (4D tensor)')

# Concatenating along the last axis places two copies of q side by side in the feature dimension.
tensor3dc = create_tensor([jnp.concatenate([q,q], axis = -1)])
display_tensor(tensor3dc, 'one batch of concatenated heads of query matrices (3D tensor)')

tensor3dc3b = create_tensor([jnp.concatenate([q,q], axis = -1), jnp.concatenate([q,q], axis = -1), jnp.concatenate([q,q], axis = -1)])
display_tensor(tensor3dc3b, 'three batches of concatenated heads of query matrices (3D tensor)')

query matrix (2D tensor) shape: (2, 3)

[[1 0 1]
 [1 1 0]]

batch of two (multi-head) collections of query matrices (4D tensor) shape: (2, 2, 2, 3)

[[[[1 0 1]
   [1 1 0]]

  [[1 0 1]
   [1 1 0]]]


 [[[1 0 1]
   [1 1 0]]

  [[1 0 1]
   [1 1 0]]]]

one batch of concatenated heads of query matrices (3D tensor) shape: (1, 2, 6)

[[[1 0 1 1 0 1]
  [1 1 0 1 1 0]]]

three batches of concatenated heads of query matrices (3D tensor) shape: (3, 2, 6)

[[[1 0 1 1 0 1]
  [1 1 0 1 1 0]]

 [[1 0 1 1 0 1]
  [1 1 0 1 1 0]]

 [[1 0 1 1 0 1]
  [1 1 0 1 1 0]]]



In [None]:
def compute_attention_heads_closure(n_heads, d_head):
  """Create a function that splits concatenated attention heads apart.

  Args:
    n_heads: number of attention heads.
    d_head: dimensionality of each head.

  Returns:
    A function mapping a tensor of shape (batch_size, seqlen, n_heads * d_head)
    to a tensor of shape (batch_size * n_heads, seqlen, d_head).
  """

  def compute_attention_heads(x):
    # x: (batch_size, seqlen, n_heads * d_head)
    batch_size, seqlen = x.shape[0], x.shape[1]

    # Split the feature axis into (n_heads, d_head).
    per_head = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))

    # Bring the head axis next to the batch axis: (batch_size, n_heads, seqlen, d_head).
    heads_first = jnp.transpose(per_head, (0, 2, 1, 3))

    # Fold the heads into the batch axis.
    return jnp.reshape(heads_first, (-1, seqlen, d_head))

  return compute_attention_heads

In [None]:
display_tensor(tensor3dc3b, 'input tensor')
result = compute_attention_heads_closure(2,3)(tensor3dc3b)
display_tensor(result, 'output tensor')

input tensor shape: (3, 2, 6)

[[[1 0 1 1 0 1]
  [1 1 0 1 1 0]]

 [[1 0 1 1 0 1]
  [1 1 0 1 1 0]]

 [[1 0 1 1 0 1]
  [1 1 0 1 1 0]]]

output tensor shape: (6, 2, 3)

[[[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]]



In [None]:
def dot_product_self_attention(q, k,v):
  """Masked (causal) dot-product self-attention.

  Builds a lower-triangular boolean mask sized by the second-to-last axis of `q`
  and delegates to DotProductAttention, so each position only attends to itself
  and to earlier positions.

  Args:
    q: query tensor; its second-to-last axis is the query length.
    k: key tensor.
    v: value tensor.

  Returns:
    The masked dot-product self-attention tensor.
  """
  seq_len = q.shape[-2]

  # True on and below the main diagonal, False above; shape (1, seq_len, seq_len).
  all_true = jnp.ones((1, seq_len, seq_len), dtype = jnp.bool_)
  causal_mask = jnp.tril(all_true, k=0)

  return DotProductAttention(q, k, v, causal_mask)

In [None]:
dot_product_self_attention(q_with_batch, k_with_batch, v_with_batch)

DeviceArray([[[0.        , 1.        , 0.        ],
              [0.96964884, 0.99999994, 0.96964884]]], dtype=float32)

In [None]:
def compute_attention_output_closure(n_heads, d_head):
  """Create a function that merges per-head attention results back together.

  Args:
    n_heads: number of attention heads.
    d_head: dimensionality of each head.

  Returns:
    A function mapping a tensor of shape (batch_size * n_heads, seqlen, d_head)
    to a tensor of shape (batch_size, seqlen, n_heads * d_head).
  """

  def compute_attention_output(x):
    # x: (batch_size * n_heads, seqlen, d_head)
    seqlen = x.shape[1]

    # Recover the separate batch and head axes: (batch_size, n_heads, seqlen, d_head).
    unfolded = jnp.reshape(x, (-1, n_heads, seqlen, d_head))

    # Put the head axis after the sequence axis: (batch_size, seqlen, n_heads, d_head).
    seq_first = jnp.transpose(unfolded, (0, 2, 1, 3))

    # Concatenate the heads along the feature axis.
    merged = jnp.reshape(seq_first, (-1, seqlen, n_heads * d_head))
    return merged

  return compute_attention_output

In [None]:
display_tensor(result, 'input tensor')
results = compute_attention_output_closure(2,3)(result)
display_tensor(results, 'output tensor')

input tensor shape: (6, 2, 3)

[[[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]

 [[1 0 1]
  [1 1 0]]]

output tensor shape: (3, 2, 6)

[[[1 0 1 1 0 1]
  [1 1 0 1 1 0]]

 [[1 0 1 1 0 1]
  [1 1 0 1 1 0]]

 [[1 0 1 1 0 1]
  [1 1 0 1 1 0]]]



In [None]:
def CausalAttention(d_feature,
                    n_heads,
                    compute_attention_heads_closure = compute_attention_heads_closure,
                    dot_product_self_attention = dot_product_self_attention,
                    compute_attention_output_closure = compute_attention_output_closure,
                    mode = 'train'):
  """Transformer-style multi-headed causal (masked) attention.

  Args:
    d_feature: feature depth of the layer; must be divisible by `n_heads`.
    n_heads: number of attention heads.
    compute_attention_heads_closure: factory for the function that splits heads.
    dot_product_self_attention: function computing masked attention from q, k, v.
    compute_attention_output_closure: factory for the function that merges heads.
    mode: accepted for interface compatibility; not used in this function.

  Returns:
    A tl.Serial layer: three Dense + head-split branches (queries, keys, values),
    masked dot-product attention, head merge, and a final Dense layer.
  """
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads

  split_heads = tl.Fn('AttnHeads', compute_attention_heads_closure(n_heads, d_head), n_out = 1)
  masked_attention = tl.Fn('DotProdAttn', dot_product_self_attention, n_out = 1)
  merge_heads = tl.Fn('AttnOutput', compute_attention_output_closure(n_heads, d_head), n_out = 1)

  # One branch each for queries, keys and values: every branch gets its own Dense
  # projection, while the same head-splitting layer is reused in all three.
  qkv_branches = [[tl.Dense(d_feature), split_heads] for _ in range(3)]

  return tl.Serial(
      tl.Branch(*qkv_branches),
      masked_attention,
      merge_heads,
      tl.Dense(d_feature),  # final dense layer
  )

In [None]:
print(CausalAttention(512, 8))

Serial[
  Branch_out3[
    [Dense_512, AttnHeads]
    [Dense_512, AttnHeads]
    [Dense_512, AttnHeads]
  ]
  DotProdAttn_in3
  AttnOutput
  Dense_512
]


## **Transformer Decoder Block**

In [None]:
def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
  """Build the layers of one Transformer decoder block.

  Args:
    d_model: depth of embedding.
    d_ff: depth of feed-forward layer.
    n_heads: number of attention heads.
    dropout: dropout rate.
    mode: 'train' or 'eval'.
    ff_activation: the non-linearity in the feed-forward layer (a layer constructor).

  Returns:
    A list of two residual layers: the causal-attention sub-block followed by
    the feed-forward sub-block.
  """
  # Residual 1: layer norm -> causal attention -> dropout.
  attention_sublayer = tl.Residual(
      tl.LayerNorm(),
      CausalAttention(d_model, n_heads = n_heads, mode = mode),
      tl.Dropout(rate = dropout, mode = mode),
  )

  # Residual 2: layer norm -> Dense(d_ff) -> activation -> dropout -> Dense(d_model) -> dropout.
  feed_forward_sublayer = tl.Residual([
      tl.LayerNorm(),
      tl.Dense(d_ff),
      ff_activation(),
      tl.Dropout(rate = dropout, mode = mode),
      tl.Dense(d_model),
      tl.Dropout(rate = dropout, mode = mode),
  ])

  return [attention_sublayer, feed_forward_sublayer]

In [None]:
print(DecoderBlock(d_model = 512, d_ff = 2048, n_heads = 8, dropout = 0.1, mode = 'train', ff_activation = tl.Relu))

[Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Serial[
        Branch_out3[
          [Dense_512, AttnHeads]
          [Dense_512, AttnHeads]
          [Dense_512, AttnHeads]
        ]
        DotProdAttn_in3
        AttnOutput
        Dense_512
      ]
      Dropout
    ]
  ]
  Add_in2
], Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Dense_2048
      Relu
      Dropout
      Dense_512
      Dropout
    ]
  ]
  Add_in2
]]


## **Transformer Language Model**

In [None]:
def TransformerLM(vocab_size = 33300, d_model = 512, d_ff = 2048, n_layers = 6, n_heads = 8, dropout = 0.1, max_len = 4096, mode = 'train', ff_activation = tl.Relu):
  """Build a decoder-only Transformer language model.

  Args:
    vocab_size: size of the vocabulary (embedding rows and output logits).
    d_model: depth of embedding.
    d_ff: depth of the feed-forward layers.
    n_layers: number of decoder blocks.
    n_heads: number of attention heads.
    dropout: dropout rate.
    max_len: maximum input length given to the positional encoding.
    mode: 'train' or 'eval'.
    ff_activation: the non-linearity used in the feed-forward layers.

  Returns:
    A tl.Serial model ending in LogSoftmax over `vocab_size` entries.
  """
  # Token embedding, dropout, then positional encoding.
  embed_and_position = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate = dropout, mode = mode),
      tl.PositionalEncoding(max_len = max_len, mode = mode),
  ]

  # n_layers separately constructed decoder blocks.
  decoder_stack = []
  for _ in range(n_layers):
    decoder_stack.append(DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation))

  return tl.Serial(
      # Use teacher forcing (feed output of previous step to current step)
      tl.ShiftRight(mode = mode),
      embed_and_position,
      decoder_stack,
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )

In [None]:
print(TransformerLM(n_layers=1))

Serial[
  ShiftRight(1)
  Embedding_33300_512
  Dropout
  PositionalEncoding
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Serial[
          Branch_out3[
            [Dense_512, AttnHeads]
            [Dense_512, AttnHeads]
            [Dense_512, AttnHeads]
          ]
          DotProdAttn_in3
          AttnOutput
          Dense_512
        ]
        Dropout
      ]
    ]
    Add_in2
  ]
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Dense_2048
        Relu
        Dropout
        Dense_512
        Dropout
      ]
    ]
    Add_in2
  ]
  LayerNorm
  Dense_33300
  LogSoftmax
]


# **3. Training**

In [None]:
from trax.supervised import training

def training_loop(TransformerLM, train_gen, eval_gen, output_dir = '/content/model'):
  """Assemble a trax training Loop for the Transformer language model.

  Args:
    TransformerLM: model constructor; called here with fixed hyperparameters.
    train_gen: generator of training batches.
    eval_gen: generator of evaluation batches.
    output_dir: directory handed to the Loop as its output directory ('~' is expanded).

  Returns:
    The configured training.Loop; it has not been run yet.
  """
  checkpoint_dir = os.path.expanduser(output_dir)

  train_task = training.TrainTask(
      labeled_data = train_gen,
      loss_layer = tl.CrossEntropyLoss(),
      optimizer = trax.optimizers.Adam(0.01),
      lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps = 1000, max_value = 0.01),
      n_steps_per_checkpoint = 10,
  )

  eval_task = training.EvalTask(
      labeled_data = eval_gen,
      metrics = [tl.CrossEntropyLoss(), tl.Accuracy()],
  )

  model = TransformerLM(
      vocab_size = 33300,
      d_model = 512,
      d_ff = 2048,
      n_layers = 6,
      n_heads = 8,
      dropout = 0.1,
      max_len = 4096,
      mode = 'train',
      ff_activation = tl.Relu,
  )

  return training.Loop(model,
                       train_task,
                       eval_tasks = [eval_task],
                       output_dir = checkpoint_dir)

In [None]:
# Build the training loop on the bucketed streams; the loop's output directory is /content/model.
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream, output_dir = '/content/model')
# NOTE(review): presumably runs 5 training steps, which is fewer than the
# n_steps_per_checkpoint=10 set in training_loop — confirm which checkpoint ends up on disk.
loop.run(5)


Step      1: Ran 1 train steps in 102.14 secs
Step      1: train CrossEntropyLoss |  10.44239140
Step      1: eval  CrossEntropyLoss |  10.40130138
Step      1: eval          Accuracy |  0.00000000


# **4. Evaluation**

In [None]:
# Get the model architecture (built with mode='eval' and otherwise default hyperparameters)

model_trained = TransformerLM(mode='eval')

# Load the pre-trained weights
# NOTE(review): the path lies inside the output_dir passed to training_loop, so this is
# presumably the checkpoint written by the training run above — confirm the file exists.
model_trained.init_from_file('/content/model/model.pkl.gz', weights_only=True)

# **5. Testing with your own input**

In [None]:
def next_symbol(cur_output_tokens, model):
    """Return the id of the most likely next token for a token sequence.

    Args:
        cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at the end.
        model: callable taking a tuple of two batched arrays and returning a pair
            whose first element has shape (1, padded_length, vocab_size).

    Returns:
        int: index of the largest entry of the model output at position
        len(cur_output_tokens).
    """
    n_tokens = len(cur_output_tokens)

    # Smallest power of two strictly greater than n_tokens, so there is always
    # at least one padded position after the real tokens.
    padded_length = 1 << n_tokens.bit_length()

    # Right-pad with 0s and add a leading batch axis of size 1.
    zero_padding = [0] * (padded_length - n_tokens)
    batch = np.array(cur_output_tokens + zero_padding)[None, :]

    # The model expects a tuple containing two padded tensors (with batch).
    output, _ = model((batch, batch))

    # Scores over the vocabulary for the first position after the given tokens.
    next_token_scores = output[0, n_tokens, :]
    return int(np.argmax(next_token_scores))

In [None]:
sentence_test_nxt_symbl = "I want to fly in the sky."
detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model_trained)])

'轼'

# **Greedy Decoding**

In [None]:
def greedy_decode(input_sentence, model, max_len = 4096, verbose = True):
    """Greedily generate a summary for the input, one token at a time.

    Args:
        input_sentence (string): a sentence or article.
        model: Transformer model.
        max_len (int): upper bound on the running sequence length (input plus
            generated tokens). The default matches TransformerLM's default
            max_len. Decoding stops at this bound even without an EOS, so a
            model that never predicts EOS cannot loop forever.
        verbose (bool): if True (default), print the partial summary after
            every generated token.

    Returns:
        string: the detokenized generated tokens (including the final EOS when
        the model produced one).
    """
    # The input is followed by EOS (added by tokenize) and a 0 separator, mirroring
    # the article/summary layout built in preprocess().
    cur_output_tokens = tokenize(input_sentence) + [0]
    generated_output = []
    cur_output = 0
    EOS = 1

    # next_symbol pads to the next power of two above the current length, so keeping the
    # length strictly below a power-of-two max_len keeps the padded input within max_len.
    while cur_output != EOS and len(cur_output_tokens) < max_len:
        # Get next symbol
        cur_output = next_symbol(cur_output_tokens, model)
        # Append next symbol to original sentence
        cur_output_tokens.append(cur_output)
        # Append next symbol to generated sentence
        generated_output.append(cur_output)
        if verbose:
            print(detokenize(generated_output))
    return detokenize(generated_output)

In [None]:
# Test it out on a sentence!
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model_trained))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
則則okpensipensipensi優pensipensipensi則則則勤勤wrestlewrestledeclining
declining 則則則則則則則則則Von OreOreOreOreOredeclining declining declining
declining declining declining declining declining declining Von Von
窺socio declining declining declining declining declining declining
declining declining 했했했했declining declining declining
EkEkEk勤勤Monitoring Monitoring Ask Brad 했했Monitoring 勤勤勤窺declining
declining declining declining declining declining declining declining
Monitoring Monitoring Monitoring Monitoring Monitoring Monitoring
Monitoring Monitoring Monitoring Monitoring computed computed computed
Monitoring Monitoring Monitoringोोyearyear declining declining
declining declining 했勤勤勤勤勤declining 勤year wrestle勤declining declining
decliningोोोो했했했decliningMonitoring Monitoring
Monitoringोोोोोोोोोोोो했했했Monitoringोdecliningdeclining declining
declining declining declining declining declining declining declining
했했했했declining declining de

KeyboardInterrupt: ignored