---
title: "Implementing GPT2 in JAX for fun 🦀🦀🦀"
author:
  - name: "Tugdual Kerjan"
    url: https://tugdual.fr
    email: tkerjan@outlook.com
date: "November 9, 2024"
number-sections: true
reference-location: margin
toc: true
format: 
  html:
    standalone: true
    embed-resources: true
    self-contained-math: true
    code-fold: false
    code-tools: true
execute:
  output:
    false
bibliography: assets/bib.bibtex
theme: united
github: "https://github.com/TugdualKerjan/GPT2-for-JAX"
lightbox: true
---

# GPT2 for JAX 🚀  

Explore the full project on the [GitHub repository](https://github.com/TugdualKerjan/GPT2-for-JAX).

## Context ✍️  

This project involves rewriting XTTS in JAX to better understand its architecture and functionality. Originally developed by the now-defunct Coqai company, XTTS is a Text-to-Speech model. We'll recreate its generative component using a GPT2 architecture—a decoder-only transformer—based on [@radford2019language]. The implementation closely follows this [tutorial](https://huggingface.co/blog/sachithgunasekara/nanojaxgpt).  

![The crux of the GPT2 architecture. Layers composed of masked attention and forwards.](assets/architecture.png)  


## GPT2 in Text-to-Speech  

### What are we building?  

Our goal is to generate sequences of tokens for audio synthesis. Specifically, we aim to produce "audio tokens," small units of audio, discovered using a [VQVAE](https://tugdual.fr/Audio-VQVAE-for-JAX/). By learning to map text tokens to audio tokens, the model becomes multi-modal.  

The final output sequences represent speech, which we convert into audio using [HiFiGAN](https://tugdual.fr/HiFiGAN-for-JAX/). Additionally, we enhance speech expressiveness (e.g., tone, speed) by feeding 1024-dimensional vectors representing the target speaker's paralinguistic features.  

### Under the Hood 

__Masked Attention__ 

Masked attention is the core mechanism for learning relationships between tokens. It determines which tokens influence others by projecting them into smaller dimensions and computing relationships. Masking ensures the model focuses only on prior tokens, preventing it from "seeing" future ones.  

Studies classify attention patterns into:  
1. **Semantic**: Tokens linked by meaning.  
2. **Linguistic**: Tokens connected by grammar (e.g., verbs and nouns).  
3. **Rare Tokens**: Infrequent but critical tokens.  

__Feedforward Layers__  

Feedforward layers mix outputs, add non-linearity via activation functions, and stack layers for hierarchical abstractions. The final output approximates a one-hot encoding in the token vocabulary, enabling token selection for sequential generation.  


## Goal 🎯  

Implement a GPT2 architecture using Equinox and train it on TinyStories.  

# Model

We have a few things to implement from the ground up. The custom activation function, the forward layer, the masked attention. We then package this up in a nice layer that we can stack, and finally wrap all these stacks into a GPT2 !

We can start by importing our favorite libraries 🥰

In [None]:
import jax
import equinox as eqx
import equinox.nn as nn
import jax.numpy as jnp
import typing as tp

## Configuration file

Because of the size of our model, we're going to be passing down lots of arguments. To avoid having a long unreadable list of parameters we can define a "dataclass" that will allow us to simply pass a `config` down to the model.

Feel free to experiment with various settings !

In [None]:
from dataclasses import dataclass


@dataclass
class GPTConfig:
    block_size: int = 200
    vocab_size: int = (
        50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    )
    n_layer: int = 12
    n_head: int = 6
    n_embd: int = 512
    dropout: float = 0.0
    bias: bool = False  #

## SwiGLU Activation Function 

We start by implementing the SwiGLU activation function, introduced in [@shazeer2020gluvariantsimprovetransformer], a powerful variant of GLU.  

### Why SwiGLU?  
SwiGLU dynamically adjusts its activation based on the input. Think of it like a railway switch—redirecting the "activation path" when the input carries different information. This gives the network greater flexibility and control, leading to better performance.  

For more details, see this [explanation by Boudefel](https://medium.com/@s_boudefel/exploring-swiglu-the-activation-function-powering-modern-llms-9697f88221e7).  

![The function we implement, based on [@shazeer2020gluvariantsimprovetransformer]](assets/swiglu.png)  

Below is a visualization of the Swish function, $x \times \text{sigmoid}(x)$, which plays a role in SwiGLU:  

![](assets/graphswi.png)  

In [None]:
class SwiGLU(eqx.Module):
    W: nn.Linear
    V: nn.Linear
    b: jax.Array
    c: jax.Array

    def __init__(self, input_dim, output_dim, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.W = nn.Linear(input_dim, output_dim, key=key1)
        self.V = nn.Linear(input_dim, output_dim, key=key2)
        self.b = jax.random.normal(key3, (output_dim))
        self.c = jax.random.normal(key4, (output_dim))

    def __call__(self, x):
        return jax.nn.swish((self.W(x) + self.b) * (self.V(x) + self.c))

In [None]:
# | code-fold : true

key = jax.random.PRNGKey(69)
mod = SwiGLU(10, 4, key)

x = jnp.ones(10)
print(mod(x).shape)

## MLP

We can now move onto the multilayer perceptron, which we mentionned earlier as the feedforward part of our network. Because the model is big and we want to make sure that it doesn't just "memorize" things, we include dropout which pushes the model to avoid relying on singular neurons / data flowing through for information.

✨ You'll also notice that since our SwiGLU has two linear layers in it, in reality each MLP that we'll use uses __4__ layers !!

In [None]:
class MLP(eqx.Module):
    ff1: nn.Linear
    ff2: nn.Linear
    act: SwiGLU
    drop: nn.Dropout

    def __init__(self, config, key):

        key1, key2, key3 = jax.random.split(key, 3)

        self.ff1 = nn.Linear(
            config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=key1
        )
        self.act = SwiGLU(4 * config.n_embd, 4 * config.n_embd, key=key2)
        self.ff2 = nn.Linear(
            4 * config.n_embd, config.n_embd, use_bias=config.bias, key=key3
        )
        self.drop = nn.Dropout(config.dropout)

    @eqx.filter_jit
    def __call__(self, x):
        y = self.ff1(x)
        y = self.act(y)
        y = self.ff2(y)
        return self.drop(y)

Again, we can compare with their implementation to make sure we're close enough.

In [None]:
# | code-fold : true


class MLPTheirs(eqx.Module):
    c_fc: eqx.nn.Linear
    swiglu: SwiGLU
    c_proj: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config, key):
        lkey1, lkey2, skey = jax.random.split(key, 3)

        self.c_fc = eqx.nn.Linear(
            config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=lkey1
        )
        self.swiglu = SwiGLU(4 * config.n_embd, 4 * config.n_embd, skey)
        self.c_proj = eqx.nn.Linear(
            4 * config.n_embd, config.n_embd, use_bias=config.bias, key=lkey2
        )
        self.dropout = eqx.nn.Dropout(config.dropout)

    def __call__(self, x):
        x = jax.vmap(self.c_fc)(x)
        x = jax.vmap(self.swiglu)(x)
        x = jax.vmap(self.c_proj)(x)
        x = self.dropout(x)
        return x

In [None]:
# | code-fold : true

config = GPTConfig()
key = jax.random.PRNGKey(69)

mlp = MLP(config, key)
mlp_theirs = MLPTheirs(config, key)

x = jax.random.normal(key, (100, config.n_embd))

res = jax.vmap(mlp)(x)
res_theirs = mlp_theirs(x)

average_diff = jnp.mean(res_theirs)
print(average_diff)

## Masked attention

Moving onto one of the more complicated aspects of the model, but in the end it simply learns to output which tokens are more important with each other. There are plenty of fantastic tutorials out there for better understanding the underlying concept, notably : [Transformers explained visually](https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853)

In [None]:
import math


class CausalSelfAttention(eqx.Module):
    attnk: nn.Linear
    attnq: nn.Linear
    attnv: nn.Linear
    proj: nn.Linear

    resid_dropout: nn.Dropout
    attn_dropout: nn.Dropout

    mask: jax.Array

    def __init__(self, config, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)

        self.attnk = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key1
        )
        self.attnv = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key2
        )
        self.attnq = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key3
        )
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.proj = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key4
        )

        self.mask = jnp.tril(jnp.ones((config.block_size, config.block_size)))

    # Could play arround with the different attention score calculations (Baidhu ?)
    # X is an embedding, it should self attend.

    @eqx.filter_jit
    def __call__(self, x):
        # x = jnp.swapaxes(x, -1, -2)
        T, C = x.shape  # Seq length and embedding dim.

        q = jax.vmap(self.attnq)(x)
        k = jax.vmap(self.attnk)(x)
        v = jax.vmap(self.attnv)(x)

        att = jnp.matmul(q, jnp.transpose(k)) / math.sqrt(jnp.shape(k)[-1])
        att = jnp.where(
            jax.numpy.equal(jax.lax.stop_gradient(self.mask[:T, :T]), 0),
            float("-inf"),
            att,
        )
        att = jax.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att)

        y = jnp.matmul(att, v)

        y = jax.vmap(self.proj)(y)
        y = self.resid_dropout(y)
        return y

