## Tutorial on how to use tensorflow-addons BasicDecoder and BeamSearchDecoder classes

In [None]:
!pip install tensorflow-addons



In [1]:
import tensorflow as tf
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split

import unicodedata
import re
import numpy as np
import os
import io
import time
from utils.dataset import NMTDataset

In [2]:

BUFFER_SIZE = 32000
BATCH_SIZE = 64
num_examples = 30000

dataset_creator = NMTDataset('en-spa')
train_dataset, val_dataset, inp_lang, targ_lang = dataset_creator.call(num_examples, BUFFER_SIZE, BATCH_SIZE)

Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip


In [3]:
example_input_batch, example_target_batch = next(iter(train_dataset))
example_input_batch.shape, example_target_batch.shape

(TensorShape([64, 16]), TensorShape([64, 11]))

### Create a tf.data dataset

In [4]:
vocab_inp_size = len(inp_lang.word_index)+1
vocab_tar_size = len(targ_lang.word_index)+1
max_length_input = example_input_batch.shape[1]
max_length_output = example_target_batch.shape[1]

embedding_dim = 256
units = 1024
steps_per_epoch = num_examples//BATCH_SIZE


In [None]:
max_length_input, max_length_output, vocab_tar_size, vocab_inp_size

(16, 11, 4936, 9415)

In [5]:
##### 

class Encoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    super(Encoder, self).__init__()
    self.batch_sz = batch_sz
    self.enc_units = enc_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

    ##-------- 2 Bidirectional LSTM Layers in Encoder ------- ##
    self.lstm_layer = tf.keras.layers.LSTM(self.enc_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    


  def call(self, x, hidden):
    x = self.embedding(x)
    output, h, c = self.lstm_layer(x, initial_state = hidden)
    return output, h, c

  def initialize_hidden_state(self):
    return [tf.zeros((self.batch_sz, self.enc_units)), tf.zeros((self.batch_sz, self.enc_units))] 

In [6]:
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)


# sample input
sample_hidden = encoder.initialize_hidden_state()
sample_output, sample_h, sample_c = encoder(example_input_batch, sample_hidden)
print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print ('Encoder h vecotr shape: (batch size, units) {}'.format(sample_h.shape))
print ('Encoder c vector shape: (batch size, units) {}'.format(sample_c.shape))

Encoder output shape: (batch size, sequence length, units) (64, 16, 1024)
Encoder h vecotr shape: (batch size, units) (64, 1024)
Encoder c vector shape: (batch size, units) (64, 1024)


In [7]:
class Decoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, attention_type='luong'):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.attention_type = attention_type
    
    # Embedding Layer
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    
    #Final Dense layer on which softmax will be applied
    self.fc = tf.keras.layers.Dense(vocab_size)

    # Define the fundamental cell for decoder recurrent structure
    self.decoder_rnn_cell = tf.keras.layers.LSTMCell(self.dec_units)
   


    # Sampler
    self.sampler = tfa.seq2seq.sampler.TrainingSampler()

    # Create attention mechanism with memory = None
    self.attention_mechanism = self.build_attention_mechanism(self.dec_units, 
                                                              None, self.batch_sz*[max_length_input], self.attention_type)

    # Wrap attention mechanism with the fundamental rnn cell of decoder
    self.rnn_cell = self.build_rnn_cell(batch_sz)

    # Define the decoder with respect to fundamental rnn cell
    self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler=self.sampler, output_layer=self.fc)

    
  def build_rnn_cell(self, batch_sz):
    rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnn_cell, 
                                  self.attention_mechanism, attention_layer_size=self.dec_units)
    return rnn_cell

  def build_attention_mechanism(self, dec_units, memory, memory_sequence_length, attention_type='luong'):
    # ------------- #
    # typ: Which sort of attention (Bahdanau, Luong)
    # dec_units: final dimension of attention outputs 
    # memory: encoder hidden states of shape (batch_size, max_length_input, enc_units)
    # memory_sequence_length: 1d array of shape (batch_size) with every element set to max_length_input (for masking purpose)

    if(attention_type=='bahdanau'):
      return tfa.seq2seq.BahdanauAttention(units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)
    else:
      return tfa.seq2seq.LuongAttention(units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)

  def build_initial_state(self, batch_sz, encoder_state, Dtype):
    decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_sz, dtype=Dtype)
    decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
    return decoder_initial_state


  def call(self, inputs, initial_state):
    # Setup attention mechanism's memory
    x = self.embedding(inputs)
    outputs, _, _ = self.decoder(x, initial_state=initial_state, sequence_length=self.batch_sz*[max_length_output-1])
    return outputs



