In [1]:
import dataclasses
import typing as tp
import jax.numpy as jnp
from flax import nnx
import numpy as np
import jax
from jax import lax

In [2]:
@dataclasses.dataclass
class TransformerConfig:
    vocab_size: int
    logits_via_embedding: bool = False
    dtype: tp.Any = jnp.float32
    emb_dim: int = 512
    num_heads: int = 8
    num_layers: int = 3
    qkv_dim: int = 512
    mlp_dim: int = 2048
    max_len: int = 2048
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

In [3]:
def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
  def init(key, shape, dtype=np.float32):
    del key, dtype
    d_feature = shape[-1]
    pe = np.zeros((max_len, d_feature), dtype=np.float32)
    position = np.arange(0, max_len)[:, np.newaxis]
    scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
    div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor)
    pe[:, : d_feature // 2] = np.sin(position * div_term)
    pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term)
    pe = pe[np.newaxis, :, :]  # [1, max_len, d_feature]
    return jnp.array(pe)

  return init

def shift_right(x: jax.Array, axis: int = 1):
  """Shift the input to the right by padding and slicing on axis."""
  pad_widths: list[tuple[int, int]] = [(0, 0)] * len(x.shape)
  pad_widths[axis] = (1, 0)
  padded = jnp.pad(
    x, pad_widths, mode='constant', constant_values=x.dtype.type(0)
  )
  return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis)

def shift_inputs(x: jax.Array, axis: int = 1):
  """Shift inputs and replace EOS by 0 for packed inputs."""
  shifted = shift_right(x, axis=axis)

  return shifted

In [None]:
class AddPositionalEmbs(nnx.Module):
    def __init__(self, config, *, decode=False, rngs):
        self.config = config
        self.decode = decode
        self.init_func = sinusoidal_init(config.max_len)
        self.pos_embedding = self.init_func(rngs.params(), (config.max_len, config.emb_dim))

    def __call__(self, inputs):
        length = inputs.shape[1]
        if self.decode:
            _, _, df = self.pos_embedding.shape
            pos_embedding = lax.dynamic_slice(
                self.pos_embedding, jnp.array((0, self.cache_index.value, 0)), (1, 1, df)
            )
        else:
            pos_embedding = self.pos_embedding[:, :length, :]

        return inputs + pos_embedding
    
    def init_cache(self, input_shape, dtype = jnp.float32):
        self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.uint32))

In [5]:
class MlpBlock(nnx.Module):
    def __init__(self, config, rngs):
        self.config = config
        self.linear1 = nnx.Linear(config.emb_dim, config.mlp_dim, rngs=rngs)
        self.linear2 = nnx.Linear(config.mlp_dim, config.emb_dim, rngs=rngs)
        self.dropout = nnx.Dropout(rate=config.dropout_rate)

    def __call__(self, inputs, rngs):
        x = self.linear1(inputs)
        x = nnx.relu(x)
        x = self.dropout(x, rngs=rngs)
        output = self.linear2(x)
        output = self.dropout(output, rngs=rngs)
        return output

In [6]:
class EncoderDecoderBlock(nnx.Module):
    def __init__(self, config, rngs):
        self.config = config
        self.ln1 = nnx.LayerNorm(num_features=config.emb_dim, rngs=rngs)
        self.ln2 = nnx.LayerNorm(num_features=config.emb_dim, rngs=rngs)
        self.attention = nnx.MultiHeadAttention(
            num_heads=config.num_heads,
            in_features=config.emb_dim,
            qkv_features=config.qkv_dim,
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=config.attention_dropout_rate,
            rngs=rngs,
            )
        self.mlp = MlpBlock(config=config, rngs=rngs)
        self.dropout = nnx.Dropout(rate=config.dropout_rate)
        
    def __call__(self, inputs, rngs, decoder_mask):
        x = self.ln1(inputs)
        x = self.attention(x, rngs=rngs, mask=decoder_mask)
        x = self.dropout(x, rngs=rngs)
        x = x + inputs
        z = self.ln2(x)
        z = self.mlp(z, rngs)

        return z

In [7]:
class Decoder(nnx.Module):
    def __init__(self, config, *, decode = False, rngs):
        self.config = config
        self.decode = decode
        self.output_embed = nnx.Embed(num_embeddings=config.vocab_size, 
                                      features=config.emb_dim,
                                      rngs=rngs)
        self.posembed_out = AddPositionalEmbs(config=config, rngs=rngs)
        self.dropout = nnx.Dropout(rate=config.dropout_rate)

        for idx in range(config.num_layers):
            layer = EncoderDecoderBlock(config,rngs)
            setattr(self, f'encoderdecoderblock_{idx}', layer)

        self.encoderdecoder_norm = nnx.LayerNorm(num_features=config.emb_dim, rngs=rngs)

    def __call__(self, inputs, rngs, decoder_mask=None):
        y = inputs.astype('int32')
        if not self.decode:
            y = shift_inputs(y)
        y = self.output_embed(inputs)
        y = self.posembed_out(y)
        y = self.dropout(y, rngs=rngs)

        for idx in range(self.config.num_layers):
            layer = getattr(self, f'encoderdecoderblock_{idx}')
            y = layer(y, rngs=rngs, decoder_mask=decoder_mask)

        y = self.encoderdecoder_norm(y)

        logits = self.output_embed.attend(y)
        logits = logits / jnp.sqrt(y.shape[-1])

        return logits

