## Modified version of the Transformer with sequential Monte-Carlo Methods applied to Language Modelling
* start from this notebook: https://www.tensorflow.org/tutorials/text/transformer
* Implement the "SMC Transformer forward Model" presented here: https://drive.google.com/open?id=1ms8B9dYCU9yiMtkrEX2LzOEDD7EJJYQI 
* Use the dataset from this tutorial: https://www.tensorflow.org/tutorials/text/text_generation


* Pre-trained models - Hugging Face: https://github.com/huggingface/transformers

* See also this tutorial on Transfer Learning for NLP from NNACL 2019: https://colab.research.google.com/drive/1iDHCYIrWswIKp-n-pOg69xLoZO09MEgf

* github version: https://github.com/huggingface/naacl_transfer_learning_tutorial

### experiments
#### PTB code: 
* https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py
* https://github.com/sebastianruder/tensorflow-experiments/blob/master/rnn_ptb.py

#### Wikipedia Dataset (included in tensorflow.dataset library). 

#### code of the Particle Filter Recurrent Neural Network 
* https://github.com/AdaCompNUS/pfnet

#### About 'automatic' metrics to measure diversity/quality of language: see the link below
* https://thegradient.pub/understanding-evaluation-metrics-for-language-models/




### to do (14/11/2019)

#### Model part
* Recode and debug pass_forward_from_layer class to have the option of 'decoder-level' SMC (what is done right now), and 'layer-level' SMC (OPTIONAL AND NOT URGENT)
* Correct decoding with teacher forcing instead of with predictions
* Code the option to choose between sampling method (what is done right now) and 'greedy' method (take previous as input to compute attention parameters). 

#### Inference part: 
* adapt the evaluate function with only the decoder and the right dataset. 
* code the function that predicts a new word Yk from a sequence of words Y0:k-1 (see language_models.pdf)
* check if it works (do unit tests)

#### training part
* Implement the computation of the gradient
* See what are the 'real' trainable parameters, and the ones for which no gradient computation is needed > use the function tf.stop_gradient (on the resampling weights & on the ancestral index matrix)
* Implement the computation of the gradient as a custom loss? > cf this tuto: https://colab.research.google.com/drive/1blsyWmNzhSrvMsxIEp3g2-X-zk4FyFzx

#### experiments
* Change the prediction task to a word-level LM (instead of a character-level one)

#### formating
* Transform this code in a github repo with .py files. 

# 1. TRANSFORMER MODEL WITH PARTICLE FILTERS

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/text/transformer">
    <img src="https://www.tensorflow.org/images/tf_logo_32px.png" />
    View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/text/transformer.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
    Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/text/transformer.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />
    View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/text/transformer.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

The core idea behind the Transformer model is ** self-attention—the ability to attend to different positions of the input sequence to compute a representation of that sequence.** Transformer creates stacks of self-attention layers and is explained below in the sections *Scaled dot product attention* and *Multi-head attention*.

A transformer model handles variable-sized input using stacks of self-attention layers instead of [RNNs](text_classification_rnn.ipynb) or [CNNs](../images/intro_to_cnns.ipynb). This general architecture has a number of advantages:

* It make no assumptions about the temporal/spatial relationships across the data. This is ideal **for processing a set of objects** (for example, [StarCraft units](https://deepmind.com/blog/alphastar-mastering-real-time-strategy-game-starcraft-ii/#block-8)).
* **Layer outputs can be calculated in parallel**, instead of a series like an RNN.
* Distant items can affect each other's output without passing through many RNN-steps, or convolution layers (see [Scene Memory Transformer](https://arxiv.org/pdf/1903.03878.pdf) for example).
* It can **learn long-range dependencies.** This is a challenge in many sequence tasks.

The downsides of this architecture are:

* For a time-series, the output for a time-step is calculated from the *entire history* instead of only the inputs and current hidden-state. This _may_ be less efficient.   
* If the input *does* have a  temporal/spatial relationship, like text, **some positional encoding must be added or the model will effectively see a bag of words.**


In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

try:
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow_datasets as tfds
import tensorflow as tf

import time
import numpy as np
import matplotlib.pyplot as plt

## Setup input pipeline
* From this colab botebook: https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/text/text_generation.ipynb#scrollTo=EHDoRoc5PKWz

In [0]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

### Read the data

First, look in the text:

In [0]:
# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print ('Length of text: {} characters'.format(len(text)))

In [0]:
# Take a look at the first 250 characters in text
print(text[:250])

In [0]:
# The unique characters in the file
vocab = sorted(set(text))
print ('{} unique characters'.format(len(vocab)))

## Process the text

### Creating a word-level vocabulary

### Vectorizing the text at word-level

### Vectorize the text

Before training, we need to map strings to a numerical representation. Create two lookup tables: one mapping characters to numbers, and another for numbers to characters.

In [0]:
# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

text_as_int = np.array([char2idx[c] for c in text])

Now we have an integer representation for each character. Notice that we mapped the character as indexes from 0 to `len(unique)`.

In [0]:
print('{')
for char,_ in zip(char2idx, range(20)):
    print('  {:4s}: {:3d},'.format(repr(char), char2idx[char]))
print('  ...\n}')

In [0]:
# Show how the first 13 characters from the text are mapped to integers
print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))

### The prediction task
#### NB: change the task to a word-level prediction

**Given a character, or a sequence of characters, what is the most probable next character? This is the task we're training the model to perform.**
The input to the model will be a sequence of characters, and we train the model to predict the output—the following character at each time step.


### Create training examples and targets

Next divide the text into example sequences. Each input sequence will contain `seq_length` characters from the text.

For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.

**So break the text into chunks of `seq_length+1`. For example, say `seq_length` is 4 and our text is "Hello". The input sequence would be "Hell", and the target sequence "ello".**

To do this first use the `tf.data.Dataset.from_tensor_slices` function to convert the text vector into a stream of character indices.

In [0]:
# The maximum length sentence we want for a single input in characters
seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)

# Create training examples / targets
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)

for i in char_dataset.take(5):
  print(idx2char[i.numpy()])

The `batch` method lets us easily convert these individual characters to sequences of the desired size.

In [0]:
sequences = char_dataset.batch(seq_length+1, drop_remainder=True)

for item in sequences.take(5):
  print(repr(''.join(idx2char[item.numpy()])))

For each sequence, duplicate and shift it to form the input and target text by using the `map` method to apply a simple function to each batch:

In [0]:
def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

dataset = sequences.map(split_input_target)

Print the first examples input and target values:

In [0]:
for input_example, target_example in  dataset.take(1):
  print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))
  print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))

Each index of these vectors are processed as one time step. For the input at time step 0, the model receives the index for "F" and trys to predict the index for "i" as the next character. At the next timestep, it does the same thing but the `RNN` considers the previous step context in addition to the current input character.

In [0]:
for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):
    print("Step {:4d}".format(i))
    print("  input: {} ({:s})".format(input_idx, repr(idx2char[input_idx])))
    print("  expected output: {} ({:s})".format(target_idx, repr(idx2char[target_idx])))

### Create training batches

We used `tf.data` to split the text into manageable sequences. But before feeding this data into the model, we need to shuffle the data and pack it into batches.

In [0]:
# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

dataset

## Positional encoding

Since this model doesn't contain any recurrence or convolution, **positional encoding is added to give the model some information about the relative position of the words in the sentence.**

The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning will be closer to each other. But the embeddings do not encode the relative position of words in a sentence. So after adding the positional encoding, words will be closer to each other based on the *similarity of their meaning and their position in the sentence*, in the d-dimensional space.

