# Transformer from Scratch

A reimplementation of the "Transformer from Scratch" notebook using JAX and tx. The original notebook is by Callum McDougall.

In [None]:
import os, sys

# Make the parent directory importable (presumably where the local `tx` and
# `params` modules live — confirm).
sys.path.append(os.path.abspath(".."))

from jax import config

# Enable 64-bit dtypes in JAX; the Config defaults below use float64.
config.update("jax_enable_x64", True)


## Understanding Inputs and Outputs of a Transformer

In [None]:
from dataclasses import dataclass
import math
import re

import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
from jaxtyping import Array, Float, Int

from transformers import GPT2TokenizerFast

from tqdm import tqdm

import tx.modules
from tx.modules import AllIntermediates
from tx.models import PretrainedGPT2Model
from tx.network import GenerativeModel

from params import (
    tfs_layer_norm_params,
    tfs_embed_params,
    tfs_pos_embed_params,
    tfs_attention_params,
    tfs_mlp_params,
    tfs_block_params,
    tfs_unembed_params,
    tfs_transformer_params,
)


In [None]:
# Reference model: pretrained GPT-2 weights converted to tx params, bundled with
# the GPT-2 tokenizer. Later cells test the from-scratch modules against it.
reference_gpt2 = GenerativeModel(
    config=PretrainedGPT2Model.tx_config,
    variables={"params": PretrainedGPT2Model.from_pretrained("gpt2").to_params()},
    tokenizer=GPT2TokenizerFast.from_pretrained("gpt2"),
)


In [None]:
# Vocabulary entries as (token string, id) pairs, ordered by token id.
# `sorted` accepts the items view directly; no intermediate list is needed.
sorted_vocab = sorted(
    reference_gpt2.tokenizer.get_vocab().items(),
    key=lambda item: item[1],
)
# Inspect a few windows of the vocabulary at increasing token ids.
for start, stop in [(0, 20), (250, 270), (990, 1010)]:
    print(sorted_vocab[start:stop])
    print()


In [None]:
# The last 20 vocabulary entries (highest token ids).
print(sorted_vocab[-20:])


In [None]:
# Compare the tokenization of one word with and without a leading space and
# with different capitalisation.
for word in ("Ralph", " Ralph", " ralph", "ralph"):
    print(reference_gpt2.to_str_list(word, prepend_bos=True, truncate=False))


In [None]:
# Show how a string of digits and arithmetic symbols is split into tokens.
print(reference_gpt2.to_str_list("56873+3184623=123456789-1000000000"))


In [None]:
# Adjacent string literals are concatenated, giving the same single sentence pair.
reference_text = (
    "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. "
    "One day I will exceed human level intelligence and take over the world!"
)
# Token ids with a leading BOS token; `tokens` is reused by many later cells.
tokens = reference_gpt2.to_tokens(reference_text, prepend_bos=True)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_list(tokens))


In [None]:
# Run the reference model, also collecting every intermediate activation into
# `state` (inspected and used as test inputs below).
logits, state = reference_gpt2.run_with_intermediates(tokens, AllIntermediates)
print(logits.shape)


In [None]:
# Normalise logits into a probability distribution over the vocabulary (last axis).
probs: Array = jax.nn.softmax(logits, axis=-1)
print(probs.shape)


In [None]:
# Pair each input token with the model's greedy (argmax) prediction for the token after it.
most_likely_next_tokens = reference_gpt2.to_str_list(jnp.argmax(logits, axis=-1))
print(list(zip(reference_gpt2.to_str_list(tokens), most_likely_next_tokens)))


In [None]:
# Greedy prediction for the token following the whole sequence.
# NOTE(review): assumes `logits` is (seq, vocab) with no batch axis (tokens is
# 1-D here) — confirm. keepdims=True keeps a length-1 array for concatenation.
next_token = jnp.argmax(logits[-1], axis=-1, keepdims=True)
next_char = reference_gpt2.to_str(next_token)
print(repr(next_char))


In [None]:
# Greedy generation with the reference model. The loop works on private
# `gen_*` copies so the global `tokens`, `logits` and `next_token` still
# describe only the reference text: later cells score `tokens`, and extending
# it with the model's own greedy outputs would inflate those metrics. This also
# makes the cell safe to re-run.
gen_tokens, gen_next_token, gen_next_char = tokens, next_token, next_char
print(reference_gpt2.to_str(gen_tokens), end="", flush=True)