Small check...

In [None]:
# | code-fold : true

import optax


config = GPTConfig()
key = jax.random.PRNGKey(69)

mlp = CausalSelfAttention(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(mlp)

x = jax.random.normal(jax.random.key(2), (30, config.n_embd))


@eqx.filter_jit
def loss(model, x, y):
    output = model(x)
    return jax.numpy.mean(jax.numpy.abs(y - output))


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    updates, optimizer_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(mlp, opt_state, x, x))

print(mlp(jax.random.normal(key, (100, config.n_embd))).shape)

## Block

Ok ! Now that we have the component parts of what we call a "block" we can assemble them. This will then be stacked to get as many layers of abstraction as we wish. In our case we will stack it 12 times as per the GPTConfig we defined.

In [None]:
class Block(eqx.Module):
    norm: nn.LayerNorm
    attn: CausalSelfAttention
    mlp: MLP

    def __init__(self, config, key):
        key1, key2 = jax.random.split(key, 2)

        self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.attn = CausalSelfAttention(config, key=key1)
        self.mlp = MLP(config, key=key2)

    @eqx.filter_jit
    def __call__(self, x):
        y = jax.vmap(self.norm)(x)
        y = self.attn(
            y
        )  # Can't vmap as the whole point is exchange info between tokens.
        x = y + x

        y = jax.vmap(self.norm)(x)
        y = jax.vmap(self.mlp)(y)
        x = y + x

        return x

