<a href="https://colab.research.google.com/github/akiabe/coding-practice/blob/master/trax_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
# install trax
!pip install -q -U trax
import trax

In [33]:
# imports
import os
import numpy as np
from trax import layers as tl

In [34]:
# relu rayer
relu = tl.Relu()

print("name :", relu.name)
print("input :", relu.n_in)
print("output :", relu.n_out)

x = np.array([-2, -1, 0, 1, 2])
print("x :", x)

y = relu(x)
print("y :", y)

name : Relu
input : 1
output : 1
x : [-2 -1  0  1  2]
y : [0 0 0 1 2]


In [35]:
# concat layer
concat = tl.Concatenate()

print("name :", concat.name)
print("input :", concat.n_in)
print("output :", concat.n_out)

x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print("x1  :", x1)
print("x2 :", x2)

y = concat([x1, x2])
print("y :", y)

name : Concatenate
input : 2
output : 1
x1  : [-10 -20 -30]
x2 : [1. 2. 3.]
y : [-10. -20. -30.   1.   2.   3.]


In [36]:
# configure concat layer
concat_3 = tl.Concatenate(n_items=3)

print("name :", concat_3.name)
print("input :", concat_3.n_in)
print("output :", concat_3.n_out)

x1 = np.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99
print("x1 :", x1)
print("x2 :", x2)
print("x3 :", x3)

y = concat_3([x1, x2, x3])
print("y :", y)

name : Concatenate
input : 3
output : 1
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
x3 : [0.99 1.98 2.97]
y : [-10.   -20.   -30.     1.     2.     3.     0.99   1.98   2.97]


In [37]:
# layer norm
norm = tl.LayerNorm()

x = np.array([0, 1, 2, 3], dtype="float")

from trax import shapes
norm.init(shapes.signature(x))

print("x shape :", x.shape)
print("x shape data type :", type(x.shape))
print("trax shape :", shapes.signature(x))
print("trax shape data type :", type(shapes.signature(x)))

print("name :", norm.name)
print("input :", norm.n_in)
print("output :", norm.n_out)
print("weight :", norm.weights[0])
print("biases :", norm.weights[1])

print("x :", x)
y = norm(x)
print("y :", y)

x shape : (4,)
x shape data type : <class 'tuple'>
trax shape : ShapeDtype{shape:(4,), dtype:float64}
trax shape data type : <class 'trax.shapes.ShapeDtype'>
name : LayerNorm
input : 1
output : 1
weight : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.]
x : [0. 1. 2. 3.]
y : [-1.3416404  -0.44721344  0.44721344  1.3416404 ]




In [39]:
#help(tl.LayerNorm)

In [None]:
#help(shapes.signature)

In [41]:
# custom layers
def TimesTwo():
  layer_name = "TimesTwo"
  
  def func(x):
    return x * 2
  
  return tl.Fn(layer_name, func)

times_two = TimesTwo()

print("name :", times_two.name)
print("inputs :", times_two.n_in)
print("outputs :", times_two.n_out)

x = np.array([1, 2, 3])
print("x :", x)

y = times_two(x)
print("y :", y)

name : TimesTwo
inputs : 1
outputs : 1
x : [1 2 3]
y : [2 4 6]


In [46]:
# serial combinator
serial = tl.Serial(
    tl.LayerNorm(),
    tl.Relu(),
    times_two,
)

x = np.array([-2, -1, 0, 1, 2])
serial.init(shapes.signature(x))

print("name :", serial.name)
print("sublayers: ", serial.sublayers)
print("input :", serial.n_in)
print("output :", serial.n_out)
print("weight & biases :", serial.weights)

print("x :", x)
y = serial(x)
print("y :", y)

name : Serial
sublayers:  [LayerNorm, Relu, TimesTwo]
input : 1
output : 1
weight & biases : ((DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), (), ())
x : [-2 -1  0  1  2]
y : [0.        0.        0.        1.4142132 2.8284264]




In [47]:
from trax import fastmath

x_numpy = np.array([1, 2, 3])
print("numpy :", type(x_numpy))

x_jax = fastmath.numpy.array([1, 2, 3])
print("jax trax numpy :", type(x_jax))

numpy : <class 'numpy.ndarray'>
jax trax numpy : <class 'jax.interpreters.xla.DeviceArray'>
