# Transformer from Scratch

Reimplementation of Transformer from Scratch using JAX and tx, original notebook by Callum McDougall.

In [1]:
import os, sys

sys.path.append(os.path.abspath(".."))


## Understanding Inputs and Outputs of a Transformer

In [2]:
from dataclasses import dataclass
import math
import re

import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
from jaxtyping import Array, Float, Int

from transformers import GPT2TokenizerFast

from tqdm import tqdm

from tx.models import PretrainedGPT2Model
from tx.network import GenerativeModel
from tx.modules import AllIntermediates
import tx.modules
from params import (
    tfs_layer_norm_params,
    tfs_embed_params,
    tfs_pos_embed_params,
    tfs_attention_params,
    tfs_mlp_params,
    tfs_block_params,
    tfs_unembed_params,
    tfs_transformer_params,
)


In [3]:
reference_gpt2 = GenerativeModel(
    config=PretrainedGPT2Model.tx_config,
    variables={"params": PretrainedGPT2Model.from_pretrained("gpt2").to_params()},
    tokenizer=GPT2TokenizerFast.from_pretrained("gpt2"),
)


In [4]:
sorted_vocab = sorted(
    list(reference_gpt2.tokenizer.get_vocab().items()),
    key=lambda n: n[1],
)
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()


[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



In [5]:
print(sorted_vocab[-20:])


[('Revolution', 50237), ('Ġsnipers', 50238), ('Ġreverted', 50239), ('Ġconglomerate', 50240), ('Terry', 50241), ('794', 50242), ('Ġharsher', 50243), ('Ġdesolate', 50244), ('ĠHitman', 50245), ('Commission', 50246), ('Ġ(/', 50247), ('âĢ¦."', 50248), ('Compar', 50249), ('Ġamplification', 50250), ('ominated', 50251), ('Ġregress', 50252), ('ĠCollider', 50253), ('Ġinformants', 50254), ('Ġgazed', 50255), ('<|endoftext|>', 50256)]


In [6]:
print(reference_gpt2.to_str_list("Ralph", prepend_bos=True, truncate=False))
print(reference_gpt2.to_str_list(" Ralph", prepend_bos=True, truncate=False))
print(reference_gpt2.to_str_list(" ralph", prepend_bos=True, truncate=False))
print(reference_gpt2.to_str_list("ralph", prepend_bos=True, truncate=False))


['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


In [7]:
print(reference_gpt2.to_str_list("56873+3184623=123456789-1000000000"))


['568', '73', '+', '318', '46', '23', '=', '123', '45', '67', '89', '-', '1', '000000', '000']


In [8]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text, prepend_bos=True)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_list(tokens))


[50256    40   716   281  4998  1960   382 19741    11   875 12342    12
  8807    11   402 11571    12    17  3918 47385    13  1881  1110   314
   481  7074  1692  1241  4430   290  1011   625   262   995     0]
(35,)
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


In [9]:
logits, state = reference_gpt2.run_with_intermediates(
    tokens, intermediates=AllIntermediates
)
print(logits.shape)


(35, 50257)


In [10]:
probs: Array = jax.nn.softmax(logits, axis=-1)
print(probs.shape)


(35, 50257)


In [11]:
most_likely_next_tokens = reference_gpt2.to_str_list(jnp.argmax(logits, axis=-1))
print(list(zip(reference_gpt2.to_str_list(tokens), most_likely_next_tokens)))


[('<|endoftext|>', '\n'), ('I', "'m"), (' am', ' a'), (' an', ' avid'), (' amazing', ' person'), (' aut', 'od'), ('ore', 'sp'), ('gressive', '.'), (',', ' and'), (' dec', 'ently'), ('oder', ','), ('-', 'driven'), ('only', ' programmer'), (',', ' and'), (' G', 'IM'), ('PT', '-'), ('-', 'only'), ('2', '.'), (' style', ','), (' transformer', '.'), ('.', ' I'), (' One', ' of'), (' day', ' I'), (' I', ' will'), (' will', ' be'), (' exceed', ' my'), (' human', 'ly'), (' level', ' of'), (' intelligence', ' and'), (' and', ' I'), (' take', ' over'), (' over', ' the'), (' the', ' world'), (' world', '.'), ('!', '\n')]


