In [3]:
import jax
import jax.numpy as jnp

from flax import linen as nn

## Implement a Transformer block as a layer

In [4]:
class TransformerBlock(nn.Module):
    """Single Transformer encoder block (post-norm layout).

    Self-attention followed by a two-layer feed-forward network; each
    sub-layer output goes through dropout, a residual add and a LayerNorm.

    Attributes:
        embed_dim: width of the attention q/k/v projections and of the
            feed-forward output.
        num_heads: number of attention heads.
        ff_dim: width of the hidden feed-forward layer.
        rate: dropout rate used after both sub-layers.
    """
    embed_dim: int
    num_heads: int
    ff_dim: int
    rate: float = 0.1

    @nn.compact
    def __call__(self, x, training: bool):
        """Apply the block to `x`; dropout is only active when `training` is true."""
        deterministic = not training

        # Attention sub-layer: attend -> dropout -> residual -> norm.
        attended = nn.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.embed_dim,
        )(x)
        attended = nn.Dropout(rate=self.rate, deterministic=deterministic)(attended)
        normed = nn.LayerNorm()(x + attended)

        # Feed-forward sub-layer: expand -> relu -> project -> dropout -> residual -> norm.
        hidden = nn.Dense(self.ff_dim)(normed)
        hidden = nn.relu(hidden)
        projected = nn.Dense(self.embed_dim)(hidden)
        projected = nn.Dropout(rate=self.rate, deterministic=deterministic)(projected)
        return nn.LayerNorm()(normed + projected)

## Implement embedding layer

In [5]:
class TokenAndPositionEmbedding(nn.Module):
    """Sum of a learned token embedding and a learned position embedding.

    Attributes:
        maxlen: number of rows in the position table, i.e. the longest
            sequence this layer can embed.
        vocab_size: number of rows in the token table.
        embed_dim: width of both embeddings.
    """
    maxlen: int
    vocab_size: int
    embed_dim: int

    @nn.compact
    def __call__(self, x):
        """Embed integer token ids.

        Args:
            x: integer token ids; axis 1 is treated as the sequence axis
                (callers in this notebook pass `(batch, seq_len)`).

        Returns:
            `pos_emb + tok_emb`; the position embedding has a leading axis
            of size 1 and is broadcast against the token embedding.

        Raises:
            ValueError: if the sequence axis is longer than `maxlen`.
        """
        seq_len = x.shape[1]
        # Array shapes are static, so this Python-level check also works
        # under `jax.jit`. Without it, positions >= maxlen would index past
        # the end of the position table and silently yield invalid values
        # instead of an error.
        if seq_len > self.maxlen:
            raise ValueError(
                f"Sequence length {seq_len} exceeds maxlen={self.maxlen}; "
                "truncate the input or increase maxlen."
            )
        positions = jnp.expand_dims(jnp.arange(seq_len), axis=0)
        pos_emb = nn.Embed(num_embeddings=self.maxlen, features=self.embed_dim)(positions)
        tok_emb = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)(x)
        return pos_emb + tok_emb

In [6]:
class Transformer(nn.Module):
    """Sequence classifier: embedding -> one Transformer block -> pooled MLP head.

    Attributes:
        num_classes: number of output logits.
        vocab_size: token vocabulary size passed to the embedding layer.
        maxlen: maximum sequence length passed to the embedding layer.
        embed_dim: embedding / model width.
        num_heads: number of attention heads in the Transformer block.
        ff_dim: hidden width of the block's feed-forward network.
        rate: dropout rate used in the block and in the classifier head.
        classifier_dim: hidden width of the classifier head (defaults to 20,
            the value that used to be hard-coded).
    """
    num_classes: int
    vocab_size: int
    maxlen: int
    embed_dim: int
    num_heads: int
    ff_dim: int
    rate: float = 0.1
    classifier_dim: int = 20

    @nn.compact
    def __call__(self, x, training: bool):
        """Return unnormalised class logits for a batch of token ids.

        Args:
            x: integer token ids, passed straight to the embedding layer.
            training: enables dropout when true.

        Returns:
            Logits whose last axis has size `num_classes`.
        """
        deterministic = not training

        x = TokenAndPositionEmbedding(
            maxlen=self.maxlen,
            vocab_size=self.vocab_size,
            embed_dim=self.embed_dim,
        )(x)
        x = TransformerBlock(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            ff_dim=self.ff_dim,
            rate=self.rate,
        )(x, training=training)

        # Average over the sequence axis (the counterpart of Keras'
        # GlobalAveragePooling1D).
        x = jnp.mean(x, axis=1)
        x = nn.Dropout(rate=self.rate, deterministic=deterministic)(x)
        x = nn.Dense(self.classifier_dim)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.rate, deterministic=deterministic)(x)
        # Raw logits are returned on purpose: the loss used in this notebook
        # (softmax cross-entropy with integer labels) takes logits as input.
        return nn.Dense(self.num_classes)(x)

