In [1]:
import numpy as np
import trax

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 


In [2]:
from trax import layers as tl
from trax import shapes, fastmath

In [3]:
from trax.fastmath import numpy as fastnp

In [4]:
one = fastnp.ones([4,4])

In [5]:
one

DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)

In [6]:
one + 1

DeviceArray([[2., 2., 2., 2.],
             [2., 2., 2., 2.],
             [2., 2., 2., 2.],
             [2., 2., 2., 2.]], dtype=float32)

In [7]:
x = fastnp.array([[1, 1, 1, 1],[1, 1, 1, 1]])

In [8]:
fastnp.dot(x, one)

DeviceArray([[4., 4., 4., 4.],
             [4., 4., 4., 4.]], dtype=float32)

In [9]:
def f(x):
    """Return 2 * x**2, evaluated left to right as (2 * x) * x."""
    doubled = 2 * x
    return doubled * x
    

In [10]:
# Differentiate f with the fastmath backend; grad_f(x) evaluates df/dx at x
# (f(x) = 2*x*x, so the cells below show 4 at x=1.0 and 8 at x=2.0).
grad_f = trax.fastmath.grad(f)

In [11]:
trax

<module 'trax' from '/home/prhyme/.local/lib/python3.8/site-packages/trax/__init__.py'>

In [12]:
grad_f(1.0)

DeviceArray(4., dtype=float32)

In [13]:
print(grad_f(2.0))

8.0


In [14]:
relu = tl.Relu()

In [15]:
relu.name

'Relu'

In [16]:
relu.n_in

1

In [17]:
relu.n_out

1

In [18]:
x = np.array([-2,0,1,3])

In [19]:
relu(x)

DeviceArray([0, 0, 1, 3], dtype=int32)

In [20]:
# Concatenate Layer

In [21]:
concat = tl.Concatenate()

In [22]:
concat.name

'Concatenate'

In [23]:
tl.Relu()(np.array([1,1,]))

DeviceArray([1, 1], dtype=int32)

In [24]:
concat.n_in

2

In [25]:
concat.n_out

1

In [26]:
x1 = np.array([1,1,2,2])
x2 = np.array([3,3,4,4])

In [27]:
concat([x1, x2])

DeviceArray([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

In [28]:
# Layers are configurable

In [29]:
concat_3 = tl.Concatenate(n_items=3)

In [30]:
x3 = np.array([5,5, 6,6])

In [31]:
concat_3([x1,x2,x3])

DeviceArray([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6], dtype=int32)

In [32]:
help(tl.Concatenate)

Help on class Concatenate in module trax.layers.combinators:

class Concatenate(trax.layers.base.Layer)
 |  Concatenate(n_items=2, axis=-1)
 |  
 |  Concatenates n tensors into a single tensor.
 |  
 |  Method resolution order:
 |      Concatenate
 |      trax.layers.base.Layer
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, n_items=2, axis=-1)
 |      Creates a partially initialized, unconnected layer instance.
 |      
 |      Args:
 |        n_in: Number of inputs expected by this layer.
 |        n_out: Number of outputs promised by this layer.
 |        name: Class-like name for this layer; for use when printing this layer.
 |        sublayers_to_print: Sublayers to display when printing out this layer;
 |          By default (when None) we display all sublayers.
 |  
 |  forward(self, xs)
 |      Computes this layer's output as part of a forward pass through the model.
 |      
 |      Authors of new layer subclasses should override this method 

In [33]:
# Layers can have weights

In [34]:
help(tl.LayerNorm)

Help on class LayerNorm in module trax.layers.normalization:

class LayerNorm(trax.layers.base.Layer)
 |  LayerNorm(epsilon=1e-06)
 |  
 |  Layer normalization.
 |  
 |  Method resolution order:
 |      LayerNorm
 |      trax.layers.base.Layer
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, epsilon=1e-06)
 |      Creates a partially initialized, unconnected layer instance.
 |      
 |      Args:
 |        n_in: Number of inputs expected by this layer.
 |        n_out: Number of outputs promised by this layer.
 |        name: Class-like name for this layer; for use when printing this layer.
 |        sublayers_to_print: Sublayers to display when printing out this layer;
 |          By default (when None) we display all sublayers.
 |  
 |  forward(self, x)
 |      Computes this layer's output as part of a forward pass through the model.
 |      
 |      Authors of new layer subclasses should override this method to define the
 |      forward computation

In [35]:
norm = tl.LayerNorm()

In [36]:
help(shapes.signature)

