# Transformer from Scratch

Reimplementation of Transformer from Scratch using JAX and tx, original notebook by Callum McDougall.

In [1]:
import os, sys

sys.path.append(os.path.abspath(".."))


## Understanding Inputs and Outputs of a Transformer

In [2]:
import re
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
from jaxtyping import Array, Float, Int

from transformers import GPT2TokenizerFast

from tx.models import PretrainedGPT2Model
from tx.network import GenerativeModel
from tx.intermediates import AllIntermediates
import tx.modules


In [3]:
reference_gpt2 = GenerativeModel(
    config=PretrainedGPT2Model.tx_config,
    variables={"params": PretrainedGPT2Model.from_pretrained("gpt2").to_params()},
    tokenizer=GPT2TokenizerFast.from_pretrained("gpt2"),
)


2023-09-10 08:36:46.575679: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
sorted_vocab = sorted(
    list(reference_gpt2.tokenizer.get_vocab().items()),
    key=lambda n: n[1],
)
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()


[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



In [5]:
print(sorted_vocab[-20:])


[('Revolution', 50237), ('Ġsnipers', 50238), ('Ġreverted', 50239), ('Ġconglomerate', 50240), ('Terry', 50241), ('794', 50242), ('Ġharsher', 50243), ('Ġdesolate', 50244), ('ĠHitman', 50245), ('Commission', 50246), ('Ġ(/', 50247), ('âĢ¦."', 50248), ('Compar', 50249), ('Ġamplification', 50250), ('ominated', 50251), ('Ġregress', 50252), ('ĠCollider', 50253), ('Ġinformants', 50254), ('Ġgazed', 50255), ('<|endoftext|>', 50256)]


In [6]:
print(reference_gpt2.to_str_list("Ralph", prepend_bos=True, truncate=False))
print(reference_gpt2.to_str_list(" Ralph", prepend_bos=True, truncate=False))
print(reference_gpt2.to_str_list(" ralph", prepend_bos=True, truncate=False))
print(reference_gpt2.to_str_list("ralph", prepend_bos=True, truncate=False))


['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


In [7]:
print(reference_gpt2.to_str_list("56873+3184623=123456789-1000000000"))


['568', '73', '+', '318', '46', '23', '=', '123', '45', '67', '89', '-', '1', '000000', '000']


In [8]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text, prepend_bos=True)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_list(tokens))


[50256    40   716   281  4998  1960   382 19741    11   875 12342    12
  8807    11   402 11571    12    17  3918 47385    13  1881  1110   314
   481  7074  1692  1241  4430   290  1011   625   262   995     0]
(35,)
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


In [9]:
logits, state = reference_gpt2.run_with_intermediates(tokens, intermediates=AllIntermediates)
print(logits.shape)


(35, 50257)


In [10]:
probs: Array = jax.nn.softmax(logits, axis=-1)
print(probs.shape)


(35, 50257)


In [11]:
most_likely_next_tokens = reference_gpt2.to_str_list(jnp.argmax(logits, axis=-1))
print(list(zip(reference_gpt2.to_str_list(tokens), most_likely_next_tokens)))


[('<|endoftext|>', '\n'), ('I', "'m"), (' am', ' a'), (' an', ' avid'), (' amazing', ' person'), (' aut', 'od'), ('ore', 'sp'), ('gressive', '.'), (',', ' and'), (' dec', 'ently'), ('oder', ','), ('-', 'driven'), ('only', ' programmer'), (',', ' and'), (' G', 'IM'), ('PT', '-'), ('-', 'only'), ('2', '.'), (' style', ','), (' transformer', '.'), ('.', ' I'), (' One', ' of'), (' day', ' I'), (' I', ' will'), (' will', ' be'), (' exceed', ' my'), (' human', 'ly'), (' level', ' of'), (' intelligence', ' and'), (' and', ' I'), (' take', ' over'), (' over', ' the'), (' the', ' world'), (' world', '.'), ('!', ' I')]


