In [1]:
import collections
from typing import Dict, Sequence, Text

from absl import logging

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



##### Adapted from the Flax SST-2 example: https://github.com/google/flax/tree/master/examples/sst2

In [2]:
# Show which TensorFlow version this notebook was executed with.
tf.__version__

'1.14.0'

## Input pipeline

In [3]:
def build_vocab(datasets: Sequence[tf.data.Dataset],
                special_tokens: Sequence[Text] = (b'<pad>', b'<unk>', b'<s>', b'</s>'),
                min_freq: int = 0) -> Dict[Text, int]:
  """Returns a vocabulary of tokens with optional minimum frequency.

  Special tokens receive the first IDs, in the order given. With the defaults,
  <pad> gets ID 0, matching the 0s used when padding batches. All other tokens
  found in the `sentence` feature are then added in sorted order if they occur
  at least `min_freq` times.

  Args:
    datasets: Datasets whose examples have a `sentence` feature. They are read
      through `tfds.as_numpy`, so sentences (and tokens) are bytes.
    special_tokens: Tokens placed at the start of the vocabulary. Despite the
      `Text` annotation, the defaults are bytes to match the tokenized data.
    min_freq: Minimum count for a non-special token to be included.

  Returns:
    An ordered mapping from token to integer ID. IDs are unique and contiguous,
    starting at 0.
  """
  # Count the tokens in the datasets.
  counter = collections.Counter()
  for dataset in datasets:
    for example in tfds.as_numpy(dataset):
      counter.update(whitespace_tokenize(example['sentence']))

  vocab = collections.OrderedDict()

  def add(token):
    # Never re-assign an existing token. Otherwise a special token that is
    # listed twice or also appears in the corpus would get a new ID, leaving a
    # gap (e.g. <pad> would no longer be 0) and colliding with later IDs.
    if token not in vocab:
      vocab[token] = len(vocab)

  # Add special tokens to the start of vocab.
  for token in special_tokens:
    add(token)

  # Add all other tokens to the vocab if their frequency is >= min_freq.
  for token in sorted(counter):
    if counter[token] >= min_freq:
      add(token)

  logging.info('Number of unfiltered tokens: %d', len(counter))
  logging.info('Vocabulary size: %d', len(vocab))
  return vocab


def whitespace_tokenize(text: Text) -> Sequence[Text]:
  """Tokenizes `text` by splitting it on runs of whitespace.

  Leading and trailing whitespace produce no empty tokens. Works for both
  `str` and `bytes`; the returned tokens have the same type as `text`.
  """
  # A separator-less split() already ignores leading/trailing whitespace.
  tokens = text.split()
  return tokens


def get_shuffled_batches(dataset: tf.data.Dataset,
                         seed: int = 0,
                         batch_size: int = 64) -> tf.data.Dataset:
  """Shuffles `dataset` and groups it into padded, fixed-size batches.

  The example order is reshuffled on every pass over the returned dataset. The
  sequence of orders is deterministic for a given `seed`.

  Sentences in a batch generally differ in length, so shorter ones are padded
  with 0s at the end until every sentence in the batch is as long as the
  longest one.

  Args:
    dataset: A TF Dataset with examples to be shuffled and batched.
    seed: Seed controlling the shuffle order (a new order each epoch).
    batch_size: Number of examples per batch. A final incomplete batch is
      dropped.

  Returns:
    A TF Dataset of padded batches.
  """
  # Count the examples so the shuffle buffer can hold the whole dataset.
  # `.numpy()` requires eager execution.
  num_examples = dataset.reduce(
      np.int64(0), lambda running_count, _: running_count + 1).numpy()

  # Per-feature shapes: [] is a scalar, [-1] a variable-length vector (padded
  # to the longest in the batch) and [1] a vector with exactly one element.
  shapes_per_feature = {
      'idx': [],
      'sentence': [-1],
      'label': [1],
      'length': []
  }
  shuffled = dataset.shuffle(
      num_examples, seed=seed, reshuffle_each_iteration=True)
  batched = shuffled.padded_batch(
      batch_size, padded_shapes=shapes_per_feature, drop_remainder=True)
  return batched.prefetch(tf.data.experimental.AUTOTUNE)


