# Trax Transformer

We train **Trax Transformer** on a simple copy problem and run inference.
* Training and inference can run on TPU, even with multiple input lengths
* Inputs are fed from Python asynchronously, so feeding them does not slow down training
* Transformer in predict mode implements fast inference (attention caches)

BEGIN GOOGLE-INTERNAL

To use, click Connect, then the Start tab, borg runtime, and select the DeepMind JellyDonut (TPUv2, Python 3) borg runtime type.

END GOOGLE-INTERNAL

## General Setup
Execute the following few cells (once) before running any of the code samples in this notebook.

In [0]:
from six.moves import cPickle
import os
import datetime
import random
import tempfile
from functools import partial
import copy

import numpy as onp
import jax
from jax import lax
from jax import random as jr
from jax import numpy as np
from jax.ops import index, index_update

from matplotlib import pyplot as plt

import tensorflow.google as tf
tf.enable_eager_execution()
import tensorflow_datasets as tfds

from colabtools import adhoc_import
with adhoc_import.Google3():
  import trax
  from trax import trainer_lib
  from trax import layers as tl
  from trax import inputs as trax_input
  from trax import models as trax_models
  from trax import optimizers as trax_optimizers
  from trax import backend
  from trax import shapes
  from trax import trainer_lib

ImportError: ignored

# Transformer

In [0]:
def feed_forward(d_model, d_ff, dropout, layer_idx, mode):
  """Builds the feed-forward sub-block, normalizing its input first.

  The layers run in this order: LayerNorm, Dense(d_ff), Relu, Dropout,
  Dense(d_model), Dropout.

  Args:
    d_model: int: width of the final Dense layer
    d_ff: int: width of the first Dense layer
    dropout: float: rate given to both Dropout layers
    layer_idx: index of the enclosing block; formatted into the Dropout names
    mode: str: passed through to both Dropout layers

  Returns:
    A tl.Serial layer wrapping the sequence above.
  """
  middle_name = 'ff_middle_%d' % layer_idx
  final_name = 'ff_final_%d' % layer_idx
  layers = [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      tl.Relu(),
      tl.Dropout(rate=dropout, name=middle_name, mode=mode),
      tl.Dense(d_model),
      tl.Dropout(rate=dropout, name=final_name, mode=mode),
  ]
  return tl.Serial(*layers)


def encoder_block(d_model, d_ff, n_heads, dropout, layer_idx, mode):
  """Builds one Transformer encoder block from two residual sub-blocks.

  The block consumes an (activations, mask) pair and produces a pair of the
  same form. The mask is the one built from the original source tokens so that
  attention does not look at the padding part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'

  Returns:
    A tl.Serial layer holding a residual attention sub-block followed by a
    residual feed-forward sub-block.
  """
  # Sub-block 1: normalize, attend, then apply dropout.
  attention_layers = [
      tl.LayerNorm(),
      tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, name='enc_attn_dropout', mode=mode),
  ]
  # Sub-block 2: the feed-forward stack (it does its own normalization).
  ff_layers = [
      feed_forward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
  ]
  residuals = [tl.Residual(sub) for sub in (attention_layers, ff_layers)]
  return tl.Serial(*residuals)


@tl.layer()
def no_padding_mask(x, **kwargs):
  """Returns an all-ones mask of shape (x.shape[0], 1, 1, x.shape[-2]).

  The mask values depend only on the shape and dtype of x, never on its
  contents.
  """
  del kwargs  # Unused.
  first_dim = x.shape[0]
  second_to_last_dim = x.shape[-2]
  return np.ones((first_dim, 1, 1, second_to_last_dim), dtype=x.dtype)


def non_tokenizing_transformer(n_classes=57,
                               d_model=512,
                               d_ff=2048,
                               n_layers=6,
                               n_heads=8,
                               dropout=0.1,
                               max_len=1001,
                               mode='train'):
  """Returns a Transformer encoder that averages over length before output.

  The input is projected with a Dense layer rather than looked up in an
  embedding table, and the attention mask comes from no_padding_mask (all
  ones) instead of being derived from padding tokens.

  Args:
    n_classes: how many classes on output
    d_model: int:  depth of the input projection and encoder activations
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder blocks
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A tl.Serial layer that maps the input tensor to activations over
    n_classes outputs.
  """
  # Dense takes the place of an embedding layer here.
  input_projection = [
      tl.Dense(d_model),
      tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]
  layers = [
      # Duplicate the input: one copy is projected, the other yields the mask.
      tl.Dup(),
      tl.Parallel(input_projection, no_padding_mask()),
      # Each encoder block maps (vecs, mask) to (vecs, mask).
      [encoder_block(d_model, d_ff, n_heads, dropout, layer_idx, mode)
       for layer_idx in range(n_layers)],
      # Discard the mask and keep only the activations.
      tl.Parallel([], tl.Drop()),
      tl.LayerNorm(),
      # Average on length.
      tl.Mean(axis=1),
      tl.Dense(n_classes),
  ]
  return tl.Serial(*layers)

NameError: ignored

