In [None]:
import functools
import jax
from jax import numpy as jnp, random, lax
import numpy as onp

In [None]:
from flax import nn, struct

In [None]:
from flax.core.scope import Scope, init, apply, Array, group_kinds

In [None]:
def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros):
  """Fully-connected layer computing ``inputs @ kernel (+ bias)``.

  Args:
    scope: scope that owns the 'kernel' (and optional 'bias') parameter.
    inputs: array whose last axis is contracted with the kernel.
    features: size of the output's last axis.
    bias: if True, add a 'bias' parameter of shape ``(features,)``.
    kernel_init: initializer for the ``(inputs.shape[-1], features)`` kernel.
    bias_init: initializer for the bias.
  Returns:
    The projected array.
  """
  in_features = inputs.shape[-1]
  kernel = scope.param('kernel', kernel_init, (in_features, features))
  projected = jnp.dot(inputs, kernel)
  if not bias:
    return projected
  return projected + scope.param('bias', bias_init, (features,))

# Bind `features` up front so the model function only needs (scope, inputs).
model_fn = functools.partial(dense, features=3)

x = jnp.ones((1, 2))
# `init` returns the model output and the initialized variables (named
# `params` here).
y, params = init(model_fn)(random.PRNGKey(0), x)
print(params)

def mlp(scope: Scope, inputs: Array, features: int):
  """Two-layer perceptron: dense(features) -> relu -> dense(1).

  Args:
    scope: parent scope; sub-layers live under 'hidden' and 'out'.
    inputs: input array; the last axis is the feature axis.
    features: width of the hidden layer.
  Returns:
    Array whose last axis has size 1.
  """
  activations = nn.relu(dense(scope.push('hidden'), inputs, features))
  return dense(scope.push('out'), activations, 1)

# Initialize the two-layer MLP; as the last expression of the cell, the
# returned value is displayed by the notebook.
init(mlp)(random.PRNGKey(0), x, features=3)

In [None]:
from typing import Iterable, Any

def dense_general(
    scope: Scope, inputs: Array, features: int, axis=-1,
    batch_dims=(), bias=True, dtype=jnp.float32,
    kernel_init=nn.linear.default_kernel_init, bias_init=nn.initializers.zeros,
    precision=None):
  """Applies a linear transformation to the inputs along multiple dimensions.

  Args:
    scope: module scope that owns the 'kernel' (and optional 'bias') params.
    inputs: The nd-array to be transformed.
    features: int or tuple with numbers of output features.
    axis: int or tuple with axes to apply the transformation on.
    batch_dims: int or tuple with batch axes; must be the consecutive leading
      dimensions starting from 0. Each batch element gets its own kernel/bias.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
  Returns:
    The transformed input, of shape
    ``batch_shape + remaining_input_dims + features``.
  Raises:
    ValueError: if `batch_dims` are not consecutive leading dimensions.
  """
  inputs = jnp.asarray(inputs, dtype)
  features = _as_tuple(features)
  axis = _as_tuple(axis)
  batch_dims = _as_tuple(batch_dims)

  if batch_dims:
    max_dim = onp.max(batch_dims)
    if set(batch_dims) != set(range(max_dim + 1)):
      raise ValueError(f'batch_dims {batch_dims} must be consecutive leading '
                       'dimensions starting from 0.')

  ndim = inputs.ndim
  n_batch_dims = len(batch_dims)
  axis = _normalize_axes(axis, ndim)
  batch_dims = _normalize_axes(batch_dims, ndim)
  n_axis = len(axis)

  batch_shape = tuple(inputs.shape[ax] for ax in batch_dims)
  kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
  kernel = scope.param('kernel', kernel_init, batch_shape + kernel_shape)
  kernel = jnp.asarray(kernel, dtype)

  batch_ind = tuple(range(n_batch_dims))
  contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
  out = lax.dot_general(inputs,
                        kernel,
                        ((axis, contract_ind), (batch_dims, batch_ind)),
                        precision=precision)
  if bias:
    bias_value = scope.param('bias', bias_init, batch_shape + features)
    # dot_general lays out its output as: batch dims, then the input dims that
    # are neither batch nor contracted (in their original order), then the
    # kernel's free dims (`features`). Insert singleton axes for the middle
    # group so the bias broadcasts correctly wherever `axis` sits.
    n_free_dims = ndim - n_axis - n_batch_dims
    bias_value = jnp.reshape(
        bias_value, batch_shape + (1,) * n_free_dims + features)
    out = out + jnp.asarray(bias_value, dtype)
  return out