See the notebook on [positional encoding](https://github.com/tensorflow/examples/blob/master/community/en/position_encoding.ipynb) to learn more about it. The formula for calculating the positional encoding is as follows:

$$\Large{PE_{(pos, 2i)} = sin(pos / 10000^{2i / d_{model}})} $$
$$\Large{PE_{(pos, 2i+1)} = cos(pos / 10000^{2i / d_{model}})} $$

In [0]:
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates

In [0]:
def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
  
  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
  
  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    
  pos_encoding = angle_rads[np.newaxis, ...]
    
  return tf.cast(pos_encoding, dtype=tf.float32)

In [0]:
def positional_encoding_SMC(position, d_model, num_particles):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
  
  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
  
  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    
  pos_encoding = angle_rads[np.newaxis, ...]
  pos_encoding=pos_encoding[:,np.newaxis,:,:]

  pos_encoding=tf.tile(pos_encoding, [1,num_particles,1,1])
  print(pos_encoding.shape)
    
  return tf.cast(pos_encoding, dtype=tf.float32)

In [0]:
pos_encoding = positional_encoding(50, 512)
print (pos_encoding.shape)

pos_encoding_SMC=positional_encoding_SMC(50,512,10)
print('pos_encoding_SMC', positional_encoding_SMC)

plt.pcolormesh(pos_encoding[0], cmap='RdBu')
plt.xlabel('Depth')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()

## Masking

Mask all the pad tokens in the batch of sequence. It ensures that the model does not treat padding as the input. The mask indicates where pad value `0` is present: it outputs a `1` at those locations, and a `0` otherwise.

In [0]:
def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
  
  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

In [0]:
x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
create_padding_mask(x)

**The look-ahead mask is used to mask the future tokens in a sequence. In other words, the mask indicates which entries should not be used.**

This means that to predict the third word, only the first and second word will be used. Similarly to predict the fourth word, only the first, second and the third word will be used and so on.

In [0]:
#tf.linalg.band_part(
    #input,
    #num_lower,
    #num_upper,
    #name=None
#)
# num_lower: number of subdiagonals to keep. If negative, keeps, entire triangle. 
# num_upper: # of upper diagonals to keep. If negative, keeps entire triangle. 


def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)

In [0]:
x = tf.random.uniform((1, 3))
print('x', x)
temp = create_look_ahead_mask(x.shape[1])
print('mask', temp)

In [0]:
print(tf.matmul(x,temp))
x+=temp
print(x)

## Scaled dot product attention

<img src="https://www.tensorflow.org/images/tutorials/transformer/scaled_attention.png" width="500" alt="scaled_dot_product_attention">

The attention function used by the transformer takes three inputs: **Q (query), K (key), V (value).** The equation used to calculate the attention weights is:

$$\Large{Attention(Q, K, V) = softmax_k(\frac{QK^T}{\sqrt{d_k}}) V} $$

The dot-product attention is scaled by a factor of square root of the depth. **This is done because for large values of depth, the dot product grows large in magnitude pushing the softmax function where it has small gradients resulting in a very hard softmax.** 

For example, consider that `Q` and `K` have a mean of 0 and variance of 1. Their matrix multiplication will have a mean of 0 and variance of `dk`. Hence, *square root of `dk`* is used for scaling (and not any other number) because the matmul of `Q` and `K` should have a mean of 0 and variance of 1, and you get a gentler softmax.

**The mask is multiplied with -1e9 (close to negative infinity). This is done because the mask is summed with the scaled matrix multiplication of Q and K and is applied immediately before a softmax. The goal is to zero out these cells, and large negative inputs to softmax are near zero in the output.**

In [0]:
def scaled_dot_product_attention(q, k, v, mask, dec_timestep, K=None, V=None, Z=None, mode='self'): 
  
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.
  
  Args:
    q: query shape == (..., num_particles, depth) for sampled word
    k: key shape == (..., num_particles, depth) for sampled word
    v:value shape == (..., num_particles, depth_v) for sampled word
    
    K: key shape == (..., num_particles, seq_len_k, depth)
    V: value shape == (..., num_particles, seq_len_v, depth_v)
    Z: value shape == (..., num_particles, seq_len_v, depth)
    mask: Float tensor with shape broadcastable 
          to (..., num_particles, seq_len_q, seq_len_k). Defaults to None.
    mode: to distinct between 'encoder-decoder' (no SMC for now) attention & self-attention 
    
  Returns:
    output (new Z), attention_weights, K, V
  """

  # FOR SMC: SHAPE OF Q (..., NUM_PARTICLES, seq_len_q, depth)
             # SHAPE OF K (..., NUM_PARTICLES, seq_len_k, depth)
            # SHAPE OF V (..., NUM_PARTICLES, seq_len_v, depth_v)
      
  # FOR SMC: k[l]=K0:k[l], v[l]=v0:k[l], q[l]=q[Il]     
  if K is not None:
    # compute K(0:k) from K(0:k-1) & k
    # dim of K in the case of multi-head attention: (batch_size, num_particles, num_heads, seq_length, Depth)
    K=tf.concat([K[:,:,:,:dec_timestep,:], k, tf.expand_dims(K[:,:,:,dec_timestep+1,:],axis=3)], axis=3)
  else:
    K=k
  if V is not None:
    # compute the V(0:k) with V(0:k-1) & v
    V=tf.concat([V[:,:,:,:dec_timestep,:], v, tf.expand_dims(V[:,:,:,dec_timestep+1,:],axis=3)], axis=3)
  else:
    V=v
  
  # adapt shape of q for doing the matmul
  matmul_qk = tf.matmul(q, K, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  
  # scale matmul_qk
  dk = tf.cast(tf.shape(K)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9) 
    # ok, makes sense: mask multiplied by a number closed to -infinity (to zero out after softmax) and sum with scaled_attention_logits 
    #(because sum means product after softmax)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., num_particles, seq_len_q, seq_len_k)
  
  output = tf.matmul(attention_weights, V)
  # here, output= z_k[l] -> shape (..., num_particles, depth)
  
  if mode=='self' and Z is not None:
    # compute the Z(0:k) from z(0:k-1) & output
    Z[:,:,dec_timestep,:]=output
  else:
    Z=output
  
  # The output return the variable Z(0:k)
  # FOR THE SMC TRANSFORMER, THE OUTPUT SHAPE SHOULD BE (..., NUM_PARTICLES, seq_len_q, depth_v) 
  
  return (Z,K,V), attention_weights

In [0]:
def print_out(q, k, v):
  (temp_out, temp_K, temp_V), temp_attn = scaled_dot_product_attention(
      q, k, v, None, 3)
  print ('Attention weights are:')
  print (temp_attn)
  print ('Output is:')
  print (temp_out)

In [0]:
np.set_printoptions(suppress=True)

temp_k = tf.constant([[10,0,0],
                      [0,10,0],
                      [0,0,10],
                      [0,0,10]], dtype=tf.float32)  # (4, 3)

temp_v = tf.constant([[   1,0],
                      [  10,0],
                      [ 100,5],
                      [1000,6]], dtype=tf.float32)  # (4, 2)

# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

In [0]:
# This query aligns with a repeated key (third and fourth), 
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

In [0]:
# This query aligns equally with the first and second key, 
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

Pass all the queries together.

In [0]:
temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32)  # (3, 3)
print_out(temp_q, temp_k, temp_v)

## Multi-head attention

<img src="https://www.tensorflow.org/images/tutorials/transformer/multi_head_attention.png" width="500" alt="multi-head attention">


Multi-head attention consists of four parts:
*    Linear layers and split into heads.
*    Scaled dot-product attention.
*    Concatenation of heads.
*    Final linear layer.

Each multi-head attention block gets three inputs; Q (query), K (key), V (value). These are put through linear (Dense) layers and split up into multiple heads. 

The `scaled_dot_product_attention` defined above is applied to each head (broadcasted for efficiency). **An appropriate mask must be used in the attention step.**  The attention output for each head is then concatenated (using `tf.transpose`, and `tf.reshape`) and put through a final `Dense` layer.

**Instead of one single attention head, Q, K, and V are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. After the split each head has a reduced dimensionality, so the total computation cost is the same as a single head attention with full dimensionality.**

In [0]:
class MultiHeadAttention(tf.keras.layers.Layer):
  '''
  multi-head attention mechanism for each layer of the Transformer
  -args: 
    -d_model: depth model
    -num_heads: number of heads for the multi-head attention mechanism
    -num_particles: number of particles to generate 
    -dec_timestep: current decoding timestep (=k) for the sequential mechanism
    -mode: self-attention (default) or encoder/decoder attention
    '''
  def __init__(self, d_model, num_heads, num_particles, dec_timestep, mode='self'): # 2 arguments added: dec_timestep, mode. 
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    
    assert d_model % self.num_heads == 0
    
    self.depth = d_model // self.num_heads
    
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)
    
    self.dense = tf.keras.layers.Dense(d_model)
    
    # distinct between encoder-decoder attention (no SMC) & self-attention (w/ SMC):
    self.mode=mode
    self.num_particles=num_particles
    self.timestep=dec_timestep
        
  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    (batch_size, num_particle, seq_length, d_model) => (batch_size, num_particle, seq_length, num_heads, depth=d_model/num_heads)
    """
    x = tf.reshape(x, (batch_size, self.num_particles, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 1, 3, 2, 4])
    
  def concat_heads(self, x):
    '''concat attention parameters over all heads (and permute dimensions)
    -returns a tensor of shape (B, P, S, D)'''
    scaled_attention = tf.transpose(x, perm=[0, 1, 3, 2, 4])  # (batch_size, NUM_PARTICLES, seq_len_q, num_heads, depth)

    return tf.reshape(scaled_attention, 
                                  (tf.shape(scaled_attention)[0], tf.shape(scaled_attention)[1], -1, self.d_model))  # (batch_size, NUM_PARTICLES, seq_len_q, d_model)

  def call(self, v, k, q, sigma, mask, K=None, V=None, Z=None):
    '''
    -Args:
      -v,k,q: v(k), k(k), q(k): attention parameters (over all heads) @ current decoding timestep. > shape (B,P,D)
      - sigma: to compute the Gaussian noise with the reparametrization trick > float tensor of shape (B, P, S, D)
      -mask: padding or look_ahead mask. 
      -K,V,Z: KO:k, V0:k, Z0:k: total length attention parameters (until decoding timestep) > shape (B, P, S, D)
    -Returns:
      -K:0:k+1, V0:k+1, Z0:k+1
      -attention_weights
    '''
    batch_size = tf.shape(q)[0]
    
    #> FOR SMC: q is only the query of the current word: shape (batch_size, num_particles, d_model)
    q = self.wq(q)  # (batch_size, NUM_PARTICLES, seq_len_q, d_model)  
    k = self.wk(k)  # (batch_size, NUM_PARTICLES, seq_len_k, d_model)
    v = self.wv(v)  # (batch_size, NUM_PARTICLES, seq_len_v, d_model)
    
    q = self.split_heads(q, batch_size)  # (batch_size, NUM_PARTICLES, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, NUM_PARTICLES, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, NUM_PARTICLES, num_heads, seq_len_v, depth)
    
    if K is not None:
      K = self.split_heads(K, batch_size)
    if V is not None:
      V = self.split_heads(V, batch_size)
    if Z is not None:
      Z = self.split_heads(Z, batch_size)
 
    # FOR SMC: attention_weights.shape == (batch_size, NUM_PARTICLES, num_heads, seq_len_q, seq_len_k)
    (scaled_attention, K, V), attention_weights= scaled_dot_product_attention(q, k, v, mask, self.timestep, K, V, Z, self.mode)
    
    # concat attention, K, V over all the heads
    concat_attention=self.concat_heads(scaled_attention)
    concat_K=self.concat_heads(K)
    concat_V=self.concat_heads(V)
    
    # Add gaussian noise here: before or after the dense layer? (test both)
    # COMPUTE THE REPARAMETRIZATION TRICK
    self.stddev=tf.math.multiply(sigma, tf.random.normal(shape=tf.shape(concat_attention), name='gaussian_noise_reparametrized'))
    # replace by a 'fixed' noise for the inference part (with pre-training) > OPTIONAL
    #self.gaussian_noise=tf.keras.layers.GaussianNoise(self.stddev)
    #output = self.gaussian_noise(self.dense(concat_attention))  # (batch_size, NUM_PARTICLES, seq_len_q, d_model)

    output=concat_attention+self.stddev
        
    # THE OUTPUT IS ALSO THE VARIABLE Z (CONCATENATION OF THE Z OF EACH HEAD)
    # FOR SMC: OUTPUT SHAPE (batch_size, NUM_PARTICLES, seq_len_q, d_model)
    return (output, concat_K, concat_V), attention_weights

