# Trax

Trax is much more concise than larger deep learning frameworks. It runs on a TensorFlow backend but allows you to train models with one-line commands. It also runs end to end.

We can focus on learning instead of spending many hours on the implementation details of a big framework.

It is good for implementing new SOTA models such as Transformers, Reformers, and BERT. It is maintained by the Google Brain Team.

### Trax is based on 2 concepts:
- Layers
  - Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax Combinators, allowing you to build layers and models of any complexity.
- Combinators
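
The chaining idea can be sketched in plain Python, without Trax. The `serial` helper and the two list-based "layers" below are hypothetical stand-ins, not Trax's implementation: a combinator is treated as a function that composes other functions left to right.

```python
from functools import reduce

def serial(*layers):
    """Hypothetical stand-in for a serial combinator: apply layers left to right."""
    def composite(x):
        return reduce(lambda value, layer: layer(value), layers, x)
    return composite

def relu(x):
    # Element-wise max(0, v) on a plain Python list
    return [max(0, v) for v in x]

def times_two(x):
    # Element-wise doubling
    return [2 * v for v in x]

model = serial(relu, times_two)
result = model([-2, -1, 0, 1, 2])
print(result)  # [0, 0, 0, 2, 4]
```

The composite is itself a callable, so it could be passed to `serial` again, which is what allows models of any depth to be built from small pieces.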


**Trax:**
- It uses TensorFlow as a backend, but it uses the JAX library to speed up computations. JAX offers a NumPy-like API with optimized, accelerated execution.

In [1]:
import numpy as np

from trax import layers as tl
from trax import shapes
from trax import fastmath

Activation functions are also layers in Trax.

## ReLU layer

In [4]:
relu = tl.Relu()

print(relu.name)
print(relu.n_in)
print(relu.n_out)



# Inputs
x = np.array([-2, -1, 0, 1, 2])
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = relu(x)
print("-- Outputs --")
print("y :", y)

Serial
1
1
-- Inputs --
x : [-2 -1  0  1  2] 

-- Outputs --
y : [0 0 0 1 2]
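
As a cross-check, the same result can be reproduced with plain NumPy. This is only a reference sketch of the ReLU function itself (`relu_reference` is a name made up here), not how Trax implements the layer.

```python
import numpy as np

def relu_reference(x):
    # ReLU keeps positive values and replaces negatives with zero
    return np.maximum(x, 0)

x = np.array([-2, -1, 0, 1, 2])
y_ref = relu_reference(x)
print(y_ref)  # [0 0 0 1 2]
```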


## Concatenate layer

In [6]:
concat = tl.Concatenate()
print(concat.name)
print(concat.n_in)
print(concat.n_out)


# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")

# Outputs
y = concat([x1, x2])
print("-- Outputs --")
print("y :", y)

concat2 = tl.Concatenate(n_items=3)


Concatenate
2
1
-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.] 

-- Outputs --
y : [-10. -20. -30.   1.   2.   3.]
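
The output above matches what NumPy's own concatenation gives. The sketch below assumes the layer joins its inputs along the last axis, which is consistent with the result shown; note that the integer input is promoted to float because `x2` is a float array.

```python
import numpy as np

x1 = np.array([-10, -20, -30])
x2 = x1 / -10  # true division yields a float array: [1., 2., 3.]

# Join along the last axis; the integer values are promoted to float
y_ref = np.concatenate([x1, x2], axis=-1)
print(y_ref)  # [-10. -20. -30.   1.   2.   3.]
```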


## Layers can have weights

In [7]:
norm = tl.LayerNorm()

x = np.array([0, 1, 2, 3], dtype="float")

# shapes.signature(x) describes the input as a trax ShapeDtype (shape and dtype); init uses it to create the weights
norm.init(shapes.signature(x))

print(x.shape)
print(shapes.signature(x), type(shapes.signature(x)))

norm.name, norm.n_in, norm.n_out, norm.weights[0], norm.weights[1]


(4,)
ShapeDtype{shape:(4,), dtype:float64} <class 'trax.shapes.ShapeDtype'>




('LayerNorm',
 1,
 1,
 Array([1., 1., 1., 1.], dtype=float32),
 Array([0., 0., 0., 0.], dtype=float32))
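
The two weights shown above are a scale (initialized to ones) and a bias (initialized to zeros). A minimal NumPy sketch of the standard layer normalization formula shows where they enter the computation. The function name and the `epsilon` value are assumptions made for illustration, not values read from the Trax source.

```python
import numpy as np

def layer_norm_reference(x, scale, bias, epsilon=1e-6):
    # Normalize over the last axis to zero mean and unit variance,
    # then apply the learned scale and bias
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    normalized = (x - mean) / np.sqrt(variance + epsilon)
    return normalized * scale + bias

x = np.array([0, 1, 2, 3], dtype="float")
scale = np.ones(4)   # same values as the initial norm.weights[0]
bias = np.zeros(4)   # same values as the initial norm.weights[1]
y_ref = layer_norm_reference(x, scale, bias)
print(y_ref)
```

With the initial weights, the scale and bias leave the normalized values unchanged; they only start to matter once training updates them.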

## Custom Layers

In [9]:
def TimesTwo():
    layer_name = "TimesTwo"  # give the custom layer a name so it can be identified

    def func(x):
        return x * 2

    return tl.Fn(layer_name, func)

times_two = TimesTwo()

times_two.name, times_two.n_in, times_two.n_out

('TimesTwo', 1, 1)

## Combinators

We can combine layers to build more complex layers. 


- Serial Combinator

In [10]:
serial = tl.Serial(
    tl.LayerNorm(),
    tl.Relu(),
    times_two,

    tl.Dense(n_units=2),
    tl.Dense(n_units=1),
    tl.LogSoftmax()
)


# Initialization
x = np.array([-2, -1, 0, 1, 2])  # input
serial.init(shapes.signature(x))  # initialize the weights of the serial instance

print("-- Serial Model --")
print(serial,"\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
print("weights & biases:", serial.weights, "\n")

# Inputs
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = serial(x)
print("-- Outputs --")
print("y :", y)



-- Serial Model --
Serial[
  LayerNorm
  Serial[
    Relu
  ]
  TimesTwo
  Dense_2
  Dense_1
  LogSoftmax
] 

-- Properties --
name : Serial
sublayers : [LayerNorm, Serial[
  Relu
], TimesTwo, Dense_2, Dense_1, LogSoftmax]
expected inputs : 1
promised outputs : 1
weights & biases: ((Array([1, 1, 1, 1, 1], dtype=int32), Array([0, 0, 0, 0, 0], dtype=int32)), ((), (), ()), (), (Array([[-0.792707  , -0.85926765],
       [ 0.72052234,  0.6414506 ],
       [ 0.6638057 , -0.3427339 ],
       [ 0.3194956 , -0.5063189 ],
       [-0.3447836 ,  0.52460796]], dtype=float32), Array([ 4.2711878e-07, -1.1384492e-06], dtype=float32)), (Array([[-0.17818879],
       [ 0.00781706]], dtype=float32), Array([1.6863944e-06], dtype=float32)), ()) 

-- Inputs --
x : [-2 -1  0  1  2] 

-- Outputs --
y : [0.]
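
The output `[0.]` is not a coincidence of the random weights. The model ends with `Dense(n_units=1)` followed by `LogSoftmax`, and the softmax of a single value is always 1, whose logarithm is 0. A small NumPy sketch of the standard log-softmax definition (the helper below is written for illustration, it is not the Trax layer) makes this visible:

```python
import numpy as np

def log_softmax(x):
    # Numerically stable log-softmax over the last axis
    shifted = x - np.max(x, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

single = log_softmax(np.array([3.7]))     # one unit: always 0, whatever the input
pair = log_softmax(np.array([1.0, 1.0]))  # two equal units: log(0.5) each
print(single, pair)
```

To get an informative output from a model like this one, the last `Dense` layer would need as many units as there are classes.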


## JAX

Some things are not possible with JAX's fastmath NumPy but are still possible with regular NumPy.

In [11]:
x_numpy = np.array([1, 2, 3])

n_jax = fastmath.numpy.array([1, 2, 3])
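
One well-known difference is in-place assignment: NumPy arrays are mutable, while JAX arrays are immutable. The sketch below uses `jax.numpy` directly and assumes that Trax's fastmath NumPy behaves the same way when the JAX backend is active. The JAX part is skipped if JAX is not installed.

```python
import numpy as np

x_numpy = np.array([1, 2, 3])
x_numpy[0] = 10  # NumPy arrays are mutable, so this works in place
print(x_numpy)   # [10  2  3]

try:
    import jax.numpy as jnp  # optional: only used if JAX is installed
except ImportError:
    jnp = None

if jnp is not None:
    x_jax = jnp.array([1, 2, 3])
    try:
        x_jax[0] = 10  # JAX arrays are immutable, so this raises TypeError
    except TypeError as err:
        print("in-place assignment failed:", type(err).__name__)
    # The functional alternative returns a new array instead of modifying x_jax
    x_updated = x_jax.at[0].set(10)
    print(x_updated)
```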

## **Classes and Subclasses**

In [12]:
class My_class():
    x = None

instance = My_class()

str(instance.x)

'None'

### `__init__` method

In [13]:
class My_class():
    def __init__(self, y):
        self.x = y
    
instance = My_class(10)

str(instance.x)

'10'

### `__call__` method

In [16]:
class My_class():
    def __init__(self, y):
        self.x = y
    
    def __call__(self, z):
        self.x += z
        print(self.x)
    
instance = My_class(5)

instance(5)

10
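
`__call__` is what makes an expression like `relu(x)` possible: the object is an instance, yet it can be used like a function. A minimal sketch with a hypothetical layer-like class (not a Trax class) that operates on plain Python lists:

```python
class TimesTwoLayer:
    """Hypothetical layer-like class; not a Trax class."""

    def __init__(self, name):
        self.name = name

    def __call__(self, x):
        # Calling the instance runs the layer's computation
        return [2 * v for v in x]

layer = TimesTwoLayer("TimesTwo")
output = layer([1, 2, 3])   # same as layer.__call__([1, 2, 3])
print(layer.name, output)   # TimesTwo [2, 4, 6]
```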


#### Custom Methods

In [17]:
class My_Class: 
    def __init__(self, y, z): 
        self.x_1 = y
        self.x_2 = z
    def __call__(self):      
        a = self.x_1 - 2*self.x_2 
        return a
    def my_method(self, w):  
        result = self.x_1*self.x_2 + w
        return result

instance = My_Class(1, 10)

instance.my_method(10)
    

20

#### Subclass and Inheritance


`Trax` uses classes and subclasses to define layers. Every layer in a model is defined as a subclass of the base layer class.

In [25]:
'''
A subclass inherits every method and attribute from its parent ('super') class, including `__init__` and `__call__`.
'''

class sub_c(My_Class):
    def additional_method(self):
        print(self.x_1)

instance = sub_c(1,10)
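
To see inheritance at work, the inherited methods can be called on the subclass instance. The sketch below repeats `My_Class` so that it runs on its own, and uses a variant of `sub_c` that also overrides `my_method`. The override is an addition for illustration and is not part of the cell above.

```python
class My_Class:
    def __init__(self, y, z):
        self.x_1 = y
        self.x_2 = z

    def __call__(self):
        return self.x_1 - 2 * self.x_2

    def my_method(self, w):
        return self.x_1 * self.x_2 + w

class sub_c(My_Class):
    def additional_method(self):
        return self.x_1

    def my_method(self, w):
        # Override: reuse the parent's result, then add 1 to it
        return super().my_method(w) + 1

instance = sub_c(1, 10)
inherited_call = instance()          # inherited __call__: 1 - 2*10
overridden = instance.my_method(10)  # parent gives 1*10 + 10, the override adds 1
print(inherited_call, overridden, instance.additional_method())  # -19 21 1
```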


## Data Generators

A data generator behaves like an iterator: each call returns the next item.

In [None]:
import random as rnd

def data_generator(batch, data_x, data_y, shuffle=True):
    data_lng = len(data_x)
    index_list = [*range(data_lng)]  # list with the ordered indexes of the sample data

    if shuffle:
        rnd.shuffle(index_list)

    index = 0  # position in index_list of the next sample to use
    while True:
        x = []
        y = []

        for i in range(batch):
            # Wrap around (and optionally reshuffle) once all samples have been used
            if index >= data_lng:
                index = 0

                if shuffle:
                    rnd.shuffle(index_list)

            x.append(data_x[index_list[index]])
            y.append(data_y[index_list[index]])
            index += 1  # advance to the next sample

        yield (x, y)
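
A generator of this kind is consumed with `next()`, and because of the `while True` loop it never runs out. The sketch below uses a simplified, hypothetical `batch_generator` (no shuffling) so that the output is deterministic and the wrap-around at the end of the data is easy to see.

```python
def batch_generator(batch, data_x, data_y):
    """Hypothetical simplified generator: no shuffling, wraps around forever."""
    index = 0
    while True:
        x, y = [], []
        for _ in range(batch):
            if index >= len(data_x):
                index = 0  # start over once the data is exhausted
            x.append(data_x[index])
            y.append(data_y[index])
            index += 1
        yield x, y

gen = batch_generator(2, [1, 2, 3], ["a", "b", "c"])
first = next(gen)   # ([1, 2], ['a', 'b'])
second = next(gen)  # ([3, 1], ['c', 'a']): wraps around to the start
print(first, second)
```

Execution pauses at `yield` and resumes from the same point on the next call, which is why `index` keeps its value between batches.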
        