In [None]:
from typing import Mapping, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

## Self-Attention

Haiku has a built-in MultiHeadAttention layer, but it does not include causal masking. We extend the multi-head attention block to build a masked (causal) self-attention block. The block accepts a query, an optional key and value, and an optional mask, and returns the output as a JAX array.

In [None]:
class SelfAttention(hk.MultiHeadAttention):
    """Multi-head attention that always applies a causal mask.

    A lower-triangular mask is built from the query's sequence length so that
    position ``t`` can only attend to positions ``<= t``. A mask supplied by
    the caller is combined with the causal mask by multiplication.
    """

    def __call__(self, query: jnp.ndarray, key: Optional[jnp.ndarray] = None, value: Optional[jnp.ndarray] = None, mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """Run causally masked attention.

        Args:
            query: Queries of shape [B, T, D].
            key: Keys; defaults to ``query`` when omitted.
            value: Values; defaults to ``query`` when omitted.
            mask: Optional extra mask, broadcastable against [1, 1, T, T]
                (for example a padding mask of shape [B, 1, 1, T]).

        Returns:
            Whatever ``hk.MultiHeadAttention.__call__`` returns for the
            masked inputs.

        Raises:
            ValueError: If ``query`` is not rank 3, because the sequence
                length is read from axis 1.
        """
        if query.ndim != 3:
            raise ValueError(
                f"Expected query of shape [B, T, D], got rank {query.ndim}.")

        key = query if key is None else key
        value = query if value is None else value

        seq_len = query.shape[1]
        # Give the mask explicit batch and head axes ([1, 1, T, T]) so that it
        # has the same rank as the attention logits even when no extra mask is
        # passed; the values still broadcast exactly like a [T, T] mask.
        causal_mask = np.tril(np.ones((seq_len, seq_len)))[None, None, :, :]
        mask = causal_mask if mask is None else mask * causal_mask

        return super().__call__(query, key, value, mask)


In [None]:
class DenseBlock(hk.Module):
    """Two-layer feed-forward block: widen, apply GELU, project back."""

    def __init__(self, init_scale: float, widening_factor: int = 4, name: Optional[str] = None):
        """Store the initializer scale and the hidden-layer widening factor."""
        super().__init__(name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Apply the MLP along the last axis; the output width equals the input width."""
        width = x.shape[-1]
        w_init = hk.initializers.VarianceScaling(self._init_scale)

        # Expansion layer followed by the non-linearity.
        expand = hk.Linear(self._widening_factor * width, w_init=w_init)
        activated = jax.nn.gelu(expand(x))

        # Projection back to the original width.
        project = hk.Linear(width, w_init=w_init)
        return project(activated)


In [None]:
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Normalize ``x`` over its last axis with a new, learnable-scale-and-offset LayerNorm.

    Each call constructs its own ``hk.LayerNorm`` module under ``name``.
    """
    norm = hk.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
        name=name,
    )
    return norm(x)


In [None]:
class Transformer(hk.Module):
    """Pre-LayerNorm transformer stack with causal self-attention."""

    def __init__(self, num_heads: int, num_layers: int, dropout_rate: float, name: Optional[str] = None, *, key_size: int = 64):
        """
        Configure the stack.

        Args:
            num_heads: Number of attention heads per layer.
            num_layers: Number of (attention + MLP) layers.
            dropout_rate: Dropout rate applied to each residual branch when
                training.
            name: Optional Haiku module name.
            key_size: Per-head key size passed to the attention layers.
        """
        super().__init__(name)
        self._num_heads = num_heads
        self._num_layers = num_layers
        self._dropout_rate = dropout_rate
        self._key_size = key_size

    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray] = None, is_training: bool = True) -> jnp.ndarray:
        """
        Connects the transformer.

        Args:
            h: Inputs, [B, T, H].
            mask: Optional padding mask, [B, T]; it is reshaped to
                [B, 1, 1, T] before being handed to the attention layers.
            is_training: Whether we are in training mode. Dropout is only
                active when this is true.

        Returns:
            Array of shape [B, T, H].
        """

        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        # Width of the residual stream; every branch must produce this width.
        model_size = h.shape[-1]

        if mask is not None:
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # Attention branch.
            h_norm = layer_norm(h, name=f"h{i}_ln_1")
            # model_size is passed explicitly so the attention output can be
            # added to h even when H != num_heads * key_size.
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=self._key_size,
                w_init_scale=init_scale,
                model_size=model_size,
                name=f"h{i}_attn")(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn

            # Feed-forward branch.
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        # Final normalization of the residual stream.
        h = layer_norm(h, name='h_ln_f')

        return h


In [None]:
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int,
               d_model: Optional[int] = None) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Embed the tokens in ``data['obs']`` and add learned positional embeddings.

    Args:
        data: Batch mapping; ``data['obs']`` holds the token ids, with the
            sequence length on axis 1.
        vocab_size: Size of the token vocabulary.
        d_model: Embedding width. When omitted, the module-level name
            ``d_model`` is used, which is what this function relied on
            before the argument existed.

    Returns:
        A pair ``(input_embeddings, input_mask)``: the sum of token and
        positional embeddings, and a boolean mask that is true where the
        token id is greater than 0.

    Raises:
        NameError: If ``d_model`` is neither passed nor defined at module
            level.
    """
    if d_model is None:
        # Backward-compatible fallback to the notebook-level global.
        try:
            d_model = globals()['d_model']
        except KeyError:
            raise NameError(
                "embeddings() needs d_model: pass it explicitly or define a "
                "module-level `d_model`.") from None

    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask

In [None]:
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass.

    Args:
        vocab_size: Size of the token vocabulary; also the width of the
            output logits.
        d_model: Width of the token and positional embeddings.
        num_heads: Number of attention heads in each transformer layer.
        num_layers: Number of transformer layers.
        dropout_rate: Dropout rate used while training.

    Returns:
        A function ``forward_fn(data, is_training=True)`` that maps a batch
        to logits over the vocabulary.
    """

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        tokens = data['obs']
        input_mask = jnp.greater(tokens, 0)
        seq_length = tokens.shape[1]

        # Embeddings are built here, from the closed-over `d_model`, so the
        # width requested by the caller is the one actually used.
        embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
        token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
        token_embs = token_embedding_map(tokens)
        positional_embeddings = hk.get_parameter(
            'pos_embs', [seq_length, d_model], init=embed_init)
        input_embeddings = token_embs + positional_embeddings

        # Run the transformer over the inputs.
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # Reverse the embeddings (untied).
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn

In [None]:
def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """Compute the masked mean cross-entropy loss on data wrt params.

    Args:
        forward_fn: Callable invoked as
            ``forward_fn(params, rng, data, is_training)``; must return
            logits whose shape equals that of the one-hot targets.
        vocab_size: Number of classes used to one-hot encode the targets.
        params: Passed through to ``forward_fn``.
        rng: Passed through to ``forward_fn``.
        data: Batch mapping with ``'obs'`` (input token ids) and ``'target'``
            (target token ids).
        is_training: Passed through to ``forward_fn``.

    Returns:
        Scalar loss: the cross-entropy averaged over the positions where
        ``data['obs'] > 0``. If no position is selected, the loss is 0
        instead of NaN.
    """
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape, (
        f"logits shape {logits.shape} != targets shape {targets.shape}")

    mask = jnp.greater(data['obs'], 0)
    token_loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    # Clamp the count to at least 1 so that a batch containing only masked
    # positions yields 0 rather than a 0/0 NaN.
    num_valid = jnp.maximum(jnp.sum(mask), 1)
    loss = jnp.sum(token_loss * mask) / num_valid

    return loss