Create a `MultiHeadAttention` layer to try out. At each location in the sequence, `y`, the `MultiHeadAttention` runs all 8 attention heads across all other locations in the sequence, returning a new vector of the same length at each location.

In [0]:
temp_mha = MultiHeadAttention(d_model=512, num_heads=8, num_particles=10, dec_timestep=3, mode='enc-dec')
y = tf.random.uniform((1, 10, 60, 512))  # (batch_size, encoder_sequence, d_model)
(Z,K,V), attn = temp_mha(y, k=y, q=y, sigma=1, mask=None)
#out.shape, attn.shape, K.shape, V.shape
print(Z.shape)

## Point wise feed forward network

Point wise feed forward network consists of two fully-connected layers with a ReLU activation in between.

In [0]:
def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, NUM_PARTICLES, seq_len, d_model)
  ])

In [0]:
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape

## Decoder

<img src="https://www.tensorflow.org/images/tutorials/transformer/transformer.png" width="600" alt="transformer">

The transformer model follows the same general pattern as a standard [sequence to sequence with attention model](nmt_with_attention.ipynb). 

* **The input sentence is passed through `N` encoder layers that generates an output for each word/token in the sequence.**
* The decoder attends on the encoder's output and its own input (self-attention) to predict the next word. 

### Decoder layer

Each decoder layer consists of sublayers:

1.   Masked multi-head attention (with look ahead mask and padding mask)
2.   Point wise feed forward networks

Each of these sublayers has a residual connection around it followed by a layer normalization. The output of each sublayer is `LayerNorm(x + Sublayer(x))`. The normalization is done on the `d_model` (last) axis.

There are N decoder layers in the transformer.

**In other words, the decoder predicts the next word by self-attending to its own output.** See the demonstration above in the scaled dot product attention section.

In [0]:
# original DecoderLayer from TF 2.0 tutorial on Tranformer
class DecoderLayer(tf.keras.layers.Layer):
  '''adaptated version of the original Decoder Layer of the Transformer. 
  The only difference are the shapes of the input tensor (B, P, S, D) instead of (B, S, D)
  -args:
    -d_model: model depth
    -num_heads: number of heads in the multi-head attention mechanism
    -dff: output dimension of the feed forward network
    -num_particles: number of simulated particles for the latent state space of the Transformer
    -rate: dropout rate for output layers
  '''

  def __init__(self, d_model, num_heads, dff, num_particles, rate=0.1):
    super(DecoderLayer, self).__init__()
    
    self.dec_timestep=0
    self.num_particles=num_particles
    
    self.mha1 = MultiHeadAttention(d_model, num_heads, num_particles, self.dec_timestep, mode='self')

    self.ffn = point_wise_feed_forward_network(d_model, dff)
 
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)

  def set_decoder(self, decoder):
    '''trick because this function is present in the DecoderLayer_SMC class.'''
    pass
    
  def call(self, x, training, look_ahead_mask, padding_mask):
    '''
    -args: 
        -x: 
        -training: boolean to distinct between training and evaluation phase. 
        -look_ahead_mask: for masking future decoding timestep
        -padding_mask: for fixed-size words sequence. 

    -returns
        -r0:k: entire sequence (until current decoding timestep) output of the Decoder layer > dim (B, P, S, D)
        -attention_weights_block: useful to visualize attention 
    '''
    (Z,K,V), attn_weights_block1 = self.mha1(x, x, x, 0, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(Z, training=training)
    out1 = self.layernorm1(Z + x)
    
    ffn_output = self.ffn(out1)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out1)  # (batch_size, target_seq_len, d_model)
    return out3, attn_weights_block1

In [0]:
sample_decoder_layer = DecoderLayer(512, 8, 2048, 10)

sample_decoder_layer_output, _,= sample_decoder_layer(
    tf.random.uniform((64, 10, 50, 512)),
    False, None, None)

sample_decoder_layer_output.shape  # (batch_size, target_seq_len, d_model)