def _as_tuple(x):
  if isinstance(x, Iterable):
    return tuple(x)
  return (x,)

def _normalize_axes(axes, ndim):
  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])

# Smoke test: batch_dims=0 makes dim 0 a batch dim, so the kernel gets shape
# (2, 3, 1) and the bias shape (2, 1): one set of weights per batch element.
init(dense_general)(random.PRNGKey(0), jnp.ones((2, 3,)), features=1, batch_dims=0)

In [None]:
@struct.dataclass
class CacheEntry:
  """Decoding cache used by `multi_head_dot_product_attention`.

  Holds the projected keys/values written so far and a step counter that
  determines where the next step is written.
  """
  key: onp.ndarray  # cached projected keys; trailing dims are (num_heads, head_dim)
  value: onp.ndarray  # cached projected values; same shape as `key`
  i: onp.ndarray  # uint32 scalar: number of decoding steps written so far

def multi_head_dot_product_attention(
    scope: Scope,
    inputs_q,
    inputs_kv,
    num_heads,
    dtype=jnp.float32,
    qkv_features=None,
    out_features=None,
    attention_axis=None,
    causal_mask=False,
    padding_mask=None,
    key_padding_mask=None,
    segmentation=None,
    key_segmentation=None,
    cache=False,
    broadcast_dropout=True,
    dropout_rng=None,
    dropout_rate=0.,
    deterministic=False,
    precision=None,
    kernel_init=nn.attention.default_kernel_init,
    bias_init=nn.initializers.zeros,
    bias=True,
    attention_fn=nn.attention.dot_product_attention):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    scope: module scope; owns the 'query', 'key', 'value' and 'out'
      projections and, when `cache` is True, the ('cache', 'entry') variable.
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. `qkv_features` (default:
      inputs_q.shape[-1]) must be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32)
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection (default:
      inputs_q.shape[-1]).
    attention_axis: axes over which the attention is applied ('None' means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: bool, whether to decode autoregressively with a cache kept in the
      'cache' variable collection. Requires `causal_mask=True`. When no
      ('cache', 'entry') variable exists yet, a lazy initializer
      `fn(shape) -> CacheEntry` is stored there, where `shape` is the batch +
      attention dims shape; replace it with the `CacheEntry` it returns
      before decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts
      query, key, value, and returns output of shape
      `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`, where `features`
    is `out_features` if given, else inputs_q.shape[-1].

  Raises:
    AssertionError: if `cache` is set without `causal_mask`, or if
      `qkv_features` is not divisible by `num_heads`.
    ValueError: if the stored cache entry is not a `CacheEntry`, if
      `inputs_q` does not have the single-step shape the cache expects, or if
      the cache initializer is called with a shape of the wrong rank.
  """

  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')

  if inputs_kv is None:
    inputs_kv = inputs_q

  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  # Named `project` to avoid shadowing the module-level `dense` function.
  project = functools.partial(dense_general,
                              axis=-1,
                              dtype=dtype,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision)
  # project inputs_q to multi-headed q/k/v
  # dimensions are then [bs, dims..., n_heads, n_features_per_head]
  query = scope.child(project, 'query')(inputs_q)
  key = scope.child(project, 'key')(inputs_kv)
  value = scope.child(project, 'value')(inputs_kv)

  if cache:
    if not scope.has_variable('cache', 'entry'):
      ndim, tail_shape = key.ndim, key.shape[-2:]
      # Allocate the cache in the projection dtype: lax.dynamic_update_slice
      # below requires operand and update to share a dtype, so float32 zeros
      # would break decoding whenever `dtype` is not float32.
      key_dtype, value_dtype = key.dtype, value.dtype

      def init_fn(shape):
        full_shape = tuple(shape) + tail_shape
        if len(full_shape) != ndim:
          raise ValueError('Shape should be a tuple with the shape of the batch '
                           'and attention dims.')
        return CacheEntry(
            key=jnp.zeros(full_shape, key_dtype),
            value=jnp.zeros(full_shape, value_dtype),
            i=jnp.zeros((), jnp.uint32))
      cache_entry = init_fn
    else:
      cache_entry = scope.get_variable('cache', 'entry')
      if not isinstance(cache_entry, CacheEntry):
        raise ValueError('Cache is not initialized.')

      # During decoding the query covers exactly one position per attention
      # axis.
      expected_shape = list(cache_entry.key.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, '
                         'expected shape %s instead got %s.' %
                         (expected_shape, inputs_q.shape))

      # Convert the flat step counter `i` into a multi-dimensional write
      # position over the attention axes (row-major order).
      cshape = cache_entry.key.shape
      indices = [0] * len(cshape)
      i = cache_entry.i
      attn_size = onp.prod(onp.take(cshape, attention_axis))
      for attn_dim in attention_axis:
        attn_size //= cshape[attn_dim]
        indices[attn_dim] = i // attn_size
        i = i % attn_size

      key = lax.dynamic_update_slice(cache_entry.key, key, indices)
      value = lax.dynamic_update_slice(cache_entry.value, value, indices)
      one = jnp.array(1, jnp.uint32)
      cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                        key=key,
                                        value=value)

      # TODO(levskaya): verify this is still needed in translation decoding.
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]
    scope.put_variable('cache', 'entry', cache_entry)

  # create attention masks
  mask_components = []

  if causal_mask:
    if cache and isinstance(cache_entry, CacheEntry):
      # While decoding, attend to every cached position up to and including
      # the current one (`i` was already incremented above).
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(nn.attention._make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    padding_mask = nn.attention.make_padding_mask(
        padding_mask_query=padding_mask,
        padding_mask_key=key_padding_mask,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis)
    mask_components.append(padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = nn.attention.make_padding_mask(
        padding_mask_query=segmentation,
        padding_mask_key=key_segmentation,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis,
        segmentation_mask=True)
    mask_components.append(segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # attention mask in the form of attention bias
    attention_bias = lax.select(
        attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # apply attention
  x = attention_fn(
      query,
      key,
      value,
      dtype=dtype,
      axis=attention_axis,
      bias=attention_bias,
      precision=precision,
      dropout_rng=dropout_rng,
      dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout,
      deterministic=deterministic)

  # back to the original inputs dimensions
  out = scope.child(dense_general, name='out')(
      x,
      features=features,
      axis=(-2, -1),
      kernel_init=kernel_init,
      bias_init=bias_init,
      bias=bias,
      dtype=dtype,
      precision=precision)

  return out

x = jnp.ones((1, 2, 4))
attn = functools.partial(multi_head_dot_product_attention, num_heads=1,
                         cache=True, causal_mask=True)
y, variables = init(attn)(random.PRNGKey(0), x, x)
params = variables['param']
# The init call stores a lazy initializer under ('cache', 'entry'); call it with
# the (batch, length) shape to materialize the actual CacheEntry.
# `jax.tree_util.tree_map` is used because `jax.tree_map` is deprecated/removed
# in recent JAX releases.
cache = jax.tree_util.tree_map(lambda fn: fn((1, 2)), variables['cache'])
variables = variables.copy(cache=cache)

In [None]:
# Run one cached decoding step with a single-position query (`x[:, 0:1]`).
# `mutable='cache'` presumably allows the call to write the updated
# ('cache', 'entry') variable — confirm against flax.core.scope.apply.
apply(attn, mutable='cache')(variables, x[:, 0:1], x)

In [None]:
@struct.dataclass
class Embedding:
  """An embedding table supporting row lookup and attention over rows."""
  table: onp.ndarray

  def lookup(self, indices):
    """Return the table rows selected by `indices`."""
    rows = self.table
    return rows[indices]

  def attend(self, query):
    """Return dot products of `query` with every row of the table."""
    transposed = jnp.transpose(self.table)
    return jnp.dot(query, transposed)

# All the embedding module does is provide a convenient initializer for the table.

def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=nn.linear.default_embed_init) -> Embedding:
  """Create an `Embedding` backed by a `(num_embeddings, features)` 'table' param."""
  table_shape = (num_embeddings, features)
  return Embedding(scope.param('table', init_fn, table_shape))

# Bind the result to a new name so the `embedding` module function is not
# shadowed; otherwise re-running this cell would pass an Embedding instance
# to `init`.
emb, _ = init(embedding)(random.PRNGKey(0), num_embeddings=2, features=3)
print(emb.table)
print(emb.lookup(1))
print(emb.attend(jnp.ones((1, 3,))))

In [None]:
def lstm(scope, carry, inputs,
         gate_fn=nn.activation.sigmoid, activation_fn=nn.activation.tanh,
         kernel_init=nn.linear.default_kernel_init,
         recurrent_kernel_init=nn.initializers.orthogonal(),
         bias_init=nn.initializers.zeros):
  r"""A long short-term memory (LSTM) cell.

  the mathematical definition of the cell is as follows
  .. math::
      \begin{array}{ll}
      i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
      f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
      g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
      o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
      c' = f * c + i * g \\
      h' = o * \tanh(c') \\
      \end{array}
  where x is the input, h is the output of the previous time step, and c is
  the memory.

  Args:
    scope: scope owning the eight dense sub-layers, named 'ii', 'if', 'ig',
      'io' (input projections) and 'hi', 'hf', 'hg', 'ho' (recurrent
      projections).
    carry: the hidden state ``(c, h)`` of the LSTM cell, e.g. created with
      `lstm_init_carry`. The last axis of `h` sets the hidden size.
    inputs: an ndarray with the input for the current time step.
      All dimensions except the final are considered batch dimensions.
    gate_fn: activation function used for gates (default: sigmoid)
    activation_fn: activation function used for output and memory update
      (default: tanh).
    kernel_init: initializer function for the kernels that transform
      the input (default: lecun_normal).
    recurrent_kernel_init: initializer function for the kernels that transform
      the hidden state (default: orthogonal).
    bias_init: initializer for the bias parameters (default: zeros)
  Returns:
    A tuple with the new carry ``(new_c, new_h)`` and the output ``new_h``.
  """
  c, h = carry
  hidden_features = h.shape[-1]

  def dense_h(name):
    # Recurrent projection of `h`. Input and recurrent projections are
    # summed, so only this one needs a bias.
    return scope.child(dense, name)(
        h, features=hidden_features, bias=True,
        kernel_init=recurrent_kernel_init, bias_init=bias_init)

  def dense_i(name):
    # Bias-free input projection of `inputs`.
    return scope.child(dense, name)(
        inputs, features=hidden_features, bias=False,
        kernel_init=kernel_init)

  i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))
  f = gate_fn(dense_i(name='if') + dense_h(name='hf'))
  g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))
  o = gate_fn(dense_i(name='io') + dense_h(name='ho'))
  new_c = f * c + i * g
  new_h = o * activation_fn(new_c)
  return (new_c, new_h), new_h

def lstm_init_carry(batch_dims, size, init_fn=jnp.zeros):
  """Create an initial ``(c, h)`` carry for `lstm`.

  Args:
    batch_dims: sequence (tuple or list) of leading batch dimension sizes.
    size: number of hidden features (last axis of `c` and `h`).
    init_fn: function mapping a shape tuple to an array (default: zeros).
  Returns:
    A tuple ``(c, h)`` of two arrays, each of shape ``(*batch_dims, size)``.
  """
  # tuple() so list-valued batch_dims work too (list + tuple raises TypeError).
  shape = tuple(batch_dims) + (size,)
  return init_fn(shape), init_fn(shape)

x = jnp.ones((1, 2))
carry = lstm_init_carry((1,), 3)
y, variables = init(lstm)(random.PRNGKey(0), carry, x)
# Display the shapes of the output and of every initialized variable.
# `jax.tree_util.tree_map` replaces the deprecated/removed `jax.tree_map`.
jax.tree_util.tree_map(onp.shape, (y, variables))