<a href="https://colab.research.google.com/github/MichalRyszardWojcik/hello-world/blob/master/trax_2020_08_19_first_day.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on https://github.com/jalammar/jalammar.github.io/blob/master/notebooks/Trax_TransformerLM_Intro.ipynb

The example there trains a model that returns the input array reversed (flipped).

This notebook uses essentially the same setup, but the model learns to return an array in which every entry equals the maximum of the input values.

$[a, b, c, d] \rightarrow [m, m, m, m]$, where $m = \max(a, b, c, d)$

In [None]:
import os
import numpy as np
! pip install -q -U trax
import trax

In [17]:
def tiny_transformer_lm(mode='train'):
  """Build a deliberately small trax TransformerLM for the toy task.

  Args:
    mode: forwarded to trax.models.TransformerLM; this notebook uses
      'train' for training and 'predict' for autoregressive decoding.

  Returns:
    A trax TransformerLM model. Weights are not loaded here; callers load
    them separately (e.g. with init_from_file).
  """
  # vocab_size=32 covers the token ids the task emits: 0 as separator and
  # 1..30 as values.
  hyperparams = dict(d_model=32, d_ff=128, n_layers=2, vocab_size=32)
  return trax.models.TransformerLM(mode=mode, **hyperparams)

In [73]:
def max_ints_task(batch_size, length=4):
  """Yield an endless stream of batches for the "fill with the maximum" task.

  Each row of a batch has the layout

      [0, a_1, ..., a_length, 0, m, ..., m]     with m = max(a_1, ..., a_length)

  i.e. 2 * length + 2 tokens. 0 acts as start/separator token and the values
  a_i are drawn uniformly from 1..30.

  Args:
    batch_size: number of sequences per batch.
    length: number of input values (and of repeated-maximum target tokens).

  Yields:
    (x, x, loss_weights). Inputs and targets are the same array (language-
    model style). loss_weights has the same shape as x and is 1 only on the
    `length` target positions, so only the answer contributes to the loss.
  """
  while True:
    # Values in [1, 30]; 0 is reserved for the separator tokens.
    source = np.random.randint(1, 31, (batch_size, length))

    # Row-wise maximum repeated `length` times -> shape (batch_size, length).
    row_max = source.max(axis=1, keepdims=True)
    target = np.repeat(row_max, length, axis=1)

    zero = np.zeros([batch_size, 1], np.int32)
    x = np.concatenate([zero, source, zero, target], axis=1)

    # Mask out the leading 0, the source values and the separator; weight
    # only the target half.
    loss_weights = np.concatenate([np.zeros((batch_size, length + 2)),
                                   np.ones((batch_size, length))], axis=1)
    yield (x, x, loss_weights)  # Here inputs and targets are the same.

In [71]:
# Draw one batch of 8 sequences to inspect the data layout.
demo_stream = max_ints_task(8)
sequence_batch, _, masks = next(demo_stream)
sequence_batch


array([[ 0,  7,  9,  8, 16,  0, 16, 16, 16, 16],
       [ 0, 19,  2,  3,  3,  0, 19, 19, 19, 19],
       [ 0, 28,  4, 20, 10,  0, 28, 28, 28, 28],
       [ 0, 30, 16, 12, 27,  0, 30, 30, 30, 30],
       [ 0,  7,  3, 30,  7,  0, 30, 30, 30, 30],
       [ 0,  4, 26, 12, 16,  0, 26, 26, 26, 26],
       [ 0, 13,  1, 19, 24,  0, 24, 24, 24, 24],
       [ 0, 14, 20,  4,  5,  0, 20, 20, 20, 20]])

In [74]:
# The stream factory receives one argument (presumably n_devices -- TODO
# confirm against the trax Inputs API); it is ignored and every batch holds
# 16 sequences.
max_ints_inputs = trax.data.inputs.Inputs(
    lambda _n_devices: max_ints_task(batch_size=16))

In [None]:
output_dir = os.path.expanduser('~/train_dir/')
model_path = os.path.join(output_dir, 'model.pkl.gz')

# Remove a model.pkl.gz left over from a previous run (presumably so the
# Trainer does not restore it -- TODO confirm). Done in Python rather than
# `!rm -f ~/train_dir/...` so the path is always derived from output_dir.
if os.path.isfile(model_path):
  os.remove(model_path)