In [0]:
# Code for the Decoder Layer with SMC. 
class DecoderLayer_SMC(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, sigma,  num_particles, layer_num, decoder=None, rate=0.1):
    '''
    -Args: 
      -d_model: model depth
      -num_heads: number of heads in the multi-head attention mechanism
      -dff: output dimension of the feed forward network
      -target_vocab_size (for computing the sampling weights)
      -maximum_position_encoding: number of positions for positional encodings. 
      -sigma: constant noise used for the Gaussian Noise with the reparametrization trick > float tensor of shape (B, P, S, D) # CODE IT AS A LEARNED PARAMETER INSTEAD. 
      -num_particles: number of simulated particles for the latent state space of the Transformer
      -layer_num: only used if resampling is done at layer-level (in the Decoder class)
      -rate: dropout rate for output layers
    '''
    super(DecoderLayer_SMC, self).__init__()
    
    # store the decoding timestep
    self.dec_timestep=0
    self.mha1 = MultiHeadAttention(d_model, num_heads, num_particles, self.dec_timestep, mode='self')
    self.ffn = point_wise_feed_forward_network(d_model, dff)
 
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)
    
    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
    self.pos_encoding_SMC=positional_encoding_SMC(maximum_position_encoding, d_model, num_particles)
    self.dropout = tf.keras.layers.Dropout(rate)
    
    self.num_particles=num_particles
    self.d_model=d_model
    self.target_vocab_size=target_vocab_size

    self.layer_num=layer_num
    self.decoder=decoder

    # This layer needs to share weights with the Transformer.final_layer
    if self.decoder is None:
      print('SMC mechanism done at decoder-level...')
      self.output_layer=tf.keras.layers.Dense(target_vocab_size)
    
    self.initialize=False

  def set_decoder(self, decoder):
    '''only used is the resampling is done at layer level.'''
    self.decoder=decoder

  def forward_pass_from_layer(self, x):
    '''compute the forward pass from the layer until the decoder output
      only used if re-sampling is done at layer level.
    '''
    #if self.decoder is not None:
        #forward_layers=self.decoder.dec_layers[self.layer_num:]
        #for layer in forward_layers:
          #forward_func=layer()
    #else:
    forward_func=self.output_layer
    return forward_func(x)

  def initialize_indices_matrix(self, batch_size, seq_length):
    # TO COMPUTE AT TRANSFORMER LEVEL? 
    ind_matrix=tf.zeros(shape=(batch_size, self.num_particles, seq_length),dtype=tf.int32)
    self.ind_matrix=ind_matrix
    return ind_matrix # tf.stop_gradient(ind_matrix)? 
    
  def preprocess_words(self, x, training):
    '''add words embeddings and positional encodings:
        -Args: 
          -x: 3D tensor of words sequence > dim (B, S, dim_words)
          -training: boolean for dropout
        -Returns: 
          - A 3D tensor of pre-processed words sequence > dim (B, S, D)
    '''
    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) # division by the root of the d_model
    # addition of the positional encoding to the input x for the current decoding step:
    #x += self.pos_encoding_SMC[:,:,self.dec_timestep, :] # dim of positional encoding (1, num_particles,  num_positions, d_model)
    x += self.pos_encoding[:,self.dec_timestep, :] # dim of positional encoding (1, num_positions, d_model)
    x = self.dropout(x, training=training)
    return x
    
  def sample_previous_word(self, indices, prev_Z, training):
    '''sample Xk-1[Ik[l]] the embedding of a selected word at the previous position, with probabilities given by the last layer output
    Args:
      -prev_Z: Z0:k-1 > dim (B, P, S, D)
      -indices: used to recompute Zk-1 > dim (B,P)
      -training: boolean for dropout 
    Returns:
      - the batch of sampled word of shape (B, P, 1, D)
    '''
    # recompute Z0:k-1 with the set of indices
    if len(indices.shape)==3:
      indices=tf.squeeze(indices, axis=-1)
    
    # compute z(k-1) with resampling 
    Z_previousk=tf.gather(prev_Z[:,:,self.dec_timestep,:],indices, axis=1, batch_dims=1)
    # compute Z0:k-1
    if self.dec_timestep==0:
      prev_Z=tf.concat([tf.expand_dims(Z_previousk,axis=2),prev_Z[:,:,self.dec_timestep+1:,:]], axis=2)
    else:
      prev_Z_left=prev_Z[:,:,:self.dec_timestep,:]
      prev_Z=tf.concat([prev_Z_left,tf.expand_dims(Z_previousk,axis=2),prev_Z[:,:,self.dec_timestep+1:,:]], axis=2)
  
    #compute the log probabilities associated to the prediction at Z0:k-1
    sample_words_id=[]
    predictions_probas=self.forward_pass_from_layer(Z_previousk) # dimensions (batch_size, num_particles, vocabulary_size)

    # select a word id randomly with proba equal to predictions_probas. 
    for n in range(self.num_particles):
      # TRY TO ELIMINATE THIS FOR LOOP
      sample_words_id+=[tf.random.categorical(predictions_probas[:,n,:], num_samples=1)]

    sample_words_id=tf.concat(sample_words_id, axis=1) # dimensions (B, P)
    sample_words_id=tf.expand_dims(sample_words_id, axis=-1) # adding the seq_length dimension
    # > dim (B, P, 1)

    #preprocess the words with the word embedding & positional encoding.
    x=self.preprocess_words(sample_words_id, training) # dim (B, P, 1)
    return x
  
  def sample_and_keep_indices(self, prev_sampling_weights, ind_matrix): # add a mask argument?
    '''samples the set of N indices for doing the weights resampling
    adds this set of indices to the matrix of ancestor indices
    Args:
    -prev_Z: Z0:k-1 > dim (B, P, S, D)
    -prev_sampling_weights: w(k-1) > dim (B, P, V) or (B, P)? 
    -indice matrix: I0:k-1 > dim (B, P, S)
    Returns:
    -The current set of indices to do a forward pass on the Decoder Layer > dim (batch_size, num_particles)
    -The updated ancestor indices matrix > dim (batch_size, NUM_PARTICLES, seq_length)'''

    # generate a uniform distribution between 0 and 1 
    unif_distrib=tf.random.uniform(shape=tf.shape(prev_sampling_weights), maxval=1)
    # add a tf.stop_gradient(prev_sampling_weights)? 

    # use the function compute ancestral_index
    #indices=self.compute_ancestral_index(prev_sampling_weights, unif_distrib)

    # Sample current set of indices with proba proportional to prev_sampling_weights 
    indices=tf.random.categorical(prev_sampling_weights, self.num_particles) # shape (..., num_particles)

    # Add this set of indices to the indices matrix tensor: 
    indices=tf.cast(indices, tf.int32)
    indices=tf.expand_dims(indices, axis=-1)
    updated_ind_matrix=tf.concat([ind_matrix[:,:,:self.dec_timestep],indices,ind_matrix[:,:,self.dec_timestep+1:]],axis=-1)
  
    return indices, updated_ind_matrix # tf.stop_gradient(indices), tf.stop_gradient(updated_ind_matrix)

  def compute_ancestral_index(self, prev_sampling_weights, uniform_distribution):
    '''
    -args: 
      -prev_sampling_weights: float tensor of shape (B,P)
      -uniform distribution: float tensor of dimension (B,P)
    -returns:
      -the current set of M indices > tensor of shape (B,P)
    '''
    batch_size=tf.shape(prev_sampling_weights)[0]
    num_particles=tf.shape(prev_sampling_weights)[1]

    # compute w_bar:
    W_0=tf.expand_dims(tf.constant(0,dtype=tf.float32, shape=(batch_size,)), axis=-1)
    W_m=[tf.reduce_sum(prev_sampling_weights[:,:m], axis=-1) for m in range(num_particles)]
    W_m=tf.stack(W_m, axis=-1)
    W_m=tf.concat([W_0, W_m], axis=-1)

    indices_func=np.zeros(shape=tf.shape(prev_sampling_weights))

    # use tf.random.categorical (check if the same)

    # TRY TO REMOVE THIS DOUBLE FOR LOOP!!! > see this github repo as an example: 
    #https://github.com/rlabbe/filterpy/blob/master/filterpy/monte_carlo/resampling.py
    for b in range(batch_size): 
      for i in range(num_particles):
        unif=uniform_distribution[b,i]
        if unif>=W_m[b,i] and unif<=W_m[b,i+1]:
          indices_func[b,i]=i+1  

    indices_func=tf.convert_to_tensor(indices_func, dtype=tf.int32)
    
    ancestral_index=tf.stack([tf.reduce_sum(indices_func[:,:i], axis=-1) for i in range(num_particles)], axis=-1)
    ancestral_index=tf.cast(ancestral_index, tf.int32)

    return ancestral_index # tf.stop_gradient(ancestral_index)

  def call(self, x, PREV_SAMPL_WEIGHTS, K, V, TARGET_WORD_ID, training, 
           look_ahead_mask, padding_mask, previous_word=None):
    '''
    -args: 
        -x: input tensor of the multi-head attention mechanism (Z0:k-1)
        -prev_sampling_weights:
        -K: K0:k-1
        -V: VO:k-1
        -Target_word_ID: to compute the new set of sampling weights > dim (B,)
        -training: for dropout
        -look_ahead_mask: to mask the future decoding timesteps
        -padding_mask: to have fixed-size words sequence. 
        -previous_word: used only if are not sampling a word for computing attention but taking directly the previous word. 
    -returns: 
        -attention vectors (Z0:k, K0:k, V0:k)
        -sampling_weights wk
        -indices matrix I0:k
        -attention_weights_block: for attention vizualisation 
    '''
    
    # check if the decoder layer has been initialized. 
    # to remove? initialization done at Transformer level
           
    
    # FOR SMC: 1. SAMPLE THE SET OF N INDICES USING THE sample_indices function
    if self.dec_timestep==0:
      self.initialize_indices_matrix(tf.shape(K)[0], tf.shape(K)[2]) # TO REMOVE (OR NOT) INITIALIZATION DONE AT TRANSFORMER LEVEL. 
    
    indices, self.ind_matrix=self.sample_and_keep_indices(PREV_SAMPL_WEIGHTS, self.ind_matrix)
    
    # 2. Using the sampled indices, get the set of N sampled previous embedded words using the function 'sample_previous_words'
    if previous_word is None:
      # sampling method
      sample_word=self.sample_previous_word(indices, x, training) # dim (B, P, 1, D)
    else:
      # greedy method
      sample_word=self.preprocess_words(previous_word)
    
    # compute the self-attention vectors over x
    # tensors of dim (B,P,S,D)
    (attn1, K, V), attn_weights_block1= self.mha1(sample_word, sample_word, sample_word, sigma, look_ahead_mask, K, V)

    # RESAMPLE TRAJECTORIES OF ATTENTION PARAMETERS FROM THE INDICES MATRIX AND THE CURRENT SET OF WEIGHTS.  
    attn1=tf.gather(attn1, self.ind_matrix, batch_dims=1)
    attn1=tf.squeeze(attn1, axis=3)
    K=tf.gather(K, self.ind_matrix, batch_dims=1)
    V=tf.gather(V, self.ind_matrix, batch_dims=1)

    # (batch_size, NUM_PARTICLES, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x) # ERROR OF SHAPES HERE: here x shouldn't be the previous word but the whole (masked) sequence of words (or the previous_Z)

    ffn_output = self.ffn(out1)  # (batch_size, NUM_PARTICLES, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out1)  # (batch_size, NUM_PARTICLES, target_seq_len, d_model)
    
    # 3. FOR SMC: compute the new set of weights.
    if len(tf.shape(TARGET_WORD_ID))==1:
      TARGET_WORD_ID=tf.expand_dims(TARGET_WORD_ID, axis=-1)
    predictions_probas=self.output_layer(out3) # (batch_size, NUM_PARTICLES, target_seq_len, target_vocab_size)
    sampling_weights=tf.gather(predictions_probas[:,:,self.dec_timestep,:], TARGET_WORD_ID, axis=-1, batch_dims=1)
    #sampling_weights=predictions_probas[:,:,self.dec_timestep,TARGET_WORD_ID]
    
    # add tf.stop_gradient here? 

    self.dec_timestep+=1 # right place to put it? 
    
    # FOR SMC: RETURN ADDITIONAL PARAMETERS, THE CURRENT SAMPLING WEIGHTS (wk[l]), THE ANCESTOR INDICES MATRIX, K, AND V. 
    return (out3, K, V), sampling_weights, self.ind_matrix, attn_weights_block1 
    # or here tf.stop_gradient(sampling_weights)

