# Setup

The below code mixes TF and JAX, which won't work on stock colab with TPUs.

Find the cell of your DM Py3 TPUv2 colab, 

---


```
borgcfg --vars=cell=YOURCELL production/borg/colaboratory/mint/deepmind_tpu_py3.borg update colab_kernel.kernel
```

# Imports

In [0]:
from six.moves import cPickle
import os
import datetime
import random
import tempfile
from functools import partial
import copy

import gin
gin.enter_interactive_mode()

import numpy as onp
import jax
from jax import lax
from jax import random as jr
from jax import numpy as np
from jax.ops import index, index_update

from matplotlib import pyplot as plt

import tensorflow.google as tf
tf.enable_eager_execution()
import tensorflow_datasets as tfds

from colabtools import adhoc_import
with adhoc_import.Google3():
  import trax
  from trax.supervised import trainer_lib
  from trax import layers as tl
  from trax.supervised import inputs as trax_input
  from trax import models as trax_models
  from trax.models import transformer
  from trax import optimizers as trax_optimizers
  from trax import backend
  from trax import shapes

T2T: skipped importing 1 data_generators modules. OK if no other errors. Depend on _heavy or problem-specific py_binary targets if trying to use a module that was skipped.


# Trax Input Fn

In [0]:
@gin.configurable()
def vikram_inputs(
    n_batch=8192,
    num_input_timestamps=1001,
    num_embed=117,
    num_output_predictions=57,
    cns_path='/readahead/200M/cns/jq-d/home/levskaya/calico/vikram/',
    # shuffle_files=True,
    batch_shuffle_size=1024,
    n_prefetch=16,
    mode='train'):

  print(cns_path)
  # grab filenames from CNS
  train_files = tf.gfile.Glob(os.path.join(cns_path, '*train*'))
  dev_files =  tf.gfile.Glob(os.path.join(cns_path, '*dev*'))
  test_files = tf.gfile.Glob(os.path.join(cns_path, '*test*'))

  # tf.example proto parsing
  feature_description = {
    "inputs": tf.VarLenFeature(tf.float32),
    "targets": tf.VarLenFeature(tf.float32),
  }

  def _parse_example(x):
    return tf.parse_example([x], feature_description)

  # reshaping
  input_shape = [-1, num_input_timestamps, num_embed]
  input_dtype = np.float32
  target_shape = [-1, 1, num_output_predictions]
  target_dtype = np.float32
  
  def _reshape(x):
    inps = x['inputs'].values
    inps = tf.reshape(inps, input_shape)
    #inps = inps[:, 499:501, :] ## LOOK AT PROMOTER REGION ONLY AS BASELINE
    tgts = x['targets'].values
    # condition = tf.less(tgts, -0.5)
    # tgts = tf.dtypes.cast(tf.where(condition, tf.zeros_like(tgts), tf.ones_like(tgts)),tf.int32)
    tgts = tf.reshape(tgts, target_shape)
    return (inps, tgts)

  # tf.data chain
  def make_dataset_iterator(data_files):
    return (
        tf.data.TFRecordDataset(data_files)
        .map(_parse_example)
        .repeat() #used to be after map_reshape
        .shuffle(batch_shuffle_size) #used to be after batch
        .batch(n_batch, drop_remainder=True)
        .map(_reshape)
        .prefetch(n_prefetch)
        .as_numpy_iterator()
        )

  def make_dataset_iterator2(data_files):
    return (
        tf.data.TFRecordDataset(data_files)
        .map(_parse_example) #removed shuffle & repeat operations
        .batch(n_batch, drop_remainder=False)
        .map(_reshape)
        .prefetch(n_prefetch)
        .as_numpy_iterator()
        )

  if mode=='linregtrain':
    return make_dataset_iterator2(train_files)
  elif mode=='test':
    return make_dataset_iterator2(test_files)
  else:
    # make dataset iterators
    ds_train = lambda: make_dataset_iterator(train_files)
    ds_dev = lambda: make_dataset_iterator(dev_files)
    ds_test = lambda: make_dataset_iterator(test_files)
    
    # put information in form trax wants
    input_shape_without_batch = list(input_shape)[1:]
    target_shape_without_batch = list(target_shape)[1:]
    return trax_input.Inputs(train_stream=ds_train,
                  train_eval_stream=ds_dev,
                  eval_stream=ds_test,
                  input_shape=input_shape_without_batch,
                  input_dtype=input_dtype,
                  target_shape=target_shape_without_batch,
                  target_dtype=target_dtype)

# Modified Transformer

In [0]:
@tl.base.layer()
def no_padding_mask(x, **kwargs):
  del kwargs
  return np.reshape(np.ones(x.shape[0]*x.shape[-2], dtype=x.dtype), 
                    (x.shape[0], 1, 1, x.shape[-2]))