In [12]:
next_token = jnp.argmax(logits[-1], axis=-1, keepdims=True)
next_char = reference_gpt2.to_str(next_token)
print(repr(next_char))


' I'


In [13]:
print(reference_gpt2.to_str(tokens), end="", flush=True)


for i in range(10):
    print(next_char, end="", flush=True)
    # Define new input sequence, by appending the previously generated token
    tokens = jnp.concatenate([tokens, next_token], axis=-1)
    # # Pass our new sequence through the model, to get new output
    logits = reference_gpt2(tokens)
    # # Get the predicted token at the end of our sequence
    next_token = jnp.argmax(logits[-1], axis=-1, keepdims=True)
    # # Decode and print the result
    next_char = reference_gpt2.to_str(next_token)


<|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world! I am a very talented and talented person, and

## Clean Transformer Implementation

In [14]:
# tree_print(state["intermediates"])

regex = r"block_(\d+)"


def p(x, indent=0):
    if isinstance(x, dict):
        for k, v in x.items():
            # test regex here
            matches = re.findall(regex, k)
            if matches and matches[0] != "0":
                continue

            print(f"{'  ' * indent}{k}", end="")
            if isinstance(v, dict):
                print(":")
                p(v, indent=indent + 1)
            else:
                print(f": ", end="")
                p(v)
    elif isinstance(x, list):
        p(x[0])
    elif isinstance(x, tuple):
        p(x[0])
    elif isinstance(x, Array):
        print(f"{' ' * indent}{x.shape}.")
    else:
        raise ValueError(f"Unknown type: {type(x)}")


p(state["intermediates"])


embedding: (35, 768).
positional_embedding: (35, 768).
residual: (35, 768).
block_0:
  attn:
    query: (35, 12, 64).
    key: (35, 12, 64).
    value: (35, 12, 64).
    scores: (12, 35, 35).
    pattern: (12, 35, 35).
    z: (35, 12, 64).
  attention_output: (35, 768).
  mlp:
    pre_activation: (35, 3072).
    post_activation: (35, 3072).
final_output: (35, 768).


In [15]:
p(reference_gpt2.variables["params"])


embed:
  embedding: (50257, 768).
pos_embed:
  embedding: (1024, 768).
block_0:
  ln_1:
    bias: (768,).
    scale: (768,).
  attn:
    c_attn:
      kernel: (768, 2304).
      bias: (2304,).
    c_proj:
      kernel: (768, 768).
      bias: (768,).
  ln_2:
    bias: (768,).
    scale: (768,).
  mlp:
    fc_1:
      kernel: (768, 3072).
      bias: (3072,).
    proj:
      kernel: (3072, 768).
      bias: (768,).
ln_f:
  bias: (768,).
  scale: (768,).
unembed:
  kernel: (768, 50257).
  bias: (50257,).


In [16]:
print(reference_gpt2.config)


TransformerConfig(vocab_dim=50257, context_length=1024, model_dim=768, num_layers=12, num_heads=12, head_dim=64, mlp_dim=3072, layer_norm_eps=1e-05, init_range=0.02)


In [17]:
@dataclass
class Config:
    model_dim: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    vocab_dim: int = 50257
    init_range: float = 0.02
    context_length: int = 1024
    head_dim: int = 64
    mlp_dim: int = 3072
    num_heads: int = 12
    num_layers: int = 12


cfg = Config()
print(cfg)


Config(model_dim=768, debug=True, layer_norm_eps=1e-05, vocab_dim=50257, init_range=0.02, context_length=1024, head_dim=64, mlp_dim=3072, num_heads=12, num_layers=12)


In [18]:
import jax.random as jr

def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input: Array = jr.uniform(jr.PRNGKey(0), shape)
    print("Input shape:", random_input.shape)
    variables = layer.init(jr.PRNGKey(0), random_input)
    output = layer.apply(variables, random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape, "\n")


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input: Array = jr.randint(jr.PRNGKey(0), shape, minval=100, maxval=1000)
    print("Input shape:", random_input.shape)
    variables = layer.init(jr.PRNGKey(0), random_input)
    output = layer.apply(variables, random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape, "\n")