In [0]:
d_model=512
num_heads=8
dff=2048
target_vocab_size=1000
num_particles=10
max_positional_encoding=5000
sigma=1

sample_decoder_layer = DecoderLayer_SMC(d_model, num_heads, dff,target_vocab_size, max_positional_encoding, sigma, num_particles, layer_num=0)

PREV_SAMPL_WEIGHTS=tf.random.uniform(shape=(64,num_particles), maxval=1)
K=tf.random.uniform((64, num_particles, 50, 512))
V=tf.random.uniform((64, num_particles, 50, 512))
TARGET_WORD_ID=tf.constant(34, shape=(64,1))

(sample_decoder_layer_output, K, V), sampling_weights, ind_matrix, _ = sample_decoder_layer(
    tf.random.uniform((64, 10, 50, 512)), PREV_SAMPL_WEIGHTS, K, V, TARGET_WORD_ID, training=False, look_ahead_mask=None, padding_mask=None)

sample_decoder_layer_output.shape  # (batch_size, target_seq_len, d_model)

In [0]:
#print(sampling_weights)
print(sampling_weights.shape)
#print(ind_matrix)
print(ind_matrix.shape)
#print(K)
print(K.shape)
# print(V)
print(V.shape)

### Decoder

 The `Decoder` consists of:
1.   Output Embedding
2.   Positional Encoding
3.   N decoder layers

The target is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.

In [0]:
# class Decoder that takes a SMC Decoder Layers as input. 
class Decoder(tf.keras.layers.Layer):
  '''Class Decoder with the Decoder architecture
  -args
    -num_layers: number of layers in the Decoder
    -d_model: model depth 
    -num_heads: number of heads in the multi-attention mechanism
    -dff: output dim of the feedforward network
    -target_vocab_size (for computing the sampling weights for the last layer (or all layers))
    -maxixum_position_encoding: to preprocess the words sequence (addition of positional embeddings)
    -sigma: constant noise for the Gaussian Noise model with reparametrization trick > float tensor of shape (B, P, S, D)
    -PF_algo: decoder-level (default) or layer-level
    -rate: dropout rate for feed-forward layers. 
    '''
  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, sigma,  num_particles, PF_algo='decoder-level', rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    self.num_particles=num_particles
    self.sigma=sigma
    self.PF_algo=PF_algo
    
    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
    self.pos_encoding_SMC=positional_encoding_SMC(maximum_position_encoding, d_model, num_particles)


    # build the decoder architecture

    # for layer-level PF resampling mechanism: list of layers= DecoderLayer_SMC layers
    if PF_algo=='layer-level':
      print('building the Decoder with SMC mechanism done at layer level...')
      self.dec_layers = [DecoderLayer_SMC(d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, sigma, num_particles, l) 
                       for l in range(num_layers)]


    # for decoder-level PF resampling mechanism: 
    # list of layers= N-1 DecoderLayer layers for the first N-1 layers + one DecoderLayer_SMC as the top layer. 
    elif PF_algo=='decoder-level':
      print('building the Decoder with SMC mechanism done at decoder level...')
      self.dec_layers=[DecoderLayer(d_model, num_heads, dff, num_particles) 
                       for _ in range(num_layers-1)]+[DecoderLayer_SMC(d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, sigma, num_particles, num_layers)]
    else:
      assert False, "Invalid PF_algo: should be: layer-level/decoder-level"
    
    self.dropout = tf.keras.layers.Dropout(rate)
    self.dec_timestep=0
    self.output_layer=self.dec_layers[self.num_layers-1].output_layer
    self.decoder_initialized=False


  def set_decoder_inside_layers(self, decoder):
    '''only useful for layer-level resampling mechanism'''
    for i in range(self.num_layers):
      self.dec_layers[i].set_decoder(decoder)
    self.decoder_initialized=True

  def preprocess_words_input(self, x, training):
    '''pre_process sequence of words by adding embeddings + positional encoding
      -Args:
        -x: 3D tensor for the input sequence of words > dim (B, P, S, d_input) OR a 2D tensor of dim (B, P, S) (word_id instead of words...)
        -training: boolean for dropout
      -Returns:
        -A 3D tensor of the pre-processed sequence of words > dim (B, S, D)
    '''
    seq_len = tf.shape(x)[2]
    if len(tf.shape(x))==3:
      x = self.embedding(x) # (batch_size, num_particles, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) # division by the root of the d_model
    x += self.pos_encoding_SMC[:,:, :seq_len, :] # addition of the positional encoding to the input x
    x = self.dropout(x, training=training)
    return x
    
  def call(self, x, PREV_SAMPLING_WEIGHTS, K, V, TARGET_WORD_ID, training, 
           look_ahead_mask, padding_mask, previous_word=None):
    '''
    -args:
      -x: input of the first decoder layer (X0:k-1)
      -prev_sampling_weights: wk-1
      -K: K0:k-1
      -V: V0k-1
      -TARGET_WORD_ID: to compute resampling weights wk > shape (B,)
      -training
      -look_ahead_mask
      -padding_mask
      -previous_word: used if self.attn_method='greedy',
    -returns: 
      -attention vectors (Z[final_layer]0:k, K[final_layer]0:k, V[final_layer]0:k)
      -sampling weights wk
      -ancestor indices matrix I0:k
      -attention_weights: for attention visualization
    '''
    attention_weights = {}
    
    if self.PF_algo=='layer-level':
      for i in range(self.num_layers):
        # PF @ layer-level: compute attention vectors and sampling_weights for each layer
        # PREV_SAMPLING_WEIGHTS (from previous decoding step or sampling_weights from previous layer?)
        (x,K,V), sampling_weights, ind_matrix, block2 = self.dec_layers[i](x, PREV_SAMPLING_WEIGHTS, K, V, TARGET_WORD_ID, training,
                                               look_ahead_mask, padding_mask, previous_word)

        attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
    
    elif self.PF_algo=='decoder-level':
      # do the pre_processing step for x (not included at layer level)
      x=self.preprocess_words_input(x, training)
      for i in range(self.num_layers-1):
        # No PF & resampling mechanism @ the first N-1 layers
        x, block2=self.dec_layers[i](x, training, look_ahead_mask, padding_mask)
        attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

      # PF mechanism & resampling @ the last layer.
      (x,K,V), sampling_weights, ind_matrix, block2 = self.dec_layers[-1](x, PREV_SAMPLING_WEIGHTS, K, V, TARGET_WORD_ID, training,
                                               look_ahead_mask, padding_mask, previous_word)
    else:
      assert False, "Invalid PF_algo: should be: layer-level/decoder-level"
      
    attention_weights['decoder_layer{}_block2'.format(self.num_layers+1)] = block2
      
    self.dec_timestep+=1
    
    # x.shape == (batch_size, target_seq_len, d_model)
    return (x,K,V), tf.stop_gradient(sampling_weights), tf.stop_gradient(ind_matrix), attention_weights # tf.stop_gradient(sampling_weights)