In [8]:
class TransformerLM(nnx.Module):
    def __init__(self, config, *, decode=False, rngs):
        self.config = config
        self.decode = decode
        self.decoder = Decoder(config, rngs=rngs)

    def __call__(self, inputs, rngs):
        if self.decode:
            decoder_mask = None
        else:
            decoder_mask = nnx.combine_masks(
                nnx.make_attention_mask(inputs > 0, inputs > 0, dtype=self.config.dtype),
                nnx.make_causal_mask(inputs, dtype=self.config.dtype),
            )
            
        logits = self.decoder(inputs, rngs, decoder_mask=decoder_mask)
        return logits

In [9]:
config = TransformerConfig(vocab_size=30_728)

In [10]:
rngs = nnx.Rngs(0)

In [11]:
transformer = TransformerLM(config, rngs=rngs)

transformer.set_attributes(deterministic=False, decode=True)

In [12]:
from typing_extensions import Protocol, runtime_checkable

@runtime_checkable
class HasCache(Protocol):
  def init_cache(self, input_shape, dtype): ...

In [19]:
inputs = jax.random.randint(rngs.params(), (1, 1), minval=0, maxval=config.vocab_size)

In [20]:
for _path, m in transformer.iter_modules():
    if isinstance(m, HasCache):
      input_shape = (inputs.shape[0], config.max_len, config.emb_dim)
      m.init_cache(input_shape, dtype=config.dtype)

In [27]:
transformer(inputs, rngs)

Array([[[ 0.00925957, -0.02592195, -0.02881995, ..., -0.00104929,
         -0.00665741,  0.06749014]]], dtype=float32)

In [28]:
nnx.state(transformer, nnx.Cache)

ValueError: Arrays leaves are not supported, at 'decoder/posembed_out/pos_embedding': [[[ 0.          0.          0.         ...  1.          1.
    1.        ]
  [ 0.84147096  0.82177866  0.8018049  ...  1.          1.
    1.        ]
  [ 0.9092974   0.9365102   0.9582946  ...  1.          1.
    1.        ]
  ...
  [ 0.17589758 -0.44885764 -0.96926904 ...  0.9759369   0.97760755
    0.97916263]
  [-0.7333133   0.47858194 -0.381975   ...  0.9759134   0.97758573
    0.9791423 ]
  [-0.9683193   0.9942562   0.51274335 ...  0.97589     0.9775639
    0.979122  ]]]

In [None]:
inputs

In [None]:
for _ in range(300):
    inputs = jax.random.randint(rngs.params(), (6, config.max_len), minval=0, maxval=config.vocab_size)

    output = transformer(inputs, rngs)

In [None]:
nnx.displasy(transformer)

# TESTING

In [10]:
from flax import nnx
import jax.numpy as jnp

In [11]:
rngs = nnx.Rngs(0)

In [12]:
a = nnx.Embed(
    num_embeddings=30_000,
    features=128,
    embedding_init=nnx.initializers.normal(stddev=1.0),
    rngs=rngs,
    )

In [17]:
import jax

In [19]:
y = jax.random.uniform(rngs.params(), (128, ))

In [20]:
a.attend(y) / jnp.sqrt(y.shape[-1])

Array([ 0.69756776, -0.8900631 ,  0.10091793, ..., -0.8226665 ,
        0.25933772, -0.23882598], dtype=float32)

In [31]:
inputs = jnp.array([[1, 2, 3, 4, 1, 0, 0, 0, 0], [1, 2, 3, 4, 5, 0, 0, 0, 0] ])

In [32]:
inputs

Array([[1, 2, 3, 4, 1, 0, 0, 0, 0],
       [1, 2, 3, 4, 5, 0, 0, 0, 0]], dtype=int32)

In [33]:
nnx.make_causal_mask(inputs)

Array([[[[1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1.]]],


       [[[1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1.]]]], dtype=float32)

In [34]:
nnx.combine_masks(
        nnx.make_attention_mask(inputs > 0, inputs > 0),
        nnx.make_causal_mask(inputs),
      )

Array([[[[1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


       [[[1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]]]], dtype=float32)