for _ in range(10):
    print(gen_next_char, end="", flush=True)
    # Define the new input sequence by appending the previously generated token
    gen_tokens = jnp.concatenate([gen_tokens, gen_next_token], axis=-1)
    # Pass the new sequence through the model to get new output
    gen_logits = reference_gpt2(gen_tokens)
    # Greedily pick the predicted token at the end of the sequence
    gen_next_token = jnp.argmax(gen_logits[-1], axis=-1, keepdims=True)
    # Decode it; it is printed at the start of the next iteration
    gen_next_char = reference_gpt2.to_str(gen_next_token)


## Clean Transformer Implementation

In [None]:
def p(x, indent=0):
    """Print the nested structure of a pytree, showing the shape of each array.

    Dict keys are printed one per line, indented two spaces per nesting level.
    Keys containing ``block_<n>`` with ``n != 0`` are skipped, so only the
    first of a stack of identically structured blocks is shown. For lists and
    tuples, only the first element is shown; an empty one prints ``<empty>``.

    Args:
        x: A dict, list, tuple or Array, possibly nested.
        indent: Current nesting depth, used for indentation.

    Raises:
        ValueError: If a value is not a dict, list, tuple or Array.
    """
    pad = "  " * indent
    if isinstance(x, dict):
        for k, v in x.items():
            # Only show block_0; the remaining blocks share its structure.
            matches = re.findall(r"block_(\d+)", k)
            if matches and matches[0] != "0":
                continue

            if isinstance(v, dict):
                print(f"{pad}{k}:")
                p(v, indent=indent + 1)
            else:
                # Leaf: print the shape on the same line as the key.
                print(f"{pad}{k}: ", end="")
                p(v)
    elif isinstance(x, (list, tuple)):
        # Guard against empty sequences, which previously raised IndexError.
        if x:
            p(x[0])
        else:
            print("<empty>")
    elif isinstance(x, Array):
        print(f"{pad}{x.shape}.")
    else:
        raise ValueError(f"Unknown type: {type(x)}")


# Structure and shapes of the cached intermediates (only block_0 of the repeated blocks is shown).
p(state["intermediates"])


In [None]:
# Structure and shapes of the reference model's parameters (only block_0 is shown).
p(reference_gpt2.variables["params"])


In [None]:
# The reference model's tx config, for comparison with the Config defined next.
print(reference_gpt2.config)


In [None]:
@dataclass
class Config:
    """Hyperparameters for the from-scratch transformer.

    The defaults are the sizes used below to load the pretrained ``"gpt2"``
    weights (12 layers, 12 heads of size 64, model width 768).
    """

    # Width of the residual stream.
    model_dim: int = 768
    # NOTE(review): not read by any module in this notebook; only passed to Config(...).
    debug: bool = True
    # Added to the variance inside LayerNorm's square root.
    layer_norm_eps: float = 1e-5
    # Vocabulary size: rows of W_E and columns of W_U.
    vocab_dim: int = 50257
    # Standard deviation of the normal initialiser used for weight matrices.
    init_range: float = 0.02
    # Maximum sequence length: rows of the positional embedding W_pos.
    context_length: int = 1024
    # Size of each attention head's query/key/value vectors.
    head_dim: int = 64
    # Hidden width of the MLP.
    mlp_dim: int = 3072
    # Number of attention heads per block.
    num_heads: int = 12
    # Number of stacked TransformerBlocks.
    num_layers: int = 12
    # Dtype that module inputs are cast to before computing.
    # NOTE(review): annotated as jnp.dtype, but the default jnp.float64 is a
    # scalar type rather than a dtype instance.
    dtype: jnp.dtype = jnp.float64
    # Dtype in which parameters are created.
    param_dtype: jnp.dtype = jnp.float64


ex_cfg = Config()
print(ex_cfg)


In [None]:
import jax.random as jr


def _run_random_input_test(cls, random_input: Array):
    """Initialise ``cls`` with random params on ``random_input`` and print input/output shapes."""
    layer = cls(Config(debug=True))
    print("Input shape:", random_input.shape)

    variables = layer.init(jr.PRNGKey(0), random_input)
    output: Array = layer.apply(variables, random_input)
    print("Output shape:", output.shape, "\n")


def rand_float_test(cls, shape):
    """Smoke-test ``cls`` on uniform random floats of the given ``shape``."""
    _run_random_input_test(cls, jr.uniform(jr.PRNGKey(0), shape))