In [7]:
# Data settings.
vocab_size = 20000  # vocabulary is capped at the 20k most frequent words
maxlen = 200        # every review is padded/truncated to this many tokens

# Model settings.
embed_dim = 32      # width of each token embedding
num_heads = 2       # attention heads in the Transformer block
ff_dim = 32         # hidden width of the block's feed-forward network

transformer = Transformer(
    num_classes=2,
    vocab_size=vocab_size,
    maxlen=maxlen,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_dim=ff_dim,
)

In [8]:
# Print a per-module summary table, traced with a dummy batch holding one
# sequence of `maxlen` token ids (dropout disabled).
print(transformer.tabulate(jax.random.PRNGKey(0), jnp.ones((1, maxlen), dtype=jnp.int32), training=False))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)



[3m                              Transformer Summary                               [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs      [0m[1m [0m┃[1m [0m[1mparams      [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│               │ Transformer   │ -             │ [2mfloat32[0m[1,2]  │              │
│               │               │ [2mint32[0m[1,200]  │               │              │
│               │               │ - training:   │               │              │
│               │               │ False         │               │              │
├───────────────┼───────────────┼───────────────┼───────────────┼──────────────┤
│ TokenAndPosi… │ TokenAndPosi… │ [2mint32[0m[1,200]  │ [2mfloat32[0m[1,20… │              │
├───────────────┼────────────

In [9]:
from tensorflow import keras;

2023-02-19 00:44:22.929820: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-02-19 00:44:23.365061: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-19 00:44:23.365137: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [10]:
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax                           # Common loss functions and optimizers

In [11]:
@struct.dataclass
class Metrics(metrics.Collection):
  """Metrics accumulated across batches during training and evaluation.

  `compute_metrics` updates this collection from the keyword outputs
  `logits`, `labels` and `loss`.
  """
  # Classification accuracy (presumably derived from `logits` and `labels`;
  # see the clu `metrics.Accuracy` documentation).
  accuracy: metrics.Accuracy
  # Average of the value supplied under the 'loss' output name.
  loss: metrics.Average.from_output('loss')

In [12]:
# Initialise the model once with a dummy batch of one `maxlen`-token sequence.
# The returned variables are discarded (the trailing `;` also suppresses the
# cell output); the parameters actually trained are created in
# `create_train_state`.
rng = jax.random.PRNGKey(0)
transformer.init(rng, jnp.ones((1, maxlen), dtype=jnp.int32), training=False);

In [13]:
class TrainState(train_state.TrainState):
  """Flax train state extended with running metrics and a PRNG key.

  Attributes (in addition to those of `train_state.TrainState`):
    metrics: the `Metrics` collection accumulated so far.
    key: PRNG key stored alongside the parameters.
  """
  metrics: Metrics
  # `jax.random.KeyArray` was deprecated and later removed from JAX, so
  # evaluating it here fails on current versions; `jax.Array` is the
  # recommended annotation for PRNG keys.
  key: jax.Array

def create_train_state(module, root_key, learning_rate):
  """Creates an initial `TrainState`.

  Args:
    module: Flax module to train; it is initialised by calling it with a
      dummy `(1, seq_len)` int32 token batch and `training=False`.
    root_key: PRNG key that is split into the parameter-initialisation key
      and the key stored on the returned state.
    learning_rate: learning rate passed to `optax.adam`.

  Returns:
    A `TrainState` holding the freshly initialised parameters, an Adam
    optimizer, the derived dropout key and an empty `Metrics` collection.
  """
  # Still split three ways so the derived keys (and therefore the initial
  # parameters) are unchanged; the first sub-key is intentionally unused.
  _, params_key, dropout_key = jax.random.split(key=root_key, num=3)
  # Use the module's own maximum sequence length for the dummy batch when it
  # declares one, falling back to the previous fixed length of 200.
  seq_len = getattr(module, 'maxlen', 200)
  dummy_tokens = jnp.ones((1, seq_len), dtype=jnp.int32)
  params = module.init(params_key, dummy_tokens, training=False)['params']
  tx = optax.adam(learning_rate=learning_rate)
  return TrainState.create(
      apply_fn=module.apply, params=params, tx=tx, key=dropout_key, metrics=Metrics.empty())

In [14]:
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
  """Run one optimisation step on `batch` and return the updated state.

  `dropout_key` is folded with `state.step`, so every step draws its dropout
  masks from a different key derived from the same base key.
  """
  step_key = jax.random.fold_in(key=dropout_key, data=state.step)

  def batch_loss(params):
    # Forward pass in training mode, i.e. with dropout enabled.
    logits = state.apply_fn(
      {'params': params},
      x=batch['token'],
      training=True,
      rngs={'dropout': step_key},
    )
    per_example = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return per_example.mean(), logits

  # The (loss, logits) pair is produced alongside the gradients but is not
  # needed here; only the gradients drive the update.
  _, grads = jax.value_and_grad(batch_loss, has_aux=True)(state.params)
  return state.apply_gradients(grads=grads)

In [15]:
@jax.jit
def compute_metrics(*, state, batch):
  """Evaluate `batch` with dropout disabled and merge the result into `state.metrics`.

  Returns a copy of `state` whose `metrics` field includes this batch.
  """
  labels = batch['label']
  logits = state.apply_fn({'params': state.params}, x=batch['token'], training=False)
  mean_loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels).mean()
  batch_metrics = state.metrics.single_from_model_output(
    logits=logits, labels=labels, loss=mean_loss)
  return state.replace(metrics=state.metrics.merge(batch_metrics))

In [16]:
# Derive three keys from one seed; `params_key` seeds `create_train_state`
# and `dropout_key` is handed to `train_step`.
root_key = jax.random.PRNGKey(0)
main_key, params_key, dropout_key = jax.random.split(root_key, 3)

In [17]:
# Build the initial training state (parameters, Adam optimizer, empty metrics).
state = create_train_state(
    module=transformer,
    root_key=params_key,
    learning_rate=1e-3,
)

In [18]:
# One independent list per (split, metric) pair, appended to once per epoch.
metrics_history = {
    f'{split}_{metric}': []
    for split in ('train', 'test')
    for metric in ('loss', 'accuracy')
}

In [19]:

# Load the IMDB reviews as integer sequences, keeping the `vocab_size` most
# frequent words.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)

# Pad / truncate every review to exactly `maxlen` tokens.
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=maxlen)

# Drop the trailing samples so each split is a whole number of batches.
# Each split uses its own batch count, so a validation set of a different
# size is not truncated to the training set's length.
batch_size = 32
num_batches = len(x_train) // batch_size
num_val_batches = len(x_val) // batch_size
x_train = x_train[:num_batches * batch_size]
y_train = y_train[:num_batches * batch_size]
x_val = x_val[:num_val_batches * batch_size]
y_val = y_val[:num_val_batches * batch_size]

print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")


24992 Training sequences
24992 Validation sequences


In [20]:

# Train for three epochs; after each epoch record the accumulated training
# metrics and evaluate once on the validation set.
num_steps_per_epoch = len(x_train) // batch_size
for epoch in range(1, 4):
    for step, batch_ix in enumerate(range(0, len(x_train), batch_size)):
        batch = {
            "token": jnp.array(x_train[batch_ix:batch_ix + batch_size]),
            "label": jnp.array(y_train[batch_ix:batch_ix + batch_size]),
        }
        state = train_step(state, batch, dropout_key=dropout_key)
        state = compute_metrics(state=state, batch=batch) # aggregate batch metrics

        # Compare the step index directly. Multiplying by `epoch` before
        # taking the modulus would also fire mid-epoch whenever `epoch`
        # shares a factor with the number of steps per epoch.
        if step + 1 == num_steps_per_epoch: # one training epoch has passed
            for metric, value in state.metrics.compute().items(): # compute metrics
                metrics_history[f'train_{metric}'].append(value) # record metrics
            state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch

            # Compute metrics on the test set after each training epoch.
            # The whole validation set is evaluated as a single batch, on a
            # copy of the state so the training metrics stay empty.
            test_state = state
            test_batch = {"token": jnp.array(x_val), "label": jnp.array(y_val)}
            test_state = compute_metrics(state=test_state, batch=test_batch)

            for metric, value in test_state.metrics.compute().items():
                metrics_history[f'test_{metric}'].append(value)

            print(f"train epoch: {epoch}, "
                f"loss: {metrics_history['train_loss'][-1]}, "
                f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
            print(f"test epoch: {epoch}, "
                f"loss: {metrics_history['test_loss'][-1]}, "
                f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")


train epoch: 1, loss: 0.3994245231151581, accuracy: 80.27368927001953
test epoch: 1, loss: 0.33239904046058655, accuracy: 85.30729675292969
train epoch: 2, loss: 0.20272627472877502, accuracy: 92.36555480957031
test epoch: 2, loss: 0.35978537797927856, accuracy: 85.24327850341797
train epoch: 3, loss: 0.139784038066864, accuracy: 95.1624526977539
test epoch: 3, loss: 0.3845835328102112, accuracy: 85.91149139404297