In [0]:
num_particles=5
sample_decoder = Decoder(num_layers=2, d_model=64, num_heads=8, 
                         dff=2048, target_vocab_size=1000,
                         maximum_position_encoding=5000, sigma=1, num_particles=5, PF_algo='decoder-level')


PREV_SAMPL_WEIGHTS=tf.random.uniform(shape=(64,num_particles), maxval=1)
K=tf.random.uniform((64, num_particles, 20, 64))
V=tf.random.uniform((64, num_particles, 20, 64))
TARGET_WORD_ID=tf.constant(10, shape=(64,1))
Z=tf.random.uniform((64, num_particles, 20, 64))

(output, K, V), final_sampling_weights, ind_matrix, attn = sample_decoder(Z, PREV_SAMPL_WEIGHTS,
                              K,V, TARGET_WORD_ID,
                              training=False, look_ahead_mask=None, 
                              padding_mask=None)

output.shape

## Create the Transformer

Transformer consists of the encoder, decoder and a final linear layer. The output of the decoder is the input to the linear layer and its output is returned.

In [0]:
class Transformer(tf.keras.Model):
  '''class for the Transformer Model
  -args
    -num_layers: number of decoder layers
    -d_model: model_depth
    -num_heads: number of heads in the multi-head attention mechanism. 
    -dff: output dimension of the feed-forward layer. 
    -target_vocab_size:for computing the resampling weights
    -pe_target: maximum_positional_encoding
    -num_particles: number of particles generated. 
    -sigma: constant noise for the Gaussian Noise (reparametrization trick) > float tensor of shape (B, P, S, D). 
    -PF_algo: decoder-level/ layer-level > PF resampling mechanism done @ decoder-level (only on last layer) or layer level.
    -attn_method: greedy (use previous for computing attention vectors) or sampling (sample a word with proba equal to log_probas)
    -rate: dropout rate for the feed-forward layer. 
    '''
  def __init__(self, num_layers, d_model, num_heads, dff, 
               target_vocab_size, pe_target, num_particles, sigma,  PF_algo='layer-level', attn_method='sampling', rate=0.1):
    super(Transformer, self).__init__()

    self.decoder = Decoder(num_layers, d_model, num_heads, dff, 
                           target_vocab_size, pe_target, sigma, num_particles, PF_algo, rate)
    
    #get the output layer of the last decoder layer as final layer. 
    #self.final_layer = tf.keras.layers.Dense(target_vocab_size)
    self.final_layer=self.decoder.dec_layers[num_layers-1].output_layer

    self.num_particles=num_particles
    self.d_model=d_model

    self.initialize=False

  def preprocess_words(self, x, dec_timestep, training):
    '''add words embeddings and positional encodings:
        -Args: 
          -x: 3D tensor of words sequence > dim (B, S, dim_words)
          -training: boolean for dropout
        -Returns: 
          - A 3D tensor of pre-processed words sequence > dim (B, S, D)
    '''
    x = self.decoder.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) # division by the root of the d_model
    # addition of the positional encoding to the input x for the current decoding step:
    #x += self.pos_encoding_SMC[:,:,self.dec_timestep, :] # dim of positional encoding (1, num_particles,  num_positions, d_model)
    x += self.decoder.pos_encoding[:,dec_timestep, :] # dim of positional encoding (1, num_positions, d_model)
    x = self.decoder.dropout(x, training=training)
    return tf.reshape(x, shape=[tf.shape(x)[0], tf.shape(x)[2], tf.shape(x)[1], tf.shape(x)[-1]])

  def initialize_attn_SMC_parameters(self, batch_size, seq_length, initial_word_tensor):
    ''' initialize the attention parameters of the Transformer
          -Args: 
            -batch_size
            -seq_length: longueur of input sequence of words 
            -initial_word_tensor: 1D tensor of dim (batch_size) with the initial words for each element of the batch. 
            Used to compute the initial set of weights 
          -Returns      dtype=int32)>
 
            -Z0, K0, V0 (dim (B,P,S,D)) W0 (dim (B,P)), initial indices matrix (dim (B, P, S))
    '''
    # initialize K0, V0, Z0 (=V0)
    K=tf.random.uniform(shape=(batch_size, self.num_particles, seq_length, self.d_model), maxval=1)
    V=tf.random.uniform(shape=(batch_size, self.num_particles, seq_length, self.d_model), maxval=1)
    Z=V
    # initialize w0
    log_probas=self.decoder.dec_layers[0].forward_pass_from_layer(Z) # shape (B, P, S, V)
    log_probas_initial=log_probas[:,:,0,:]
    initial_word_tensor=tf.expand_dims(initial_word_tensor, axis=-1)
    initial_weights=tf.squeeze(tf.gather(log_probas_initial, initial_word_tensor, axis=-1, batch_dims=1),axis=-1)
    # tf.stop_gradient(initial_weights)? 

    # call the initialization of the ancestor indices matrix
    ind_matrix_init=self.decoder.dec_layers[-1].initialize_indices_matrix(batch_size, seq_length)
    # tf.stop_gradient(ind_matrix_init)? 

    self.initialize=True

    return (Z, K, V), initial_weights, ind_matrix_init

  def compute_average_weighted_output(self, decoder_output, resampling_weights):
      '''
      -Args:
        - decoder_output: current output of the decoder (M particles) > dims (B, P, D)
        - resampling_weights: output sampling weights of the decoder > dims (B,P)
      -Returns:
        - the weighted average of the decoder's M particles > dims (B,D)
      '''
      # USELESS ACTUALLY HERE. 
      mean_output = tf.squeeze(tf.matmul(decoder_output, resampling_weights, transpose_a=True))
      sum_weights = tf.tile(tf.expand_dims(tf.reduce_sum(resampling_weights, axis=-1), axis=-1), 
                            [1, mean_output.shape[-1], 1])
      sum_weights = tf.reduce_sum(resampling_weights, axis=1)
      mean_output= mean_output / sum_weights
      return mean_output
    
  def call(self, input_tensor, prev_sampling_weights, K, V, targets, training, 
           look_ahead_mask, dec_padding_mask):
    # remove prev_sampling_weights, K, V (initialize them instead)
    # target_word_id here is instead the targets tensor (tensor shift from one compared to the input tensor)
    # add a decoding_timestep argument? 
    '''
    -args:
      -input tensor: transformer input data : sequence of words id. > shape (B,S)
      -targets: target tensor
      -training: for dropout layers
      -look_ahead_mask: 
      -dec_padding_mask: 
    -returns
      -final_output: Y0:S > shape (?, P, S, V)
      -attention_weights: for visualisation
    '''

    # initialize the attention parameters
    batch_size=tf.shape(input_tensor)[0]
    seq_length=tf.shape(input_tensor)[1]
    # get initial_word_id instead of initial_word? 
    (Z0, K0, V0), w0, I0=self.initialize_attn_SMC_parameters(batch_size, seq_length, input_tensor[:,0])
    # tf.stop_gradient(w0)? 

    # add the function 'set_decoder_inside_layers?
    self.decoder.set_decoder_inside_layers(self.decoder) # DOES THIS WORK???

    # do a loop over decoding_timestep...
    for t in range(tf.shape(targets)[1]): 
      target_word_id=targets[:,t]
      sampling_weights=w0
      K=K0
      V=V0
      print('decoding timestep:', t)

      if len(tf.shape(input_tensor)) < 4: 
        # transform the 2D input_tensor (dim (B,S)) into a 4D input_tensor (dim (B,P,S,D))
        input_tensor=tf.expand_dims(input_tensor, axis=-1)
        input_tensor=self.preprocess_words(input_tensor, t, training=training) # dim (B, P, S)
        input_tensor=tf.tile(input_tensor, multiples=[1,self.num_particles,1,1])

      # dec_output.shape == (batch_size, NUM_PARTICLES, tar_seq_len, d_model)
      (dec_output, K, V), sampling_weights, ind_matrix, attention_weights= self.decoder(
        input_tensor, sampling_weights, K, V, target_word_id, training, look_ahead_mask, dec_padding_mask)
      
      # compute the predicted word from the decoder output
      # REPLACE BY TEACHER FORCING
      # cf tutorial notebook on NMt with attention:
      #dec_input = tf.expand_dims(targ[:, t], 1)
      final_output = self.final_layer(dec_output)  # (batch_size, NUM_PARTICLES, tar_seq_len, target_vocab_size)
      current_dec_output=final_output[:,:,t,:]
      average_output=self.compute_average_weighted_output(current_dec_output, sampling_weights) # dims (B, V) # LIGN TO DEBUG. 
    
      # concatentate the predicted_id to the input_tensor which is given to the decoder
      # as its input.
      # find the predicted_word and add the word embedding
      predicted_word=tf.cast(tf.argmax(average_output, axis=-1), tf.int32)
      predicted_word = self.decoder.embedding(tf.expand_dims(predicted_word, axis=-1))
      predicted_word=tf.tile(tf.expand_dims(predicted_word, axis=1), [1,self.num_particles,1,1])

      # add it to the input_tensor
      input_tensor=tf.cast(input_tensor, dtype=tf.float32)
      input_tensor=tf.concat([input_tensor[:,:,:t+1,:], predicted_word, input_tensor[:,:,t+2:,:]], axis=2)

    return final_output, attention_weights