def get_batches(dataset: tf.data.Dataset,
                batch_size: int = 64) -> tf.data.Dataset:
  """Groups `dataset` into padded batches without shuffling.

  Unlike `get_shuffled_batches`, the example order is preserved and the final
  (possibly smaller) batch is kept, so every example is visited once.

  Args:
    dataset: A TF Dataset of pre-processed examples.
    batch_size: Maximum number of examples per batch.

  Returns:
    A TF Dataset of padded batches.
  """
  shapes_per_feature = {
      'idx': [],
      'sentence': [-1],
      'label': [1],
      'length': []
  }
  batched = dataset.padded_batch(
      batch_size, padded_shapes=shapes_per_feature, drop_remainder=False)
  return batched.prefetch(tf.data.experimental.AUTOTUNE)


class SST2DataSource:
  """Provides SST-2 data as pre-processed datasets and a token vocabulary.

  After construction the instance exposes:
    train_dataset / valid_dataset / test_dataset: cached `tf.data.Dataset`s.
      In each, `sentence` is a sequence of token IDs wrapped in <s> ... </s>,
      `label` is a 1-element vector, and `length` is the number of IDs in
      `sentence` (including <s> and </s>).
    valid_raw / test_raw: the unprocessed validation and test splits.
    preprocess_fn: the per-example pre-processing function.
    vocab / vocab_size: the token (bytes) -> ID mapping and its size.
    unk_idx / bos_idx / eos_idx: the IDs of <unk>, <s> and </s>.
  """
  # pylint: disable=too-few-public-methods

  def __init__(self, min_freq: int = 0):
    """Loads `glue/sst2` via TFDS, builds the vocabulary and pre-processes.

    Args:
      min_freq: Minimum number of occurrences in the training split for a
        token to get its own ID. Other tokens are encoded as <unk>.
    """
    # Load datasets.
    data = tfds.load('glue/sst2')
    train_raw = data['train']
    valid_raw = data['validation']
    test_raw = data['test']

    # Log one example. `iter` makes this work whether `tfds.as_numpy` returns
    # an iterator or only an iterable (on which `next` would fail).
    logging.info('Data sample: %s',
                 next(iter(tfds.as_numpy(train_raw.skip(4)))))

    # Build the vocabulary from the training split only.
    vocab = build_vocab((train_raw,), min_freq=min_freq)

    unk_idx = vocab[b'<unk>']
    bos_idx = vocab[b'<s>']
    eos_idx = vocab[b'</s>']

    # Turn data examples into pre-processed examples by turning each sentence
    # into a sequence of token IDs. Also prepend a beginning-of-sequence
    # token <s> and append an end-of-sequence token </s>.

    def tokenize(text: tf.Tensor):
      """Whitespace tokenize text."""
      # Wrapped in a one-element list so `tf.py_function` converts the whole
      # token list into its single (string) output tensor.
      return [whitespace_tokenize(text.numpy())]

    def tf_tokenize(text: tf.Tensor):
      """Graph-compatible wrapper around `tokenize`."""
      return tf.py_function(tokenize, [text], Tout=tf.string)

    def encode(tokens: tf.Tensor):
      """Encodes a sequence of tokens (bytes) into a sequence of token IDs."""
      # Out-of-vocabulary tokens map to <unk>. The outer list plays the same
      # role as in `tokenize`.
      return [[vocab.get(t, unk_idx) for t in tokens.numpy()]]

    def tf_encode(tokens: tf.Tensor):
      """Maps tokens to token IDs."""
      return tf.py_function(encode, [tokens], Tout=tf.int64)

    def tf_wrap_sequence(sequence: tf.Tensor):
      """Prepends BOS ID and appends EOS ID to a sequence of token IDs."""
      return tf.concat(([bos_idx], tf.concat((sequence, [eos_idx]), 0)), 0)

    def preprocess_example(example: Dict[Text, tf.Tensor]):
      """Tokenizes/encodes `sentence`, boxes `label` and adds `length`."""
      example['sentence'] = tf_wrap_sequence(
          tf_encode(tf_tokenize(example['sentence'])))
      example['label'] = [example['label']]
      example['length'] = tf.shape(example['sentence'])[0]
      return example

    self.preprocess_fn = preprocess_example

    # Pre-process all datasets.
    self.train_dataset = train_raw.map(preprocess_example).cache()
    self.valid_dataset = valid_raw.map(preprocess_example).cache()
    self.test_dataset = test_raw.map(preprocess_example).cache()

    self.valid_raw = valid_raw
    self.test_raw = test_raw

    self.vocab = vocab
    self.vocab_size = len(vocab)

    self.unk_idx = unk_idx
    self.bos_idx = bos_idx
    self.eos_idx = eos_idx