In [8]:
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE, 'luong')
sample_x = tf.random.uniform((BATCH_SIZE, max_length_output))
decoder.attention_mechanism.setup_memory(sample_output)
initial_state = decoder.build_initial_state(BATCH_SIZE, [sample_h, sample_c], tf.float32)


sample_decoder_outputs = decoder(sample_x, initial_state)

print("Decoder Outputs Shape: ", sample_decoder_outputs.rnn_output.shape)


Decoder Outputs Shape:  (64, 10, 4936)


## Define the optimizer and the loss function

In [9]:
optimizer = tf.keras.optimizers.Adam()


def loss_function(real, pred):
  # real shape = (BATCH_SIZE, max_length_output)
  # pred shape = (BATCH_SIZE, max_length_output, tar_vocab_size )
  cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
  loss = cross_entropy(y_true=real, y_pred=pred)
  mask = tf.logical_not(tf.math.equal(real,0))   #output 0 for y=0 else output 1
  mask = tf.cast(mask, dtype=loss.dtype)  
  loss = mask* loss
  loss = tf.reduce_mean(loss)
  return loss  

## Checkpoints (Object-based saving)

In [11]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

In [10]:
@tf.function
def train_step(inp, targ, enc_hidden):
  loss = 0

  with tf.GradientTape() as tape:
    enc_output, enc_h, enc_c = encoder(inp, enc_hidden)


    dec_input = targ[ : , :-1 ] # Ignore <end> token
    real = targ[ : , 1: ]         # ignore <start> token

    # Set the AttentionMechanism object with encoder_outputs
    decoder.attention_mechanism.setup_memory(enc_output)

    # Create AttentionWrapperState as initial_state for decoder
    decoder_initial_state = decoder.build_initial_state(BATCH_SIZE, [enc_h, enc_c], tf.float32)
    pred = decoder(dec_input, decoder_initial_state)
    logits = pred.rnn_output
    loss = loss_function(real, logits)

  variables = encoder.trainable_variables + decoder.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))

  return loss

In [12]:
EPOCHS = 10

for epoch in range(EPOCHS):
  start = time.time()

  enc_hidden = encoder.initialize_hidden_state()
  total_loss = 0
  # print(enc_hidden[0].shape, enc_hidden[1].shape)

  for (batch, (inp, targ)) in enumerate(train_dataset.take(steps_per_epoch)):
    batch_loss = train_step(inp, targ, enc_hidden)
    total_loss += batch_loss

    if batch % 100 == 0:
      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                   batch,
                                                   batch_loss.numpy()))
  # saving (checkpoint) the model every 2 epochs
  if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

  print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      total_loss / steps_per_epoch))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Epoch 1 Batch 0 Loss 5.0891
Epoch 1 Batch 100 Loss 2.2703
Epoch 1 Batch 200 Loss 1.9739
Epoch 1 Batch 300 Loss 1.9094
Epoch 1 Loss 1.7286
Time taken for 1 epoch 33.77519965171814 sec

Epoch 2 Batch 0 Loss 1.5886
Epoch 2 Batch 100 Loss 1.5343
Epoch 2 Batch 200 Loss 1.5672
Epoch 2 Batch 300 Loss 1.4411
Epoch 2 Loss 1.1849
Time taken for 1 epoch 28.854251861572266 sec

Epoch 3 Batch 0 Loss 1.2023
Epoch 3 Batch 100 Loss 1.0288
Epoch 3 Batch 200 Loss 1.1784
Epoch 3 Batch 300 Loss 1.0025
Epoch 3 Loss 0.8803
Time taken for 1 epoch 28.387887477874756 sec

Epoch 4 Batch 0 Loss 0.7196
Epoch 4 Batch 100 Loss 0.9163
Epoch 4 Batch 200 Loss 0.6681
Epoch 4 Batch 300 Loss 0.7946
Epoch 4 Loss 0.6216
Time taken for 1 epoch 29.015039443969727 sec

Epoch 5 Batch 0 Loss 0.5510
Epoch 5 Batch 100 Loss 0.5196
Epoch 5 Batch 200 Loss 0.6549
Epoch 5 Batch 300 Loss 0.5678
Epoch 5 Loss 0.4362
Time taken for 1 epoch 28.51006507873535 sec

Epoch 6 Batch 0 Loss 0.3410
Epoch 6 Batch 100 Loss 0.3889
Epoch 6 Batch 200 L

## Use tf-addons BasicDecoder for decoding