def _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode, activation):
  """Returns a list of layers implementing a feed-forward block.
  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    activation: the non-linearity in feed-forward layer
  Returns:
    A list of layers which maps vectors to vectors.
  """
  dropout_middle = tl.Dropout(
      rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode)
  dropout_final = tl.Dropout(
      rate=dropout, name='ff_final_%d' % layer_idx, mode=mode)

  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      activation(),
      dropout_middle,
      tl.Dense(d_model),
      dropout_final,
  ]

def _EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                  ff_activation):
  """Returns a list of layers that implements a Transformer encoder block.
  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.
  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer
  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  attention = tl.Attention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='dropout_enc_attn', mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward
      ),
  ]

@tl.layer()
def subsetdata(x, startidx, stopidx, **kwargs):
  # print(x.shape)
  x = x[:, startidx:stopidx, :]
  # print(x.shape)
  return x

@gin.configurable()
def non_tokenizing_transformer(n_classes=57,
                               d_model=512, #512
                               d_ff=2048, #2048
                               n_layers=2,
                               n_heads=8,
                               dropout=0.1,
                               max_len=1001,
                               mode='train',
                               ff_activation=tl.Relu):
  """Returns a Transformer encoder model.

  The input to the model is a tensor of tokens.

  Args:
    n_classes: how many classes on output
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a tensor of tokens to
    activations over a set of output classes.
  """
  # embedder = [
  #     # tl.Embedding(d_model, vocab_size),
  #     tl.Dense(d_model),
  #     tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
  #     tl.PositionalEncoding(max_len=max_len),
  # ]
  # return tl.Serial(                             #      tokens
  #    tl.Dup(),                                 # toks toks
  #    tl.Parallel(embedder, no_padding_mask()),  # vecs mask
  #    [encoder_blocks(d_model, d_ff, n_heads, dropout, i, mode)
  #     for i in range(n_layers)],               # vecs mask
  #    tl.Parallel([], tl.Drop()),               # ____  0
  #    tl.LayerNorm(),                           # vecs
  #    tl.Mean(axis=1),  # Average on length.    # vecs
  #    tl.Dense(n_classes)                      # vecs
  #    # tl.LogSoftmax()
  # )
  positional_encoder = [
    # tl.Embedding(d_model, vocab_size),
    tl.Dense(d_model),
    tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
    tl.PositionalEncoding(max_len=max_len)]

  encoder_blocks = [
      tl._EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
      for i in range(n_layers)]

  # Assemble and return the model.
  return tl.Serial(                               # toks
      # Encode.
      tl.Branch(positional_encoder, no_padding_mask()),  # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs
      subsetdata(startidx=498, stopidx=501),
      # Map to output categories.
      tl.Mean(axis=1),                            # vecs
      tl.Dense(n_classes),                        # vecs
  )


# **Run**

In [0]:
from scipy import stats

# # Config Trax for simple synthetic data problem
# gin.parse_config("""
# # Parameters for MultifactorSchedule:
# # ==============================================================================
# MultifactorSchedule.constant = 0.1
# MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay'
# MultifactorSchedule.warmup_steps = 8000
# """)

timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
output_dir = os.path.expanduser("/tmp/trax_%s" % timestamp)
print(output_dir)

_ = trainer_lib.train(model=non_tokenizing_transformer,
               optimizer=trax_optimizers.Adam, #Adam, #RMSProp, Adafactor
               loss_fn=tl.L2LossScalar, #CrossEntropyLossScalar, #L2LossScalar, #CrossEntropyLossScalar, 
               inputs=vikram_inputs,
               output_dir=output_dir,
               steps=20000,
               eval_steps=10,
               eval_frequency=1000,
               has_weights=False,
               metrics = {'loss': trax.layers.metrics.L2Scalar})

/tmp/trax_20200101_1720
/readahead/200M/cns/jq-d/home/levskaya/calico/vikram/
Step      0: Starting training using 8 devices
Step      0: Total number of trainable weights: 6907961

Step      1: Ran 1 train steps in 39.31 secs
Step      1: Evaluation
Step      1: train loss |  136.64440765
Step      1: eval  loss |  143.02021942
Step      1: Finished evaluation

Step   1000: Ran 999 train steps in 133.47 secs
Step   1000: Evaluation
Step   1000: train loss |  32.59833298
Step   1000: eval  loss |  22.95791292
Step   1000: Finished evaluation

Step   2000: Ran 1000 train steps in 110.77 secs
Step   2000: Evaluation
Step   2000: train loss |  26.39411240
Step   2000: eval  loss |  29.57369671
Step   2000: Finished evaluation

Step   3000: Ran 1000 train steps in 110.58 secs
Step   3000: Evaluation
Step   3000: train loss |  20.64959049
Step   3000: eval  loss |  25.49830055
Step   3000: Finished evaluation