## Modeling

In [4]:
"""LSTM classifier model for SST-2."""

import functools
from typing import Any, Callable, Dict, Text

import flax
from flax import nn
import jax
import jax.numpy as jnp
from jax import lax

import numpy as np

# pylint: disable=arguments-differ,too-many-arguments


@functools.partial(jax.jit, static_argnums=(0, 1, 2, 3))
def create_model(seed: int, batch_size: int, max_len: int,
                 model_kwargs: Dict[Text, Any]):
  """Builds a `TextClassifier` model with freshly initialized parameters.

  Parameters are created from input *shapes* only. `batch_size` and `max_len`
  just size the dummy `(inputs, lengths)` used for initialization; the
  parameter shapes presumably do not depend on them. All four arguments are
  static for `jax.jit`.

  NOTE(review): `model_kwargs` is a dict, which is not hashable. Newer JAX
  versions require hashable static arguments, so confirm for the JAX version
  in use.

  Args:
    seed: Seed for the parameter-initialization PRNG key.
    batch_size: Batch dimension of the dummy inputs.
    max_len: Sequence length of the dummy inputs.
    model_kwargs: Keyword arguments bound into `TextClassifier`.

  Returns:
    An `nn.Model` wrapping the classifier and its initial parameters.
  """
  classifier = TextClassifier.partial(train=False, **model_kwargs)
  init_key = jax.random.PRNGKey(seed)
  dummy_input_specs = [((batch_size, max_len), jnp.int32),
                       ((batch_size,), jnp.int32)]
  _, fresh_params = classifier.init_by_shape(init_key, dummy_input_specs)
  return nn.Model(classifier, fresh_params)


def word_dropout(inputs: jnp.ndarray, rate: float, unk_idx: int,
                 deterministic: bool = False):
  """Replaces a fraction (rate) of inputs with <unk>.

  Each element of `inputs` is independently replaced by `unk_idx` with
  probability `rate`. The randomness comes from `nn.make_rng()`, so when active
  this must run inside an `nn.stochastic` context.

  Args:
    inputs: Integer token IDs (any shape).
    rate: Replacement probability. `0` or `None` disables word dropout.
    unk_idx: Token ID written into the replaced positions.
    deterministic: If True, `inputs` is returned unchanged.

  Returns:
    An array with the same shape as `inputs`.
  """
  # `TextClassifier` defaults `word_dropout_rate` to None. Treat it like 0
  # instead of passing p=None to `jax.random.bernoulli`.
  if deterministic or not rate:
    return inputs

  mask = jax.random.bernoulli(nn.make_rng(), p=rate, shape=inputs.shape)
  # A scalar cast to the input dtype broadcasts everywhere and avoids needless
  # type promotion of the result.
  return jnp.where(mask, jnp.asarray(unk_idx, dtype=inputs.dtype), inputs)