In [12]:
next_token = jnp.argmax(logits[-1], axis=-1, keepdims=True)
next_char = reference_gpt2.to_str(next_token)
print(repr(next_char))


'\n'


In [13]:
print(reference_gpt2.to_str(tokens), end="", flush=True)

for i in range(10):
    print(next_char, end="", flush=True)
    # Define new input sequence, by appending the previously generated token
    tokens = jnp.concatenate([tokens, next_token], axis=-1)
    # # Pass our new sequence through the model, to get new output
    logits = reference_gpt2(tokens)
    # # Get the predicted token at the end of our sequence
    next_token = jnp.argmax(logits[-1], axis=-1, keepdims=True)
    # # Decode and print the result
    next_char = reference_gpt2.to_str(next_token)


<|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!

I am an amazing autoregressive,

## Clean Transformer Implementation

In [14]:
def p(x, indent=0):
    if isinstance(x, dict):
        for k, v in x.items():
            matches = re.findall(r"block_(\d+)", k)
            if matches and matches[0] != "0":
                continue

            print(f"{'  ' * indent}{k}", end="")
            if isinstance(v, dict):
                print(":")
                p(v, indent=indent + 1)
            else:
                print(f": ", end="")
                p(v)
    elif isinstance(x, list):
        p(x[0])
    elif isinstance(x, tuple):
        p(x[0])
    elif isinstance(x, Array):
        print(f"{' ' * indent}{x.shape}.")
    else:
        raise ValueError(f"Unknown type: {type(x)}")


p(state["intermediates"])


embedding: (35, 768).
positional_embedding: (35, 768).
residual: (35, 768).
block_0:
  ln_1_output: (35, 768).
  attn:
    query: (35, 12, 64).
    key: (35, 12, 64).
    value: (35, 12, 64).
    scores: (12, 35, 35).
    z: (35, 12, 64).
  attention_output: (35, 768).
  ln_2_output: (35, 768).
  mlp:
    pre_activation: (35, 3072).
    post_activation: (35, 3072).
final_output: (35, 768).


In [15]:
p(reference_gpt2.variables["params"])


embed:
  embedding: (50257, 768).
pos_embed:
  embedding: (1024, 768).
block_0:
  ln_1:
    bias: (768,).
    scale: (768,).
  attn:
    c_attn:
      kernel: (768, 2304).
      bias: (2304,).
    c_proj:
      kernel: (768, 768).
      bias: (768,).
  ln_2:
    bias: (768,).
    scale: (768,).
  mlp:
    fc_1:
      kernel: (768, 3072).
      bias: (3072,).
    proj:
      kernel: (3072, 768).
      bias: (768,).
ln_f:
  bias: (768,).
  scale: (768,).
unembed:
  kernel: (768, 50257).
  bias: (50257,).


In [16]:
print(reference_gpt2.config)


TransformerConfig(vocab_dim=50257, context_length=1024, model_dim=768, num_layers=12, num_heads=12, head_dim=64, mlp_dim=3072, layer_norm_eps=1e-05, init_range=0.02)


In [17]:
@dataclass
class Config:
    model_dim: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    vocab_dim: int = 50257
    init_range: float = 0.02
    context_length: int = 1024
    head_dim: int = 64
    mlp_dim: int = 3072
    num_heads: int = 12
    num_layers: int = 12


ex_cfg = Config()
print(ex_cfg)


Config(model_dim=768, debug=True, layer_norm_eps=1e-05, vocab_dim=50257, init_range=0.02, context_length=1024, head_dim=64, mlp_dim=3072, num_heads=12, num_layers=12)


In [18]:
import jax.random as jr


def rand_float_test(cls, shape):
    layer = cls(Config(debug=True))
    random_input: Array = jr.uniform(jr.PRNGKey(0), shape)
    print("Input shape:", random_input.shape)

    variables = layer.init(jr.PRNGKey(0), random_input)
    output: Array = layer.apply(variables, random_input)
    print("Output shape:", output.shape, "\n")


def rand_int_test(cls, shape):
    layer = cls(Config(debug=True))
    random_input: Array = jr.randint(jr.PRNGKey(0), shape, minval=100, maxval=1000)
    print("Input shape:", random_input.shape)

    variables = layer.init(jr.PRNGKey(0), random_input)
    output: Array = layer.apply(variables, random_input)
    print("Output shape:", output.shape, "\n")


