In [1]:
!pip install Trax

Collecting Trax
[?25l  Downloading https://files.pythonhosted.org/packages/42/51/305b839f51d53abb393777f743e497d27bb341478f3fdec4d6ddaccc9fb5/trax-1.3.7-py2.py3-none-any.whl (521kB)
[K     |▋                               | 10kB 26.7MB/s eta 0:00:01[K     |█▎                              | 20kB 23.9MB/s eta 0:00:01[K     |█▉                              | 30kB 12.1MB/s eta 0:00:01[K     |██▌                             | 40kB 9.9MB/s eta 0:00:01[K     |███▏                            | 51kB 7.5MB/s eta 0:00:01[K     |███▊                            | 61kB 7.9MB/s eta 0:00:01[K     |████▍                           | 71kB 8.3MB/s eta 0:00:01[K     |█████                           | 81kB 8.3MB/s eta 0:00:01[K     |█████▋                          | 92kB 8.1MB/s eta 0:00:01[K     |██████▎                         | 102kB 8.6MB/s eta 0:00:01[K     |███████                         | 112kB 8.6MB/s eta 0:00:01[K     |███████▌                        | 122kB 8.6MB/s eta 0:

In [2]:
import sys
import os
import numpy as np

import textwrap

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.supervised import training

wrapper = textwrap.TextWrapper(width=70)
np.set_printoptions(threshold=sys.maxsize)

In [3]:
train_stream_fn = trax.data.TFDS('cnn_dailymail',
                                 data_dir='data/',
                                 keys=('article', 'highlights'),
                                 train=True)

eval_stream_fn = trax.data.TFDS('cnn_dailymail',
                                data_dir='data/',
                                keys=('article', 'highlights'),
                                train=False)

[1mDownloading and preparing dataset 558.32 MiB (download: 558.32 MiB, generated: 1.27 GiB, total: 1.82 GiB) to data/cnn_dailymail/3.1.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=0.0, description='Generating splits...', max=3.0, style=ProgressStyle(descr…

HBox(children=(FloatProgress(value=0.0, description='Generating train examples...', max=287113.0, style=Progre…

HBox(children=(FloatProgress(value=0.0, description='Shuffling cnn_dailymail-train.tfrecord...', max=287113.0,…

HBox(children=(FloatProgress(value=0.0, description='Generating validation examples...', max=13368.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Shuffling cnn_dailymail-validation.tfrecord...', max=1336…

HBox(children=(FloatProgress(value=0.0, description='Generating test examples...', max=11490.0, style=Progress…

HBox(children=(FloatProgress(value=0.0, description='Shuffling cnn_dailymail-test.tfrecord...', max=11490.0, s…

[1mDataset cnn_dailymail downloaded and prepared to data/cnn_dailymail/3.1.0. Subsequent calls will reuse this data.[0m


In [4]:
def tokenize(input_str, EOS=1):
    """Input str to features dict, ready for inference"""
    inputs =  next(trax.data.tokenize(iter([input_str]), vocab_dir='.', vocab_file='summarize32k.subword.subwords'))
    return list(inputs) + [EOS]

def detokenize(integers):
    """List of ints to str"""
    s = trax.data.detokenize(integers, vocab_dir='.', vocab_file='summarize32k.subword.subwords')
    return wrapper.fill(s)

In [5]:
EOS = 1 # End of sentence token

def preprocess(stream):
    """Concatenate tokenized inputs and targets using 0 as separator."""
    for (article, summary) in stream:
        input_article = np.array(list(article) + [EOS])
        target_summary = np.array(list(summary) + [EOS])
        
        yield input_article, target_summary

input_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_dir='.', vocab_file='summarize32k.subword.subwords'),
    preprocess,
    trax.data.FilterByLength(2048)
)

train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

In [6]:
train_input, train_target = next(train_stream)
print(train_target)

[ 6945  5517  2869  9783   158    11  6940  5539  6340   922   691  1819
  3118  1572 16346 27439  6774  1628   368  8627 20373     4 20872    21
   492  6460  1019  3074     7     5   580 25541   917 16346 27439  6774
  1628    69   127   368  1550   117  1003  8404    51   790     7    26
  1194  3445   186 14661  6053 27439  6774  1628  7057  3074  1299 24882
   368  8627 20373     4   132   163 12339   922 10038  2104     1]


In [7]:
print(detokenize(train_target))

EXCLUSIVE: Source reveals extraordinary call by ex Prime Minister . Mr
Umunna blamed former PM for Labour's economic credibility problem . He
said Mr Brown 'gave impression we didn't understand debt and deficit'
Former Labour leader confronted Mr Umunna in an angry call afterwards
.<EOS>


In [8]:
# Bucketing to create batched generators. 
boundaries =  [128, 256,  512, 1024]
batch_sizes = [16,    8,    4,    2, 1]

train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(train_stream)

eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(eval_stream)

In [9]:
train_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(train_batch_stream)
eval_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(eval_batch_stream)

In [10]:
input_batch, target_batch, loss_weight = next(train_batch_stream)

In [11]:
print(loss_weight)
print(target_batch.shape)

[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
(1, 117)


In [12]:
print(input_batch[0])

[  567   379 20529    20 25360     4   379  7226  5182  3047  6611   136
  4601     3  2879   180  1725 16958     4     2   118   389   420 11969
 28081   379  9720 22449  3590  4601     3   406   180   137 16958     4
     2   118   389   420   379  2702 14815     4    11    69    39  1151
  1537  1687  1346    88   226    28   984  1019   213   382   104   379
     9  8998   527  1651 29725     4     5  2448  3558   458    39   952
   320  1555    15 11001   527  1687  1346    88   226    28   984  1019
   213   382   104  1786    15 24701    16 20240  7356     3  2702 14815
     4     2  1779    23 15472   412  2028  2462   527  9213  2640   102
    28 21611  7433 20376     2    39  1555    28  8749 10453 19331  1378
   239  1687    32    10  5411   446   132   585     3  2764  3112     2
  2671   318  1641   527 21611  7433  5251  2214   320   179    15  1961
     6  5401    17  1153  5865  1378    61   320  1687    65    10  1313
   446   220   104     3    52  1353   213  2150  2

In [13]:
print(detokenize(input_batch[0]))

By . Becky Barrow . PUBLISHED: . 03:39 EST, 8 May 2012 . | . UPDATED:
. 18:10 EST, 8 May 2012 . Andrew Moss: He will be paid £80,000 a month
for the next year . The boss of Britain’s biggest insurance company
will continue to receive his salary of £80,000 a month for the next
year despite his humiliating resignation yesterday. Andrew Moss, who
has quit as chief executive of Aviva after a shareholder revolt, will
receive a golden goodbye worth around £1.75million in total. Last
Thursday, 59 per cent of shareholder votes failed to back his gold-
plated pay package worth up to £5.2million last year. It was the
latest chapter in the growing backlash against boardroom greed,
nicknamed the Shareholder Spring. Yesterday the 54-year-old chief
executive said he ‘felt it was in the best interests of the company
that he step aside to make way for new leadership’. But Mr Moss, who
has also sparked public criticism for leaving his wife of 25 years and
their four children for a junior married collea

In [14]:
def Transformer_Model(input_vocab_size=33300,
                  output_vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_encoder_layers=6,
                  n_decoder_layers=6,
                  n_heads=8,
                  max_len=4096,
                  dropout=0.1,
                  mode='train',
                  ff_activation=tl.Relu):
    """Returns a Transformer language model.
    Args:
        vocab_size: vocab size.
        d_model:  depth of embedding.
        d_ff: depth of feed-forward layer.
        n_layers: number of decoder layers.
        n_heads: number of attention heads.
        dropout: dropout rate (how much to drop out).
        max_len: maximum symbol length for positional encoding.
        mode: 'train', 'eval' or 'predict', predict mode is for fast inference.
        ff_activation: the non-linearity in feed-forward layer.
    """
    return tl.Serial(trax.models.transformer.Transformer(input_vocab_size=input_vocab_size,
                                                         output_vocab_size=output_vocab_size,
                                                         d_model=d_model,
                                                         d_ff=d_ff,
                                                         n_encoder_layers=n_encoder_layers,
                                                         n_decoder_layers=n_decoder_layers,
                                                         n_heads=n_heads,
                                                         max_len=max_len,
                                                         dropout=dropout,
                                                         mode=mode,
                                                         ff_activation=ff_activation),
                     tl.LogSoftmax())

In [15]:
print(Transformer_Model(n_encoder_layers=1, n_decoder_layers=1))

Serial_in2_out2[
  Serial_in2_out2[
    Select[0,1,1]_in2_out3
    Branch_out2[
      []
      Serial[
        PaddingMask(0)
      ]
    ]
    Serial_in2_out2[
      Embedding_33300_512
      Dropout
      PositionalEncoding
      Serial_in2_out2[
        Branch_in2_out3[
          None
          Serial_in2_out2[
            LayerNorm
            Serial_in2_out2[
              _in2_out2
              Serial_in2_out2[
                Select[0,0,0]_out3
                Serial_in4_out2[
                  _in4_out4
                  Serial_in4_out2[
                    Parallel_in3_out3[
                      Dense_512
                      Dense_512
                      Dense_512
                    ]
                    PureAttention_in4_out2
                    Dense_512
                  ]
                  _in2_out2
                ]
              ]
              _in2_out2
            ]
            Dropout
          ]
        ]
        Add_in2
      ]
      Serial[
        Branch_ou

In [34]:
def training_loop(Transformer, train_gen, eval_gen, output_dir = "."):
    '''
    Returns Training Loop
    Input:
        TransformerLM : The model you are building.
        train_gen: Training stream of data.
        eval_gen: Evaluation stream of data.
        output_dir: folder to save your file.
    '''
    output_dir = os.path.expanduser(output_dir)
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=4000, max_value=0.01)

    train_task = training.TrainTask( 
      labeled_data=train_gen,
      loss_layer=tl.CrossEntropyLoss(),
      optimizer=trax.optimizers.Adam(0.01),
      lr_schedule=lr_schedule,
      n_steps_per_checkpoint=10
    )

    eval_task = training.EvalTask( 
      labeled_data=eval_gen,
      metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]
    )

    loop = training.Loop(Transformer(),
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
    
    return loop

In [35]:
loop = training_loop(Transformer_Model, train_batch_stream, eval_batch_stream)
loop.run(30)


Step     40: Ran 10 train steps in 469.55 secs
Step     40: train CrossEntropyLoss |  9.75104141
Step     40: eval  CrossEntropyLoss |  9.64670181
Step     40: eval          Accuracy |  0.02453988

Step     50: Ran 10 train steps in 376.45 secs
Step     50: train CrossEntropyLoss |  9.40565491
Step     50: eval  CrossEntropyLoss |  8.97205830
Step     50: eval          Accuracy |  0.03896104

Step     60: Ran 10 train steps in 379.98 secs
Step     60: train CrossEntropyLoss |  8.81313992
Step     60: eval  CrossEntropyLoss |  8.34542370
Step     60: eval          Accuracy |  0.04201681


In [36]:
model = Transformer_Model(mode='eval')
# Load the pre-trained weights
model.init_from_file('model.pkl.gz', weights_only=True)

In [44]:
MAX_LENGTH = 30 # Maximum length approvable for summary of an article.

In [45]:
def next_symbol(cur_input_tokens, cur_output_tokens, model):
    """Returns the next symbol for a given sentence.
    Args:
        cur_output_tokens: tokenized sentence with EOS and PAD tokens at the end.
        model: The transformer model.
    """
    # current output tokens length
    token_length = len(cur_output_tokens)
    
    # calculate the minimum power of 2 big enough to store token_length
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))

    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]



    output = model((cur_input_tokens, padded_with_batch)) 
    log_probs = output[0][0, token_length, :]
    
    return int(np.argmax(log_probs))

In [48]:
def greedy_decode(input_sentence, model):
    """Greedy decode function.
    Args:
        input_sentence: a sentence or article.
        model: Transformer model.
    """
    cur_input_tokens = tokenize(input_sentence) + [0]
    cur_output_tokens = [-1]
    generated_output = [] 
    cur_output = 0 
    EOS = 1
    i = 0

    token_length = len(cur_input_tokens)
    # calculate the minimum power of 2 big enough to store token_length
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))

    padded = cur_input_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]

    while cur_output != EOS and i < MAX_LENGTH:
        cur_output = next_symbol(padded_with_batch, cur_output_tokens, model)
        cur_output_tokens.append(cur_output)
        generated_output.append(cur_output)
        i = i + 1
    
    return detokenize(generated_output)

In [49]:
# It will not give the result i want. It will output some random text/numbers.
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model))

It was a sunny day when I went to the market to buy some flowers. But
I only found roses, not tulips. 

 . . . . . . . . . . . .101010101010101010101010101010101010