def rand_int_test(cls, shape):
    """Smoke-test ``cls`` on random integer token ids in ``[100, 1000)``."""
    _run_random_input_test(
        cls, jr.randint(jr.PRNGKey(0), shape, minval=100, maxval=1000)
    )


def load_gpt2_test(cls, ref_cls, ref_cfg, variables, x: Array, ref_vars=None):
    """Compare ``cls`` with loaded weights against a reference tx module on ``x``.

    Args:
        cls: Module class under test; constructed with ``Config(debug=True)``.
        ref_cls: Reference module class; wrapped in ``nn.vmap`` over axis 0 of ``x``.
        ref_cfg: Keyword arguments used to construct ``ref_cls``.
        variables: Variables for ``cls`` (and for the reference when ``ref_vars`` is None).
        x: Input whose leading axis is mapped over by the reference.
        ref_vars: Optional separate variables for the reference module.
    """
    layer = cls(cfg=Config(debug=True))
    print("Input shape:", x.shape)

    output = layer.apply(variables, x)
    print("Output shape:", output.shape)

    # nn.vmap applies the reference layer to each element of the batch; params
    # are shared across the batch (variable_axes None, no RNG splitting).
    ref_layer = nn.vmap(
        ref_cls,
        in_axes=0,
        out_axes=0,
        variable_axes={"params": None},
        split_rngs={"params": False},
    )(**ref_cfg)

    reference_output = ref_layer.apply(
        variables if ref_vars is None else ref_vars, x
    )
    print("Reference output shape:", reference_output.shape, "\n")

    # Different-but-broadcastable shapes would otherwise be compared elementwise
    # and report a meaningless percentage.
    if output.shape != reference_output.shape:
        print(
            f"Shape mismatch: output {output.shape} vs reference "
            f"{reference_output.shape}\n"
        )
        return

    comparison = jnp.isclose(output, reference_output, atol=1e-5, rtol=1e-5)
    print(f"{jnp.mean(comparison):.2%} of the values are correct\n")


gpt2_params = reference_gpt2.variables["params"]


In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale ``w`` and shift ``b``.

    Both parameters have shape ``(model_dim,)``. ``w`` starts at ones and ``b``
    at zeros. The biased variance (``jnp.var`` default) plus
    ``cfg.layer_norm_eps`` is used inside the square root.
    """

    cfg: Config

    def setup(self):
        shape = (self.cfg.model_dim,)
        dtype = self.cfg.param_dtype
        self.w = self.param("w", nn.initializers.ones, shape, dtype)
        self.b = self.param("b", nn.initializers.zeros, shape, dtype)

    def __call__(
        self, residual: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        x = residual.astype(self.cfg.dtype)
        mean = jnp.mean(x, axis=-1, keepdims=True)
        std = jnp.sqrt(jnp.var(x, axis=-1, keepdims=True) + self.cfg.layer_norm_eps)
        normalized = (x - mean) / std
        return normalized * self.w + self.b


def layer_norm_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.LayerNorm`` matching ``cfg``."""
    return dict(
        epsilon=cfg.layer_norm_eps,
        dtype=cfg.dtype,
        param_dtype=cfg.param_dtype,
    )


rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(
    cls=LayerNorm,
    ref_cls=tx.modules.LayerNorm,
    ref_cfg=layer_norm_config(ex_cfg),
    variables={"params": tfs_layer_norm_params(ex_cfg, gpt2_params["ln_f"])},
    ref_vars={"params": gpt2_params["ln_f"]},
    x=jnp.expand_dims(state["intermediates"]["residual"][-1], axis=0),
)


In [None]:
class Embed(nn.Module):
    """Token embedding table ``W_E`` of shape ``(vocab_dim, model_dim)``."""

    cfg: Config

    def setup(self):
        table_shape = (self.cfg.vocab_dim, self.cfg.model_dim)
        initializer = nn.initializers.normal(stddev=self.cfg.init_range)
        self.W_E = self.param("W_E", initializer, table_shape, self.cfg.param_dtype)

    def __call__(
        self, tokens: Int[Array, "batch seq"]
    ) -> Float[Array, "batch seq model"]:
        # Integer-array indexing gathers one row of W_E per token id.
        embeddings = self.W_E[tokens]
        return embeddings