class Embedding(nn.Module):
  """Embedding Module: looks up rows of a learned `[num_embeddings, features]` table."""

  def apply(self,
            inputs: jnp.ndarray,
            num_embeddings: int,
            features: int,
            emb_init: Callable[...,
                               np.ndarray] = nn.initializers.normal(stddev=0.1),
            frozen: bool = False):
    """Returns the embedding vectors for integer token IDs.

    Args:
      inputs: Integer token IDs, presumably [batch_size, seq_length].
      num_embeddings: Number of rows in the table (vocabulary size).
      features: Size of each embedding vector.
      emb_init: Initializer for the embedding table.
      frozen: If truthy, gradients are stopped at the looked-up vectors, so no
        gradient flows back into the table.

    Returns:
      An array of shape `inputs.shape + (features,)`.
    """
    table = self.param('embedding', (num_embeddings, features), emb_init)
    looked_up = jnp.take(table, inputs, axis=0)
    return lax.stop_gradient(looked_up) if frozen else looked_up


class LSTM(nn.Module):
  """LSTM encoder. Turns a sequence of vectors into a vector.

  Runs an LSTM cell along the time axis (axis 1). For each sequence it returns
  the output at its last valid position, `lengths - 1`, clamped at 0.
  """

  def apply(self,
            inputs: jnp.ndarray,
            lengths: jnp.ndarray,
            hidden_size: int = None):
    # Assumed shapes, based on the callers in this notebook (not checked here):
    #   inputs:  float [batch_size, seq_length, emb_size]
    #   lengths: int   [batch_size]
    num_sequences = inputs.shape[0]
    # `initialize_carry` takes a PRNG key. A fixed key is used here.
    # TODO confirm the initial state does not actually depend on it.
    initial_carry = nn.LSTMCell.initialize_carry(
        jax.random.PRNGKey(0), (num_sequences,), hidden_size)
    cell = nn.LSTMCell.partial(name='lstm_cell')
    _, per_step_outputs = flax.jax_utils.scan_in_dim(
        cell, initial_carry, inputs, axis=1)
    last_positions = jnp.maximum(0, lengths - 1)
    return per_step_outputs[jnp.arange(num_sequences), last_positions, :]


class MLP(nn.Module):
  """A 2-layer MLP: Dense -> tanh -> (dropout when training) -> Dense."""

  def apply(self,
            inputs: jnp.ndarray,
            hidden_size: int = None,
            output_size: int = None,
            output_bias: bool = False,
            dropout: float = None,
            train: bool = None):
    """Applies the MLP to `inputs`.

    In this notebook `inputs` is the LSTM encoding, presumably
    [batch_size, hidden_size] (not [batch_size, seq_length, ...]). Confirm
    this for any other caller.

    Args:
      inputs: Input features.
      hidden_size: Width of the hidden layer.
      output_size: Width of the output layer.
      output_bias: Whether the output layer has a bias.
      dropout: Dropout rate applied to the hidden activations when training.
      train: Enables dropout when truthy.

    Returns:
      The output-layer activations (no final non-linearity).
    """
    activations = nn.tanh(nn.Dense(inputs, hidden_size, name='hidden'))
    if train:
      activations = nn.dropout(activations, rate=dropout)
    return nn.Dense(activations, output_size, bias=output_bias, name='output')


class LSTMClassifier(nn.Module):
  """LSTM classifier: LSTM-encodes the embeddings, then predicts with an MLP."""

  def apply(self,
            embed: jnp.ndarray,
            lengths: jnp.ndarray,
            hidden_size: int = None,
            output_size: int = None,
            dropout: float = None,
            emb_dropout: float = None,
            train: bool = None):
    """Encodes the input sequence and makes a prediction using an MLP.

    Args:
      embed: Embedded tokens, presumably float [batch_size, seq_length,
        embedding_size].
      lengths: Sequence lengths, presumably int [batch_size].
      hidden_size: Hidden size shared by the LSTM and the MLP.
      output_size: Number of output logits.
      dropout: Dropout rate on the LSTM encoding and inside the MLP.
      emb_dropout: Dropout rate on the embeddings.
      train: Enables all dropout when truthy.

    Returns:
      The logits produced by the MLP.
    """
    # Dropout is applied in the same order as the layers run:
    # embeddings -> LSTM encoding -> MLP hidden layer.
    if train:
      embed = nn.dropout(embed, rate=emb_dropout)

    encoding = LSTM(embed, lengths, hidden_size=hidden_size, name='lstm')
    if train:
      encoding = nn.dropout(encoding, rate=dropout)

    return MLP(
        encoding,
        hidden_size=hidden_size,
        output_size=output_size,
        output_bias=False,
        dropout=dropout,
        name='mlp',
        train=train)