Can compare with their work.

In [None]:
# | code-fold : true

import optax


config = GPTConfig()
key = jax.random.PRNGKey(69)

block = Block(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(block)

x = jax.random.normal(jax.random.key(2), (30, config.n_embd))


@eqx.filter_jit
def loss(model, x, y):
    output = model(x)
    return jax.numpy.mean(jax.numpy.abs(y - output))


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    updates, optimizer_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(block, opt_state, x, x))

We can finally add the embeddings to our model, which are the maps that send tokens to the dimension that the model works with, i.e. 1024 dims.

In [None]:
class GPT(eqx.Module):
    wte: nn.Embedding  # Token embeddings
    wpe: nn.Embedding  # Positional embeddings

    drop: nn.Dropout

    layers: list
    norm: nn.LayerNorm
    lm_head: nn.Linear

    def __init__(self, config, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd, key=key1)
        self.wpe = nn.Embedding(config.block_size, config.n_embd, key=key2)
        self.drop = nn.Dropout(config.dropout)

        self.layers = [Block(config, key3) for _ in range(config.n_layer)]
        self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, key=key4)

    # @eqx.filter_jit
    def __call__(self, token_ids):
        (t,) = token_ids.shape

        # Should use better positional embeddings with cos and sin.
        pos = jnp.arange(0, t, dtype=jnp.int64)

        tok_emb = jax.vmap(self.wte)(token_ids)
        pos_emb = jax.vmap(self.wpe)(pos)

        # Dropout at the first layer ? Seems a bit aggressive...
        x = self.drop(tok_emb + pos_emb)

        for block in self.layers:
            x = block(x)
        x = jax.vmap(self.norm)(x)
        logits = jax.vmap(self.lm_head)(x)
        logits = jax.nn.softmax(logits)

        return logits