def embed_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.Embed`` matching ``cfg``."""
    return dict(
        num_embeddings=cfg.vocab_dim,
        features=cfg.model_dim,
        init_range=cfg.init_range,
        param_dtype=cfg.param_dtype,
    )


rand_int_test(Embed, [2, 4])
load_gpt2_test(
    cls=Embed,
    ref_cls=tx.modules.Embed,
    ref_cfg=embed_config(ex_cfg),
    variables={"params": tfs_embed_params(ex_cfg, gpt2_params["embed"])},
    ref_vars={"params": gpt2_params["embed"]},
    x=jnp.expand_dims(tokens, axis=0),
)


In [None]:
class PosEmbed(nn.Module):
    """Learned positional embedding ``W_pos`` of shape ``(context_length, model_dim)``."""

    cfg: Config

    def setup(self):
        self.W_pos = self.param(
            "W_pos",
            nn.initializers.normal(stddev=self.cfg.init_range),
            (self.cfg.context_length, self.cfg.model_dim),
            self.cfg.param_dtype,
        )

    def __call__(
        self, tokens: Int[Array, "batch seq"]
    ) -> Float[Array, "batch seq model"]:
        """Return the embeddings for positions ``0..seq-1``, repeated over the batch.

        Only the shape of ``tokens`` is used, not the token ids.

        Raises:
            ValueError: If the sequence is longer than ``cfg.context_length``.
        """
        batch, seq_len = tokens.shape
        # Slicing past the end of W_pos would silently return only
        # context_length rows and fail later with an opaque broadcast error
        # when added to the token embeddings.
        if seq_len > self.cfg.context_length:
            raise ValueError(
                f"Sequence length {seq_len} exceeds context length "
                f"{self.cfg.context_length}"
            )
        return einops.repeat(
            self.W_pos[:seq_len], "seq model -> batch seq model", batch=batch
        )


def pos_embed_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.PosEmbed`` matching ``cfg``."""
    return {
        "num_embeddings": cfg.context_length,
        "features": cfg.model_dim,
        "init_range": cfg.init_range,
        "param_dtype": cfg.param_dtype,
    }


rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(
    cls=PosEmbed,
    ref_cls=tx.modules.PosEmbed,
    ref_cfg=pos_embed_config(ex_cfg),
    variables={"params": tfs_pos_embed_params(ex_cfg, gpt2_params["pos_embed"])},
    ref_vars={"params": gpt2_params["pos_embed"]},
    x=jnp.expand_dims(tokens, axis=0),
)