In [0]:
def vikram_inputs(
    n_batch=200,
    num_input_timestamps=1001,
    num_embed=117,
    num_output_predictions=57,
    cns_path='/readahead/200M/cns/tm-d/home/vikrama/',
    n_prefetch=4):
  """Returns a numpy iterator over batched (inputs, targets) test examples.

  Globs `cns_path` for files matching '*test*', reads them as TFRecords,
  parses the float 'inputs' and 'targets' features, batches the examples in
  file order (no shuffling) and reshapes every batch.

  Args:
    n_batch: int: examples per batch; the last batch is kept even if smaller
    num_input_timestamps: int: size of axis 1 of each reshaped input batch
    num_embed: int: size of axis 2 of each reshaped input batch
    num_output_predictions: int: size of the last axis of each target batch
    cns_path: str: directory that is globbed for '*test*' files
    n_prefetch: int: number of batches to prefetch

  Returns:
    An iterator of (inputs, targets) pairs, where inputs has shape
    (batch, num_input_timestamps, num_embed) and targets has shape
    (batch, 1, num_output_predictions).
  """
  test_files = tf.gfile.Glob(os.path.join(cns_path, '*test*'))

  # tf.Example schema: both features are variable-length float lists.
  feature_description = {
      'inputs': tf.VarLenFeature(tf.float32),
      'targets': tf.VarLenFeature(tf.float32),
  }

  # The leading -1 lets the batch axis follow the actual batch size, which
  # matters for the final, possibly smaller, batch.
  input_shape = [-1, num_input_timestamps, num_embed]
  target_shape = [-1, 1, num_output_predictions]

  def _parse_example(serialized):
    return tf.parse_example([serialized], feature_description)

  def _reshape(parsed):
    # `.values` is the flat data of each parsed variable-length feature.
    inps = tf.reshape(parsed['inputs'].values, input_shape)
    tgts = tf.reshape(parsed['targets'].values, target_shape)
    return (inps, tgts)

  return (
      tf.data.TFRecordDataset(test_files)
      .map(_parse_example)
      # drop_remainder=False keeps the last partial batch.
      .batch(n_batch, drop_remainder=False)
      .map(_reshape)
      # Keep a queue of ready batches so the consumer does not wait on input.
      .prefetch(n_prefetch)
      .as_numpy_iterator()
  )

# Inference model
output_dir = '/cns/tm-d/home/vikrama/rs=6.3/test7/15'

predict_model = non_tokenizing_transformer(
    mode='eval', d_model=512, d_ff=2048, max_len=1001, n_classes=57,
    n_heads=2, n_layers=8)

# Shape and dtype of the input used to initialize the model.
predict_signature = shapes.ShapeDtype((1, 1001, 117), dtype=np.float32)
predict_model.initialize_once(predict_signature)

# Load from file (API for trainer_state may change)
trainer_state = trainer_lib.load_trainer_state(output_dir)
predict_model.params = trainer_state.opt_state.weights[0]

# Run inference
# NOTE(review): `random_key` was referenced but never defined in this
# notebook; a fixed key is used here -- confirm this is the intended rng.
random_key = jax.random.PRNGKey(0)
preds = []
obs = []
inp = []
ds_test = vikram_inputs()

# Collect, per batch: the baseline features, the model predictions and the
# observed targets. Without these appends the lists stay empty and the
# vstack calls below fail.
for batch_inputs, batch_targets in ds_test:
  print(batch_inputs.shape)
  # Baseline features: mean over the timesteps at indices 499 and 500.
  inp.append(np.mean(batch_inputs[:, 499:501, :], axis=1))
  preds.append(predict_model(batch_inputs, rng=random_key))
  # Targets arrive as (batch, 1, n_outputs); drop only the singleton axis so
  # a batch of size one keeps its batch dimension.
  obs.append(np.squeeze(batch_targets, axis=1))
preds = np.vstack(preds)
obs = np.vstack(obs)
inp = np.vstack(inp)

print(preds.shape)
print(obs.shape)
print(inp.shape)

from scipy import stats
from sklearn import linear_model

# Baseline: linear regression from the averaged features to the first target.
regr = linear_model.LinearRegression()
regr.fit(inp, obs[:, 0])
y_hat = regr.predict(inp)
_, _, r_value, _, _ = stats.linregress(obs[:, 0], y_hat)
print('BASELINE: Test R^2 = %.3f' % r_value**2)

# Model: R^2 between predictions and observations, one output at a time.
rvalues = []
for i in range(preds.shape[1]):
  _, _, r_value, _, _ = stats.mstats.linregress(preds[:, i], obs[:, i])
  print('Test R^2 %d = %.3f' % (i, r_value**2))
  rvalues.append(r_value)

Model loaded from /cns/tm-d/home/vikrama/rs=6.3/test7/15/model.pkl at step 20000
(200, 2, 117)
(200, 2, 117)


KeyboardInterrupt: ignored

In [0]:
# Scratch check of the shape arithmetic used by no_padding_mask.
# NumPy is bound to `onp`, as in the setup cell, so that the notebook-wide
# `np` alias (jax.numpy) is not overwritten by running this cell.
import numpy as onp

x = onp.ones((2, 3, 3, 3))
c = onp.reshape(onp.ones(x.shape[0] * x.shape[-2], dtype=x.dtype),
                (x.shape[0], 1, 1, x.shape[-2]))

c.shape

(2, 1, 1, 3)