In [None]:
# | code-fold : true

import optax


config = GPTConfig()
key = jax.random.PRNGKey(69)

block = GPT(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(block)

x = jax.numpy.ones((30, 128), dtype=jax.numpy.int32)


@eqx.filter_jit
def loss(model, x, y):
    output = jax.vmap(model)(x)
    return jax.numpy.mean(
        jax.vmap(optax.softmax_cross_entropy_with_integer_labels)(output, y)
    )


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    updates, optimizer_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(block, opt_state, x, x))

(GPT(
  wte=Embedding(num_embeddings=50304, embedding_size=200, weight=f32[50304,200]),
  wpe=Embedding(num_embeddings=128, embedding_size=200, weight=f32[128,200]),
  drop=Dropout(p=0.0, inference=False),
  layers=[
    Block(
      norm=LayerNorm(
        shape=(200,),
        eps=1e-05,
        use_weight=True,
        use_bias=False,
        weight=f32[200],
        bias=None
      ),
      attn=CausalSelfAttention(
        attnk=Linear(
          weight=f32[200,200],
          bias=None,
          in_features=200,
          out_features=200,
          use_bias=False
        ),
        attnq=Linear(
          weight=f32[200,200],
          bias=None,
          in_features=200,
          out_features=200,
          use_bias=False
        ),
        attnv=Linear(
          weight=f32[200,200],
          bias=None,
          in_features=200,
          out_features=200,
          use_bias=False
        ),
        proj=Linear(
          weight=f32[200,200],
          bias=None,
        

# Training

We can now move onto training the model ! We're going to be using the TinyStories dataset. [Tiktoken](https://github.com/openai/tiktoken) is used to map the sentences to sequences of tokens that the model would understand. Below is the code to download and transform the data into a binary file, and then provide it with a dataloader to our training regime.

In [None]:
# | code-fold : true

# saves the openwebtext dataset to a binary file for training. following was helpful:
# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py

import os
from tqdm import tqdm
import numpy as np
import tiktoken
from datasets import load_dataset  # huggingface datasets

# number of workers in .map() call
# good number to use is ~order number of cpu cores // 2
num_proc = 16

dataset = load_dataset("roneneldan/TinyStories")

# we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
enc = tiktoken.get_encoding("gpt2")


def process(example):
    ids = enc.encode_ordinary(
        example["text"]
    )  # encode_ordinary ignores any special tokens
    ids.append(enc.eot_token)  # add the end of text token, e.g. 50256 for gpt2 bpe
    # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
    out = {"ids": ids, "len": len(ids)}
    return out


# tokenize the dataset
tokenized = dataset.map(
    process,
    remove_columns=["text"],
    desc="tokenizing the splits",
    num_proc=num_proc,
)

# concatenate all the ids in each dataset into one large file we can use for training
for split, dset in tokenized.items():
    arr_len = np.sum(dset["len"])
    filename = os.path.join(os.path.dirname("dataset"), f"{split}.bin")
    dtype = np.uint16  # (can do since enc.max_token_value == 50256 is < 2**16)
    arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
    total_batches = 1024

    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
        # Batch together samples for faster write
        batch = dset.shard(
            num_shards=total_batches, index=batch_idx, contiguous=True
        ).with_format("numpy")
        arr_batch = np.concatenate(batch["ids"])
        # Write into mmap
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)
    arr.flush()

We can now load the code from the compressed binary representation to the inputs and outputs. Since we want the GPT to learn to predict the next token, we simply shift the input by 1 !

In [None]:
# | code-fold : true

import os
import jax.numpy as np
from GPT2 import GPTConfig
import numpy

data_dir = "dataset"
config = GPTConfig()


def get_batch(split: str):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == "train":
        data = numpy.memmap(
            os.path.join(data_dir, "train.bin"), dtype=numpy.uint16, mode="r"
        )
    else:
        data = numpy.memmap(
            os.path.join(data_dir, "validation.bin"), dtype=numpy.uint16, mode="r"
        )

    ix = numpy.random.randint(len(data) - config.block_size, size=(8,))
    x = np.stack(
        [np.array(data[i : i + config.block_size], dtype=np.int64) for i in ix]
    )
    y = np.stack(
        [np.array(data[i + 1 : i + 1 + config.block_size], dtype=np.int64) for i in ix]
    )

    return x, y

We can now define our loss function. Our goal here is to motivate the model to output something close to [0, 0, 0, ..., 1,..., 0, 0] where the 1 is placed at the $n$th index. This index would ideally correspond to the word we're attempting to match. `optax`, the ML optimisation library of JAX conveniently has a function for this.

In [94]:
import optax

learning_rate = 1e-5
warmup_iters = 10
init_from = "scratch"
lr_decay_iters = 20
iter_num = 0
min_lr = 1e-6

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == "scratch" else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)

optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=learning_rate)


@eqx.filter_jit
def loss(model, x, y):
    output = jax.vmap(model)(x)
    return jax.numpy.mean(
        jax.vmap(optax.softmax_cross_entropy_with_integer_labels)(output, y)
    )


def make_step(model, optimizer_state, x, y):
    losses, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, model)
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, losses

We can now move onto initializing our model and training it ! We can log the progress on wandb to see the loss curve.

In [None]:
import wandb

key = jax.random.PRNGKey(69)

gptconf = GPTConfig()
model = GPT(gptconf, key)

wandb.init(project="gpt-training", config=gptconf.__dict__)

optimizer_state = optimizer.init(model)
num_iterations = 100

for local_iter_num in range(num_iterations):
    x, y = get_batch("train")

    # Perform a single training step
    model, optimizer_state, losses = make_step(model, optimizer_state, x, y)

    wandb.log({"loss": losses, "iteration": local_iter_num})

Initializing a new model from scratch
[[ 7967    88    13 ...   645   780   673]
 [ 1375 15342   683 ...   477   465  1204]
 [  290   517    13 ...   760    11   475]
 ...
 [  339  1625   284 ... 19751    30   520]
 [   11   339  1965 ...   640  1978    13]
 [ 6403   373   523 ...  5059  3371   683]]


  [np.array(data[i : i + config.block_size], dtype=np.int64) for i in ix]
  [np.array(data[i + 1 : i + 1 + config.block_size], dtype=np.int64) for i in ix]


(8, 128)
Iteration 1/100 | Loss: 10.825838088989258
[[  467    13   632 ...    13  1649   262]
 [  287   465  1263 ...  3574   326  1110]
 [  287   465 13008 ...   607    13  1375]
 ...
 [ 2300   703  1327 ...  9461    11  1757]
 [  290   531    11 ...   957   475  2147]
 [  714  1100    13 ...    11   484  2982]]


  [np.array(data[i : i + config.block_size], dtype=np.int64) for i in ix]
  [np.array(data[i + 1 : i + 1 + config.block_size], dtype=np.int64) for i in ix]


(8, 128)
Iteration 2/100 | Loss: 10.825838088989258
[[   11  6184    95 ...   523   881  7838]
 [  339   714   407 ...   355   484   714]
 [   13   314  1842 ...   284   262  3952]
 ...
 [  257  1263  3704 ...  3521   470   651]
 [  373   257  1310 ...  5822 22075   465]
 [  465  5318   326 ...   523   326   314]]
(8, 128)
Iteration 3/100 | Loss: 10.825836181640625
[[ 1022   262   734 ...   339  4251   262]
 [  257  1657    13 ...  9955    11   257]
 [  351   465  1021 ...   460  1833  4186]
 ...
 [  475   339  1422 ...   366 10449   345]
 [ 6253    13   383 ...   550   284   466]
 [  736  1363    11 ...   262 16365   644]]
(8, 128)
Iteration 4/100 | Loss: 10.82583999633789
[[  262  2119   351 ...   603   338 10955]
 [ 1243   326   547 ...   198   198 22940]
 [  546   284   923 ...  1339   262 11376]
 ...
 [  607   290 26834 ...  1110    11   257]
 [ 1123   584    13 ...   340    13   366]
 [  711   517 19780 ... 45230   284   262]]
(8, 128)
Iteration 5/100 | Loss: 10.825838088989258
[

KeyboardInterrupt: 