# Creating a GRU model using Trax

In this notebook, we use Trax's layers to implement the GRU (gated recurrent unit) architecture.

In [1]:
#!pip install -q -U trax


In [2]:
# Uncomment to install Trax:
#%pip install trax

In [4]:
import trax
from trax import layers as tl

# `show_layers` is a helper from a local utils.py file (not part of Trax)
# that prints information for every layer (sublayer within `Serial`):
from utils import show_layers


Trax lets us define neural network architectures by stacking layers, much as other libraries such as Keras do. The usual tool for this is `Serial()`, roughly Trax's equivalent of Keras's `Sequential`: it is a combinator that stacks layers one after another, feeding the output of each layer into the next (function composition).
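To make the "function composition" idea concrete, here is a minimal plain-Python sketch of a serial combinator. The name `serial` and the toy layers `add_one` and `double` are hypothetical, and this is not Trax's implementation: real Trax layers also carry weights, state, and input/output signatures. It does mirror one behavior the model below relies on: nested lists of layers are flattened, so a list comprehension can be passed as a single argument.

```python
from typing import Any, Callable, List

Layer = Callable[[Any], Any]


def _flatten(layers) -> List[Layer]:
    """Flatten arbitrarily nested lists/tuples of callables into one flat list."""
    flat: List[Layer] = []
    for layer in layers:
        if isinstance(layer, (list, tuple)):
            flat.extend(_flatten(layer))
        else:
            flat.append(layer)
    return flat


def serial(*layers) -> Layer:
    """Compose layers left to right: serial(f, g, h)(x) == h(g(f(x)))."""
    flat = _flatten(layers)

    def composed(x: Any) -> Any:
        for layer in flat:
            x = layer(x)
        return x

    return composed


def add_one(x: int) -> int:
    return x + 1


def double(x: int) -> int:
    return x * 2


# A nested list (here from a list comprehension) is accepted as a single argument.
model = serial(add_one, [double for _ in range(3)], add_one)
print(model(1))  # 17, i.e. ((1 + 1) * 2 * 2 * 2) + 1
```

The order of arguments is the order of execution, which is why the model below lists its layers from input to output.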



## GRU MODEL

To create a GRU model, we need the following Trax layers (each layer name links to its documentation):
   - [`ShiftRight`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.attention.ShiftRight) Shifts the tensor to the right by padding with zeros at the start of axis 1 (the time axis) and dropping the last position, so the sequence length is unchanged. The `mode` argument should be specified; it refers to the context in which the model is used. Possible values are 'train', 'eval', and 'predict' (the last is for fast inference), and it defaults to 'train'. Shifting the input sequence to the right ensures that, at every time step, the GRU cell does not receive as input the very element it has to predict. Note that this layer isn't always necessary; whether to include it depends on the NLP task at hand.

   - [`Embedding`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Embedding) Maps discrete tokens to vectors. Its weight matrix has shape `(vocabulary size, dimension of output vectors)`. The dimension of the output vectors (called `d_feature`) is the number of elements in each word embedding.
   - [`GRU`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.rnn.GRU) The GRU layer. It leverages another Trax layer called [`GRUCell`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.rnn.GRUCell). The hidden state dimension is set with `n_units` and, by design in Trax, should match the number of elements in the word embedding. To stack several consecutive GRU layers, we can use a Python list comprehension, which gives the following architecture:

   <img src="images/3_grus.png" width="400"/>

   - [`Dense`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Dense) Vanilla dense (fully connected) layer.
   - [`LogSoftmax`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.LogSoftmax) Log-softmax function, which turns the dense layer's outputs into log-probabilities.
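The parameter-free steps in this list can be sketched in plain Python to show what they compute on a batch of token ids. The functions `shift_right`, `embed`, and `log_softmax` below are hypothetical illustrations, not Trax code, and assume simplifications: Python lists instead of arrays, a shift of exactly one position, and train-mode behavior only. The embedding table is a toy, hand-filled stand-in for the learned `Embedding` weights.

```python
import math
from typing import List


def shift_right(batch: List[List[int]], pad_id: int = 0) -> List[List[int]]:
    """Shift every sequence one step right along the time axis (axis 1).

    A pad id is inserted at the front and the last token is dropped, so the
    length is unchanged and position t only sees tokens that came before t.
    """
    return [[pad_id] + seq[:-1] if seq else [] for seq in batch]


def embed(batch: List[List[int]], table: List[List[float]]) -> List[List[List[float]]]:
    """Look up each token id in a (vocab_size, d_feature) embedding table."""
    return [[table[token] for token in seq] for seq in batch]


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: x_i - log(sum_j exp(x_j))."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


tokens = [[5, 7, 9]]
shifted = shift_right(tokens)
print(shifted)  # [[0, 5, 7]]

# Toy embedding table: vocab_size = 10, d_feature = 2 (values chosen for readability).
table = [[float(i), float(i) * 10] for i in range(10)]
print(embed(shifted, table))  # [[[0.0, 0.0], [5.0, 50.0], [7.0, 70.0]]]

log_probs = log_softmax([1.0, 2.0, 3.0])
print(round(sum(math.exp(v) for v in log_probs), 6))  # 1.0
```

Subtracting the maximum logit before exponentiating is a standard trick to avoid overflow; it does not change the result because it cancels out in the normalization.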

Putting everything together, the GRU model looks like this:

In [None]:
mode = 'train'
vocab_size = 256
model_dimension = 512
n_layers = 3

GRU = tl.Serial(
    # Remember to pass `mode` explicitly (e.g. 'eval' or 'predict') when using
    # the model for evaluation or inference; the default is 'train'.
    tl.ShiftRight(mode=mode),
    tl.Embedding(vocab_size=vocab_size, d_feature=model_dimension),
    # Stack n_layers GRU layers; Serial flattens the nested list.
    [tl.GRU(n_units=model_dimension) for _ in range(n_layers)],
    # Project each hidden state to one score per vocabulary entry.
    tl.Dense(n_units=vocab_size),
    tl.LogSoftmax()
)
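Each `tl.GRU` layer runs a GRU cell along the time axis, carrying a hidden state from one step to the next. The sketch below shows one step of a common GRU formulation (update gate `u`, reset gate `r`, candidate `c`, new state `u * h + (1 - u) * c`) in plain Python. The names `gru_step`, `run_gru`, `W_u`, `W_r`, `W_c`, and the bias keys are hypothetical, and Trax's `GRUCell` may organize its weights and gate conventions differently. The hidden size equals the input size here, matching the requirement that `n_units` equal the embedding dimension.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gru_step(x: Vector, h: Vector, p: Dict[str, list]) -> Vector:
    """One GRU time step with hidden size d == input size d.

    Weight matrices have shape (d, 2 * d) and act on concatenated vectors.
    """
    xh = x + h  # list concatenation, i.e. [x; h]
    u = [sigmoid(z + b) for z, b in zip(matvec(p["W_u"], xh), p["b_u"])]  # update gate
    r = [sigmoid(z + b) for z, b in zip(matvec(p["W_r"], xh), p["b_r"])]  # reset gate
    x_rh = x + [ri * hi for ri, hi in zip(r, h)]  # [x; r * h]
    c = [math.tanh(z + b) for z, b in zip(matvec(p["W_c"], x_rh), p["b_c"])]  # candidate
    # Blend old state and candidate: u near 1 keeps h, u near 0 takes c.
    return [ui * hi + (1.0 - ui) * ci for ui, hi, ci in zip(u, h, c)]


def run_gru(xs: List[Vector], p: Dict[str, list], d: int) -> List[Vector]:
    """Apply gru_step along the time axis, starting from a zero hidden state."""
    h = [0.0] * d
    outputs = []
    for x in xs:
        h = gru_step(x, h, p)
        outputs.append(h)
    return outputs


# With all-zero weights, u = 0.5 and c = tanh(0) = 0, so the step halves h.
d = 2
zero_params = {name: [[0.0] * (2 * d) for _ in range(d)] for name in ("W_u", "W_r", "W_c")}
zero_params.update({name: [0.0] * d for name in ("b_u", "b_r", "b_c")})
print(gru_step([0.3, 0.7], [1.0, -2.0], zero_params))  # [0.5, -1.0]
```

Because the new state is a convex combination of the previous state and a `tanh` candidate, a state that starts at zero stays strictly inside (-1, 1); stacking layers simply feeds each layer's sequence of hidden states to the next one as its inputs.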

In [None]:
show_layers(GRU)

That's it! A full GRU architecture in five lines of Trax.
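To see how a batch flows through this stack, the hypothetical helper `trace_shapes` below records the output shape of each layer for token ids of shape `(batch_size, seq_len)`. It is pure bookkeeping with no Trax calls, assuming the layer behavior described above: `ShiftRight` and each GRU preserve the shape, `Embedding` adds a feature axis of size `model_dimension`, and `Dense` projects it to `vocab_size`. The batch size and sequence length in the usage line are arbitrary example values, not taken from the notebook.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def trace_shapes(batch_size: int, seq_len: int, vocab_size: int,
                 model_dimension: int, n_layers: int) -> List[Tuple[str, Shape]]:
    """Return (layer name, output shape) pairs for the Serial model above."""
    shapes: List[Tuple[str, Shape]] = []
    shape: Shape = (batch_size, seq_len)
    shapes.append(("ShiftRight", shape))        # same shape, tokens shifted right
    shape = (batch_size, seq_len, model_dimension)
    shapes.append(("Embedding", shape))         # one vector per token
    for i in range(n_layers):
        shapes.append((f"GRU_{i + 1}", shape))  # n_units == model_dimension
    shape = (batch_size, seq_len, vocab_size)
    shapes.append(("Dense", shape))             # one score per vocabulary entry
    shapes.append(("LogSoftmax", shape))        # log-probabilities, same shape
    return shapes


# Example values: batch of 32 sequences of length 64, with the notebook's settings.
for name, shp in trace_shapes(batch_size=32, seq_len=64, vocab_size=256,
                              model_dimension=512, n_layers=3):
    print(f"{name:<11} {shp}")
```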