# Trax 

Trax is a concise framework, built on TensorFlow, for end-to-end machine learning. Its key building blocks are layers and combinators. This notebook gives a first taste of the Trax framework and some of its basic building blocks.

## Background

### Why Trax and not TensorFlow or PyTorch?

TensorFlow and PyTorch are both extensive frameworks that can do almost anything in deep learning. They offer a lot of flexibility, but that often means verbosity of syntax and extra time to code.

Trax is much more concise. It runs on a TensorFlow backend but lets us train models with one-line commands. Trax also runs end to end, so we can focus on learning instead of spending hours on the peculiarities of a big framework's implementation.

### Why not Keras then?

Keras has been part of TensorFlow itself since version 2.0. Trax, in turn, is well suited to implementing state-of-the-art models such as Transformers, Reformers, and BERT, because it is actively maintained by the Google Brain team for advanced deep learning tasks. It also runs smoothly on CPUs, GPUs, and TPUs with comparatively few code changes.

### How to Code in Trax
Building models in Trax relies on two key concepts: **layers** and **combinators**.
Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax combinators, which lets us build layers and models of any complexity.
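
The relationship between layers and combinators can be sketched in plain Python. This is only an illustration of the idea, not Trax code: a layer is modeled as an ordinary callable, and `serial_sketch` is a hypothetical helper that chains callables into a new one.

```python
# A "layer" is modeled here as a plain callable, and a "combinator" as a
# function that builds a new callable out of existing ones.
# `serial_sketch` is a hypothetical helper, not part of Trax.

def double(x):
    return x * 2


def add_one(x):
    return x + 1


def serial_sketch(*layers):
    """Chain layers so that the output of one feeds the next."""
    def composite(x):
        for layer in layers:
            x = layer(x)
        return x
    return composite


double_then_add = serial_sketch(double, add_one)
# A composite behaves like any other layer, so it can be combined again.
nested = serial_sketch(double_then_add, double)

print(double_then_add(3))  # 7
print(nested(3))           # 14
```

Because the composite is itself a callable, nesting costs nothing extra, which is the property that lets combinators build models of any depth.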

### Trax, JAX, TensorFlow and Tensor2Tensor

As mentioned earlier, Trax uses TensorFlow as a backend, but it also uses the JAX library to speed up computation. JAX is like an enhanced and optimized version of NumPy.

Tensor2Tensor, on the other hand, started as an end-to-end solution much like Trax, but it grew unwieldy and complicated. Trax is the new, improved version: faster and simpler to work with.

### Resources

- Trax source code on GitHub: [Trax](https://github.com/google/trax)
- JAX library: [JAX](https://jax.readthedocs.io/en/latest/index.html)


## Installing Trax

Trax depends on JAX and related libraries, which are not yet supported on [Windows](https://github.com/google/jax/blob/1bc5896ee4eab5d7bb4ec6f161d8b2abb30557be/README.md#installation) but work well on Ubuntu and macOS. If you are working on Windows, we suggest installing Trax under WSL2.

The officially maintained documentation is at [trax-ml](https://trax-ml.readthedocs.io/en/latest/), not to be confused with the other [TraX](https://trax.readthedocs.io/en/latest/index.html).

In [1]:
%pip install trax

## Imports

In [2]:
import numpy as np  # regular old numpy (Trax also provides its own numpy implementation)

from trax import layers as tl  # core building block
from trax import shapes  # data signatures: dimensionality and type
from trax import fastmath  # uses jax, offers enhanced numpy (on steroids :p)

In [None]:
# Trax version  
!pip list | grep trax

## Layers
Layers are the core building blocks in Trax: every component, from an activation function to a full model, is a layer.

They take inputs, apply a computation, and return outputs.

Let's inspect layer properties.

### Relu Layer

First, let's build a ReLU activation function as a layer. A layer like this is one of the simplest types: it has no weights to initialize, so it works just like a math function.

> In Trax, activation functions are also layers.

In [None]:
# Create a relu trax layer
relu = tl.Relu()

# Inspect properties
print("-- Properties --")
print("name :", relu.name)
print("expected inputs :", relu.n_in)
print("promised outputs :", relu.n_out, "\n")

# Inputs
x = np.array([-2,-1,0,1,2])

print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = relu(x)
print("-- Outputs --")
print("y :", y)
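
For comparison, the computation that a ReLU performs can be written directly in NumPy. This is a sketch of the math only; `relu_sketch` is a hypothetical name and not the Trax layer itself.

```python
import numpy as np


def relu_sketch(x):
    # Element-wise max(x, 0): negatives become zero, everything else passes through.
    return np.maximum(x, 0)


x = np.array([-2, -1, 0, 1, 2])
y = relu_sketch(x)
print("y :", y)
```

The two negative entries are clipped to zero while the rest of the array is unchanged.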

### Concatenate Layer
Now let's build a layer that takes 2 inputs. Notice the change in the expected inputs property from 1 to 2.

In [None]:
# Create a concatenate trax layer
concat = tl.Concatenate()

print("-- Properties --")
print("name :", concat.name)
print("expected inputs :", concat.n_in)
print("promised outputs :", concat.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")

# Outputs
y = concat([x1, x2])
print("-- Outputs --")
print("y :", y)
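
The result can be cross-checked against plain NumPy. The sketch below assumes the layer joins its inputs along the last axis, which is what `axis=-1` expresses here.

```python
import numpy as np

x1 = np.array([-10, -20, -30])
x2 = x1 / -10  # true division yields floats: [1., 2., 3.]

# Assumption: concatenating along the last axis mirrors the layer's default.
y = np.concatenate([x1, x2], axis=-1)

print("y     :", y)
print("shape :", y.shape)  # (6,)
print("dtype :", y.dtype)  # the integer input is promoted to float
```

Note that mixing an integer array with a float array promotes the whole result to floating point.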

## Layers are Configurable
We can change the default settings of layers. For example, we can change the expected inputs for a concatenate layer from 2 to 3 using the optional parameter `n_items`.

In [None]:
# Configure a concatenate layer
concat_3 = tl.Concatenate(n_items=3)  # configure the layer's expected inputs

print("-- Properties --")
print("name :", concat_3.name)
print("expected inputs :", concat_3.n_in)
print("promised outputs :", concat_3.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2)
print("x3 :", x3, "\n")

# Outputs
y = concat_3([x1, x2, x3])
print("-- Outputs --")
print("y :", y)
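
One way to picture a configurable layer is as a factory function whose arguments are captured by the function it returns. The sketch below is an analogy in plain NumPy; `make_concatenate_sketch` is a hypothetical helper and says nothing about how Trax implements `n_items` internally.

```python
import numpy as np


def make_concatenate_sketch(n_items=2, axis=-1):
    """Return a function that concatenates exactly `n_items` arrays."""
    def concatenate(inputs):
        # The configured input count is enforced on every call.
        if len(inputs) != n_items:
            raise ValueError(f"expected {n_items} inputs, got {len(inputs)}")
        return np.concatenate(inputs, axis=axis)
    return concatenate


concat_3_sketch = make_concatenate_sketch(n_items=3)

x1 = np.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99
y = concat_3_sketch([x1, x2, x3])
print("shape :", y.shape)  # (9,)
```

Passing only two arrays to `concat_3_sketch` raises a `ValueError`, which parallels the way a layer's expected input count is part of its configuration.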

## Layers can have Weights
Some layer types include mutable weights and biases that are used in computation and training. Layers of this type require initialization before use.

For example, the `LayerNorm` layer normalizes its input and then scales and shifts the result using weights and biases. During initialization we pass the shape and data type of the inputs so that the layer can create compatible arrays of weights and biases.

In [None]:
# Layer initialization
norm = tl.LayerNorm()

# We first must know what the input data will look like
x = np.array([0, 1, 2, 3], dtype="float")

# Use the input data signature to get shape and type for initializing weights and biases
norm.init(shapes.signature(x))  # signature() wraps the array's shape and dtype in a Trax ShapeDtype

print("Normal shape:",x.shape, "Data Type:",type(x.shape))
print("Shapes Trax:",shapes.signature(x),"Data Type:",type(shapes.signature(x)))

# Inspect properties
print("-- Properties --")
print("name :", norm.name)
print("expected inputs :", norm.n_in)
print("promised outputs :", norm.n_out)
# Weights and biases
print("weights :", norm.weights[0])
print("biases :", norm.weights[1], "\n")

# Inputs
print("-- Inputs --")
print("x :", x)

# Outputs
y = norm(x)
print("-- Outputs --")
print("y :", y)
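
The normalization itself can be sketched in NumPy. The sketch assumes the usual layer-norm formula (subtract the mean, divide by the standard deviation over the last axis, then scale and shift), a small `epsilon` of `1e-6` for numerical stability, and initial weights of ones and biases of zeros; all of these are assumptions for illustration rather than a statement of the Trax implementation.

```python
import numpy as np


def layer_norm_sketch(x, scale, bias, epsilon=1e-6):
    # Statistics are taken over the last axis (the feature axis).
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    normalized = (x - mean) / np.sqrt(variance + epsilon)
    # Learnable scale and shift, one value per feature.
    return normalized * scale + bias


x = np.array([0, 1, 2, 3], dtype="float")
scale = np.ones(x.shape[-1])   # assumed initial weights
bias = np.zeros(x.shape[-1])   # assumed initial biases

y = layer_norm_sketch(x, scale, bias)
print("y :", np.round(y, 4))  # approximately -1.3416, -0.4472, 0.4472, 1.3416
```

The output has mean zero and (up to `epsilon`) unit variance, which is why the weights and biases need the input shape: there is one scale and one shift per feature.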

## Custom Layers

Things get more interesting with custom layers. We can create our own layers and define their computations using `tl.Fn`.

In [None]:
# Define a custom layer
# In this example we will create a layer to calculate the input times 2

def TimesTwo():
    layer_name = "TimesTwo" #identify the custom layer by a name

    # Custom function for the custom layer
    def func(x):
        return x * 2

    return tl.Fn(layer_name, func)


# Test it
times_two = TimesTwo()

# Inspect properties
print("-- Properties --")
print("name :", times_two.name)
print("expected inputs :", times_two.n_in)
print("promised outputs :", times_two.n_out, "\n")

# Inputs
x = np.array([1, 2, 3])
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = times_two(x)
print("-- Outputs --")
print("y :", y)
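
The cell above never states how many inputs `TimesTwo` expects, yet the layer reports a value for `n_in`. One plausible mechanism is to count the parameters of the wrapped function. The class below is a hypothetical stand-in written with the standard library only; `FnSketch` is not the Trax class and only illustrates the idea.

```python
import inspect


class FnSketch:
    """Hypothetical stand-in for a function-backed layer (not the Trax class)."""

    def __init__(self, name, f, n_out=1):
        self.name = name
        # Count the function's parameters to decide how many inputs to expect.
        self.n_in = len(inspect.signature(f).parameters)
        self.n_out = n_out
        self._f = f

    def __call__(self, *inputs):
        if len(inputs) != self.n_in:
            raise ValueError(f"{self.name} expects {self.n_in} inputs, got {len(inputs)}")
        return self._f(*inputs)


times_two_sketch = FnSketch("TimesTwo", lambda x: x * 2)
add_sketch = FnSketch("Add", lambda a, b: a + b)

print(times_two_sketch.name, times_two_sketch.n_in, times_two_sketch.n_out)  # TimesTwo 1 1
print(add_sketch.n_in)   # 2
print(add_sketch(3, 4))  # 7
```

With this design, a two-argument function automatically becomes a two-input layer without any extra configuration.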

## Combinators
We can combine layers to build more complex layers. Trax provides a set of objects called combinator layers to make this happen. Combinators are themselves layers, so they can be nested and combined like any other layer.



### Serial Combinator
This is the most common and easiest to use. 

For example, we could build a simple neural network by combining layers into a single layer using the `Serial` combinator (much like `Sequential` in Keras). The new layer then acts just like a single layer, so we can inspect its inputs, outputs, and weights, or even combine it into another layer. Combinators can then be used as trainable models.

**Just as there is a serial combinator, there is a parallel combinator as well, like Keras's functional API**
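
The idea of a parallel combinator can be sketched in plain Python. This is a simplified illustration that assumes every branch takes exactly one input; `parallel_sketch` is a hypothetical helper, and `help(tl.Parallel)` remains the reference for the real behavior.

```python
def parallel_sketch(*layers):
    """Apply the i-th layer to the i-th input and return all outputs."""
    def composite(inputs):
        if len(inputs) != len(layers):
            raise ValueError(f"expected {len(layers)} inputs, got {len(inputs)}")
        # Each branch sees only its own input; nothing is shared between branches.
        return tuple(layer(x) for layer, x in zip(layers, inputs))
    return composite


def negate(x):
    return -x


def square(x):
    return x * x


branches = parallel_sketch(negate, square)
print(branches((2, 3)))  # (-2, 9)
```

Where a serial combinator passes one stream of data through a chain, a parallel one keeps several streams side by side.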

In [None]:
# Uncomment any of them to see information regarding the function
# help(tl.Serial)
# help(tl.Parallel)

In [None]:
# Serial combinator
serial = tl.Serial(
    tl.LayerNorm(),         # normalize input
    tl.Relu(),              # convert negative values to zero
    times_two,              # the custom layer we created above,
    )

# Initialization
x = np.array([-2, -1, 0, 1, 2]) #input

serial.init(shapes.signature(x)) #initialising serial instance

print("-- Serial Model --")
print(serial,"\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
print("weights & biases:", serial.weights, "\n")

# Inputs
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = serial(x)
print("-- Outputs --")
print("y :", y)
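
To follow what each sublayer contributes, the same three steps can be traced in NumPy. The sketch assumes the standard layer-norm formula with an `epsilon` of `1e-6` and unit scale and zero bias; the `*_sketch` functions are hypothetical stand-ins for the Trax layers.

```python
import numpy as np


def layer_norm_sketch(x, epsilon=1e-6):
    # Unit scale and zero bias are assumed, so only the normalization remains.
    mean = x.mean(axis=-1, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon)


def relu_sketch(x):
    return np.maximum(x, 0)


def times_two_sketch(x):
    return x * 2


x = np.array([-2, -1, 0, 1, 2], dtype=float)

# Feed the output of each step into the next, printing intermediate values.
y = x
for step in (layer_norm_sketch, relu_sketch, times_two_sketch):
    y = step(y)
    print(f"{step.__name__:18s}", np.round(y, 4))
```

The final values are approximately `0, 0, 0, 1.4142, 2.8284`: the normalized negatives are zeroed by the ReLU, and the remaining entries are doubled.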

## JAX
Trax, as mentioned earlier, has an enhanced numpy implementation compatible with JAX. Both tend to use the alias `np`, so watch those import blocks.

**Note: Some things that are possible in numpy are still not possible in `fastmath.numpy`, so we will switch back to regular numpy whenever needed.**

In [None]:
# Numpy vs fastmath.numpy have different data types

# Regular numpy
x_numpy = np.array([1, 2, 3])
print("good old numpy : ", type(x_numpy), "\n")

# Fastmath and jax enhanced numpy
x_jax = fastmath.numpy.array([1, 2, 3])
print("jax trax numpy : ", type(x_jax))

Next is a simple vanilla neural network (NN) architecture: one hidden dense layer with 128 units, followed by an output dense layer with 10 units, to which we apply a final `LogSoftmax` layer.

In [None]:
#multi layer perceptron
mlp = tl.Serial(
  tl.Dense(128),
  tl.Relu(),
  tl.Dense(10),
  tl.LogSoftmax()
)

Each of the layers within the `Serial` combinator layer is considered a sublayer. 
We can try printing this object:

In [None]:
print(mlp)

Printing the model gives us the exact same information as the model's definition itself.

By just looking at the definition, we can clearly see what is going on inside the neural network. Trax is very straightforward in the way a network is defined, which is one of the things that makes it awesome!
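
The forward pass of such a network can be sketched in NumPy to make the shapes concrete. Everything here is illustrative: the batch size, the number of input features, and the random weights are assumptions, and the `*_sketch` functions are hypothetical stand-ins for the Trax layers.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense_sketch(x, w, b):
    # Affine map: (batch, n_in) @ (n_in, n_out) + (n_out,)
    return x @ w + b


def log_softmax_sketch(z):
    # Subtract the row maximum first for numerical stability.
    shifted = z - z.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


batch, n_features = 4, 20  # hypothetical input size
w1, b1 = rng.normal(scale=0.1, size=(n_features, 128)), np.zeros(128)
w2, b2 = rng.normal(scale=0.1, size=(128, 10)), np.zeros(10)

x = rng.normal(size=(batch, n_features))
hidden = np.maximum(dense_sketch(x, w1, b1), 0)               # Dense(128) + Relu
log_probs = log_softmax_sketch(dense_sketch(hidden, w2, b2))  # Dense(10) + LogSoftmax

print("hidden shape    :", hidden.shape)     # (4, 128)
print("log_probs shape :", log_probs.shape)  # (4, 10)
print("row sums of exp :", np.exp(log_probs).sum(axis=-1))  # each close to 1
```

Exponentiating the log-softmax output gives one probability distribution over the 10 classes per example, which is why each row sums to one.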