Step   4000: Ran 1000 train steps in 108.64 secs
Step   4000: Evaluation
Step   4

KeyboardInterrupt: ignored

# **Test pretrained NN model**

In [0]:
#print(backend.device_count())
# output_dir='/tmp/trax_20191125_1349'

# Inference model
predict_model = non_tokenizing_transformer(mode='eval')
predict_signature = shapes.ShapeDtype((1, 1001, 117), dtype=np.float32) #shape of input
# predict_signature = shapes.ShapeDtype((1, 2, 117), dtype=np.float32) #shape of input
predict_model.init(predict_signature)

# Load from file (API for trainer_state may change)
trainer_state = trainer_lib.load_trainer_state(output_dir)
# predict_model.params = trainer_state.opt_state.weights[0]

weights = trainer_state.opt_state.weights[0]
model_state = trainer_state.model_state[0]
rng = backend.random.get_prng(0)

# Run inference
preds = []
obs = []
ds_test = vikram_inputs(mode='test', n_batch=2)

# make dataset iterators
for input in ds_test:
  # print(input[0].shape)
  preds.append(predict_model(input[0], weights=weights, state=model_state, rng=rng)) #, rng=random_key
  obs.append(np.squeeze(input[1]))
preds = np.vstack(preds)
obs = np.vstack(obs)

print(preds.shape)
print(obs.shape)
print(preds.dtype)
print(obs.dtype)
print(preds)

#preds = [1 if x < 0.5 else 0 for x in np.exp(preds[:,0])] 
#print(preds)

# import matplotlib.pyplot as plt
# fig1, ax1 = plt.subplots()
# ax1.hist(np.exp(preds[:,0]), bins=100)
# fig1, ax1 = plt.subplots()
# ax1.hist(obs[:,0], bins=100)

# from sklearn import metrics
# print(metrics.accuracy_score(obs[:,0], preds[:,0]))

slope, intercept, r_value, p_value, std_err = stats.mstats.linregress(preds[:,53], obs[:,53])
print('Test R^2 = %.3f' % (r_value**2))

Model loaded from /tmp/trax_20200101_1720/model.pkl at step 4000
/readahead/200M/cns/jq-d/home/levskaya/calico/vikram/
(1000, 57)
(1000, 57)
float32
float32
[[-1.1813539  -0.6862472  -0.86021227 ... -0.998313   -1.2930957
  -0.81592447]
 [ 0.11105715  0.25464073  0.3180077  ... -0.40874064 -0.08990811
   0.61788934]
 [ 0.27884412  0.20377345  0.2841013  ...  0.31727543  0.36500064
   0.64982074]
 ...
 [ 0.33665946  0.13464828  0.23719117 ...  0.49370202  0.33674708
   0.64944476]
 [-1.0406121  -0.94504756 -0.9263856  ... -0.71154547 -0.85463125
  -0.9171174 ]
 [-1.273698   -1.0722353  -1.052209   ... -0.87417805 -0.88557506
  -0.91105646]]
Test R^2 = 0.615


# Baseline Linear Reg

In [0]:
import numpy as np
from scipy import stats
from sklearn import linear_model, metrics

obs = []
inp = []
ds_train = vikram_inputs(mode='linregtrain', n_batch=2000)

###TRAINING

# make dataset iterators
for input in ds_train:
  print(input[0].shape)
  inp.append(np.mean(input[0], axis=1))
  obs.append(np.squeeze(input[1]))

obs = np.vstack(obs)
inp = np.vstack(inp)

print(obs.shape)
print(inp.shape)

regr = linear_model.LinearRegression()
# regr = linear_model.LogisticRegression()
regr.fit(inp, obs[:,1])


###TESTING

obs = []
inp = []
ds_test = vikram_inputs(mode='test', n_batch=200)

# make dataset iterators
for input in ds_test:
  print(input[0].shape)
  inp.append(np.mean(input[0], axis=1))
  obs.append(np.squeeze(input[1]))

obs = np.vstack(obs)
inp = np.vstack(inp)
y_hat = regr.predict(inp)

# print(metrics.accuracy_score(obs[:,0], y_hat))
slope, intercept, r_value, p_value, std_err = stats.linregress(obs[:,53], y_hat)
print('BASELINE: Test R^2 = %.3f' % r_value**2)

/readahead/200M/cns/jq-d/home/levskaya/calico/vikram/
(2000, 2, 117)
(2000, 2, 117)
(2000, 2, 117)
(2000, 2, 117)
(2000, 2, 117)
(2000, 2, 117)
(2000, 2, 117)
(1877, 2, 117)
(15877, 57)
(15877, 117)
/readahead/200M/cns/jq-d/home/levskaya/calico/vikram/
(200, 2, 117)
(200, 2, 117)
(200, 2, 117)
(200, 2, 117)
(200, 2, 117)
BASELINE: Test R^2 = 0.524