# Train tiny model with Trainer.
trainer = trax.supervised.Trainer(
    model=tiny_transformer_lm,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.
    lr_schedule=trax.lr.constant(0.001),  # Change lr schedule here.
    inputs=max_ints_inputs,
    output_dir=output_dir)

# Train for n_epochs epochs, each consisting of train_steps train batches,
# evaluating on eval_steps batches (currently 3 epochs x 800 batches, 2 eval).
n_epochs = 3
train_steps = 800
eval_steps = 2
for _epoch in range(n_epochs):  # not `_`, which IPython uses for last output
  trainer.train_epoch(train_steps, eval_steps)

In [86]:
# Prompt with the same layout as a training row up to the separator:
# [0, a, b, c, d, 0]; the model should continue it with [m, m, m, m].
# Named `prompt` (not `input`) so the builtin input() stays usable.
prompt = np.array([[0, 4, 6, 29, 10, 0]])

# Initialize model for inference from the trained checkpoint.
predict_model = tiny_transformer_lm(mode='predict')
predict_signature = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)
predict_model.init_from_file(os.path.join(output_dir, "model.pkl.gz"),
                             weights_only=True, input_signature=predict_signature)

# Generate 4 tokens; temperature=0.0 makes decoding deterministic (argmax).
prediction = trax.supervised.decoding.autoregressive_sample(
    predict_model, prompt, temperature=0.0, max_length=4)

# Display the predicted tokens.
prediction

array([[29, 29, 29, 29]])

In [97]:
def output(input, max_length=4):
  """Decode the model's answer for one prompt with temperature 0.

  A fresh predict-mode model is built and loaded from output_dir on every
  call. (Presumably this keeps predict-mode decoding state from carrying
  over between calls -- TODO confirm before caching the model.)

  Args:
    input: prompt array laid out as [0, a_1, ..., a_length, 0]; presumably a
      batch of one row, matching the (1, 1) input signature. The name
      shadows the builtin but is kept so existing callers keep working.
    max_length: number of tokens to generate. It should equal the task
      length the model was trained on (4 in this notebook).

  Returns:
    The generated tokens as returned by autoregressive_sample, e.g. an array
    of shape (1, max_length).
  """
  predict_model = tiny_transformer_lm(mode='predict')
  predict_signature = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)
  checkpoint = os.path.join(output_dir, "model.pkl.gz")
  predict_model.init_from_file(checkpoint, weights_only=True,
                               input_signature=predict_signature)
  return trax.supervised.decoding.autoregressive_sample(
      predict_model, input, temperature=0.0, max_length=max_length)

In [110]:
#input = np.array([[0, 4, 6, 29, 10, 0]])
#output(input)

def randominput(length=4):
  """Return a random prompt of shape (1, length + 2): [0, a_1, ..., a_length, 0].

  Values are drawn from 1..30, the same range max_ints_task uses. The two
  end positions are drawn too and then zeroed, which keeps the exact
  random-number stream of the original length-4 version.

  Args:
    length: number of random values between the two 0 tokens.
  """
  prompt = np.random.randint(1, 31, (1, length + 2))
  prompt[:, [0, -1]] = 0  # start token and separator
  return prompt

def randomdemo():
  """Run the model on one random prompt.

  Returns:
    A two-element list [prompt, prediction].
  """
  prompt = randominput()
  return [prompt, output(prompt)]

# Show 7 random prompts next to the model's predictions.
for _demo in range(7):
  print(randomdemo())

[array([[ 0, 17, 19, 26,  5,  0]]), array([[26, 26, 26, 26]])]
[array([[ 0, 24, 10, 15, 23,  0]]), array([[24, 24, 24, 24]])]
[array([[ 0,  6, 12, 26, 14,  0]]), array([[26, 26, 26, 26]])]
[array([[ 0, 28, 26,  9,  2,  0]]), array([[28, 28, 28, 28]])]
[array([[ 0, 27, 18,  4, 14,  0]]), array([[27, 27, 27, 27]])]
[array([[ 0, 24,  8, 17, 18,  0]]), array([[24, 24, 24, 24]])]
[array([[ 0,  2, 19, 30, 28,  0]]), array([[30, 30, 30, 30]])]
