<a href="https://colab.research.google.com/github/akiabe/coding-practice/blob/master/trax_with_dnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [35]:
!pip install -q -U trax
import trax

In [None]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [36]:
# test tensors and fast math
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')

matrix = fastnp.array([
                       [1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]
])
vector = fastnp.ones(3)
product = fastnp.dot(vector, matrix)

print(type(matrix))

print(f"matrix :\n{matrix}")
print(f"vector :\n{vector}")
print(f"dot product :\n{product}")

<class 'jax.interpreters.xla.DeviceArray'>
matrix :
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector :
[1. 1. 1.]
dot product :
[12. 15. 18.]


In [37]:
# gradients
def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(grad_f(1.0))
print(grad_f(-2.0))

4.0
-8.0


In [38]:
import numpy as np
from trax import layers as tl

x = np.arange(15)
print(x)

embed = tl.Embedding(
    vocab_size=20,
    d_feature=4
)

embed.init(
    trax.shapes.signature(x)
)

y = embed(x)
print(x.shape)
print(y.shape)
print(y)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
(15,)
(15, 4)
[[-1.0577534  -1.942313   -1.1480163  -0.44229656]
 [-0.02922153  0.741222    0.947727   -0.57896084]
 [ 0.8722794  -0.3139976   1.016909   -0.16104808]
 [ 1.2119057   2.3335469   0.15371336  0.11246555]
 [ 0.06260403  1.5229656   0.55029255 -0.2252464 ]
 [-0.18557116  1.2805232   0.08519783 -0.35955766]
 [ 1.2621924   0.3161323  -0.82232356 -1.2015381 ]
 [-1.3860985   0.22837402  2.4857194  -0.36892715]
 [ 0.79144067  0.16667114 -0.79280484 -1.625344  ]
 [ 0.10067508  0.37912208  1.5271277   0.12817016]
 [-0.16322467 -1.4718566  -2.7391403  -1.2465898 ]
 [ 0.49899516 -1.1553074   0.02754989 -0.25867775]
 [ 0.27598247 -1.0560894   2.472187   -0.6668469 ]
 [ 1.1305292   0.24889068  0.24556398 -0.889463  ]
 [ 1.8789382   1.1204485  -0.31733564 -0.14913593]]


In [39]:
model = tl.Serial(
    tl.Embedding(vocab_size=8129, d_feature=256),
    tl.Mean(axis=1),
    tl.Dense(2),
    tl.LogSoftmax()
)

print(model)

Serial[
  Embedding_8129_256
  Mean
  Dense_2
  LogSoftmax
]


In [40]:
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream)) 

(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)


In [41]:
print(train_stream)

<generator object TFDS.<locals>.gen at 0x7ff4608270f8>


In [42]:
print(next(train_stream))
print(next(eval_stream))

(b'I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.', 0)
(b"There are films that make careers. For George Romero, it was NIGHT OF THE LIVING DEAD; for Kevin Smith, CLERKS; for Robert Rodriguez, EL MARIACHI. Add to that list Onur Tukel's absolutely amazing DING-A-LING-LESS. Flawless film-making, and as assured and as professional as any of the aforementioned movies. I haven't laughed this hard since I saw THE FULL MONTY. (And, e

In [70]:
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[512, 128,  32,    8, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )

train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)

example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}') 
print()

train_gen = train_batches_stream
print(train_gen.__next__())
print()
print(train_gen.__next__())
print()


shapes = [(8, 2048), (8,), (8,)]

(array([[ 139, 2293,   36, ...,    0,    0,    0],
       [ 182, 3898,   22, ...,    0,    0,    0],
       [ 139,   96,   13, ...,    0,    0,    0],
       ...,
       [ 919, 2586, 6582, ...,    0,    0,    0],
       [ 728, 1764,  962, ...,    0,    0,    0],
       [ 182, 3898,   22, ...,    0,    0,    0]]), array([1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1,
       0, 1, 0, 1, 0, 0, 0, 0, 0, 1]), array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
      dtype=float32))

(array([[ 274,   21,  947, ...,    0,    0,    0],
       [  28,  380, 1548, ...,    0,    0,    0],
       [ 182,   25,   12, ...,    0,    0,    0],
       ...,
       [ 182, 1077, 1476, ...,    0,    0,    0],
       [ 284,  297,  305, ...,    0,    0,    0],
       [ 139, 2519,  114, ...,    0,    0,    0]]), array([1, 1, 0, 1, 1, 0, 0, 0]), array([1., 1., 1., 1., 1.

In [44]:
import os
from trax.supervised import training

train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

training_loop.run(2000)


Step      1: Ran 1 train steps in 1.40 secs
Step      1: train CrossEntropyLoss |  0.81095260
Step      1: eval  CrossEntropyLoss |  0.79715430
Step      1: eval          Accuracy |  0.46250000

Step    500: Ran 499 train steps in 24.55 secs
Step    500: train CrossEntropyLoss |  0.61700368
Step    500: eval  CrossEntropyLoss |  0.54913587
Step    500: eval          Accuracy |  0.70468750

Step   1000: Ran 500 train steps in 21.85 secs
Step   1000: train CrossEntropyLoss |  0.40181661
Step   1000: eval  CrossEntropyLoss |  0.40217640
Step   1000: eval          Accuracy |  0.83750000

Step   1500: Ran 500 train steps in 22.04 secs
Step   1500: train CrossEntropyLoss |  0.36357671
Step   1500: eval  CrossEntropyLoss |  0.37914133
Step   1500: eval          Accuracy |  0.83710938

Step   2000: Ran 500 train steps in 21.38 secs
Step   2000: train CrossEntropyLoss |  0.31435111
Step   2000: eval  CrossEntropyLoss |  0.32641223
Step   2000: eval          Accuracy |  0.87968750


In [None]:
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')

sentiment_log_probs = model(example_input[None, :])
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')