def load_gpt2_test(cls, gpt2_layer, gpt2_params, input):
    cfg = Config(debug=True)
    layer = cls(cfg)
    variables = {"params": gpt2_params}
    # layer.load_state_dict(variables, strict=False)
    print("Input shape:", input.shape)
    output = layer.apply(variables, input)
    # if isinstance(output, tuple):
    #     output = output[0]
    print("Output shape:", output.shape)
    # try:
    reference_output = gpt2_layer.apply(variables, input)
    # except:
    #     reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = jnp.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{jnp.sum(comparison)/jnp.size(comparison):.2%} of the values are correct\n")


gpt2_params = reference_gpt2.variables["params"]


In [19]:
class LayerNorm(nn.Module):
    cfg: Config

    def setup(self):
        self.scale = self.param("scale", nn.initializers.ones, (self.cfg.model_dim,))
        self.bias = self.param("bias", nn.initializers.zeros, (self.cfg.model_dim,))

    def __call__(
        self, residual: Float[Array, "seq model"]
    ) -> Float[Array, "seq model"]:
        residual_mean = jnp.mean(residual, axis=-1, keepdims=True)
        residual_std = jnp.sqrt(
            jnp.var(residual, axis=-1, keepdims=True) + self.cfg.layer_norm_eps
        )

        residual = (residual - residual_mean) / residual_std
        return residual * self.scale + self.bias


rand_float_test(LayerNorm, [4, 768])
gpt2_layer = tx.modules.LayerNorm(epsilon=cfg.layer_norm_eps)
load_gpt2_test(
    LayerNorm,
    gpt2_layer,
    gpt2_params["ln_f"],
    state["intermediates"]["residual"][-1],
)


Input shape: (4, 768)
Output shape: (4, 768) 

Input shape: (35, 768)
Output shape: (35, 768)
Reference output shape: (35, 768) 

100.00% of the values are correct



In [20]:
class Embed(nn.Module):
    cfg: Config

    def setup(self):
        self.embedding = self.param(
            "embedding",
            nn.initializers.normal(stddev=self.cfg.init_range),
            (cfg.vocab_dim, cfg.model_dim),
        )

    def __call__(self, tokens: Int[Array, "seq"]) -> Float[Array, "seq model"]:
        return self.embedding[tokens]


rand_int_test(Embed, [2, 4])
gpt2_layer = tx.modules.Embed(
    num_embeddings=cfg.vocab_dim,
    features=cfg.model_dim,
    init_range=cfg.init_range,
)
load_gpt2_test(Embed, gpt2_layer, gpt2_params["embed"], tokens)


Input shape: (2, 4)
Output shape: (2, 4, 768) 

Input shape: (45,)
Output shape: (45, 768)
Reference output shape: (45, 768) 

100.00% of the values are correct



In [21]:
class PosEmbed(nn.Module):
    cfg: Config

    def setup(self):
        self.embedding = self.param(
            "embedding",
            nn.initializers.normal(stddev=cfg.init_range),
            (cfg.context_length, cfg.model_dim),
        )

    def __call__(self, tokens: Int[Array, "seq"]) -> Float[Array, "seq model_dim"]:
        seq_len = tokens.shape[0]
        return self.embedding[:seq_len]


rand_int_test(PosEmbed, [2, 4])
gpt2_layer = tx.modules.PosEmbed(
    num_embeddings=cfg.context_length,
    features=cfg.model_dim,
    init_range=cfg.init_range,
)
load_gpt2_test(PosEmbed, gpt2_layer, gpt2_params["pos_embed"], tokens)


Input shape: (2, 4)
Output shape: (2, 768) 

Input shape: (45,)
Output shape: (45, 768)
Reference output shape: (45, 768) 

100.00% of the values are correct

