In [1]:
import jax.random as jr
import jax.numpy as jnp
import flax.linen as nn

from jaxtyping import Float, Int, Array
from modules import LayerNorm, Embed, PosEmbed, TransformerBlock, Unembed
from config import Config

In [2]:

class Transformer(nn.Module):
    """Transformer built from the project's building blocks.

    Pipeline: token embedding + positional embedding -> ``cfg.n_layers``
    TransformerBlocks applied in order -> final LayerNorm -> Unembed.
    """

    # Hyper-parameter object handed unchanged to every submodule.
    cfg: Config

    def setup(self):
        """Register the submodules; each one receives the same ``cfg``."""
        cfg = self.cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        # Flax registers list entries as blocks_0, blocks_1, ... in the param tree.
        self.blocks = [TransformerBlock(cfg) for _layer in range(cfg.n_layers)]
        self.norm = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def __call__(self, tokens: Int[Array, "b p"]) -> Float[Array, "b p v"]:
        """Run token ids through the full stack.

        Args:
            tokens: integer token ids, annotated as shape (b, p).

        Returns:
            Float array annotated as shape (b, p, v), i.e. the Unembed output
            (presumably logits with v = vocab size -- confirm in ``modules``).
        """
        # NOTE(review): PosEmbed is fed the token ids, not the embeddings;
        # presumably it only uses their shape -- confirm in ``modules``.
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for layer in self.blocks:
            residual = layer(residual)
        return self.unembed(self.norm(residual))


In [10]:
# Smoke test: build the debug-sized model and initialise its parameters on a
# dummy batch of token ids (2 sequences x 4 positions, every id equal to 1).
shape = (2, 4)
RNG = jr.PRNGKey(0)
model = Transformer(cfg=Config(debug=True))
variables = model.init(RNG, jnp.ones(shape, dtype=jnp.int32))