def load_gpt2_test(cls, ref_cls, ref_cfg, variables, x: Array, ref_vars=None):
    # Initialise the layer to test
    layer = cls(cfg=Config(debug=True))
    print("Input shape:", x.shape)

    # Apply the layer to the input
    output = layer.apply(variables, x)
    print("Output shape:", output.shape)

    # Initialise the reference layer to test against
    # nn.vmap is used to apply the layer to each element of the batch
    ref_layer = nn.vmap(
        ref_cls,
        in_axes=0,
        out_axes=0,
        variable_axes={"params": None},
        split_rngs={"params": False},
    )(**ref_cfg)

    # Apply the reference layer to the input
    if ref_vars is None:
        reference_output = ref_layer.apply(variables, x)
    else:
        reference_output = ref_layer.apply(ref_vars, x)
    print("Reference output shape:", reference_output.shape, "\n")

    # Compare the output of the layer to the reference output
    comparison = jnp.isclose(output, reference_output, atol=1e-5, rtol=1e-5)
    print(
        f"{jnp.sum(comparison) / jnp.size(comparison):.2%} of the values are correct\n"
    )


gpt2_params = reference_gpt2.variables["params"]


In [19]:
class LayerNorm(nn.Module):
    cfg: Config

    def setup(self):
        self.w = self.param("w", nn.initializers.ones, (self.cfg.model_dim,))
        self.b = self.param("b", nn.initializers.zeros, (self.cfg.model_dim,))

    def __call__(
        self, residual: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        residual_mean = jnp.mean(residual, axis=-1, keepdims=True)
        residual_std = jnp.sqrt(
            jnp.var(residual, axis=-1, keepdims=True) + self.cfg.layer_norm_eps
        )

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b


def layer_norm_config(cfg: Config):
    return {"epsilon": cfg.layer_norm_eps}


rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(
    cls=LayerNorm,
    ref_cls=tx.modules.LayerNorm,
    ref_cfg={"epsilon": ex_cfg.layer_norm_eps},
    variables={"params": tfs_layer_norm_params(ex_cfg, gpt2_params["ln_f"])},
    ref_vars={"params": gpt2_params["ln_f"]},
    x=jnp.expand_dims(state["intermediates"]["residual"][-1], axis=0),
)


Input shape: (2, 4, 768)
Output shape: (2, 4, 768) 

Input shape: (1, 35, 768)
Output shape: (1, 35, 768)
Reference output shape: (1, 35, 768) 

100.00% of the values are correct



In [20]:
class Embed(nn.Module):
    cfg: Config

    def setup(self):
        self.W_E = self.param(
            "W_E",
            nn.initializers.normal(stddev=self.cfg.init_range),
            (self.cfg.vocab_dim, self.cfg.model_dim),
        )

    def __call__(
        self, tokens: Int[Array, "batch seq"]
    ) -> Float[Array, "batch seq model"]:
        return self.W_E[tokens]


def embed_config(cfg: Config):
    return {
        "num_embeddings": cfg.vocab_dim,
        "features": cfg.model_dim,
        "init_range": cfg.init_range,
    }


rand_int_test(Embed, [2, 4])
load_gpt2_test(
    cls=Embed,
    ref_cls=tx.modules.Embed,
    ref_cfg=embed_config(ex_cfg),
    variables={"params": tfs_embed_params(ex_cfg, gpt2_params["embed"])},
    ref_vars={"params": gpt2_params["embed"]},
    x=jnp.expand_dims(tokens, axis=0),
)


Input shape: (2, 4)
Output shape: (2, 4, 768) 

Input shape: (1, 45)
Output shape: (1, 45, 768)
Reference output shape: (1, 45, 768) 

100.00% of the values are correct



In [21]:
class PosEmbed(nn.Module):
    cfg: Config

    def setup(self):
        self.W_pos = self.param(
            "W_pos",
            nn.initializers.normal(stddev=self.cfg.init_range),
            (self.cfg.context_length, self.cfg.model_dim),
        )

    def __call__(
        self, tokens: Int[Array, "batch seq"]
    ) -> Float[Array, "batch seq model"]:
        batch, seq_len = tokens.shape
        return einops.repeat(
            self.W_pos[:seq_len], "seq model -> batch seq model", batch=batch
        )


def pos_embed_config(cfg: Config):
    return {
        "num_embeddings": cfg.context_length,
        "features": cfg.model_dim,
        "init_range": cfg.init_range,
    }


rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(
    cls=PosEmbed,
    ref_cls=tx.modules.PosEmbed,
    ref_cfg=pos_embed_config(ex_cfg),
    variables={"params": tfs_pos_embed_params(ex_cfg, gpt2_params["pos_embed"])},
    ref_vars={"params": gpt2_params["pos_embed"]},
    x=jnp.expand_dims(tokens, axis=0),
)


Input shape: (2, 4)
Output shape: (2, 4, 768) 

Input shape: (1, 45)
Output shape: (1, 45, 768)
Reference output shape: (1, 45, 768) 

100.00% of the values are correct



In [23]:
class Attention(nn.Module):
    cfg: Config

    def setup(self):
        init_fn = nn.initializers.normal(stddev=self.cfg.init_range)
        qkv_kernel_shape = (self.cfg.num_heads, self.cfg.model_dim, self.cfg.head_dim)
        self.W_Q = self.param("W_Q", init_fn, qkv_kernel_shape)
        self.W_K = self.param("W_K", init_fn, qkv_kernel_shape)
        self.W_V = self.param("W_V", init_fn, qkv_kernel_shape)
        self.W_O = self.param(
            "W_O",
            init_fn,
            (qkv_kernel_shape[0], qkv_kernel_shape[2], qkv_kernel_shape[1]),
        )

        qkv_bias_shape = (self.cfg.num_heads, self.cfg.head_dim)
        self.b_Q = self.param("b_Q", nn.initializers.zeros, qkv_bias_shape)
        self.b_K = self.param("b_K", nn.initializers.zeros, qkv_bias_shape)
        self.b_V = self.param("b_V", nn.initializers.zeros, qkv_bias_shape)
        self.b_O = self.param("b_O", nn.initializers.zeros, (self.cfg.model_dim,))

        self.IGNORE = jnp.array(-1e5, dtype=jnp.float32)

    def __call__(
        self, normalized_resid_pre: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        # Calculate query, key and value vectors
        q = (
            einops.einsum(
                normalized_resid_pre,
                self.W_Q,
                "batch seq model, n_head model h_dim -> batch seq n_head h_dim",
            )
            + self.b_Q
        )
        k = (
            einops.einsum(
                normalized_resid_pre,
                self.W_K,
                "batch seq model, n_head model h_dim -> batch seq n_head h_dim",
            )
            + self.b_K
        )
        v = (
            einops.einsum(
                normalized_resid_pre,
                self.W_V,
                "batch seq model, n_head model h_dim -> batch seq n_head h_dim",
            )
            + self.b_V
        )

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = einops.einsum(
            q,
            k,
            "batch seq_q n_head h_dim, batch seq_k n_head h_dim -> batch n_head seq_q seq_k",
        )
        attn_scores_masked = self.apply_causal_mask(
            attn_scores / self.cfg.head_dim**0.5
        )
        attn_pattern = jax.nn.softmax(attn_scores_masked, axis=-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v,
            attn_pattern,
            "batch seq_k n_head h_dim, batch n_head seq_q seq_k -> batch seq_q n_head h_dim",
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = (
            einops.einsum(
                z,
                self.W_O,
                "batch seq_q n_head h_dim, n_head h_dim model -> batch seq_q model",
            )
            + self.b_O
        )

        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Array, "batch n_head seq_q seq_k"]
    ) -> Float[Array, "batch n_head seq_q seq_k"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = jnp.ones((attn_scores.shape[-2], attn_scores.shape[-1]))
        mask = jnp.triu(all_ones, k=1)
        # Apply the mask to attention scores, then return the masked scores
        # attn_scores.masked_fill_(mask, self.IGNORE)
        attn_scores = jnp.where(mask, self.IGNORE, attn_scores)
        return attn_scores


rand_float_test(Attention, [2, 4, 768])


def attn_config(cfg: Config):
    return {
        "num_heads": cfg.num_heads,
        "head_dim": cfg.head_dim,
        "model_dim": cfg.model_dim,
        "init_range": cfg.init_range,
    }


load_gpt2_test(
    cls=Attention,
    ref_cls=tx.modules.Attention,
    ref_cfg=attn_config(ex_cfg),
    variables={"params": tfs_attention_params(ex_cfg, gpt2_params["block_0"]["attn"])},
    ref_vars={"params": gpt2_params["block_0"]["attn"]},
    x=jnp.expand_dims(state["intermediates"]["block_0"]["ln_1_output"][0], axis=0),
)



Input shape: (2, 4, 768)
Output shape: (2, 4, 768) 

Input shape: (1, 35, 768)
Output shape: (1, 35, 768)
Reference output shape: (1, 35, 768) 

91.85% of the values are correct



In [24]:
class MLP(nn.Module):
    cfg: Config

    def setup(self):
        init_fn = nn.initializers.normal(stddev=self.cfg.init_range)
        self.W_in = self.param("W_in", init_fn, (self.cfg.model_dim, self.cfg.mlp_dim))
        self.W_out = self.param(
            "W_out", init_fn, (self.cfg.mlp_dim, self.cfg.model_dim)
        )
        self.b_in = self.param("b_in", nn.initializers.zeros, (self.cfg.mlp_dim,))
        self.b_out = self.param("b_out", nn.initializers.zeros, (self.cfg.model_dim,))
        self.IGNORE = jnp.array(-1e5, dtype=jnp.float32)

    def __call__(
        self, normalized_resid_mid: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        pre = (
            einops.einsum(
                normalized_resid_mid,
                self.W_in,
                "batch seq model, model mlp -> batch seq mlp",
            )
            + self.b_in
        )
        post = nn.gelu(pre)
        mlp_out = (
            einops.einsum(
                post,
                self.W_out,
                "batch seq mlp, mlp model -> batch seq model",
            )
            + self.b_out
        )
        return mlp_out


def mlp_config(cfg: Config):
    return {
        "features": [cfg.mlp_dim, cfg.model_dim],
        "init_range": cfg.init_range,
    }


rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(
    cls=MLP,
    ref_cls=tx.modules.MLP,
    ref_cfg=mlp_config(ex_cfg),
    variables={"params": tfs_mlp_params(ex_cfg, gpt2_params["block_0"]["mlp"])},
    ref_vars={"params": gpt2_params["block_0"]["mlp"]},
    x=jnp.expand_dims(state["intermediates"]["block_0"]["ln_2_output"][0], axis=0),
)


Input shape: (2, 4, 768)
Output shape: (2, 4, 768) 

Input shape: (1, 35, 768)
Output shape: (1, 35, 768)
Reference output shape: (1, 35, 768) 

100.00% of the values are correct



In [26]:
class TransformerBlock(nn.Module):
    cfg: Config

    def setup(self):
        self.ln1 = LayerNorm(self.cfg)
        self.attn = Attention(self.cfg)
        self.ln2 = LayerNorm(self.cfg)
        self.mlp = MLP(self.cfg)

    def __call__(
        self, resid_pre: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq model"]:
        resid_mid = self.attn(self.ln1(resid_pre)) + resid_pre
        resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid
        return resid_post


def block_config(cfg: Config):
    return {
        "num_heads": cfg.num_heads,
        "head_dim": cfg.head_dim,
        "model_dim": cfg.model_dim,
        "mlp_dim": cfg.mlp_dim,
        "epsilon": cfg.layer_norm_eps,
        "init_range": cfg.init_range,
    }


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(
    cls=TransformerBlock,
    ref_cls=tx.modules.TransformerBlock,
    ref_cfg=block_config(ex_cfg),
    variables={"params": tfs_block_params(ex_cfg, gpt2_params["block_0"])},
    ref_vars={"params": gpt2_params["block_0"]},
    x=jnp.expand_dims(state["intermediates"]["residual"][0], axis=0),
)


Input shape: (2, 4, 768)
Output shape: (2, 4, 768) 

Input shape: (1, 35, 768)
Output shape: (1, 35, 768)
Reference output shape: (1, 35, 768) 

25.15% of the values are correct



In [27]:
class Unembed(nn.Module):
    cfg: Config

    def setup(self):
        init_fn = nn.initializers.normal(stddev=self.cfg.init_range)
        self.W_U = self.param("W_U", init_fn, (self.cfg.model_dim, self.cfg.vocab_dim))
        self.b_U = self.param("b_U", nn.initializers.zeros, (self.cfg.vocab_dim,))

    def __call__(
        self, normalized_resid_final: Float[Array, "batch seq model"]
    ) -> Float[Array, "batch seq vocab"]:
        return (
            einops.einsum(
                normalized_resid_final,
                self.W_U,
                "batch seq model, model vocab -> batch seq vocab",
            )
            + self.b_U
        )
        # Or, could just do `normalized_resid_final @ self.W_U + self.b_U`


def unembed_config(cfg: Config):
    return {
        "num_embeddings": cfg.vocab_dim,
        "features": cfg.model_dim,
        "init_range": cfg.init_range,
    }


rand_float_test(Unembed, [2, 4, 768])

load_gpt2_test(
    cls=Unembed,
    ref_cls=tx.modules.Unembed,
    ref_cfg=unembed_config(ex_cfg),
    variables={"params": tfs_unembed_params(ex_cfg, gpt2_params["unembed"])},
    ref_vars={"params": gpt2_params["unembed"]},
    x=jnp.expand_dims(state["intermediates"]["final_output"][0], axis=0),
)


Input shape: (2, 4, 768)
Output shape: (2, 4, 50257) 

Input shape: (1, 35, 768)
Output shape: (1, 35, 50257)
Reference output shape: (1, 35, 50257) 

100.00% of the values are correct



In [28]:
class DemoTransformer(nn.Module):
    cfg: Config

    def setup(self):
        self.embed = Embed(self.cfg)
        self.pos_embed = PosEmbed(self.cfg)
        self.blocks = [
            TransformerBlock(name=f"block_{i}", cfg=self.cfg)
            for i in range(self.cfg.num_layers)
        ]
        self.ln_final = LayerNorm(self.cfg)
        self.unembed = Unembed(self.cfg)

    def __call__(
        self, tokens: Int[Array, "batch seq"]
    ) -> Float[Array, "batch seq vocab"]:
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits


def transformer_config(cfg: Config):
    return {
        "vocab_dim": cfg.vocab_dim,
        "model_dim": cfg.model_dim,
        "mlp_dim": cfg.mlp_dim,
        "num_heads": cfg.num_heads,
        "head_dim": cfg.head_dim,
        "context_length": cfg.context_length,
        "init_range": cfg.init_range,
        "num_layers": cfg.num_layers,
    }


rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(
    cls=DemoTransformer,
    ref_cls=tx.modules.Transformer,
    ref_cfg=transformer_config(ex_cfg),
    variables={"params": tfs_transformer_params(ex_cfg, gpt2_params)},
    ref_vars={"params": gpt2_params},
    x=jnp.expand_dims(tokens, axis=0),
)


Input shape: (2, 4)
Output shape: (2, 4, 50257) 

Input shape: (1, 45)
Output shape: (1, 45, 50257)
Reference output shape: (1, 45, 50257) 

10.20% of the values are correct



In [29]:
demo_cfg = Config(debug=False)
demo_gpt2 = DemoTransformer(demo_cfg)

demo_logits = demo_gpt2.apply(
    {"params": tfs_transformer_params(demo_cfg, gpt2_params)},
    jnp.expand_dims(tokens, axis=0),
)


In [30]:
def get_log_probs(
    logits: Float[Array, "batch seq vocab"], tokens: Int[Array, "batch seq"]
) -> Float[Array, "batch seq-1"]:
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Get logprobs the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    expanded_tokens = jnp.expand_dims(tokens[:, 1:], axis=-1)
    y = jnp.take_along_axis(log_probs[:, :-1], expanded_tokens, axis=-1)
    return jnp.squeeze(y, axis=-1)


pred_log_probs = get_log_probs(demo_logits, jnp.expand_dims(tokens, axis=0))
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(
    f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.vocab_dim):4f}"
)
print(
    f"Avg probability assigned to correct token: {jnp.mean(jnp.exp(pred_log_probs)):4f}"
)





Avg cross entropy loss: 3.6749
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.211349


In [31]:
test_string = """The Total Perspective Vortex derives its picture of the whole Universe on the principle of"""
for i in tqdm(range(100)):
    test_tokens = jnp.expand_dims(reference_gpt2.to_tokens(test_string), axis=0)
    demo_logits = demo_gpt2.apply(
        {"params": tfs_transformer_params(demo_cfg, gpt2_params)}, test_tokens
    )
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)


  4%|▍         | 4/100 [00:05<02:11,  1.37s/it]


KeyboardInterrupt: 