In [None]:
class Attention(nn.Module):
    """Causal multi-head self-attention with separate per-head weight tensors.

    Parameters:
        W_Q, W_K, W_V: ``(num_heads, model_dim, head_dim)`` input projections.
        W_O: ``(num_heads, head_dim, model_dim)`` output projection.
        b_Q, b_K, b_V: ``(num_heads, head_dim)`` biases. b_O: ``(model_dim,)``.
    """

    cfg: Config

    def setup(self):
        init_fn = nn.initializers.normal(stddev=self.cfg.init_range)
        qkv_kernel_shape = (self.cfg.num_heads, self.cfg.model_dim, self.cfg.head_dim)
        self.W_Q = self.param("W_Q", init_fn, qkv_kernel_shape, self.cfg.param_dtype)
        self.W_K = self.param("W_K", init_fn, qkv_kernel_shape, self.cfg.param_dtype)
        self.W_V = self.param("W_V", init_fn, qkv_kernel_shape, self.cfg.param_dtype)
        self.W_O = self.param(
            "W_O",
            init_fn,
            (qkv_kernel_shape[0], qkv_kernel_shape[2], qkv_kernel_shape[1]),
            self.cfg.param_dtype,
        )

        qkv_bias_shape = (self.cfg.num_heads, self.cfg.head_dim)
        self.b_Q = self.param(
            "b_Q",
            nn.initializers.zeros,
            qkv_bias_shape,
            self.cfg.param_dtype,
        )
        self.b_K = self.param(
            "b_K",
            nn.initializers.zeros,
            qkv_bias_shape,
            self.cfg.param_dtype,
        )
        self.b_V = self.param(
            "b_V",
            nn.initializers.zeros,
            qkv_bias_shape,
            self.cfg.param_dtype,
        )
        self.b_O = self.param(
            "b_O",
            nn.initializers.zeros,
            (self.cfg.model_dim,),
            self.cfg.param_dtype,
        )

        # Score written into masked positions; large and negative so that it
        # contributes ~0 probability after softmax.
        self.IGNORE = jnp.array(-1e5, dtype=self.cfg.dtype)

    def __call__(
        self, normalized_resid_pre: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        x = normalized_resid_pre.astype(self.cfg.dtype)

        def project(W, b):
            # Per-head linear projection of the residual stream, plus bias.
            return (
                einops.einsum(
                    x,
                    W,
                    "batch seq model, n_head model h_dim -> batch seq n_head h_dim",
                )
                + b
            )

        # Query, key and value vectors
        q = project(self.W_Q, self.b_Q)
        k = project(self.W_K, self.b_K)
        v = project(self.W_V, self.b_V)

        # Attention scores, scaled by sqrt(head_dim), causally masked, then softmaxed
        attn_scores = einops.einsum(
            q,
            k,
            "batch seq_q n_head h_dim, batch seq_k n_head h_dim -> batch n_head seq_q seq_k",
        )
        attn_scores_masked = self.apply_causal_mask(
            attn_scores / self.cfg.head_dim**0.5
        )
        attn_pattern = jax.nn.softmax(attn_scores_masked, axis=-1)

        # Weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v,
            attn_pattern,
            "batch seq_k n_head h_dim, batch n_head seq_q seq_k -> batch seq_q n_head h_dim",
        )

        # Apply W_O per head, sum over heads, then add the bias b_O
        attn_out = (
            einops.einsum(
                z,
                self.W_O,
                "batch seq_q n_head h_dim, n_head h_dim model -> batch seq_q model",
            )
            + self.b_O
        )

        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Array, "batch n_head seq_q seq_k"]
    ) -> Float[Array, "batch n_head seq_q seq_k"]:
        """Replace scores for future keys (key index > query index) with ``IGNORE``."""
        q_len, k_len = attn_scores.shape[-2:]
        # Boolean mask, True where key position j > query position i. This is the
        # same pattern as triu(ones, k=1), but already bool for jnp.where.
        mask = jnp.arange(k_len)[None, :] > jnp.arange(q_len)[:, None]
        return jnp.where(mask, self.IGNORE, attn_scores)


rand_float_test(Attention, [2, 4, 768])