In [13]:
def evaluate_sentence(sentence):
  sentence = dataset_creator.preprocess_sentence(sentence)

  inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
  inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                         maxlen=max_length_input,
                                                         padding='post')
  inputs = tf.convert_to_tensor(inputs)
  inference_batch_size = inputs.shape[0]
  result = ''

  enc_start_state = [tf.zeros((inference_batch_size, units)), tf.zeros((inference_batch_size,units))]
  enc_out, enc_h, enc_c = encoder(inputs, enc_start_state)

  dec_h = enc_h
  dec_c = enc_c

  start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
  end_token = targ_lang.word_index['<end>']

  greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()

  # Make Decoder input from <start> tokens. 
  # Decoder input shape = (inference_batch_size, 1)
  decoder_input = tf.expand_dims([targ_lang.word_index['<start>']]*inference_batch_size, axis=1)

  
  decoder_emb_input = decoder.embedding(decoder_input)

  # Instantiate BasicDecoder object
  decoder_instance = tfa.seq2seq.BasicDecoder(cell=decoder.rnn_cell, sampler=greedy_sampler, output_layer=decoder.fc)
  decoder.attention_mechanism.setup_memory(enc_out)
  
  # pass [ last step activations , encoder memory_state ] as input to decoder for LSTM
  decoder_initial_state = decoder.build_initial_state(inference_batch_size, [enc_h, enc_c], tf.float32)
  
  # Since we do not know the target sequence lengths in advance, we use maximum_iterations to limit the translation lengths.
  # One heuristic is to decode up to two times the source sentence lengths.
  maximum_iterations = tf.round(tf.reduce_max(max_length_input) * 2)

  #initialize inference decoder
  ### Since the BasicDecoder wraps around Decoder's rnn cell only, we have to ensure that the inputs to BasicDecoder 
  ### decoding step is output of embedding layer. tfa.seq2seq.GreedyEmbeddingSampler() takes care of this. 
  ### We only need to get the weights of embedding layer, which can be done by decoder.embedding.variables[0] and pass this callable
  ### to decoder_instance.initialize() function.

  decoder_embedding_matrix = decoder.embedding.variables[0]
  # first_inputs = embedding of <start_token>
  (first_finished, first_inputs, first_state) = decoder_instance.initialize(decoder_embedding_matrix, 
                                                                            start_tokens = start_tokens, 
                                                                            end_token= end_token,
                                                                           initial_state=decoder_initial_state)
  inputs = first_inputs
  state = first_state
  predictions = np.empty((inference_batch_size,0), dtype=np.int32)
  for t in range(maximum_iterations):
    outputs, next_state, next_inputs, finished = decoder_instance.step(t, inputs, state)
    # print(outputs.sample_id, next_inputs.shape, finished)
    inputs = next_inputs
    state = next_state
    outputs = tf.expand_dims(outputs.sample_id, axis=-1)
    predictions = np.append(predictions, outputs, axis=-1)
    # If <end> token has been predicted then stop here.
    if(finished):
      break

  return predictions

def translate(sentence):
  result = evaluate_sentence(sentence)
  print(result)
  result = targ_lang.sequences_to_texts(result)
  print('Input: %s' % (sentence))
  print('Predicted translation: {}'.format(result))

## Restore the latest checkpoint and test

In [14]:
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd9a7727f60>

In [15]:
translate(u'hace mucho frio aqui.')

[[ 11  12  49 184  40   4   3]]
Input: hace mucho frio aqui.
Predicted translation: ['it s very cold here . <end>']


In [16]:
translate(u'esta es mi vida.')

[[ 20   9  22 190   4   3]]
Input: esta es mi vida.
Predicted translation: ['this is my life . <end>']


In [None]:
translate(u'¿todavia estan en casa?')

Input: ¿todavia estan en casa?
Predicted translation: ['are you still home ? <end>']


In [None]:
# wrong translation
translate(u'trata de averiguarlo.')

Input: trata de averiguarlo.
Predicted translation: ['try to figure it out . <end>']


## Use tf-addons BeamSearchDecoder 