class TextClassifier(nn.Module):
  """Full classification model: word dropout -> embedding -> LSTM classifier."""

  def apply(self,
            inputs: jnp.ndarray,
            lengths: jnp.ndarray,
            unk_idx: int = 1,
            vocab_size: int = None,
            embedding_size: int = None,
            word_dropout_rate: float = None,
            freeze_embeddings: bool = None,
            train: bool = False,
            emb_init: Callable[..., Any] = nn.initializers.normal(stddev=0.1),
            **kwargs):
    """Embeds token IDs and classifies the resulting sequences.

    Args:
      inputs: Integer token IDs, presumably [batch_size, seq_length].
      lengths: Sequence lengths, presumably [batch_size].
      unk_idx: Token ID substituted by word dropout.
      vocab_size: Number of rows in the embedding table.
      embedding_size: Size of each embedding vector.
      word_dropout_rate: Probability of replacing a token by `unk_idx` while
        training.
      freeze_embeddings: If truthy, embeddings receive no gradient.
      train: Enables word dropout and the dropout layers.
      emb_init: Initializer for the embedding table.
      **kwargs: Forwarded to `LSTMClassifier` (e.g. `hidden_size`,
        `output_size`, `dropout`, `emb_dropout`).

    Returns:
      The classifier logits.
    """
    token_ids = (
        word_dropout(inputs, rate=word_dropout_rate, unk_idx=unk_idx)
        if train else inputs)

    embedded = Embedding(
        token_ids,
        vocab_size,
        embedding_size,
        emb_init=emb_init,
        frozen=freeze_embeddings,
        name='embed')

    return LSTMClassifier(
        embedded, lengths, train=train, name='lstm_classifier', **kwargs)

## Training

In [5]:
import collections
from typing import Any, Dict, Text, Tuple

from absl import app
from absl import flags
from absl import logging

import flax
import flax.training.checkpoints
from flax import nn

import jax
import jax.numpy as jnp

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tensorflow.compat.v2.io import gfile

In [6]:
FLAGS = flags.FLAGS

# NOTE(review): in the visible cells these flags are never parsed (no
# `app.run` / `FLAGS(argv)`) or read. The cells below pass literal values
# instead, so these definitions currently only document the defaults.


def _define_flag_once(define_fn, name, default, help_text):
  """Calls `define_fn` unless a flag called `name` is already registered.

  absl keeps flags in a process-wide registry. Re-executing this notebook cell
  would otherwise raise `flags.DuplicateFlagError`.
  """
  if name not in FLAGS:
    define_fn(name, default=default, help=help_text)


_define_flag_once(flags.DEFINE_float, 'learning_rate', 0.0005,
                  'The learning rate for the Adam optimizer.')

_define_flag_once(flags.DEFINE_integer, 'batch_size', 64,
                  'Batch size for training.')

_define_flag_once(flags.DEFINE_integer, 'num_epochs', 20,
                  'Number of training epochs.')

_define_flag_once(flags.DEFINE_float, 'dropout', 0.5,
                  'Dropout rate')

_define_flag_once(flags.DEFINE_float, 'emb_dropout', 0.5,
                  'Embedding dropout rate')

_define_flag_once(flags.DEFINE_float, 'word_dropout_rate', 0.1,
                  'Word dropout rate. Replaces input words with <unk>.')

_define_flag_once(flags.DEFINE_string, 'model_dir', 'output_dir',
                  'Directory to store model data')

