# TF2 ByteNet
This is a notebook that shows an example implementation of the ByteNet architecture in TensorFlow 2. It follows the flow of the [NMT tutorial](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/text/nmt_with_attention.ipynb#scrollTo=TNfHIF71ulLu) from the TensorFlow website. Comments and explanatory notes are provided. Note that neither this notebook nor the TF2 version of ByteNet was used for experiments by the group; experiments were done with the PyTorch implementations instead.




## Imports and Path Definitions 

Add the necessary imports and path definitions here.

In [0]:
#@title Imports { form-width: "150px" }

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split

import unicodedata
import string
import re
import numpy as np
import os
import io
import time

In [0]:
#@title Define Paths { form-width: "150px" }
MOUNT_PATH = ""  # DEFINE MOUNT PATH HERE
WORK_DIR = MOUNT_PATH + ""  # DEFINE WORK_DIR
CKPT_DIR = WORK_DIR + ""  # DEFINE CKPT_DIR
MODEL_DIR = WORK_DIR + ""  # DEFINE MODEL_DIR
DATA_DIR = WORK_DIR + ""  # DEFINE DATA_DIR


In [0]:
#@title Mount Drive { form-width: "150px" }
from google.colab import drive
drive.mount(MOUNT_PATH, force_remount=True)

## Preprocessing
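ByteNet uses dynamic unfolding: the source representation is stretched to an estimated target length, so that the decoder can be stacked on top of the encoder output position by position. This can be done when the data is read in, by padding each source sentence with `PAD_CHAR`.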
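In the paper, the unfolded length is a linear function of the source length, |t̂| = a·|s| + b. The sketch below is a hypothetical helper, not part of the notebook: it assumes b = 0 by default and uses a = 1.2 to match the `unfold_rate` default of `preprocess_sentence`. It rounds instead of truncating, because float products such as `10 * (1.2 - 1)` land just below the intended integer.

```python
def unfolded_length(source_len, a=1.2, b=0):
    """Estimated target length |t_hat| = a * |s| + b (a, b are hyperparameters)."""
    # round() rather than int(): 10 * (1.2 - 1) evaluates to
    # 1.9999999999999996, which int() would truncate to 1
    return int(round(a * source_len + b))


def unfold(sentence, pad_char="\2", a=1.2, b=0):
    """Right-pad a source sentence with pad_char up to its unfolded length."""
    target_len = unfolded_length(len(sentence), a, b)
    return sentence + pad_char * max(0, target_len - len(sentence))


print(unfolded_length(10), unfolded_length(20))
print(repr(unfold("abcde")))
```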

Inputs are also cleaned and normalised, so that characters with more than one Unicode representation are always encoded the same way. Use NFKC rather than NFD, because NFD would decompose umlauts and accents into separate combining characters. <br> 
Regex taken from [here](https://stackoverflow.com/questions/20690499/concrete-javascript-regex-for-accented-characters-diacritics) to make sure the input does not contain characters you wouldn't see in Latin-based languages. The start of sequence (SOS), end of sequence (EOS) and padding (used in dynamic unfolding) characters are also defined here. <br>
Strings are lowercased to reduce the size of the character vocabulary, because this example notebook uses only a small amount of data.  
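A short standard-library sketch of the NFKC versus NFD difference described above. It also shows a side effect of the "K" (compatibility) part of NFKC, which folds characters such as the "ﬁ" ligature into plain letters, and the `'Mn'` category that `unicode_to_ascii` filters out.

```python
import unicodedata

composed = "\u00e9"        # 'é' as one code point
decomposed = "e\u0301"     # 'e' followed by COMBINING ACUTE ACCENT

# NFD splits the accented letter into base letter + combining mark,
# so a character-level model would see two symbols instead of one
nfd = unicodedata.normalize("NFD", composed)

# NFKC composes them back into a single code point ...
nfkc = unicodedata.normalize("NFKC", decomposed)

# ... and also folds compatibility characters, e.g. the 'fi' ligature
ligature = unicodedata.normalize("NFKC", "\ufb01")

# any leftover combining mark has category 'Mn', which unicode_to_ascii drops
mark_category = unicodedata.category("\u0301")

print(len(nfd), len(nfkc), ligature, mark_category)
```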


In [0]:
#@title Unicode to ASCII and Preprocess { form-width: "150px" }

SOS_CHAR = "\1"
PAD_CHAR = "\2"
EOS_CHAR = "\3"

# Normalises Unicode (NFKC) and strips leftover combining marks.
# Note: despite the name, accented letters such as 'é' are kept.
def unicode_to_ascii(s):
  
  # NFKC - represent characters that have the same meaning
  # with the same representation in canonical form
  return ''.join(c for c in unicodedata.normalize('NFKC', s)
      if unicodedata.category(c) != 'Mn')

# character level translation - don't actually need to do the
# kind of preprocessing word level tutorials do
def preprocess_sentence(w, unfold=False, unfold_rate=1.2):
  
  # normalise so that chars with the same meaning share one representation;
  # lowercase because we are using smaller datasets, which makes it
  # easier to learn
  w = unicode_to_ascii(w.lower().strip())


  # need to standardise - some sentences have a space between 
  # last char and punctuation, and some do not
  # creating a space between a word and the punctuation following it
  # eg: "he is a boy." => "he is a boy ."
  # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
  w = re.sub(r"([?.!,¿])", r" \1 ", w)
  w = re.sub(r'[" "]+', " ", w)

  # replace everything with a space except a-z, A-Z, accented characters
  # and the punctuation ".", "?", "!", "," and "¿"
  accentedCharacters = "àèìòùÀÈÌÒÙáéíóúýÁÉÍÓÚÝâêîôûÂÊÎÔÛãñõÃÑÕäëïöüÿÄËÏÖÜŸçÇßØøÅåÆæœ"
  accentedCharacters = unicode_to_ascii(accentedCharacters)
  w = re.sub(r"[^a-zA-Z"+accentedCharacters+"?.!,¿]+", " ", w)
  w = w.strip()
  
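  if unfold:
    # dynamic unfolding: append about (unfold_rate - 1) * len(w) PAD_CHARs;
    # round() avoids truncation, e.g. 10 * (1.2 - 1) == 1.9999999999999996
    pad = PAD_CHAR * int(round(len(w) * (unfold_rate - 1)))
    w = SOS_CHAR + w + pad + EOS_CHAR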
  else:
    w = SOS_CHAR + w + EOS_CHAR

  return w

For the reasons to normalize, see this [blogpost](https://withblue.ink/2019/03/11/why-you-need-to-normalize-unicode-strings.html). The sanity check below illustrates the effect of the `unicode_to_ascii` function we just defined and why it is important.

In [0]:
#@title Unicode to Ascii Sanity Check { form-width: "150px" }
a = '\u0065\u0301'
print(a)
b = '\u00e9'
print(b)
print(a==b)
c = unicode_to_ascii(a)
print(c)
d = unicode_to_ascii(b)
print(d)
print(c==d)

In [0]:
#@title Preprocess Sanity Check { form-width: "150px" }

en_sentence = u"3 Q. Are the four beasts limited to individual beasts , or do they represent classes or a orders ?"
sp_sentence = u"¿Puedo tomar prestado este libro?"

print(preprocess_sentence(en_sentence,True))
print(preprocess_sentence(sp_sentence))

## Creating Dataset

Here, we build a dataset for the model from text files. We read the text from the files and preprocess it using the methods defined above. <br>
Using the built-in Keras tokenizer, we convert the input from strings to a tensor representation. <br>
Finally, we create a dataset object.
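To see roughly what `Tokenizer(filters='', char_level=True)` followed by `pad_sequences(padding='post')` produces, here is a plain-Python approximation. It is a simplified sketch: it skips Keras options such as lowercasing and out-of-vocabulary tokens, and the tie-breaking rule (first occurrence wins) is our assumption. Index 0 is left free for padding, as in Keras.

```python
from collections import Counter


def fit_char_vocab(texts):
    """Map each character to an id, most frequent first; 0 is reserved for padding."""
    counts = Counter()
    first_seen = {}
    for text in texts:
        for ch in text:
            counts[ch] += 1
            first_seen.setdefault(ch, len(first_seen))
    ordered = sorted(counts, key=lambda ch: (-counts[ch], first_seen[ch]))
    return {ch: i + 1 for i, ch in enumerate(ordered)}


def texts_to_padded(texts, vocab):
    """Convert texts to id lists and right-pad ('post') them with 0."""
    seqs = [[vocab[ch] for ch in text] for text in texts]
    max_len = max(len(s) for s in seqs)
    return [s + [0] * (max_len - len(s)) for s in seqs]


vocab = fit_char_vocab(["abba", "ab"])
print(vocab)
print(texts_to_padded(["abba", "ab"], vocab))
```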


In [0]:
#@title Create Dataset { form-width: "150px" }


# Return sentence pairs as two tuples: (source sentences, target sentences),
# e.g. (English, French). Each line of the file is "<source>\t<target>";
# lines[start:start + num_examples] are used (all remaining lines if None).
def create_dataset(path, num_examples=None, start=1000):
  with io.open(path, encoding='UTF-8') as f:
    lines = f.read().strip().split('\n')

  if num_examples is None:
    num_examples = len(lines) - start

  sentence_pairs = []
  for l in lines[start:start + num_examples]:
    w = l.split('\t')
    # only the source side is unfolded (padded) for ByteNet
    sentence_pairs.append([preprocess_sentence(w[0], unfold=True),
                           preprocess_sentence(w[1], unfold=False)])

  return zip(*sentence_pairs)

In [0]:
#@title Create Dataset Sanity Check { form-width: "150px" }
ENG_FRA_PATH = DATA_DIR + "eng-fra.txt"
en, fr = create_dataset(ENG_FRA_PATH, 10, 1000)
print(en)
print(fr)



In [0]:
def max_length(tensor):
  return max(len(t) for t in tensor)

In [0]:
#@title Tokenizing Methods { form-width: "150px" }


def tokenize(lang):
  
  lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
      filters='', char_level=True)
  lang_tokenizer.fit_on_texts(lang)
  # lang_tokenizer = add_special_tokens(lang_tokenizer)

  tensor = lang_tokenizer.texts_to_sequences(lang)
  
  # for creating dataset object
  tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,
                                                         padding='post')

  

  return tensor, lang_tokenizer

In [0]:
#@title Load Dataset Method { form-width: "150px" }
def load_dataset(path, num_examples=None, start_pos=1000):
  # creating cleaned input, output pairs
  inp_lang, targ_lang = create_dataset(path, num_examples, start_pos)

  input_tensor, inp_lang_tokenizer = tokenize(inp_lang)
  target_tensor, targ_lang_tokenizer = tokenize(targ_lang)

  return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer




In [0]:
#@title Load Dataset Method Sanity Check { form-width: "150px" }

# Try experimenting with the size of that dataset
num_examples = 1000
input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(ENG_FRA_PATH, num_examples)

# Calculate max_length of the input and target tensors
max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)

The next two cells visualize what the tokenizer does.

In [0]:
def convert(lang, tensor):
  for t in tensor:
    if t!=0:
      print(f"{t} ---> {lang.index_word[t]}")

In [0]:
print ("Input Language; index to word mapping")
convert(inp_lang, input_tensor[0])
print ()
print ("Target Language; index to word mapping")
convert(targ_lang, target_tensor[0])

The paper uses batches: input tensors are padded to the nearest multiple of 50 and grouped into buckets of the same length, which are then trained on different GPUs. Because Colab provides only one GPU, this notebook does not implement bucketing, as it would bring no practical advantage here. Instead, given the small dataset, I train the network one input tensor at a time (a batch size of 1). If a bigger dataset is used, the batch size can be increased to make training quicker. <br>
All of the dataset is used for training in this example notebook, instead of the standard training-validation-test split. In the evaluation section, we show the results of applying the model to a single unseen sentence. 
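The bucketing described above can be sketched in plain Python: round each length up to the nearest multiple of 50 and group sequences that land in the same bucket. The helper names are hypothetical; TensorFlow also offers `tf.data.experimental.bucket_by_sequence_length` for this kind of grouping, and the sketch shows only the grouping logic.

```python
from collections import defaultdict


def bucket_length(length, multiple=50):
    """Round a sequence length up to the nearest multiple (50 as described above)."""
    return ((length + multiple - 1) // multiple) * multiple


def bucket_by_length(sequences, multiple=50):
    """Group sequences whose padded lengths fall into the same bucket."""
    buckets = defaultdict(list)
    for seq in sequences:
        buckets[bucket_length(len(seq), multiple)].append(seq)
    return dict(buckets)


sentences = ["a" * 10, "b" * 49, "c" * 50, "d" * 60]
for size, members in sorted(bucket_by_length(sentences).items()):
    print(size, [len(s) for s in members])
```

Each bucket can then be padded to its bucket size and batched on its own.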

In [0]:
#@title Create Dataset Object { form-width: "150px" }

BUFFER_SIZE = len(input_tensor)
BATCH_SIZE = 1
steps_per_epoch = len(input_tensor)//BATCH_SIZE

hidden_units = 20
vocab_inp_size = len(inp_lang.word_index)+1

vocab_tar_size = len(targ_lang.word_index)+1


dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)

dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)


In [0]:
example_input_batch, example_target_batch = next(iter(dataset))
example_input_batch.shape, example_target_batch.shape


## Model 
Start creating the model components for ByteNet.


In [0]:
#@title Residual Block { form-width: "150px" }

class ResBlk(tf.keras.layers.Layer):
  def __init__(self, hidden_units, dilation, kernel_size, masking=False):
    super(ResBlk, self).__init__()
    # hidden units
    self.hidden_units = hidden_units
    # dilation
    self.dilation = dilation
    # kernel size
    self.kernel_size = kernel_size
    # masking 
    self.masking = masking

    self.normlayer1 = tf.keras.layers.LayerNormalization()
    self.relu1 = tf.keras.layers.ReLU()
    # output filters = d
    self.conv_down = tf.keras.layers.Conv1D(hidden_units, kernel_size=1)

    self.normlayer2 = tf.keras.layers.LayerNormalization()
    self.relu2 = tf.keras.layers.ReLU()

    # dilated convolution. In the decoder it must be masked (causal) so that
    # position t only sees positions <= t; padding='causal' handles this in TF2
    self.padding_type = 'causal' if masking else 'same'

    self.conv_masked = tf.keras.layers.Conv1D(hidden_units, 
                                              kernel_size=kernel_size,
                                              padding = self.padding_type,
                                              dilation_rate = dilation)
    
    
    self.normlayer3 = tf.keras.layers.LayerNormalization()
    self.relu3 = tf.keras.layers.ReLU()
    self.conv_up = tf.keras.layers.Conv1D(2*hidden_units, kernel_size=1)
    
    self.block = tf.keras.Sequential([self.normlayer1, self.relu1, self.conv_down,
                                      self.normlayer2, self.relu2, self.conv_masked,
                                      self.normlayer3, self.relu3, self.conv_up])

  
  def call(self, x):
    
    o = self.block(x)
  
    # residual part
    o += x
    return o
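To see why `padding='causal'` is enough to mask the decoder, here is a NumPy sketch of a single-channel dilated convolution. It is a hypothetical helper, not the Keras implementation, and it compares causal with symmetric padding: changing the input at position 5 leaves the causal outputs at positions 0 to 4 untouched, while with symmetric padding position 3 already changes.

```python
import numpy as np


def dilated_conv1d(x, w, dilation, causal):
    """Single-channel dilated 1-D convolution that keeps the sequence length.

    Like Keras Conv1D this is a cross-correlation: out[t] = sum_j w[j] * xp[t + j*dilation].
    causal=True pads (k-1)*dilation zeros on the left only, so out[t] depends
    only on x[0..t]. causal=False splits the padding between both sides.
    """
    k = len(w)
    span = (k - 1) * dilation
    left = span if causal else span // 2
    right = span - left
    xp = np.concatenate([np.zeros(left), x, np.zeros(right)])
    return np.array([
        sum(w[j] * xp[t + j * dilation] for j in range(k))
        for t in range(len(x))
    ])


x = np.arange(8, dtype=float)
w = np.ones(3)
x_changed = x.copy()
x_changed[5] = 100.0  # perturb a "future" character at position 5

causal_before = dilated_conv1d(x, w, dilation=2, causal=True)
causal_after = dilated_conv1d(x_changed, w, dilation=2, causal=True)
same_before = dilated_conv1d(x, w, dilation=2, causal=False)
same_after = dilated_conv1d(x_changed, w, dilation=2, causal=False)

# causal: outputs 0..4 do not see the change at position 5
print(np.allclose(causal_before[:5], causal_after[:5]))
# symmetric padding: position 3 already sees position 5
print(np.allclose(same_before[:5], same_after[:5]))
```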





In [0]:
#@title Encoder { form-width: "150px" }

class Encoder(tf.keras.Model):
  def __init__(self, num_chars, hidden_units, kernel_size, num_blocks,
               num_sets):
    super(Encoder, self).__init__()
    
    ## store attributes
    self.num_chars = num_chars
    self.hidden_units = hidden_units
    self.kernel_size = kernel_size
    # blocks in a set, for increasing amounts of dilation
    self.num_blocks = num_blocks
    # number of sets
    self.num_sets = num_sets


    # add the different layers and block sets
    
    # mask_zero=True tells Keras that index 0 is padding,
    # not a character
    self.embedding = tf.keras.layers.Embedding(input_dim=num_chars, 
                                               output_dim=hidden_units,
                                               mask_zero=True)
    # up conv to 2d
    self.conv_1 = tf.keras.layers.Conv1D(2*hidden_units, kernel_size=1)
    
    ## sets
    sets = []
    for n in range(num_sets):
      for l in range(num_blocks):
        dilation = 1<<l 
        sets.append(ResBlk(hidden_units, dilation, kernel_size))
    
    self.sets = tf.keras.Sequential(sets)

    # bring back down to d
    self.conv_2 = tf.keras.layers.Conv1D(hidden_units, kernel_size=1)
    self.relu = tf.keras.layers.ReLU()

  def call(self, x):
    o = self.embedding(x)
    o = self.conv_1(o)
    o = self.sets(o)
    o = self.conv_2(o)
    o = self.relu(o)
    return o
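Because the dilation doubles from block to block within each set, the receptive field grows quickly with depth. The small helper below (our own, not part of the notebook) counts it for a given `kernel_size`, `num_blocks` and `num_sets`, following the dilation schedule `1 << l` used in `Encoder`. The decoder's extra causal convolution after the sets would add another `kernel_size - 1` positions.

```python
def receptive_field(kernel_size, num_blocks, num_sets):
    """Number of input characters visible to one output position.

    num_sets repetitions of num_blocks residual blocks, where block l has one
    dilated conv with dilation 2**l. A conv with kernel k and dilation d widens
    the field by (k - 1) * d; the 1x1 convs add nothing.
    """
    per_set = sum((kernel_size - 1) * (1 << l) for l in range(num_blocks))
    return 1 + num_sets * per_set


print(receptive_field(3, 2, 2))   # configuration used in the sanity checks
print(receptive_field(3, 5, 6))   # 5 blocks (dilations 1..16) in 6 sets
```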


In [0]:
#@title Encoder Sanity Check { form-width: "150px" }

enc_test = Encoder(vocab_inp_size,hidden_units,3,2,2)
# Keras layers expect a batch dimension, so add one when passing a single
# sequence. Batches from the dataset object already have it.
a = tf.reshape(input_tensor[0], [1, len(input_tensor[0])])
enc_out = enc_test(a)


In [0]:
#@title Decoder { form-width: "150px" }
class Decoder(tf.keras.Model):
  def __init__(self, output_size,hidden_units, kernel_size,
               num_blocks, num_sets):
    
    super(Decoder, self).__init__()
    
    self.output_size = output_size
    self.hidden_units = hidden_units
    
    # need mask zero in decoder too
    self.embedding = tf.keras.layers.Embedding(output_size, hidden_units, 
                                               mask_zero=True)
    
    ## sets
    sets = []
    for n in range(num_sets):
      for l in range(num_blocks):
        dilation = 1<<l 
        # need to mask in decoder
        sets.append(ResBlk(hidden_units, dilation, kernel_size, masking=True))
    
    self.sets = tf.keras.Sequential(sets)

    # from 3.6 in paper - "...one more convolution and 
    # ReLU followed by a convolution and a final softmax layer"
    self.conv = tf.keras.layers.Conv1D(2*hidden_units, kernel_size=kernel_size,
                                         padding='causal')
    self.relu = tf.keras.layers.ReLU()
    # conv to get the output_size
    self.conv_final = tf.keras.layers.Conv1D(output_size, kernel_size=1)
    self.softmax = tf.keras.layers.Softmax()

  
  def call(self, d_pred, e_out):
    # d_pred: previous decoder predictions
    # e_out: encoder output 
    
    # d_pred shape is 2d - batchsize * number of chars already decoded
    # after embedding will be 3d - 
    # batchsize * number of chars already decoded * embedding dim
    emb = self.embedding(d_pred)


    # concatenate e_out with the embedded predictions along the feature
    # axis to get 2*d features
    emb = tf.concat([emb, e_out], 2) 
    
    o = self.sets(emb)
    o = self.conv(o)
    o = self.relu(o)
    o = self.conv_final(o)
    
    # take the last row, the actual prediction
    o = o[:,-1,:]
    o = self.softmax(o)
    
    return o


In [0]:
#@title Decoder Sanity Check { form-width: "150px" }

d_o_start = tf.expand_dims([targ_lang.word_index[SOS_CHAR]] * BATCH_SIZE, 1)
e_o = enc_out[:,:1,:]
print(f"e_o.shape = {e_o.shape}")

dec_test = Decoder(vocab_tar_size, hidden_units, 3, 5,2)
d_o = dec_test(d_o_start, e_o)
print(f"d_o.shape={d_o.shape}")

## Model Variants 

The paper presents "variants" of the standard ByteNet architecture: the residual multiplicative block (a variant of the residual block), the Recurrent ByteNet Encoder and the Recurrent ByteNet Decoder. These are shown below.



### Residual Block Variant

The paper mentions two variants of the residual block, used for different experiments. The Residual Multiplicative Block is shown below. It is used in the decoder for the paper's language modelling experiments (section 3.6).
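Read off from `MU.call` below, the multiplicative unit applied to its input $h$ (the $d$-channel output of the block's down-projection) computes the following, where $\mathrm{conv}_i$ are the masked dilated convolutions, $\mathrm{LN}$ is layer normalisation, $\sigma$ is the sigmoid and $\odot$ is element-wise multiplication. To our understanding this gating mirrors the multiplicative units of Video Pixel Networks, which ByteNet builds on; treat that link as our reading rather than a claim made by this notebook.

$$
\begin{aligned}
g_1 &= \sigma\big(\mathrm{LN}(\mathrm{conv}_1(h))\,W_1 + b_1\big) \\
g_2 &= \sigma\big(\mathrm{LN}(\mathrm{conv}_2(h))\,W_2 + b_2\big) \\
g_3 &= \sigma\big(\mathrm{LN}(\mathrm{conv}_3(h))\,W_3 + b_3\big) \\
u &= \tanh\big(\mathrm{LN}(\mathrm{conv}_4(h))\,W_4 + b_4\big) \\
\mathrm{MU}(h) &= g_1 \odot \tanh\big(g_2 \odot h + g_3 \odot u\big)
\end{aligned}
$$

In the code, $g_1, g_2, g_3, u$ are `x1`, `x2`, `x3`, `x4`, and the sum inside the outer $\tanh$ is `h3`.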

In [0]:
#@title Residual Multiplicative Block { form-width: "150px" }


class MU(tf.keras.layers.Layer):
  def __init__(self, hidden_units, dilation, kernel_size, masking = True):
    super(MU, self).__init__()
    
    self.hidden_units = hidden_units
    self.kernel_size = kernel_size
    self.dilation = dilation
    
    self.masking = masking
    
    self.padding_type = 'causal' if masking else 'same'

    self.conv_masked_1 = tf.keras.layers.Conv1D(hidden_units, 
                                              kernel_size=kernel_size,
                                              padding = self.padding_type,
                                              dilation_rate = dilation)
    
    self.normlayer_1 = tf.keras.layers.LayerNormalization()
    self.conv_masked_2 = tf.keras.layers.Conv1D(hidden_units, 
                                              kernel_size=kernel_size,
                                              padding = self.padding_type,
                                              dilation_rate = dilation)
    
    self.normlayer_2 = tf.keras.layers.LayerNormalization()
    self.conv_masked_3 = tf.keras.layers.Conv1D(hidden_units, 
                                              kernel_size=kernel_size,
                                              padding = self.padding_type,
                                              dilation_rate = dilation)
    
    self.normlayer_3 = tf.keras.layers.LayerNormalization()
    self.conv_masked_4 = tf.keras.layers.Conv1D(hidden_units, 
                                              kernel_size=kernel_size,
                                              padding = self.padding_type,
                                              dilation_rate = dilation)
    
    self.normlayer_4 = tf.keras.layers.LayerNormalization()
  
  def build(self, input_shape):
    self.w1 = self.add_weight(shape=(input_shape[-1], self.hidden_units),
                             initializer='random_normal',
                             trainable=True)
    self.w2 = self.add_weight(shape=(input_shape[-1], self.hidden_units),
                             initializer='random_normal',
                             trainable=True)
    self.w3 = self.add_weight(shape=(input_shape[-1], self.hidden_units),
                             initializer='random_normal',
                             trainable=True)
    self.w4 = self.add_weight(shape=(input_shape[-1], self.hidden_units),
                             initializer='random_normal',
                             trainable=True)
    self.b1 = self.add_weight(shape=(self.hidden_units,),
                             initializer='random_normal',
                             trainable=True)
    self.b2 = self.add_weight(shape=(self.hidden_units,),
                             initializer='random_normal',
                             trainable=True)
    self.b3 = self.add_weight(shape=(self.hidden_units,),
                             initializer='random_normal',
                             trainable=True)
    self.b4 = self.add_weight(shape=(self.hidden_units,),
                             initializer='random_normal',
                             trainable=True)
    
  # defined at class level (not inside build) so Keras uses it as the
  # layer's forward pass
  def call(self, x):

    # create x1 to x4: three sigmoid gates and one tanh update
    x1 = self.conv_masked_1(x)
    x1 = self.normlayer_1(x1)
    x1 = tf.matmul(x1, self.w1) + self.b1
    x1 = tf.math.sigmoid(x1)

    x2 = self.conv_masked_2(x)
    x2 = self.normlayer_2(x2)
    x2 = tf.matmul(x2, self.w2) + self.b2
    x2 = tf.math.sigmoid(x2)

    x3 = self.conv_masked_3(x)
    x3 = self.normlayer_3(x3)
    x3 = tf.matmul(x3, self.w3) + self.b3
    x3 = tf.math.sigmoid(x3)

    x4 = self.conv_masked_4(x)
    x4 = self.normlayer_4(x4)
    x4 = tf.matmul(x4, self.w4) + self.b4
    x4 = tf.math.tanh(x4)

    # elementwise multiplication
    h1 = tf.math.multiply(x, x2)
    h2 = tf.math.multiply(x3, x4)
    h3 = h2 + h1

    h3 = tf.math.tanh(h3)
    out = tf.math.multiply(x1, h3)

    return out
  



class ResMUBlk(tf.keras.layers.Layer):
  def __init__(self, hidden_units, dilation, kernel_size, masking=True):
    super(ResMUBlk, self).__init__()
    # hidden units
    self.hidden_units = hidden_units
    # dilation
    self.dilation = dilation
    # kernel size
    self.kernel_size = kernel_size
    # masking 
    self.masking = masking

    self.normlayer1 = tf.keras.layers.LayerNormalization()
    self.relu1 = tf.keras.layers.ReLU()
    # output filters = d
    self.conv_down = tf.keras.layers.Conv1D(hidden_units, kernel_size=1)

    self.normlayer2 = tf.keras.layers.LayerNormalization()
    self.relu2 = tf.keras.layers.ReLU()


    self.mu = MU(hidden_units, dilation, kernel_size, masking)

    self.conv_up = tf.keras.layers.Conv1D(2*hidden_units, kernel_size=1)
    
    self.block = tf.keras.Sequential([self.normlayer1, self.relu1, self.conv_down,
                                      self.normlayer2, self.relu2,
                                      self.mu, self.conv_up])

  
  def call(self, x):
    
    o = self.block(x)
  
    # residual part
    o += x
    return o



### RNN Encoder and Decoder 

The paper mentions recurrent variants of the standard ByteNet architecture, which are shown below.

In [0]:
#@title Stacked Bidirectional LSTMs { form-width: "150px" }
class StackedBidirectionalLSTM(tf.keras.layers.Layer):
  def __init__(self, units, num_layers, dropout):

    super(StackedBidirectionalLSTM, self).__init__()
    self.stacked_lstms = tf.keras.Sequential()
    for i in range(num_layers):
      self.stacked_lstms.add(tf.keras.layers.Bidirectional(
          tf.keras.layers.LSTM(units, dropout=dropout, 
                               return_sequences=True), merge_mode='concat'))
    
  def call(self, x):
    out = self.stacked_lstms(x)
    return out

In [0]:
#@title RNN Encoder { form-width: "150px" }
class EncoderRNN(tf.keras.Model):
  def __init__(self, num_chars, hidden_units, num_layers, dropout=0):
    super(EncoderRNN, self).__init__()
    self.num_chars = num_chars
    self.hidden_units = hidden_units
    self.num_layers = num_layers
    self.emb = tf.keras.layers.Embedding(input_dim=num_chars, 
                                               output_dim=hidden_units,
                                               mask_zero=True)
    
    self.stacked_lstms = StackedBidirectionalLSTM(hidden_units, num_layers-1, 
                                                  dropout)
    self.last_lstm_layer = tf.keras.layers.Bidirectional(
          tf.keras.layers.LSTM(hidden_units, dropout=dropout, 
                               return_sequences=True, 
                               return_state=True), 
          merge_mode='concat')
    # bidirectional gives 2*hidden_units, bring down to hidden_units
    self.out_layer = tf.keras.layers.Dense(hidden_units)

  
  def call(self, x):
    out = self.emb(x)
    if self.num_layers>1:
      out = self.stacked_lstms(out)
    
    out, state_h, state_c, _,_ = self.last_lstm_layer(out)
    out = self.out_layer(out)
    
    hidden_states = [state_h, state_c]
    return out, hidden_states



In [0]:
#@title RNN Decoder { form-width: "150px" }
class DecoderRNN(tf.keras.Model):
  def __init__(self, output_size, hidden_units, num_layers, dropout=0):
    super(DecoderRNN, self).__init__()
    self.output_size = output_size
    self.hidden_units = hidden_units
    self.num_layers = num_layers
    self.emb = tf.keras.layers.Embedding(output_size, hidden_units, 
                                               mask_zero=True)
    
    # because of the way sequential works, need a "first" layer, 
    # and then the stacks, and then the "last" layer, 
    # since we want access to final decoder hidden state
    self.first_lstm_layer = tf.keras.layers.LSTM(hidden_units, 
                                                dropout=dropout, 
                                                return_sequences=True, 
                                                return_state=True)
    self.stacked_lstms = tf.keras.Sequential()

    for i in range(num_layers-2):
      self.stacked_lstms.add(tf.keras.layers.LSTM(hidden_units, 
                                                dropout=dropout, 
                                                return_sequences=True))
    
    self.last_lstm_layer = tf.keras.layers.LSTM(hidden_units, 
                                                dropout=dropout, 
                                                return_sequences=True, 
                                                return_state=True)
    self.out_layer = tf.keras.layers.Dense(output_size)
    self.softmax = tf.keras.layers.Softmax()

  
  def call(self, x, e_out, dec_hidden):
    # x is input, i.e. most recently decoded char
    # e_out is single context from encoder - taken from encoder_output,
    # as enc_out[:,i:i+1,:] during training and evaluating 
    # dec_hidden is hidden state to use in LSTMs


    # if there is no more context from the encoder, 
    # we pad with zeros (one step, for every item in the batch)
    if e_out.shape[1] == 0:
      e_out = tf.zeros([tf.shape(x)[0], 1, self.hidden_units])
    
    emb = self.emb(x)
    emb = tf.concat([emb, e_out], 2)
    
    out, state_h, state_c = self.first_lstm_layer(emb, initial_state=dec_hidden)

    if self.num_layers>2:
      out = self.stacked_lstms(out)
      out, state_h, state_c = self.last_lstm_layer(out)
    
    elif self.num_layers==2:
      out, state_h, state_c = self.last_lstm_layer(out)
    
    hidden_states = [state_h, state_c]
    
    out = self.out_layer(out)
    out = self.softmax(out)
    return out, hidden_states


### Extensions

ByteNet was proposed before the Transformer model. We could try extending ByteNet by adding a multi-headed attention layer, which lets every position attend directly to every other position. The implementation of the attention layer is from [here](https://www.tensorflow.org/tutorials/text/transformer#multi-head_attention), but due to time constraints, a working encoder and decoder that use this layer were not completed.
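A decoder built on `AttentionLayer` would need the attention equivalent of the causal convolution: a look-ahead mask that stops position t from attending to later characters. A NumPy sketch of such a mask and of the masked softmax, assuming the `mask * -1e9` convention used by the TensorFlow Transformer tutorial linked above (the helper names are ours):

```python
import numpy as np


def look_ahead_mask(size):
    """1.0 where query position i would attend to a later key position j > i."""
    return 1.0 - np.tril(np.ones((size, size)))


def masked_attention_weights(logits, mask):
    """Row-wise softmax after pushing masked logits towards -inf."""
    logits = logits + mask * -1e9
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


mask = look_ahead_mask(3)
weights = masked_attention_weights(np.zeros((3, 3)), mask)
print(mask)
print(weights.round(3))
```

With equal logits, each position spreads its attention uniformly over itself and the earlier positions only.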

In [0]:
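#@title Attention Layer { form-width: "150px" }

# helper called by AttentionLayer.call (it was missing from this cell);
# adapted from the same TensorFlow Transformer tutorial
def scaled_dot_product_attention(q, k, v, mask):
  """softmax(q k^T / sqrt(d_k)) v; positions where mask == 1 are blocked."""
  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth)
  return output, attention_weights

class AttentionLayer(tf.keras.layers.Layer):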
  def __init__(self, hidden_units, num_heads):
    super(AttentionLayer, self).__init__()
    self.num_heads = num_heads
    self.hidden_units = hidden_units
    
    assert hidden_units % self.num_heads == 0
    
    self.depth = hidden_units // self.num_heads
    
    self.wq = tf.keras.layers.Dense(hidden_units)
    self.wk = tf.keras.layers.Dense(hidden_units)
    self.wv = tf.keras.layers.Dense(hidden_units)
    
    self.dense = tf.keras.layers.Dense(hidden_units)
        
  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])
    
  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]
    
    q = self.wq(q)  # (batch_size, seq_len, hidden_units)
    k = self.wk(k)  # (batch_size, seq_len, hidden_units)
    v = self.wv(v)  # (batch_size, seq_len, hidden_units)
    
    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
    
    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)
    
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.hidden_units))  # (batch_size, seq_len_q, hidden_units)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, hidden_units)
        
    return output, attention_weights

## Training 

Here, we show how to train a standard ByteNet model. 
<br>
There are two kinds of training possible - 
<br>
The first is parallel, where the actual targets are used as decoder inputs (i.e. teacher forcing). This lets us predict multiple characters at once and calculate the loss over all of them (this is how the ByteNet paper trains). This type of training might lead to unstable predictions at test time if the training set is small, but it significantly reduces training time. For example, when training on 1000 pairs of EN-FR, sequential training takes around 7 minutes per epoch, while training with complete teacher forcing takes around 1 minute per epoch. 
<br>
The second is sequential, where the decoder's outputs are fed back into the decoder (along with the encoder's outputs). This kind of training is slow, and converges in a reasonable time only for datasets on the order of hundreds of pairs. 

<br> 
When training, ByteNet fully learns small datasets (i.e. the negative log-likelihood loss goes to 0). For example, 100 sentence pairs take approximately 250-300 epochs to overfit, and 1000 sentence pairs take ~5000 epochs.
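The two regimes differ only in what the decoder is fed at each step. A plain-Python sketch, with integer ids standing in for characters and a toy lambda standing in for the decoder (both hypothetical): with teacher forcing, the inputs are the ground-truth targets shifted right by one, so every step can be computed in one pass; in the sequential regime, each input depends on the previous prediction.

```python
def teacher_forced_inputs(target_ids, sos_id):
    """Parallel training: at step t the decoder sees [SOS, y_0, ..., y_{t-1}]."""
    return [sos_id] + list(target_ids[:-1])


def free_running_inputs(predict_next, sos_id, steps):
    """Sequential training: feed the model's own predictions back in, one step at a time."""
    inputs = [sos_id]
    for _ in range(steps - 1):
        inputs.append(predict_next(inputs))
    return inputs


targets = [5, 6, 7]
print(teacher_forced_inputs(targets, sos_id=1))   # inputs used to predict 5, 6, 7
# a toy "model" that always guesses the previous id + 1
print(free_running_inputs(lambda prefix: prefix[-1] + 1, sos_id=1, steps=3))
```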


In [0]:
#@title Loss Function { form-width: "150px" }

def loss_function(real, pred, loss_object):

  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_mean(loss_)

In [0]:
#@title Train Step { form-width: "150px" }

@tf.function
def train_step(inp, targ, encoder, decoder, optimizer, loss_object,
               teacher_forcing=False):
  

  loss = 0
  inp_len = inp.shape[1]
  targ_len = targ.shape[1]

  with tf.GradientTape() as tape:
    
    enc_output = encoder(inp)

    d_inp = tf.expand_dims([targ_lang.word_index[SOS_CHAR]] * BATCH_SIZE, 1)
    
    for t in range(targ_len):
      
      # need to prepare inputs to decoder to be of appropriate dimensions
      if t < enc_output.shape[1]:

        enc_padding = tf.zeros([enc_output.shape[0], 
                                targ_len-(t+1), 
                                enc_output.shape[2]])
        
      else:
        enc_padding = tf.zeros([enc_output.shape[0], 
                                targ_len-enc_output.shape[1], 
                                enc_output.shape[2]])
      
      e_out = tf.concat([enc_output[:,:t+1,:], enc_padding],1)

      d_padding = tf.zeros([d_inp.shape[0],
                            targ_len-(t+1)], dtype=tf.dtypes.int32)
      
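      d_context = d_inp[:,:t+1]
      if teacher_forcing:
        # ground-truth prefix of the same length: SOS followed by the
        # first t target characters
        d_context = tf.concat([d_inp[:, :1], targ[:, :t]], 1)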
      
      
      d_prev = tf.concat([d_context, d_padding],1)
      
      # pass d_prev and e_out into decoder
      d_out = decoder(d_prev, e_out)
      
      loss += loss_function(targ[:, t], d_out, loss_object)

      _, i = tf.math.top_k(d_out)
      d_inp = tf.concat([d_inp, i], axis =1) 
      
  batch_loss = (loss / int(targ.shape[1]))

  variables = encoder.trainable_variables + decoder.trainable_variables

  gradients = tape.gradient(loss, variables)

  optimizer.apply_gradients(zip(gradients, variables))

  return batch_loss

In [0]:
#@title Train Step Sanity Check Set Up { form-width: "150px" }

# use same learning rate as paper
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)

# from_logits=False, because the decoder already ends with a softmax layer
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False, reduction='none')


# use a smaller encoder and decoder as compared to the paper for 
# this notebook 
encoder = Encoder(vocab_inp_size,hidden_units,3,2,2)
decoder = Decoder(vocab_tar_size, hidden_units, 3, 2,2)



In [0]:
#@title Train Step Sanity Check { form-width: "150px" }
loss = train_step(example_input_batch, example_target_batch, 
                  encoder, decoder, optimizer, loss_object)
print(loss)

In [0]:
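#@title Train Step With Complete Teacher Forcing { form-width: "150px" }

@tf.function
def train_step_tf(inp, targ, encoder, decoder, optimizer, loss_object):
  # Parallel training: the decoder sees the whole ground-truth target
  # (shifted right by one) and predicts every position in one pass.
  # The causal convolutions stop position t from seeing positions > t.
  targ_len = targ.shape[1]

  with tf.GradientTape() as tape:

    enc_output = encoder(inp)
    n_batch, enc_len, depth = enc_output.shape

    # align the encoder output with the target length
    if enc_len >= targ_len:
      e_out = enc_output[:, :targ_len, :]
    else:
      e_out = tf.concat([enc_output,
                         tf.zeros([n_batch, targ_len - enc_len, depth])], 1)

    # decoder input [SOS, y_0, ..., y_{T-2}] is used to predict [y_0, ..., y_{T-1}]
    d_inp = tf.expand_dims([targ_lang.word_index[SOS_CHAR]] * BATCH_SIZE, 1)
    d_prev = tf.concat([d_inp, targ[:, :-1]], 1)

    # Decoder.call only returns the last position, so run its layers
    # directly to get a prediction for every position
    o = decoder.embedding(d_prev)
    o = tf.concat([o, e_out], 2)
    o = decoder.sets(o)
    o = decoder.conv(o)
    o = decoder.relu(o)
    o = decoder.conv_final(o)
    d_out = decoder.softmax(o)  # (batch, targ_len, vocab)

    # masked mean over all positions of the batch
    loss = loss_function(targ, d_out, loss_object)

  variables = encoder.trainable_variables + decoder.trainable_variables

  gradients = tape.gradient(loss, variables)

  optimizer.apply_gradients(zip(gradients, variables))

  return loss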

Checkpoint saving

In [0]:
#@title Checkpoint Saving { form-width: "150px" }

checkpoint_dir = CKPT_DIR
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

In [0]:
#@title Training by Epoch { form-width: "150px" }

EPOCHS = 100

for epoch in range(EPOCHS):
  start = time.time()

  total_loss = 0

  for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
    batch_loss = train_step(inp, targ, encoder, decoder, optimizer,
                            loss_object)
    total_loss += batch_loss

    if batch % 100 == 0:
      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                   batch,
                                                   batch_loss.numpy()))
  # save (checkpoint) the model every 2 epochs
  if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)

  print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      total_loss / steps_per_epoch))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

## Evaluating 

Given the model we have trained, run it on individual sentences. While "greedy" search is the easiest to implement, the paper uses beam search for evaluation. Both are shown here.
<br>
This is where the difficulty of using TensorFlow becomes most apparent. Several built-in methods that could simplify evaluation, such as beam search (implemented from scratch in this notebook for ease of presentation), require non-eager execution or involve other technical complications. <br>
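Greedy search picks the single most probable next character at every step and stops at EOS or at a length limit. The sketch below replaces the decoder with a tiny hand-written table of next-character probabilities (`TOY_MODEL`, hypothetical) so that it runs without TensorFlow; SOS and EOS are written as `<` and `>` purely for illustration.

```python
SOS, EOS = "<", ">"

# Hypothetical stand-in for the decoder: next-char probabilities given the
# chars decoded so far. Prefixes not in the table predict EOS with certainty.
TOY_MODEL = {
    ("<",): {"a": 0.6, "b": 0.4},
    ("<", "a"): {">": 0.5, "a": 0.3, "b": 0.2},
    ("<", "b"): {">": 0.9, "a": 0.1},
}


def next_char_probs(prefix):
    return TOY_MODEL.get(tuple(prefix), {EOS: 1.0})


def greedy_decode(max_len=10):
    """Repeatedly append the most probable next char until EOS or max_len."""
    decoded = [SOS]
    prob = 1.0
    for _ in range(max_len):
        probs = next_char_probs(decoded)
        char = max(probs, key=probs.get)
        prob *= probs[char]
        decoded.append(char)
        if char == EOS:
            break
    return "".join(decoded), prob


sequence, prob = greedy_decode()
print(sequence, round(prob, 3))
```

Greedy search commits to `a` because it is the locally best first character, even though `<b>` has a higher overall probability (0.4 × 0.9 = 0.36 versus 0.6 × 0.5 = 0.30). Recovering sequences like this is what beam search is for.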




For beam search, `tf.nn.ctc_beam_search_decoder` requires non-eager execution, which is tricky. Instead, a simplified version of beam search is implemented from scratch, following this [article](https://medium.com/the-artificial-impostor/implementing-beam-search-part-1-4f53482daabe) and this [repository](https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py).<br>


Each node contains the following:<br>
1) List of chars so far<br>
2) Decoder output for each of the chars so far (needed to calculate the loss later)<br>
3) Its own log probability<br>

The beam search in this notebook follows a basic algorithm:<br>
1) Initialize: create a first node containing the SOS char and add it to prev_nodes<br>
2) At each search iteration:<br>
 * For each node in prev_nodes, get the k most probable next chars and create one new node per char, where k = beam_width.
 * If the char is EOS, add the node to candidates; otherwise, add it to next_nodes.
 * Sort next_nodes by log_prob and keep the top k.
 * These top k nodes become prev_nodes for the next iteration.

3) Stop when there are no nodes left to expand, more than max_candidates candidates have been collected, or max_iterations is reached, then return the candidate with the highest log probability.<br>
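The steps above can be sketched in plain Python. As in the greedy sketch, a hypothetical lookup table (`TOY_MODEL`) stands in for the decoder, and this simplified `Node` keeps only the decoded chars and the path log probability (no decoder outputs, so no loss is computed). It is an illustration of the algorithm, not the notebook's `beam_evaluate`.

```python
import math
from dataclasses import dataclass

SOS, EOS = "<", ">"

# Hypothetical stand-in for the decoder: next-char probabilities given the
# chars decoded so far. Prefixes not in the table predict EOS with certainty.
TOY_MODEL = {
    ("<",): {"a": 0.6, "b": 0.4},
    ("<", "a"): {">": 0.5, "a": 0.3, "b": 0.2},
    ("<", "b"): {">": 0.9, "a": 0.1},
}


def next_char_probs(prefix):
    return TOY_MODEL.get(tuple(prefix), {EOS: 1.0})


@dataclass
class Node:
    decoded_chars: list  # chars on the path so far, starting with SOS
    log_prob: float      # sum of log-probabilities along the path


def beam_search(width=2, max_iterations=20, max_candidates=10):
    prev_nodes = [Node([SOS], 0.0)]  # step 1: a single node holding SOS
    candidates = []
    t = 0
    while prev_nodes and t < max_iterations and len(candidates) <= max_candidates:
        t += 1
        next_nodes = []
        for node in prev_nodes:
            probs = next_char_probs(node.decoded_chars)
            # the `width` most probable next chars for this node
            top_k = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:width]
            for char, p in top_k:
                child = Node(node.decoded_chars + [char],
                             node.log_prob + math.log(p))
                if char == EOS:
                    candidates.append(child)  # finished path
                else:
                    next_nodes.append(child)
        # keep only the `width` best unfinished nodes for the next iteration
        next_nodes.sort(key=lambda n: n.log_prob, reverse=True)
        prev_nodes = next_nodes[:width]
    # if no path reached EOS, fall back to the unfinished nodes
    best = max(candidates or prev_nodes, key=lambda n: n.log_prob)
    return "".join(best.decoded_chars), math.exp(best.log_prob)


sequence, prob = beam_search(width=2)
print(sequence, round(prob, 3))
```

With `width=2` the search returns `<b>` (probability 0.36), which greedy search misses because it commits to the locally more probable `a`. With `width=1` it reduces to greedy search and returns `<a>`. Summing log probabilities instead of multiplying probabilities avoids underflow on long sequences.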

In [0]:
#@title Evaluate { form-width: "150px" }

# similar to training, but collect the predicted chars in an output list
# and stop as soon as EOS is predicted
def evaluate(inputs, targs, encoder, decoder, loss_object):

  loss = 0
  # targs is padded to the maximum target length of the dataset, so this
  # also caps the output length and prevents runaway translations
  targ_len = targs.shape[1]
  enc_output = encoder(inputs)
  d_inp = tf.expand_dims([targ_lang.word_index[SOS_CHAR]], 1)
  output = []
  for t in range(targ_len):

    # use the first t+1 encoder positions (or all of them, if the encoder
    # output is shorter) and zero-pad up to targ_len
    if t < enc_output.shape[1]:
      enc_padding = tf.zeros([enc_output.shape[0],
                              targ_len - (t + 1),
                              enc_output.shape[2]])
    else:
      enc_padding = tf.zeros([enc_output.shape[0],
                              targ_len - enc_output.shape[1],
                              enc_output.shape[2]])

    e_out = tf.concat([enc_output[:, :t+1, :], enc_padding], 1)

    # decoder input: SOS plus the chars predicted so far, zero-padded
    d_padding = tf.zeros([d_inp.shape[0],
                          targ_len - (t + 1)], dtype=tf.dtypes.int32)
    d_context = d_inp[:, :t+1]
    d_prev = tf.concat([d_context, d_padding], 1)

    d_out = decoder(d_prev, e_out)

    loss += loss_function(targs[:, t], d_out, loss_object)

    # greedy choice: index of the most probable next char
    _, i = tf.math.top_k(d_out)
    idx = int(i[0, 0])
    c = targ_lang.index_word[idx]
    output.append(c)
    if c == EOS_CHAR:
      break
    d_inp = tf.concat([d_inp, i], axis=1)

  sentence_loss = loss / len(output)
  return output, sentence_loss
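At every step, `evaluate` rebuilds a fixed-length decoder input: SOS plus the characters predicted so far, right-padded with zeros to `targ_len`. The encoder output is truncated and zero-padded to the same length in the same loop. The list-based sketch below shows only the decoder side of that padding; `padded_decoder_inputs` is a hypothetical helper and the indices are made up.

```python
def padded_decoder_inputs(decoded, target_len, pad_id=0):
    """Return the decoder input used at each step of greedy evaluation.

    At step t the decoder sees the first t + 1 decoded indices (SOS plus the
    t chars predicted so far), right-padded with pad_id to target_len.
    """
    steps = []
    for t in range(min(len(decoded), target_len)):
        prefix = decoded[: t + 1]
        steps.append(prefix + [pad_id] * (target_len - len(prefix)))
    return steps


# hypothetical indices: 1 = SOS, followed by two predicted chars
for row in padded_decoder_inputs([1, 7, 4], target_len=5):
    print(row)
```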


In [0]:
#@title Node Class { form-width: "150px"}


class Node:
  def __init__(self, decoded_chars, decoder_output, log_prob):
    # indices of the chars decoded on the path leading up to this node,
    # i.e. the record of all previous states
    self.decoded_chars = decoded_chars
    # decoder outputs produced along this path, stacked along axis 0;
    # needed to calculate the loss
    self.decoder_output = decoder_output
    # log probability of the path; nodes are sorted on this
    self.log_prob = log_prob


In [0]:
#@title Beam Search Params { form-width: "150px"}

beam_search_params = {
    # stop once more than this many finished (EOS) candidates are collected
    'max_candidates' : 10,
    # cap on search iterations, to prevent a runaway search
    'max_iterations' : 1000,
    # width of the beam search - the paper uses 12
    'width' : 5
}

In [0]:
#@title Evaluate with Beam { form-width: "150px"}

def beam_evaluate(inputs,targs, encoder, decoder, loss_object, beam_search_params):

  loss = 0
  enc_output = encoder(inputs)
  d_inp = tf.expand_dims([targ_lang.word_index[SOS_CHAR]] , 1)
  output = []


  # beam search initialize arrays 
  prev_nodes = []
  next_nodes = []
  candidates = []
  beam_width = beam_search_params['width']

  # make the first node 
  d_prev = d_inp[:,:1]
  e_out = enc_output[:,:1, :]
  d_out = decoder(d_prev, e_out)
  initial_node = Node(d_inp, d_out, 0)


  # put it in the prev_nodes 
  prev_nodes.append(initial_node)

  # limit the number of iterations to prevent a runaway search
  t = 0
  while (t <= beam_search_params['max_iterations'] and prev_nodes
         and len(candidates) <= beam_search_params['max_candidates']):
    t += 1
    for prev_node in prev_nodes:
      # consider a previous node and get its possible next chars;
      # align the encoder output with the length of this node's path
      decoded_prev = prev_node.decoded_chars
      l = decoded_prev.shape[1]
      e_out = enc_output[:, :l, :]
      d_out = decoder(decoded_prev, e_out)

      probs, indexes = tf.math.top_k(d_out, k=beam_width)

      # add all to the next nodes 
      for j in range(beam_width):
        idx_j = tf.expand_dims(indexes[:,j], axis=1)
        p_j = tf.expand_dims(probs[:,j], axis=1) 
        log_prob_j = np.log(p_j).flatten()[0] 
        decoded_next = tf.concat([decoded_prev, idx_j], axis= 1)
        output_next = tf.concat([prev_node.decoder_output, d_out], axis=0)

        next_node = Node(decoded_next, output_next, 
                        prev_node.log_prob + log_prob_j)
        
        idx_j = tf.reduce_sum(idx_j).numpy()
        c = targ_lang.index_word[idx_j]

        if c == EOS_CHAR:
          candidates.append(next_node)
        else:
          next_nodes.append(next_node)
    
    # done considering new nodes for this iteration: sort next_nodes
    # by log probability and keep only the top beam_width as prev_nodes
    next_nodes.sort(key=lambda x: x.log_prob, reverse=True)
    prev_nodes = next_nodes[:beam_width]

    # empty next_nodes for the next iteration
    next_nodes = []

    if t+1 > beam_search_params['max_iterations'] and not candidates:
      candidates = prev_nodes

  # done with beam search - found our candidates
  # sort them, get the top one and calculate loss 
  candidates.sort(key = lambda x: x.log_prob, reverse=True)
  o_node = candidates[0]

  output = []
  for i in range(o_node.decoded_chars.shape[1]):  
    idx = tf.reduce_sum(o_node.decoded_chars[0][i]).numpy()
    output.append(targ_lang.index_word[idx])

  t_len = targs.shape[1]
  for i in range(len(output)):
    
    # the prediction may be longer than the target; past the end
    # of the target, compare against the EOS token instead
    if i<t_len:
      real = targs[:,i]
    else:
      real = tf.convert_to_tensor([targ_lang.word_index[EOS_CHAR]])
    pred = o_node.decoder_output[i,:]
    pred = tf.expand_dims(pred,0)
    loss += loss_function(real, pred, loss_object)

  sentence_loss = loss / len(output)
  return output, sentence_loss

If the Dataset object had been split into training, validation and test sets, the preprocessing steps in "Evaluate Sanity Check Set Up" would be unnecessary, since the evaluation batches would already be tokenized and padded. Because this example notebook is for illustration purposes, the whole dataset is used for training, and the model is evaluated on single sentences that are preprocessed by hand.
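The set-up cell below maps each character to its index and pads the result to a fixed length with `pad_sequences(..., padding='post')`, i.e. zeros are appended after the sequence. A plain-Python equivalent, assuming a hypothetical character vocabulary (with `<` and `>` standing in for SOS and EOS) and a sentence that already fits within `maxlen`; the Keras function would also truncate longer sequences, whereas this sketch raises an error.

```python
def encode_and_pad(sentence, char_index, maxlen, pad_id=0):
    """Map characters to indices and append pad_id up to maxlen.

    Mirrors post-padding; raises instead of truncating over-long input.
    """
    ids = [char_index[c] for c in sentence]
    if len(ids) > maxlen:
        raise ValueError(f"sequence of length {len(ids)} exceeds maxlen={maxlen}")
    return ids + [pad_id] * (maxlen - len(ids))


# hypothetical vocabulary; index 0 is reserved for padding
char_index = {"<": 1, ">": 2, " ": 3, "w": 4, "e": 5, "t": 6, "r": 7, "y": 8}
print(encode_and_pad("<we try>", char_index, maxlen=10))
```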

In [0]:
#@title Evaluate Sanity Check Set Up { form-width: "150px" }
sentence = "we try"
target_sentence = "on essaye"

# map each input char to its index and post-pad to the input length
sentence = preprocess_sentence(sentence, True)
inputs = [inp_lang.word_index[c] for c in sentence]
inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                       maxlen=max_length_inp,
                                                       padding='post')

# same for the target sentence
target_sentence = preprocess_sentence(target_sentence)
targs = [targ_lang.word_index[c] for c in target_sentence]
targs = tf.keras.preprocessing.sequence.pad_sequences([targs],
                                                      maxlen=max_length_targ,
                                                      padding='post')



In [0]:
#@title Evaluate Unit Test { form-width: "150px" }

output, sentence_loss = evaluate(inputs, targs, encoder,
                                 decoder, loss_object)
