In [1]:
import jax.random as jr
import jax.numpy as jnp
import flax.linen as nn

from jaxtyping import Float, Array
from modules import LayerNorm, Embed, PosEmbed, TransformerBlock, Unembed


In [2]:
class Transformer(nn.Module):
    """Transformer model mapping integer token ids to vocabulary-sized outputs.

    Forward pipeline:
      1. token embedding (``Embed``) + positional embedding (``PosEmbed``),
      2. ``num_layers`` stacked ``TransformerBlock``s,
      3. a final ``LayerNorm``,
      4. an ``Unembed`` projection back to ``vocab_dim``.

    The default hyperparameters are those of GPT-2 small: 768-dim model,
    12 layers with 12 heads of 64 dims, 3072-dim MLP, 50,257-token
    vocabulary and a 1024-token context.
    """

    # Width of the residual stream; forwarded to every sub-module.
    model_dim: int = 768
    # Epsilon for every LayerNorm (inside each block and the final one).
    # Plain ``float``: jaxtyping's ``Float`` is meant for array annotations
    # such as ``Float[Array, "..."]``, not scalar hyperparameters.
    layer_norm_eps: float = 1e-5
    # Vocabulary size; passed as ``num_embeddings`` to Embed and Unembed.
    vocab_dim: int = 50_257
    # Passed to PosEmbed; also enforced below as the maximum input length.
    context_length: int = 1024
    # Per-block hyperparameters forwarded to each TransformerBlock.
    num_heads: int = 12
    head_dim: int = 64
    mlp_dim: int = 3072
    # Number of stacked TransformerBlocks.
    num_layers: int = 12

    @nn.compact
    def __call__(self, tokens) -> Array:
        """Run the forward pass.

        Args:
            tokens: integer token ids whose last axis is the sequence axis.
                The example in this notebook passes an ``int32`` array of
                shape ``(batch, seq)``.

        Returns:
            The output of the final ``Unembed`` layer. Presumably these are
            logits of shape ``(..., seq, vocab_dim)``; TODO confirm against
            ``Unembed``.

        Raises:
            ValueError: if the sequence length exceeds ``context_length``.
        """
        seq_len = tokens.shape[-1]
        if seq_len > self.context_length:
            # Fail loudly. JAX clamps out-of-range gather indices instead of
            # raising, so an over-long input could otherwise be silently
            # mis-embedded (or hit an opaque shape error), depending on how
            # PosEmbed is implemented.
            raise ValueError(
                f"Input sequence length {seq_len} exceeds "
                f"context_length={self.context_length}."
            )

        embedding = Embed(
            features=self.model_dim,
            num_embeddings=self.vocab_dim,
        )(tokens)

        # PosEmbed receives the token ids; presumably it only uses their
        # shape to select positions (not visible here).
        pos_embedding = PosEmbed(
            features=self.model_dim,
            context_length=self.context_length,
        )(tokens)

        x = embedding + pos_embedding

        # Each iteration creates a fresh block, so layers do not share
        # parameters (auto-named TransformerBlock_0, _1, ...).
        for _ in range(self.num_layers):
            x = TransformerBlock(
                num_heads=self.num_heads,
                head_dim=self.head_dim,
                model_dim=self.model_dim,
                mlp_dim=self.mlp_dim,
                epsilon=self.layer_norm_eps,
            )(x)

        x = LayerNorm(epsilon=self.layer_norm_eps)(x)

        logits = Unembed(
            features=self.model_dim,
            num_embeddings=self.vocab_dim,
        )(x)

        return logits


# Parameter/shape summary on a dummy batch of one full-length sequence.
# The length comes from the model itself, so it stays in sync with
# context_length. print() rather than a bare expression: tabulate returns
# a formatted string, and printing renders the table instead of its repr.
model = Transformer()
dummy_tokens = jnp.ones((1, model.context_length), dtype=jnp.int32)
print(model.tabulate(jr.PRNGKey(0), dummy_tokens))


[3m                              Transformer Summary                               [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│               │ Transformer   │ [2mint32[0m[1,1024] │ [2mfloat32[0m[1,1… │               │
├───────────────┼───────────────┼───────────────┼──────────────┼───────────────┤
│ Embed_0       │ Embed         │ [2mint32[0m[1,1024] │ [2mfloat32[0m[1,1… │ embedding:    │
│               │               │               │              │ [2mfloat32[0m[5025… │
│               │               │               │              │               │
│               │               │               │              │ [1m38,597,376 [0m   │
│            