_define_flag_once(flags.DEFINE_integer, 'hidden_size', 256,
                  'Hidden size for the LSTM and MLP.')

_define_flag_once(flags.DEFINE_integer, 'embedding_size', 256,
                  'Size of the word embeddings.')

_define_flag_once(flags.DEFINE_integer, 'max_seq_len', 55,
                  'Maximum sequence length in the dataset.')

_define_flag_once(
    flags.DEFINE_integer, 'min_freq', 5,
    'Minimum frequency for training set words to be in the vocabulary.')

_define_flag_once(flags.DEFINE_float, 'l2_reg', 1e-6,
                  'L2 regularization weight')

_define_flag_once(flags.DEFINE_integer, 'seed', 0,
                  'Random seed for network initialization.')

_define_flag_once(
    flags.DEFINE_integer, 'checkpoints_to_keep', 1,
    'How many checkpoints to keep. Default: 1 (keep best model only)')

In [15]:
@jax.vmap
def binary_cross_entropy_loss(logit: jnp.ndarray, label: jnp.ndarray):
  """Numerically stable binary cross entropy computed from a logit.

  Written for a single example; `jax.vmap` maps it over the leading (batch)
  axis of both arguments.

  Uses softplus(-x) = -log(sigmoid(x)) and softplus(x) = -log(1 - sigmoid(x)),
  which avoids taking the log of a saturated sigmoid.

  Args:
    logit: The output logit(s) for one example.
    label: The target label(s) for the same example.

  Returns:
    The per-example binary cross entropy, with the shape of `logit`.
  """
  positive_term = label * nn.softplus(-logit)
  negative_term = (1 - label) * nn.softplus(logit)
  return positive_term + negative_term


@jax.jit
def train_step(optimizer: Any, inputs: jnp.ndarray, lengths: jnp.ndarray,
               labels: jnp.ndarray, rng: Any, l2_reg: float):
  """Runs one optimizer update on a single batch.

  The loss is the mean binary cross entropy plus `l2_reg` times the summed
  squared parameters of the `lstm_classifier` sub-tree. The embedding
  parameters are not penalized.

  Args:
    optimizer: The optimizer holding the model to update.
    inputs: A batch of token IDs. <int64>[batch_size, seq_len]
    lengths: The lengths of the sequences in the batch. <int64>[batch_size]
    labels: The labels of the sequences in the batch. <int64>[batch_size, 1]
    rng: PRNG key; it is split, and one half drives dropout for this step.
    l2_reg: L2 regularization weight.

  Returns:
    A tuple `(optimizer, loss, rng)`: the updated optimizer, this step's loss,
    and the PRNG key to pass to the next step.
  """
  dropout_rng, next_rng = jax.random.split(rng)

  def loss_fn(model):
    with nn.stochastic(dropout_rng):
      logits = model(inputs, lengths, train=True)
    data_loss = jnp.mean(binary_cross_entropy_loss(logits, labels))

    penalized = jax.tree_leaves(model.params['lstm_classifier'])
    squared_norm = jnp.sum([jnp.sum(p ** 2) for p in penalized])
    return data_loss + l2_reg * squared_norm, logits

  step_loss, _, grads = optimizer.compute_gradient(loss_fn)
  return optimizer.apply_gradient(grads), step_loss, next_rng


def get_predictions(logits: jnp.ndarray) -> jnp.ndarray:
  """Converts logits into binary class predictions.

  A prediction is 1 where sigmoid(logit) is strictly greater than 0.5 and 0
  otherwise. The result is int32 and has the shape of `logits`.
  """
  probabilities = jax.nn.sigmoid(logits)
  is_positive = probabilities > 0.5
  return is_positive.astype(jnp.int32)