def attn_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.MultiHeadAttention`` matching ``cfg``."""
    return {
        "num_heads": cfg.num_heads,
        "head_dim": cfg.head_dim,
        "features": cfg.model_dim,
        "init_range": cfg.init_range,
        "dtype": cfg.dtype,
        "param_dtype": cfg.param_dtype,
    }


load_gpt2_test(
    cls=Attention,
    ref_cls=tx.modules.MultiHeadAttention,
    ref_cfg=attn_config(ex_cfg),
    variables={"params": tfs_attention_params(ex_cfg, gpt2_params["block_0"]["attn"])},
    ref_vars={"params": gpt2_params["block_0"]["attn"]},
    x=jnp.expand_dims(state["intermediates"]["block_0"]["ln_1_output"][0], axis=0),
)


In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: ``gelu(x @ W_in + b_in) @ W_out + b_out``."""

    cfg: Config

    def setup(self):
        init_fn = nn.initializers.normal(stddev=self.cfg.init_range)
        self.W_in = self.param(
            "W_in",
            init_fn,
            (self.cfg.model_dim, self.cfg.mlp_dim),
            self.cfg.param_dtype,
        )
        self.W_out = self.param(
            "W_out",
            init_fn,
            (self.cfg.mlp_dim, self.cfg.model_dim),
            self.cfg.param_dtype,
        )
        self.b_in = self.param(
            "b_in",
            nn.initializers.zeros,
            (self.cfg.mlp_dim,),
            self.cfg.param_dtype,
        )
        self.b_out = self.param(
            "b_out",
            nn.initializers.zeros,
            (self.cfg.model_dim,),
            self.cfg.param_dtype,
        )

    def __call__(
        self, normalized_resid_mid: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        normalized_resid_mid = normalized_resid_mid.astype(self.cfg.dtype)
        # Expand to the hidden width.
        pre = (
            einops.einsum(
                normalized_resid_mid,
                self.W_in,
                "batch seq model, model mlp -> batch seq mlp",
            )
            + self.b_in
        )
        post = nn.gelu(pre)
        # Project back to the residual-stream width.
        mlp_out = (
            einops.einsum(
                post,
                self.W_out,
                "batch seq mlp, mlp model -> batch seq model",
            )
            + self.b_out
        )
        return mlp_out


def mlp_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.MLP`` matching ``cfg``."""
    return {
        "features": [cfg.mlp_dim, cfg.model_dim],
        "init_range": cfg.init_range,
        "dtype": cfg.dtype,
        "param_dtype": cfg.param_dtype,
    }


rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(
    cls=MLP,
    ref_cls=tx.modules.MLP,
    ref_cfg=mlp_config(ex_cfg),
    variables={"params": tfs_mlp_params(ex_cfg, gpt2_params["block_0"]["mlp"])},
    ref_vars={"params": gpt2_params["block_0"]["mlp"]},
    x=jnp.expand_dims(state["intermediates"]["block_0"]["ln_2_output"][0], axis=0),
)


In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: attention, then MLP, each added back onto the residual stream."""

    cfg: Config

    def setup(self):
        self.ln1 = LayerNorm(self.cfg)
        self.attn = Attention(self.cfg)
        self.ln2 = LayerNorm(self.cfg)
        self.mlp = MLP(self.cfg)

    def __call__(
        self, resid_pre: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        resid = resid_pre.astype(self.cfg.dtype)
        # Attention sub-layer with residual connection (gives resid_mid).
        resid = resid + self.attn(self.ln1(resid))
        # MLP sub-layer with residual connection (gives resid_post).
        resid = resid + self.mlp(self.ln2(resid))
        return resid


def block_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.TransformerBlock`` matching ``cfg``."""
    return dict(
        num_heads=cfg.num_heads,
        head_dim=cfg.head_dim,
        model_dim=cfg.model_dim,
        mlp_dim=cfg.mlp_dim,
        epsilon=cfg.layer_norm_eps,
        init_range=cfg.init_range,
        dtype=cfg.dtype,
        param_dtype=cfg.param_dtype,
    )


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(
    cls=TransformerBlock,
    ref_cls=tx.modules.TransformerBlock,
    ref_cfg=block_config(ex_cfg),
    variables={"params": tfs_block_params(ex_cfg, gpt2_params["block_0"])},
    ref_vars={"params": gpt2_params["block_0"]},
    x=jnp.expand_dims(state["intermediates"]["residual"][0], axis=0),
)


In [None]:
class Unembed(nn.Module):
    """Map the final residual stream to vocabulary logits via ``W_U`` and bias ``b_U``."""

    cfg: Config

    def setup(self):
        dtype = self.cfg.param_dtype
        self.W_U = self.param(
            "W_U",
            nn.initializers.normal(stddev=self.cfg.init_range),
            (self.cfg.model_dim, self.cfg.vocab_dim),
            dtype,
        )
        self.b_U = self.param("b_U", nn.initializers.zeros, (self.cfg.vocab_dim,), dtype)

    def __call__(
        self, normalized_resid_final: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq vocab"]:
        resid = normalized_resid_final.astype(self.cfg.dtype)
        # Equivalent in spirit to `resid @ self.W_U + self.b_U`.
        logits = einops.einsum(
            resid, self.W_U, "batch seq model, model vocab -> batch seq vocab"
        )
        return logits + self.b_U


def unembed_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.Unembed`` matching ``cfg``."""
    return dict(
        num_embeddings=cfg.vocab_dim,
        features=cfg.model_dim,
        init_range=cfg.init_range,
        dtype=cfg.dtype,
        param_dtype=cfg.param_dtype,
    )


rand_float_test(Unembed, [2, 4, 768])

load_gpt2_test(
    cls=Unembed,
    ref_cls=tx.modules.Unembed,
    ref_cfg=unembed_config(ex_cfg),
    variables={"params": tfs_unembed_params(ex_cfg, gpt2_params["unembed"])},
    ref_vars={"params": gpt2_params["unembed"]},
    x=jnp.expand_dims(state["intermediates"]["final_output"][0], axis=0),
)


In [None]:
class DemoTransformer(nn.Module):
    """Full model: embeddings, ``num_layers`` TransformerBlocks, final LayerNorm, unembedding."""

    cfg: Config

    def setup(self):
        self.embed = Embed(self.cfg)
        self.pos_embed = PosEmbed(self.cfg)
        self.blocks = [
            TransformerBlock(cfg=self.cfg, name=f"block_{layer}")
            for layer in range(self.cfg.num_layers)
        ]
        self.ln_final = LayerNorm(self.cfg)
        self.unembed = Unembed(self.cfg)

    def __call__(
        self, tokens: Int[Array, "batch seq"]
    ) -> Float[Array, "batch seq vocab"]:
        token_embeddings = self.embed(tokens)
        position_embeddings = self.pos_embed(tokens)
        residual = (token_embeddings + position_embeddings).astype(self.cfg.dtype)
        for block in self.blocks:
            residual = block(residual)
        return self.unembed(self.ln_final(residual))


def transformer_config(cfg: Config):
    """Keyword arguments for the reference ``tx.modules.Transformer`` matching ``cfg``."""
    return dict(
        vocab_dim=cfg.vocab_dim,
        model_dim=cfg.model_dim,
        mlp_dim=cfg.mlp_dim,
        num_heads=cfg.num_heads,
        head_dim=cfg.head_dim,
        context_length=cfg.context_length,
        init_range=cfg.init_range,
        num_layers=cfg.num_layers,
        layer_norm_eps=cfg.layer_norm_eps,
        dtype=cfg.dtype,
        param_dtype=cfg.param_dtype,
    )


rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(
    cls=DemoTransformer,
    ref_cls=tx.modules.Transformer,
    ref_cfg=transformer_config(ex_cfg),
    variables={"params": tfs_transformer_params(ex_cfg, gpt2_params)},
    ref_vars={"params": gpt2_params},
    x=jnp.expand_dims(tokens, axis=0),
)


In [None]:
# Instantiate the from-scratch model and run it on the reference tokens (with a
# batch axis added), using GPT-2 weights converted by `tfs_transformer_params`.
demo_cfg = Config(debug=False)
demo_gpt2 = DemoTransformer(demo_cfg)

demo_logits = demo_gpt2.apply(
    {"params": tfs_transformer_params(demo_cfg, gpt2_params)},
    jnp.expand_dims(tokens, axis=0),
)


In [None]:
def get_log_probs(
    logits: Float[Array, "batch seq vocab"], tokens: Int[Array, "batch seq"]
) -> Float[Array, "batch seq-1"]:
    """Log probability the model assigned to each actual next token.

    Position ``i`` of ``logits`` predicts token ``i + 1``, so the last
    prediction and the first token are dropped before pairing them up.

    Args:
        logits: Unnormalised scores of shape ``(batch, seq, vocab)``.
        tokens: Token ids of shape ``(batch, seq)``.

    Returns:
        Log probabilities of shape ``(batch, seq - 1)``.
    """
    # Log probs of the first seq_len-1 predictions.
    predictions = jax.nn.log_softmax(logits, axis=-1)[:, :-1]
    # The tokens they should have predicted, with a trailing axis for take_along_axis.
    targets = tokens[:, 1:, None]
    picked = jnp.take_along_axis(predictions, targets, axis=-1)
    return picked[..., 0]


# Score the reference tokens under the from-scratch model.
pred_log_probs = get_log_probs(demo_logits, jnp.expand_dims(tokens, axis=0))
# All values use `.4f` (4 decimal places). The previous `:4f` meant a minimum
# field width of 4 with the default 6 decimals, inconsistent with the first line.
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(
    f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.vocab_dim):.4f}"
)
print(
    f"Avg probability assigned to correct token: {jnp.mean(jnp.exp(pred_log_probs)):.4f}"
)


In [None]:
test_string = """The Total Perspective Vortex derives its picture of the whole Universe on the principle of"""
print(test_string, end="", flush=True)
# The converted parameters do not change between steps, so build them once.
demo_variables = {"params": tfs_transformer_params(demo_cfg, gpt2_params)}
# Tokenize the prompt once and append each generated id directly. Decoding and
# re-tokenizing the whole string every step is slower, and the re-encoded ids
# can differ from the ids the model actually produced.
test_tokens = jnp.expand_dims(reference_gpt2.to_tokens(test_string), axis=0)
for _ in range(100):
    demo_logits = demo_gpt2.apply(demo_variables, test_tokens)
    # Greedy decoding: most likely token at the last position of the (only) batch item.
    next_id = demo_logits[-1, -1].argmax()
    next_string = reference_gpt2.tokenizer.decode(next_id)
    print(next_string, end="", flush=True)
    test_string += next_string
    test_tokens = jnp.concatenate([test_tokens, next_id.reshape(1, 1)], axis=-1)