Help on function signature in module trax.shapes:

signature(obj)
    Returns a `ShapeDtype` signature for the given `obj`.
    
    A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
    instances. Note that this function is permissive with respect to its inputs
    (accepts lists or tuples or dicts, and underlying objects can be any type
    as long as they have shape and dtype attributes) and returns the corresponding
    nested structure of `ShapeDtype`.
    
    Args:
      obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
          of such objects.
    
    Returns:
      A corresponding nested structure of `ShapeDtype` instances.



In [37]:
x = np.array([1,2,3,4], dtype='float')

In [38]:
# init() needs the input's signature (shape and dtype) to create the layer's weights and biases.
# shapes.signature(x) packages that information as a trax ShapeDtype, rather than the plain tuple returned by x.shape.

In [39]:
norm.init(shapes.signature(x))

((DeviceArray([1., 1., 1., 1.], dtype=float32),
  DeviceArray([0., 0., 0., 0.], dtype=float32)),
 ())

In [40]:
type(x.shape)

tuple

In [41]:
type(shapes.signature(x))

trax.shapes.ShapeDtype

In [42]:
norm.weights

(DeviceArray([1., 1., 1., 1.], dtype=float32),
 DeviceArray([0., 0., 0., 0.], dtype=float32))

In [43]:
y = norm(x)

In [44]:
y

DeviceArray([-1.3416404 , -0.44721344,  0.44721344,  1.3416404 ], dtype=float32)

In [45]:
# Custom Layers

In [46]:
help(tl.Fn)

Help on function Fn in module trax.layers.base:

Fn(name, f, n_out=1)
    Returns a layer with no weights that applies the function `f`.
    
    `f` can take and return any number of arguments, and takes only positional
    arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
    The following, for example, would create a layer that takes two inputs and
    returns two outputs -- element-wise sums and maxima:
    
        `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
    
    The layer's number of inputs (`n_in`) is automatically set to number of
    positional arguments in `f`, but you must explicitly set the number of
    outputs (`n_out`) whenever it's not the default value 1.
    
    Args:
      name: Class-like name for the resulting layer; for use in debugging.
      f: Pure function from input tensors to output tensors, where each input
          tensor is a separate positional arg, e.g., `f(x0, x1) --> x

In [47]:
def TimesTwo():
    """Build a weightless trax layer named "TimesTwo" that doubles its input."""
    def double(x):
        # One positional argument, so tl.Fn sets the layer's n_in to 1.
        return x * 2

    return tl.Fn("TimesTwo", double)

In [48]:
times_two = TimesTwo()

In [49]:
times_two(x)

array([2., 4., 6., 8.])

In [50]:
def SumTen():
    """Build a weightless trax layer named "SumTen" that adds 10 to its input."""
    def plus_ten(x):
        # One positional argument, so tl.Fn sets the layer's n_in to 1.
        return x + 10

    return tl.Fn("SumTen", plus_ten)

In [51]:
sum = SumTen()

In [52]:
sum(10)

20

In [53]:
# Combinators

In [54]:
# Serial Combinator

In [55]:
# Compose the individual layers into a single model with the Serial combinator.
serial = tl.Serial(
    tl.LayerNorm(),
    tl.Relu(),
    times_two,  # the custom Fn layer instance created from TimesTwo() in an earlier cell

    tl.Dense(n_units=2),  # two output units
    tl.LogSoftmax()
)

In [60]:
# Define the 5-element input in this cell so that init() sees the same shape and
# dtype that serial(x) is later called with; otherwise a fresh top-to-bottom run
# would initialise the model from the 4-element float x left over from the
# LayerNorm example.
x = np.array([1,2,3,4,5])
serial.init(shapes.signature(x))

(((DeviceArray([1, 1, 1, 1, 1], dtype=int32),
   DeviceArray([0, 0, 0, 0, 0], dtype=int32)),
  (),
  (),
  (DeviceArray([[ 0.4637612 , -0.40276128],
                [ 0.80113703,  0.64133954],
                [ 0.49185318, -0.03974866],
                [ 0.73240066,  0.60542107],
                [-0.22441587,  0.6409376 ]], dtype=float32),
   DeviceArray([ 1.1225158e-06, -5.3297305e-07], dtype=float32)),
  ()),
 ((), (), (), (), ()))

In [61]:
x = np.array([1,2,3,4,5])

In [62]:
serial(x)

DeviceArray([-2.3665137 , -0.09850311], dtype=float32)

In [63]:
serial.n_out

1

In [64]:
serial.n_in

1