def get_num_correct(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  """Counts the predictions (see `get_predictions`) that equal `labels`.

  `logits` and `labels` must broadcast against each other. In this notebook
  both are presumably [batch_size, 1].
  """
  is_correct = get_predictions(logits) == labels
  return is_correct.sum()


@jax.jit
def eval_step(model: nn.Module, inputs: jnp.ndarray, lengths: jnp.ndarray,
              labels: jnp.ndarray):
  """Scores one batch without dropout.

  Args:
    model: The model to evaluate.
    inputs: A batch of token IDs. <int64>[batch_size, seq_len]
    lengths: The lengths of the sequences in the batch. <int64>[batch_size]
    labels: The labels of the sequences in the batch. <int64>[batch_size, 1]

  Returns:
    A tuple `(loss, num_correct)`: the cross entropy summed (not averaged)
    over the batch, and the number of correct predictions.
  """
  batch_logits = model(inputs, lengths, train=False)
  summed_loss = binary_cross_entropy_loss(batch_logits, labels).sum()
  return summed_loss, get_num_correct(batch_logits, labels)


def evaluate(model: nn.Model, dataset: tf.data.Dataset):
  """Evaluates the model on a dataset.

  Args:
    model: A model to be evaluated.
    dataset: A batched dataset (e.g. from `get_batches`) whose elements have
      `sentence`, `length` and `label` features. Typically valid or test.

  Returns:
    A dict with `loss` (mean per-example loss) and `acc` (accuracy in
    percent). Both are NaN if `dataset` yields no examples.
  """
  count = 0
  total_loss = 0.
  total_correct = 0

  for ex in tfds.as_numpy(dataset):
    inputs, lengths, labels = ex['sentence'], ex['length'], ex['label']
    count += inputs.shape[0]
    loss, num_correct = eval_step(model, inputs, lengths, labels)
    total_loss += loss.item()
    total_correct += num_correct.item()

  if count == 0:
    # Avoid a ZeroDivisionError on an empty dataset; report NaN instead.
    logging.warning('evaluate() received a dataset with no examples.')
    return dict(loss=float('nan'), acc=float('nan'))

  loss = total_loss / count
  accuracy = 100. * total_correct / count
  return dict(loss=loss, acc=accuracy)


def log(stats, epoch, train_metrics, valid_metrics):
  """Logs performance for an epoch and records it in `stats`.

  Args:
    stats: A dict of lists (e.g. `collections.defaultdict(list)`), updated in
      place with a `train_loss` entry and one `valid_<metric>` entry per
      validation metric.
    epoch: The zero-based epoch number (logged as `epoch + 1`).
    train_metrics: A dict with the summed training loss (`loss`) and the
      number of training examples (`total`) for this epoch.
    valid_metrics: A dict with the validation metrics for this epoch.
  """
  total = train_metrics['total']
  # An epoch can have no training batches (e.g. fewer examples than
  # `batch_size` with `drop_remainder=True`). Report NaN instead of crashing.
  # `float()` works for both Python numbers and 0-d arrays.
  train_loss = (float(train_metrics['loss'] / total) if total
                else float('nan'))
  logging.info('Epoch %02d train loss %.4f valid loss %.4f acc %.2f', epoch + 1,
               train_loss, valid_metrics['loss'], valid_metrics['acc'])

  # Remember the metrics for later plotting.
  stats['train_loss'].append(train_loss)
  for metric, value in valid_metrics.items():
    stats['valid_' + metric].append(value)

def train(
    model: nn.Model,
    learning_rate: float = None,
    num_epochs: int = None,
    seed: int = None,
    model_dir: Text = None,
    data_source: Any = None,
    batch_size: int = None,
    checkpoints_to_keep: int = None,
    l2_reg: float = None,
) -> Tuple[Dict[Text, Any], nn.Model]:
  """Trains `model` with Adam, checkpointing on validation improvements.

  After every epoch the current model is evaluated on the validation data.
  Whenever validation accuracy strictly exceeds the best seen so far (starting
  from 0), a checkpoint is written to `model_dir`.

  NOTE(review): the final `restore_checkpoint` call passes no step. It
  therefore presumably loads the latest checkpoint in `model_dir`, which could
  be one left over from an earlier run.

  Args:
    model: An initialized model to be trained.
    learning_rate: Learning rate for the Adam optimizer.
    num_epochs: Number of passes over the training data.
    seed: Seed for the dropout PRNG and for shuffling the training data.
    model_dir: Directory where checkpoints are saved and restored from.
    data_source: Object exposing `train_dataset` and `valid_dataset`
      (e.g. an `SST2DataSource`).
    batch_size: Batch size for both training and validation data.
    checkpoints_to_keep: Forwarded as `keep` to `save_checkpoint`.
    l2_reg: L2 regularization weight.

  Returns:
    A tuple `(stats, best_model)`. `stats` maps `train_loss` and
    `valid_<metric>` to per-epoch lists.
  """
  optimizer = flax.optim.Adam(learning_rate=learning_rate).create(model)
  dropout_rng = jax.random.PRNGKey(seed)
  history = collections.defaultdict(list)
  best_acc = 0.

  shuffled_train = get_shuffled_batches(
      data_source.train_dataset, batch_size=batch_size, seed=seed)
  valid_batches = get_batches(data_source.valid_dataset, batch_size=batch_size)

  for epoch in range(num_epochs):
    epoch_metrics = collections.defaultdict(float)

    for batch in tfds.as_numpy(shuffled_train):
      num_in_batch = batch['sentence'].shape[0]
      optimizer, batch_loss, dropout_rng = train_step(
          optimizer, batch['sentence'], batch['length'], batch['label'],
          dropout_rng, l2_reg)
      # Weight by batch size so `log` can turn the sum into a mean.
      epoch_metrics['loss'] += batch_loss * num_in_batch
      epoch_metrics['total'] += num_in_batch

    # `optimizer.target` holds the current (updated) model.
    current_model = optimizer.target
    valid_metrics = evaluate(current_model, valid_batches)
    log(history, epoch, epoch_metrics, valid_metrics)

    if valid_metrics['acc'] > best_acc:
      best_acc = valid_metrics['acc']
      flax.training.checkpoints.save_checkpoint(
          model_dir, current_model, epoch + 1, keep=checkpoints_to_keep)

  logging.info('Training done! Best validation accuracy: %.2f', best_acc)
  best_model = flax.training.checkpoints.restore_checkpoint(model_dir, model)
  return history, best_model


In [8]:
# Switch TensorFlow to v2 behavior. `get_shuffled_batches` calls `.numpy()`,
# which needs eager execution.
# NOTE(review): TF1 generally requires eager mode to be enabled at program
# startup. Confirm that this call still takes effect this late in the session.
tf.enable_v2_behavior()

In [9]:
# Load SST-2, build the vocabulary from training-split tokens seen at least
# 5 times, and pre-process the train/validation/test splits.
data_source = SST2DataSource(min_freq=5)

In [10]:
# All four arguments of the jitted `create_model` are static (`static_argnums`
# refers to positions), so they are passed positionally.
model = create_model(
      3107,  # seed for parameter initialization
      4,  # batch_size: only shapes the dummy init inputs
      10,  # max_len: only shapes the dummy init inputs
      dict(
          vocab_size=data_source.vocab_size,
          embedding_size=50,
          hidden_size=50,
          output_size=1,
          unk_idx=data_source.unk_idx,
          dropout=0.5,
          emb_dropout=0.5,
          word_dropout_rate=0.1))



In [16]:
# Train for a single epoch with batch size 4, far below the flag defaults of
# 20 epochs and batch size 64. `model` is rebound to the restored best model.
# NOTE(review): `checkpoints_to_keep=0` is forwarded as `keep=0` to
# `flax.training.checkpoints.save_checkpoint`. Confirm the intended semantics;
# the flag's help text suggests 1 to keep only the best model.
train_stats, model = train(
      model,
      learning_rate=0.0005,
      num_epochs=1,
      seed=3107,
      model_dir='output_dir',
      data_source=data_source,
      batch_size=4,
      checkpoints_to_keep=0,
      l2_reg=1e-6)





















Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


KeyboardInterrupt: 