In [0]:
sample_transformer = Transformer(
    num_layers=2, d_model=512, num_heads=8, dff=2048, sigma=1, target_vocab_size=8000, 
    pe_target=6000, num_particles=10)

temp_data =tf.cast(tf.zeros(shape=(64,51)), dtype=tf.int32)
input_tensor=temp_data[:,:51]
targets=temp_data[:, 1:]

sample_transformer.initialize

PREV_SAMPL_WEIGHTS=tf.random.uniform(shape=(64,num_particles), maxval=1)
K=tf.random.uniform((64, num_particles, 50, 512))
V=tf.random.uniform((64, num_particles, 50, 512))

fn_out, _ = sample_transformer(input_tensor, PREV_SAMPL_WEIGHTS, K, V, targets, training=False, 
                               look_ahead_mask=None,
                               dec_padding_mask=None)

fn_out.shape  # (batch_size, tar_seq_len, target_vocab_size)

## Set hyperparameters

To keep this example small and relatively fast, the values for *num_layers, d_model, and dff* have been reduced. 

The values used in the base model of transformer were; *num_layers=6*, *d_model = 512*, *dff = 2048*. See the [paper](https://arxiv.org/abs/1706.03762) for all the other versions of the transformer.

Note: By changing the values below, you can get the model that achieved state of the art on many tasks.

In [0]:
num_layers = 4
d_model = 128
dff = 512
num_heads = 8

input_vocab_size = tokenizer_pt.vocab_size + 2
target_vocab_size = tokenizer_en.vocab_size + 2
dropout_rate = 0.1

## additional hyperparameters for the SMC mechanism
#### Gaussian Noise
* initialization of the standard deviation
* exact 'place' where the gaussian noise is added 

#### number of particles 

#### attention parameters computation
* sampling a word with proba proportional to log probabilites of previous timestep (sampling method) versus using directly the previous word (greedy method). 

# 2. TRAINING OF THE TRANSFORMER MODEL. (to modify for the SMC transformer). 

### Custom training

In [0]:
def loss(predicted_y, target_y):
  return tf.reduce_mean(tf.square(predicted_y - target_y))

In [0]:
def train(model, inputs, outputs, learning_rate):
  with tf.GradientTape() as t:
    current_loss = loss(model(inputs), outputs)
  dW, db = t.gradient(current_loss, [model.W, model.b])
  model.W.assign_sub(learning_rate * dW)
  model.b.assign_sub(learning_rate * db)

In [0]:
def train(model, inputs, outputs, learning_rate):
  with tf.GradientTape() as t:
    current_loss = loss(model(inputs), outputs)
  dW, db = t.gradient(current_loss, [model.W, model.b])
  model.W.assign_sub(learning_rate * dW)
  model.b.assign_sub(learning_rate * db)

In [0]:
model = Model()

# Collect the history of W-values and b-values to plot later
Ws, bs = [], []
epochs = range(10)
for epoch in epochs:
  Ws.append(model.W.numpy())
  bs.append(model.b.numpy())
  current_loss = loss(model(inputs), outputs)

  train(model, inputs, outputs, learning_rate=0.1)
  print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %
        (epoch, Ws[-1], bs[-1], current_loss))

# Let's plot it all
plt.plot(epochs, Ws, 'r',
         epochs, bs, 'b')
plt.plot([TRUE_W] * len(epochs), 'r--',
         [TRUE_b] * len(epochs), 'b--')
plt.legend(['W', 'b', 'True W', 'True b'])
plt.show()

## Optimizer
### TO MODIFY FOR SMC WITH A CUSTOM GRADIENT

Use the Adam optimizer with a custom learning rate scheduler according to the formula in the [paper](https://arxiv.org/abs/1706.03762).

$$\Large{lrate = d_{model}^{-0.5} * min(step{\_}num^{-0.5}, step{\_}num * warmup{\_}steps^{-1.5})}$$


In [0]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()
    
    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)

    self.warmup_steps = warmup_steps
    
  def __call__(self, step):
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)
    
    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

In [0]:
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

In [0]:
temp_learning_rate_schedule = CustomSchedule(d_model)

plt.plot(temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32)))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")

## Loss and metrics

Since the target sequences are padded, it is important to apply a padding mask when calculating the loss.

In [0]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

In [0]:
def loss_function(real, pred):
  # TO CHANGE WITH A CUSTOM LOSS
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  # keep the loss above and add the 'left part' of the language_models.tex
  
  return tf.reduce_mean(loss_)

  # check if you compute a loss at layer levl (can be done in TF2) and sum over layers
  # or do a global loss

In [0]:
def loss_left_part_SMC(real, pred):
  #see formula on .tex file
  #to compute the determinant: https://www.tensorflow.org/api_docs/python/tf/linalg/det
  

In [0]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')

## Training and checkpointing

In [0]:
transformer = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)

In [0]:
def create_masks(inp, tar):
  # Encoder padding mask
  enc_padding_mask = create_padding_mask(inp)
  
  # Used in the 2nd attention block in the decoder.
  # This padding mask is used to mask the encoder outputs.
  dec_padding_mask = create_padding_mask(inp)
  
  # Used in the 1st attention block in the decoder.
  # It is used to pad and mask future tokens in the input received by 
  # the decoder.
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
  
  return enc_padding_mask, combined_mask, dec_padding_mask

Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every `n` epochs.