In [17]:
def beam_evaluate_sentence(sentence, beam_width=3):
  sentence = dataset_creator.preprocess_sentence(sentence)

  inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
  inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                         maxlen=max_length_input,
                                                         padding='post')
  inputs = tf.convert_to_tensor(inputs)
  inference_batch_size = inputs.shape[0]
  result = ''

  enc_start_state = [tf.zeros((inference_batch_size, units)), tf.zeros((inference_batch_size,units))]
  enc_out, enc_h, enc_c = encoder(inputs, enc_start_state)

  dec_h = enc_h
  dec_c = enc_c

  start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
  end_token = targ_lang.word_index['<end>']

  # From official documentation
  #NOTE If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:
  #The encoder output has been tiled to beam_width via tfa.seq2seq.tile_batch (NOT tf.tile).
  #The batch_size argument passed to the get_initial_state method of this wrapper is equal to true_batch_size * beam_width.
  #The initial state created with get_initial_state above contains a cell_state value containing properly tiled final state from the encoder.

  enc_out = tfa.seq2seq.tile_batch(enc_out, multiplier=beam_width)
  decoder.attention_mechanism.setup_memory(enc_out)
  print("beam_with * [batch_size, max_length_input, rnn_units] :  3 * [1, 16, 1024]] :", enc_out.shape)
  
  #set decoder_inital_state which is an AttentionWrapperState considering beam_width
  hidden_state = tfa.seq2seq.tile_batch([enc_h, enc_c], multiplier=beam_width)
  decoder_initial_state = decoder.rnn_cell.get_initial_state(batch_size=beam_width*inference_batch_size, dtype=tf.float32)
  decoder_initial_state = decoder_initial_state.clone(cell_state=hidden_state)

  # Instantiate BeamSearchDecoder
  decoder_instance = tfa.seq2seq.BeamSearchDecoder(decoder.rnn_cell,beam_width=beam_width,
                                                 output_layer=decoder.fc)

  # Make Decoder input from <start> tokens. 
  # Decoder input shape = (inference_batch_size, 1)
  decoder_input = tf.expand_dims([targ_lang.word_index['<start>']]*inference_batch_size, axis=1)
  decoder_emb_input = decoder.embedding(decoder_input)

  
  
  # pass [ last step activations , encoder memory_state ] as input to decoder for LSTM
  
  
  # Since we do not know the target sequence lengths in advance, we use maximum_iterations to limit the translation lengths.
  # One heuristic is to decode up to two times the source sentence lengths.
  maximum_iterations = tf.round(tf.reduce_max(max_length_input) * 2)

  #initialize inference decoder
  decoder_embedding_matrix = decoder.embedding.variables[0]
  # first_inputs = embedding of <start_token>
  (first_finished, first_inputs, first_state) = decoder_instance.initialize(decoder_embedding_matrix, 
                                                                            start_tokens = start_tokens, 
                                                                            end_token= end_token,
                                                                           initial_state=decoder_initial_state)
  inputs = first_inputs
  state = first_state
  predictions = np.empty((inference_batch_size,beam_width, 0), dtype=np.int32)
  beam_scores =  np.empty((inference_batch_size, beam_width,0), dtype = np.float32)
  for t in range(maximum_iterations):
    beam_search_outputs, next_state, next_inputs, finished = decoder_instance.step(t, inputs, state)
    # print(outputs.sample_id, next_inputs.shape, finished)
    inputs = next_inputs
    state = next_state
    outputs = tf.expand_dims(beam_search_outputs.predicted_ids, axis=-1)
    scores = tf.expand_dims(beam_search_outputs.scores, axis=-1)
    predictions = np.append(predictions, outputs, axis=-1)
    beam_scores = np.append(beam_scores, scores, axis=-1)
  return predictions, beam_scores



In [37]:
def beam_translate(sentence):
  result, beam_scores = beam_evaluate_sentence(sentence)
  print(result.shape, beam_scores.shape)
  for beam, score in zip(result, beam_scores):
    print(beam.shape, score.shape)
    output = targ_lang.sequences_to_texts(beam)
    output = [a[:a.index('<end>')] for a in output]
    beam_score = [a.sum() for a in score]
    print('Input: %s' % (sentence))
    for i in range(len(output)):
      print('{} Predicted translation: {}  {}'.format(i+1, output[i], beam_score[i]))


In [38]:
beam_translate(u'hace mucho frio aqui.')

beam_with * [batch_size, max_length_input, rnn_units] :  3 * [1, 16, 1024]] : (3, 16, 1024)
(1, 3, 32) (1, 3, 32)
(3, 32) (3, 32)
Input: hace mucho frio aqui.
1 Predicted translation: it s is cold here .   -63.086402893066406
2 Predicted translation: the lot very cold here .   -69.65613555908203
3 Predicted translation: that gets very cold here .   -83.43080139160156


In [40]:
beam_translate(u'¿todavia estan en casa?')

beam_with * [batch_size, max_length_input, rnn_units] :  3 * [1, 16, 1024]] : (3, 16, 1024)
(1, 3, 32) (1, 3, 32)
(3, 32) (3, 32)
Input: ¿todavia estan en casa?
1 Predicted translation: are you still at home ?   -27.622243881225586
2 Predicted translation: is he still at home ?   -58.77956008911133
3 Predicted translation: aren it still busy ?   -90.82772827148438