In [0]:
checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

The target is divided into tar_inp and tar_real. tar_inp is passed as an input to the decoder. `tar_real` is that same input shifted by 1: At each location in `tar_input`, `tar_real` contains the  next token that should be predicted.

For example, `sentence` = "SOS A lion in the jungle is sleeping EOS"

`tar_inp` =  "SOS A lion in the jungle is sleeping"

`tar_real` = "A lion in the jungle is sleeping EOS"

The transformer is an auto-regressive model: it makes predictions one part at a time, and uses its output so far to decide what to do next. 

During training this example uses teacher-forcing (like in the [text generation tutorial](./text_generation.ipynb)). Teacher forcing is passing the true output to the next time step regardless of what the model predicts at the current time step.

As the transformer predicts each word, *self-attention* allows it to look at the previous words in the input sequence to better predict the next word.

To prevent the model from peaking at the expected output the model uses a look-ahead mask.

In [0]:
EPOCHS = 20

In [0]:
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]
  
  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
  
  with tf.GradientTape() as tape:
    predictions, _ = transformer(inp, tar_inp, 
                                 True, 
                                 enc_padding_mask, 
                                 combined_mask, 
                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables) # DOES THIS NEED TO CHANGE?  
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
  
  train_loss(loss)
  train_accuracy(tar_real, predictions)

Portuguese is used as the input language and English is the target language.

In [0]:
for epoch in range(EPOCHS):
  start = time.time()
  
  train_loss.reset_states()
  train_accuracy.reset_states()
  
  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_dataset):
    train_step(inp, tar)
    
    if batch % 50 == 0:
      print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
          epoch + 1, batch, train_loss.result(), train_accuracy.result()))
      
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))
    
  print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                train_loss.result(), 
                                                train_accuracy.result()))

  print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

# 3. EVALUATION/INFERENCE

![formula for inference](https://colab.research.google.com/drive/1K61J3rRbBrJW9dhy24jpwHPiKCKUluTP#scrollTo=mI2yf1lWkUaZ/content/IMG_7002.JPG)

### The prediction loop (from text generation with RNN)

The following code block generates the text:

* It Starts by choosing a start string, initializing the RNN state and setting the number of characters to generate.

* Get the prediction distribution of the next character using the start string and the RNN state.

* Then, use a categorical distribution to calculate the index of the predicted character. Use this predicted character as our next input to the model.

* The RNN state returned by the model is fed back into the model so that it now has more context, instead than only one word. After predicting the next word, the modified RNN states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.


Looking at the generated text, you'll see the model knows when to capitalize, make paragraphs and imitates a Shakespeare-like writing vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences.

In [0]:
def generate_text(model, start_string):
  # Evaluation step (generating text using the learned model)

  # Number of characters to generate
  num_generate = 1000

  # Converting our start string to numbers (vectorizing)
  input_eval = [char2idx[s] for s in start_string]
  input_eval = tf.expand_dims(input_eval, 0)

  # Empty string to store our results
  text_generated = []

  # Low temperatures results in more predictable text.
  # Higher temperatures results in more surprising text.
  # Experiment to find the best setting.
  temperature = 1.0
  
  # TO MODIFY USING THE SMC TRANSFORMER MODEL. 
  # Here batch size == 1
  model.reset_states()
  for i in range(num_generate):
      predictions = model(input_eval)
      # remove the batch dimension
      predictions = tf.squeeze(predictions, 0)

      # using a categorical distribution to predict the word returned by the model
      predictions = predictions / temperature
      predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()

      # We pass the predicted word as the next input to the model
      # along with the previous hidden state
      input_eval = tf.expand_dims([predicted_id], 0)

      text_generated.append(idx2char[predicted_id])

  return (start_string + ''.join(text_generated))

In [0]:
print(generate_text(model, start_string=u"ROMEO: "))

## Evaluate

The following steps are used for evaluation:

* **Encode the input sentence using the Portuguese tokenizer (`tokenizer_pt`). Moreover, add the start and end token so the input is equivalent to what the model is trained with. This is the encoder input.**
* **The decoder input is the `start token == tokenizer_en.vocab_size`.**
* Calculate the padding masks and the look ahead masks.
* **The `decoder` then outputs the predictions by looking at the `encoder output` and its own output (self-attention).**
* Select the last word and calculate the argmax of that.
* **Concatentate the predicted word to the decoder input as pass it to the decoder.**
* **In this approach, the decoder predicts the next word based on the previous words it predicted.**

Note: The model used here has less capacity to keep the example relatively faster so the predictions maybe less right. To reproduce the results in the paper, use the entire dataset and base transformer model or transformer XL, by changing the hyperparameters above.

In [0]:
def evaluate(inp_sentence):
  start_token = [tokenizer_pt.vocab_size]
  end_token = [tokenizer_pt.vocab_size + 1]
  
  # inp sentence is portuguese, hence adding the start and end token
  inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
  encoder_input = tf.expand_dims(inp_sentence, 0)
  
  # as the target is english, the first word to the transformer should be the
  # english start token.
  decoder_input = [tokenizer_en.vocab_size]
  output = tf.expand_dims(decoder_input, 0)
    
  for i in range(MAX_LENGTH):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)
  
    # predictions.shape == (batch_size, seq_len, vocab_size)
    predictions, attention_weights = transformer(encoder_input, 
                                                 output,
                                                 False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)
    
    # select the last word from the seq_len dimension
    predictions = predictions[: ,-1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
    
    # return the result if the predicted_id is equal to the end token
    if predicted_id == tokenizer_en.vocab_size+1:
      return tf.squeeze(output, axis=0), attention_weights
    
    # concatentate the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)

  return tf.squeeze(output, axis=0), attention_weights

In [0]:
def plot_attention_weights(attention, sentence, result, layer):
  fig = plt.figure(figsize=(16, 8))
  
  sentence = tokenizer_pt.encode(sentence)
  
  attention = tf.squeeze(attention[layer], axis=0)
  
  for head in range(attention.shape[0]):
    ax = fig.add_subplot(2, 4, head+1)
    
    # plot the attention weights
    ax.matshow(attention[head][:-1, :], cmap='viridis')

    fontdict = {'fontsize': 10}
    
    ax.set_xticks(range(len(sentence)+2))
    ax.set_yticks(range(len(result)))
    
    ax.set_ylim(len(result)-1.5, -0.5)
        
    ax.set_xticklabels(
        ['<start>']+[tokenizer_pt.decode([i]) for i in sentence]+['<end>'], 
        fontdict=fontdict, rotation=90)
    
    ax.set_yticklabels([tokenizer_en.decode([i]) for i in result 
                        if i < tokenizer_en.vocab_size], 
                       fontdict=fontdict)
    
    ax.set_xlabel('Head {}'.format(head+1))
  
  plt.tight_layout()
  plt.show()

In [0]:
def translate(sentence, plot=''):
  result, attention_weights = evaluate(sentence)
  
  predicted_sentence = tokenizer_en.decode([i for i in result 
                                            if i < tokenizer_en.vocab_size])  

  print('Input: {}'.format(sentence))
  print('Predicted translation: {}'.format(predicted_sentence))
  
  if plot:
    plot_attention_weights(attention_weights, sentence, result, plot)

In [0]:
translate("este é um problema que temos que resolver.")
print ("Real translation: this is a problem we have to solve .")

In [0]:
translate("os meus vizinhos ouviram sobre esta ideia.")
print ("Real translation: and my neighboring homes heard about this idea .")

In [0]:
translate("vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.")
print ("Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .")

You can pass different layers and attention blocks of the decoder to the `plot` parameter.

In [0]:
translate("este é o primeiro livro que eu fiz.", plot='decoder_layer4_block2')
print ("Real translation: this is the first book i've ever done.")

## Summary

In this tutorial, you learned about positional encoding, multi-head attention, the importance of masking and how to create a transformer.

Try using a different dataset to train the transformer. You can also create the base transformer or transformer XL by changing the hyperparameters above. You can also use the layers defined here to create [BERT](https://arxiv.org/abs/1810.04805) and train state of the art models. Futhermore, you can implement beam search to get better predictions.