# This is a GPT implementation shadowing HF's implementation of GPT2. It might be broken or noto work, I haven't checked much.


We can start by importing our favorite libraries 🥰

In [1]:
import jax
import equinox as eqx
import equinox.nn as nn
import jax.numpy as jnp
import typing as tp
import numpy

## activation function

XTTS uses HF's transformer underneath. This GPT uses the GLU_new activation function: 

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

In [2]:
import jax
from jaxtyping import ArrayLike
import math


@jax.jit
def glu_new(x: ArrayLike) -> jax.Array:
    return jax.numpy.array(
        0.5
        * x
        * (
            1
            + jax.numpy.tanh(
                math.sqrt(2.0 / math.pi) * (x + 0.044715 * jax.numpy.pow(x, 3.0))
            )
        )
    )


import torch
from torch import Tensor


def forward(input: Tensor) -> Tensor:
    return (
        0.5
        * input
        * (
            1.0
            + torch.tanh(
                math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))
            )
        )
    )

Testing if they're the same below...

In [None]:
# | code-fold : true

import numpy

key = jax.random.PRNGKey(69)
hidden_states = jax.random.normal(key, (100))
tor = torch.from_numpy(numpy.array(hidden_states))

attn_output = softplus(hidden_states)
ytor = forward(tor)

dif = ytor - torch.from_numpy(numpy.array(attn_output))
print(dif)
# assert torch.testing.assert_close(
#     ytor, torch.from_numpy(numpy.array(softplus(x)))
# )

# CONV1d

They have a strange conv.

In [3]:
class our_Conv1D(eqx.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    nf: int
    nx: int
    weight: jax.Array
    bias: jax.Array

    def __init__(self, nf, nx, key=None):
        super().__init__()
        self.nf = nf
        self.nx = nx
        self.weight = jax.nn.initializers.normal(stddev=0.02)(key, (nx, nf))
        self.bias = jax.numpy.zeros((nf))

    def __repr__(self) -> str:
        return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__)

    def __call__(self, x):
        size_out = x.shape[:-1] + (self.nf,)
        x = self.bias + jax.numpy.dot(
            jax.numpy.reshape(x, shape=(-1, x.shape[-1])), self.weight
        )
        return jax.numpy.reshape(x, size_out)

In [None]:
import transformers

input_dim = 1024
output_dim = 4096

their_gpt = transformers.pytorch_utils.Conv1D(output_dim, input_dim)
our_gpt = our_Conv1D(output_dim, input_dim, jax.random.PRNGKey(1))

print(our_gpt.bias.shape)

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(100, 1024))
their_x = torch.from_numpy(numpy.array(our_x))

print(list(their_gpt.named_parameters()))


torch_params = {
    name: param.detach().numpy() for name, param in their_gpt.named_parameters()
}

torch_to_jax_keys = {
    ("weight", "weight"),
    ("bias", "bias"),
}


# Function to update the JAX model parameters
def update_params(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    for jax_key, torch_key in torch_to_jax_keys:
        if jax_key == path:
            print(jax_key)
            if "bias" in jax_key:
                return jax.numpy.array(torch_params[torch_key])
            return jax.numpy.array(torch_params[torch_key])
    return x


our_gpt = jax.tree_util.tree_map_with_path(update_params, our_gpt)

print(type(our_x))
print(our_x.shape)
their_y = their_gpt(their_x)
print(their_y.size())
our_y = our_gpt(our_x)

torch.testing.assert_close(their_y, torch.from_numpy(numpy.array(our_y)))

## MLP

We can now move onto the multilayer perceptron.

In [4]:
from dataclasses import dataclass


@dataclass
class GPTConfig:
    block_size: int = 128
    vocab_size: int = (
        50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    )
    n_layer: int = 3
    n_head: int = 3
    n_embd: int = 200
    dropout: float = 0.0
    bias: bool = False  #

In [5]:
class MLP(eqx.Module):
    c_fc: our_Conv1D
    c_proj: our_Conv1D
    dropout: nn.Dropout

    def __init__(self, intermediate_size, config, key):
        key1, key2 = jax.random.split(key, 2)

        # The weights are transposed compraed to the feed forward.
        embed_dim = config.hidden_size
        self.c_fc = our_Conv1D(intermediate_size, embed_dim, key=key1)
        self.c_proj = our_Conv1D(embed_dim, intermediate_size, key=key2)
        self.dropout = nn.Dropout(config.resid_pdrop, deterministic=True)

    # TODO: Interesting take on the fact that vmap should be applied here ?
    def __call__(self, x):
        y = self.c_fc(x)
        y = glu_new(y)
        y = self.c_proj(y)
        y = self.dropout(y)
        return y

We can't be sure of the last part as it uses dropout but the rest seems to work 👍

In [None]:
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers.models.gpt2.modeling_gpt2 import GPT2Config

intermediate_size = 4096


their_config = GPT2Config(hidden_size=1024)
key = jax.random.PRNGKey(69)

their_gpt = GPT2MLP(intermediate_size, their_config)
our_gpt = MLP(intermediate_size, their_config, jax.random.PRNGKey(1))

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(100, 1024))
their_x = torch.from_numpy(numpy.array(our_x))

torch_params = {
    name: param.detach().numpy() for name, param in their_gpt.named_parameters()
}

print(torch_params.keys())

torch_to_jax_keys = {
    ("c_fc.weight", "c_fc.weight"),
    ("c_fc.bias", "c_fc.bias"),
    ("c_proj.weight", "c_proj.weight"),
    ("c_proj.bias", "c_proj.bias"),
}


# Function to update the JAX model parameters
def update_params(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    for jax_key, torch_key in torch_to_jax_keys:
        if jax_key == path:
            print(jax_key)
            if "bias" in jax_key:
                return jax.numpy.array(torch_params[torch_key])
            return jax.numpy.array(torch_params[torch_key])
    return x


our_gpt = jax.tree_util.tree_map_with_path(update_params, our_gpt)

print(type(our_x))
print(our_x.shape)
their_y = their_gpt.c_fc(their_x)
their_y = their_gpt.act(their_y)
their_y = their_gpt.c_proj(their_y)
# their_y = their_gpt.dropout(their_y)
print(their_y.size())
our_y = our_gpt.c_fc(our_x)
our_y = glu_new(our_y)
our_y = our_gpt.c_proj(our_y)
# our_y = our_gpt.dropout(our_y)

torch.testing.assert_close(their_y, torch.from_numpy(numpy.array(our_y)))

For some reason the Conv1D isn't the same for the GPT - testing below

In [None]:
from equinox.nn import Conv1d

hidden_states = jax.numpy.ones((10, 5))

torch_covn = our_Conv1D(7, 10)

jax_conv = Conv1d(10, 7, key=jax.random.PRNGKey(1))

Again, we can compare with their implementation to make sure we're close enough.

In [6]:
# | code-fold : true


class MLPTheirs(eqx.Module):
    c_fc: eqx.nn.Linear
    swiglu: SwiGLU
    c_proj: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config, key):
        lkey1, lkey2, skey = jax.random.split(key, 3)

        self.c_fc = eqx.nn.Linear(
            config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=lkey1
        )
        self.swiglu = SwiGLU(4 * config.n_embd, 4 * config.n_embd, skey)
        self.c_proj = eqx.nn.Linear(
            4 * config.n_embd, config.n_embd, use_bias=config.bias, key=lkey2
        )
        self.dropout = eqx.nn.Dropout(config.dropout)

    def __call__(self, x):
        x = jax.vmap(self.c_fc)(x)
        x = jax.vmap(self.swiglu)(x)
        x = jax.vmap(self.c_proj)(x)
        x = self.dropout(x)
        return x

NameError: name 'SwiGLU' is not defined

In [None]:
# | code-fold : true

their_config = GPTConfig()
key = jax.random.PRNGKey(69)

mlp = MLP(their_config, key)
mlp_theirs = MLPTheirs(their_config, key)

hidden_states = jax.random.normal(key, (100, their_config.n_embd))

res = jax.vmap(mlp)(hidden_states)
res_theirs = mlp_theirs(hidden_states)

average_diff = jnp.mean(res_theirs)
print(average_diff)

## Masked attention

Moving onto one of the more complicated aspects of the model, but in the end it simply learns to output which tokens are more important with each other.

In [7]:
import math
import equinox as eqx
import equinox.nn as nn
import jax
import jax.experimental
import jax.numpy as jnp


class CausalSelfAttention(eqx.Module):
    c_attn: our_Conv1D
    c_proj: our_Conv1D

    resid_dropout: nn.Dropout
    attn_dropout: nn.Dropout

    bias: jax.Array = eqx.field(static=True)
    scale_attn_weights: bool
    split_size: int

    num_heads: int
    head_size: int

    def __init__(self, config, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)

        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = hidden_size // config.num_attention_heads
        self.split_size = hidden_size

        self.c_attn = our_Conv1D(3 * hidden_size, hidden_size, key=key1)
        self.c_proj = our_Conv1D(hidden_size, hidden_size, key=key3)
        self.attn_dropout = nn.Dropout(config.attn_pdrop, deterministic=True)
        self.resid_dropout = nn.Dropout(config.resid_pdrop, deterministic=True)

        self.bias = jnp.tril(
            jnp.ones(
                (1, 1, config.max_position_embeddings, config.max_position_embeddings)
            )
        )

        self.scale_attn_weights = config.scale_attn_weights

    # Could play arround with the different attention score calculations (Baidhu ?)
    # X is an embedding, it should self attend.

    def _attn(self, q, k, v, attention_mask, head_mask):
        att = jnp.matmul(q, jnp.transpose(k, axes=(0, 1, 3, 2)))
        att = att / math.sqrt(jnp.shape(k)[-1])  # Scale weights is set to true in XTTS.

        query_length, key_length = q.shape[-2], k.shape[-2]
        mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
        att = jnp.where(
            jax.numpy.equal(jax.lax.stop_gradient(mask), 0),
            jnp.finfo(att.dtype).min,
            att,
        )

        if attention_mask is not None:
            att = att + attention_mask

        att = jax.nn.softmax(att, axis=-1)

        # att = self.attn_dropout(att)

        if head_mask is not None:
            att = att * head_mask

        return jnp.matmul(att, v), att

    # Stange that they do it this way and not by simply defining the dims without permutation
    def _split_heads(self, x):
        new_shape = x.shape[:-1] + (self.num_heads, self.head_size)
        x = jax.numpy.reshape(x, new_shape)
        return jax.numpy.permute_dims(x, (0, 2, 1, 3))

    def _merge_heads(self, x):
        x = jax.numpy.permute_dims(x, (0, 2, 1, 3))
        new_shape = x.shape[:-2] + (self.num_heads * self.head_size,)
        return jax.numpy.reshape(x, new_shape)

    def __call__(
        self,
        hidden_states: tp.Optional[tp.Tuple[jax.Array]],
        layer_past: tp.Optional[tp.Tuple[jax.Array]] = None,
        attention_mask: tp.Optional[jax.Array] = None,
        head_mask: tp.Optional[jax.Array] = None,
        encoder_hidden_states: tp.Optional[jax.Array] = None,
        encoder_attention_mask: tp.Optional[jax.Array] = None,
        use_cache: tp.Optional[bool] = False,
        output_attentions: tp.Optional[bool] = False,
    ):
        print(f"SHAPE OF {hidden_states.shape}")
        # x = jnp.swapaxes(x, -1, -2)
        qkv = jax.vmap(self.c_attn)(hidden_states)
        q, k, v = jax.numpy.split(qkv, 3, axis=2)

        query = self._split_heads(q)
        key = self._split_heads(k)
        value = self._split_heads(v)

        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = jax.numpy.concat((past_key, key), axis=-2)
            value = jax.numpy.concat((past_value, value), axis=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        # print(query.shape)
        attn_output, attn_weights = self._attn(
            query, key, value, attention_mask, head_mask
        )
        attn_output = self._merge_heads(attn_output)
        attn_output = jax.vmap(self.c_proj)(attn_output)
        # attn_output = self.resid_dropout(attn_output)
        print(f"Ours: {attn_output.shape}")

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

Small check...

In [None]:
%reload_ext autoreload
%autoreload 2

In [7]:
# | code-fold : true
from transformers.models.gpt2.modeling_gpt2 import GPT2SdpaAttention
from transformers import GPT2Config, GPT2Model
import jax, torch, numpy
from dataclasses import dataclass

from transformers.models.gpt2.modeling_gpt2 import GPT2Config


intermediate_size = 4096


their_config = GPT2Config().from_dict(
    {
        "_attn_implementation_autoset": True,
        "activation_function": "gelu_new",
        "attn_pdrop": 0.1,
        "bos_token_id": 50256,
        "embd_pdrop": 0.1,
        "eos_token_id": 50256,
        "gradient_checkpointing": False,
        "initializer_range": 0.02,
        "layer_norm_epsilon": 1e-05,
        "model_type": "gpt2",
        "n_ctx": 1082,
        "n_embd": 1024,
        "n_head": 16,
        "n_inner": None,
        "n_layer": 30,
        "n_positions": 1082,
        "reorder_and_upcast_attn": False,
        "resid_pdrop": 0.1,
        "scale_attn_by_inverse_layer_idx": False,
        "scale_attn_weights": True,
        "summary_activation": None,
        "summary_first_dropout": 0.1,
        "summary_proj_to_labels": True,
        "summary_type": "cls_index",
        "summary_use_proj": True,
        "transformers_version": "4.46.2",
        "use_cache": True,
        "vocab_size": 256,
    }
)

key = jax.random.PRNGKey(9)

their_gpt = GPT2SdpaAttention(their_config)
our_gpt = CausalSelfAttention(their_config, jax.random.PRNGKey(1))

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 1082, 1024))
their_x = torch.from_numpy(numpy.array(our_x))

# their_head_mask = torch.zeros((1, 1, 1024))
# our_head_mask = jax.numpy.zeros((10, 1082, 1024))

torch_params = {
    name: param.detach().numpy() for name, param in their_gpt.named_parameters()
}

print(torch_params.keys())

torch_to_jax_keys = {
    ("c_attn.weight", "c_attn.weight"),
    ("c_attn.bias", "c_attn.bias"),
    ("c_proj.weight", "c_proj.weight"),
    ("c_proj.bias", "c_proj.bias"),
}


# Function to update the JAX model parameters
def update_params(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    for jax_key, torch_key in torch_to_jax_keys:
        if jax_key == path:
            print(jax_key)
            if "bias" in jax_key:
                return jax.numpy.array(torch_params[torch_key])
            return jax.numpy.array(torch_params[torch_key])
    return x


our_gpt = jax.tree_util.tree_map_with_path(update_params, our_gpt)

print(type(our_x))
print(our_x.shape)
their_y = their_gpt(their_x, output_attentions=True)
our_y = our_gpt(our_x)
print(our_y[0][0, 0])
print(their_y[0][0, 0])
torch.testing.assert_close(their_y[0], torch.from_numpy(numpy.array(our_y[0])))

  from .autonotebook import tqdm as notebook_tqdm
  our_gpt = CausalSelfAttention(their_config, jax.random.PRNGKey(1))


dict_keys(['c_attn.weight', 'c_attn.bias', 'c_proj.weight', 'c_proj.bias'])
c_attn.weight
c_attn.bias
c_proj.weight
c_proj.bias
<class 'jaxlib.xla_extension.ArrayImpl'>
(10, 1082, 1024)
Theirs: torch.Size([10, 1082, 1024])
SHAPE OF (10, 1082, 1024)
Ours: (10, 1082, 1024)
[ 0.96926224  0.0685735  -0.2757392  ...  0.19498248 -0.12266847
  0.26153973]
tensor([ 0.9693,  0.0686, -0.2757,  ...,  0.1950, -0.1227,  0.2615],
       grad_fn=<SelectBackward0>)


In [None]:
import torch.nn.functional as F

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(3, 100, 10))
their_x = torch.from_numpy(numpy.array(our_x))

their_y = F.scaled_dot_product_attention(
    their_x[0], their_x[1], their_x[2], is_causal=True
)

our_y = jax.nn.dot_product_attention(
    our_x[0], our_x[1], our_x[2], is_causal=True, implementation="xla"
)

print(their_y[0:10])
print(our_y[0:10])
torch.testing.assert_close(their_y, torch.from_numpy(numpy.array(our_y)))

## Block

Ok ! Now that we have the component parts of what we call a "block" we can assemble them. This will then be stacked to get as many layers of abstraction as we wish. In our case we will stack it 12 times as per the GPTConfig we defined.

In [None]:
class Block(eqx.Module):
    ln_1: nn.LayerNorm
    ln_2: nn.LayerNorm
    attn: CausalSelfAttention
    mlp: MLP

    def __init__(self, config, key):
        key1, key2 = jax.random.split(key, 2)
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(
            (hidden_size),
            eps=config.layer_norm_epsilon,
            elementwise_affine=True,
        )
        self.attn = CausalSelfAttention(config, key=key1)
        self.ln_2 = nn.LayerNorm(
            (hidden_size),
            eps=config.layer_norm_epsilon,
            elementwise_affine=True,
        )

        self.mlp = MLP(inner_dim, config, key=key2)

    def __call__(
        self,
        hidden_states: tp.Optional[tp.Tuple[jax.Array]],
        layer_past: tp.Optional[tp.Tuple[jax.Array]] = None,
        attention_mask: tp.Optional[jax.Array] = None,
        head_mask: tp.Optional[jax.Array] = None,
        encoder_hidden_states: tp.Optional[jax.Array] = None,
        encoder_attention_mask: tp.Optional[jax.Array] = None,
        use_cache: tp.Optional[bool] = False,
        output_attentions: tp.Optional[bool] = False,
    ):
        residual = hidden_states
        hidden_states = jax.vmap(jax.vmap(self.ln_1))(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )  # Can't vmap as the whole point is exchange info between tokens.
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        hidden_states = attn_output + residual
        residual = hidden_states

        hidden_states = jax.vmap(jax.vmap(self.ln_2))(hidden_states)
        feed_forward_hidden_states = jax.vmap(self.mlp)(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            return (hidden_states,) + outputs
        else:
            return (hidden_states,) + outputs[1:]

Can compare with their work.

In [None]:
# | code-fold : true
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers import GPT2Config
import jax, torch, numpy
from dataclasses import dataclass

from transformers.models.gpt2.modeling_gpt2 import GPT2Config


their_config = GPT2Config().from_dict(
    {
        "_attn_implementation_autoset": True,
        "activation_function": "gelu_new",
        "attn_pdrop": 0.1,
        "bos_token_id": 50256,
        "embd_pdrop": 0.1,
        "eos_token_id": 50256,
        "gradient_checkpointing": False,
        "initializer_range": 0.02,
        "layer_norm_epsilon": 1e-05,
        "model_type": "gpt2",
        "n_ctx": 1082,
        "n_embd": 1024,
        "n_head": 16,
        "n_inner": None,
        "n_layer": 30,
        "n_positions": 1082,
        "reorder_and_upcast_attn": False,
        "resid_pdrop": 0.1,
        "scale_attn_by_inverse_layer_idx": False,
        "scale_attn_weights": True,
        "summary_activation": None,
        "summary_first_dropout": 0.1,
        "summary_proj_to_labels": True,
        "summary_type": "cls_index",
        "summary_use_proj": True,
        "transformers_version": "4.46.2",
        "use_cache": True,
        "vocab_size": 256,
    }
)

key = jax.random.PRNGKey(2)

their_gpt = GPT2Block(their_config)
our_gpt = Block(their_config, jax.random.PRNGKey(1))

torch_params = {
    name: param.detach().numpy() for name, param in their_gpt.named_parameters()
}


# Function to update the JAX model parameters
def update_params(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    # for jax_key, torch_key in torch_to_jax_keys:
    if path in torch_params.keys():
        if "bias" in path:
            return jax.numpy.array(torch_params[path])
        return jax.numpy.array(torch_params[path])
    print(path)
    return x


our_gpt = jax.tree_util.tree_map_with_path(update_params, our_gpt)

our_x = jax.random.normal(jax.random.PRNGKey(2), shape=(10, 1082, 1024))
their_x = torch.from_numpy(numpy.array(our_x))
print(our_x[0, 0])
print(their_x[0, 0])

their_y = their_gpt(their_x)
# (10, 16, 1082, 64) is the shape of the attentions
our_y = our_gpt(our_x)
torch.testing.assert_close(their_y[0], torch.from_numpy(numpy.array(our_y[0])))

All the three main parts that consitute the last part, the block have been coded, we can proceed to the actual model.

In [None]:
import typing as tp
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions


class TransformerLayer(eqx.Module):
    wte: nn.Embedding  # Token embeddings
    wpe: nn.Embedding  # Positional embeddings

    drop: nn.Dropout

    embed_dim: int

    h: list
    norm: nn.LayerNorm

    def __init__(self, config, key):
        key1, key2 = jax.random.split(key, 2)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim, key=key1)
        self.wpe = nn.Embedding(
            config.max_position_embeddings, self.embed_dim, key=key2
        )
        self.drop = nn.Dropout(config.embd_pdrop)

        self.h = [
            Block(config, y) for y in jax.random.split(key, config.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def __call__(
        self,
        input_ids: tp.Optional[jax.Array] = None,  # One ID inputted ?
        past_key_values: tp.Optional[jax.Array] = None,  # Used !
        attention_mask: tp.Optional[jax.Array] = None,  # Used !
        token_type_ids: tp.Optional[jax.Array] = None,  # Not used
        position_ids: tp.Optional[jax.Array] = None,  # Used !
        head_mask: tp.Optional[jax.Array] = None,  # Isn't used
        inputs_embeds: tp.Optional[jax.Array] = None,  # Isn't used
        output_attentions: tp.Optional[bool] = None,  # Isn't used
        output_hidden_states: tp.Optional[bool] = None,  # Isn't used
        use_cache: tp.Optional[bool] = False,  # Set to true.
        return_dict: tp.Optional[bool] = False,  # Set to true.
    ):

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            # self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Should use better positional embeddings with cos and sin.
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0].shape[-2]
        if position_ids is None:
            position_ids = jax.numpy.arange(past_length, input_shape[-1] + past_length)
            position_ids = jax.numpy.expand_dims(position_ids, 0)

        if inputs_embeds is None:
            inputs_embeds = jax.vmap(jax.vmap(self.wte))(input_ids)

        # pos = jnp.arange(0, t, dtype=jnp.int64)

        position_embeds = jax.vmap(jax.vmap(self.wpe))(position_ids)

        # Dropout at the first layer ? Seems a bit aggressive...
        hidden_states = inputs_embeds + position_embeds

        # No need for fancy stuff for the attention mask, simply since it's applied before the softmax change the values of 0 to -inf
        if attention_mask is not None:

            attention_mask = jax.numpy.where(
                jax.numpy.equal(attention_mask, 1), 1, -jax.numpy.inf
            )
        # No cross attention so we're all good.
        # No head mask.
        # Token type ids is none
        hidden_states = self.drop(hidden_states)
        # Not training.
        presents = () if use_cache else None
        # Output attentions not used
        # no cross attention_mask
        # No output hidden states
        # print(f"Ours : {hidden_states[0,0,:10]}")

        for block, layer_past in zip(self.h, past_key_values):
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = outputs[0]
            # print(f"Ours : {hidden_states[0,0,:10]}")
            if use_cache:
                presents = presents + (outputs[1],)

        hidden_states = jax.vmap(jax.vmap(self.norm))(hidden_states)

        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=presents,
                attentions=None,
                cross_attentions=None,
                hidden_states=None,
            )

        return hidden_states

Comparing with their work...

In [38]:
%reload_ext autoreload
%autoreload 2

In [39]:
# | code-fold : true
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
import jax, torch, numpy
from dataclasses import dataclass

from transformers.models.gpt2.modeling_gpt2 import GPT2Config


their_config = GPT2Config(
    vocab_size=256,  # Unused.
    n_positions=1082,
    n_ctx=1082,
    n_embd=1024,
    n_layer=30,
    n_head=16,
    gradient_checkpointing=False,
    use_cache=True,
)

key = jax.random.PRNGKey(2)

their_gpt = GPT2Model(their_config)

torch_params = {
    name: param.detach().numpy() for name, param in their_gpt.named_parameters()
}

In [None]:
# # Function to update the JAX model parameters
# def update_params(path, x):
#     path = ".".join([str(p).strip("[].") for p in path])
#     # for jax_key, torch_key in torch_to_jax_keys:
#     if path in torch_params.keys():
#         if "bias" in path:
#             return jax.numpy.array(torch_params[path])
#         return jax.numpy.array(torch_params[path])
#     print(path)
#     return x


jax.tree_util.tree_map_with_path(lambda path, _: print(path), our_gpt)

In [40]:
our_gpt = TransformerLayer(their_config, jax.random.PRNGKey(1))


# Function to update the JAX model parameters
def update_params(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    # for jax_key, to\rch_key in torch_to_jax_keys:
    if path in torch_params.keys():
        # if "ln_" in path:
        #     print(x.shape)
        #     print(torch_params[path].shape)
        #     print(path)
        #     return x
        # if "ln_" in path:
        #     print(path)
        return jax.numpy.array(torch_params[path])
    # print(path)
    return x


# our_gpt = jax.tree_util.tree_map_with_path(update_params, our_gpt)
our_gpt = jax.tree_util.tree_map_with_path(update_params, our_gpt)
# # our_gpt.h.0.ln_1.
eqx.tree_serialise_leaves("xttsgpt.eqx", our_gpt)
our_gpt = eqx.tree_deserialise_leaves("xttsgpt.eqx", our_gpt)

our_x = jax.random.randint(jax.random.PRNGKey(2), shape=(1, 1082), minval=0, maxval=255)
their_x = torch.from_numpy(numpy.array(our_x))
print(our_x[0, 0])
print(their_x[0, 0])

their_y = their_gpt(their_x)
# (10, 16, 1082, 64) is the shape of the attentions
our_y = our_gpt(our_x)

# print(their_y[0][0, 0, :10])
# print(our_y[0, 0, :10])
torch.testing.assert_close(their_y[0], torch.from_numpy(numpy.array(our_y)))
del our_gpt

  self.attn = CausalSelfAttention(config, key=key1)


5
tensor(5, dtype=torch.int32)
Theirs : tensor([ 0.0346, -0.0105,  0.0182,  0.0283, -0.0087,  0.0409, -0.0246,  0.0024,
        -0.0308, -0.0179], grad_fn=<SliceBackward0>)
ccra
Theirs : tensor([ 0.0374, -0.0112,  0.0191,  0.0348, -0.0100,  0.0409, -0.0253, -0.0005,
        -0.0327, -0.0191], grad_fn=<SliceBackward0>)
ccra
Theirs : tensor([ 0.0371, -0.0128,  0.0212,  0.0345, -0.0107,  0.0429, -0.0272,  0.0021,
        -0.0311, -0.0163], grad_fn=<SliceBackward0>)
ccra
Theirs : tensor([ 0.0361, -0.0096,  0.0201,  0.0358, -0.0133,  0.0418, -0.0299,  0.0035,
        -0.0295, -0.0158], grad_fn=<SliceBackward0>)
ccra
Theirs : tensor([ 0.0351, -0.0097,  0.0204,  0.0386, -0.0165,  0.0390, -0.0302,  0.0072,
        -0.0325, -0.0134], grad_fn=<SliceBackward0>)
ccra
Theirs : tensor([ 0.0348, -0.0120,  0.0210,  0.0416, -0.0160,  0.0389, -0.0306,  0.0066,
        -0.0322, -0.0124], grad_fn=<SliceBackward0>)
ccra
Theirs : tensor([ 0.0368, -0.0152,  0.0201,  0.0385, -0.0154,  0.0401, -0.0299,  0.0042

AssertionError: Tensor-likes are not close!

Mismatched elements: 1107846 / 1107968 (100.0%)
Greatest absolute difference: 0.6334989666938782 at index (0, 95, 23) (up to 1e-05 allowed)
Greatest relative difference: 3572257.75 at index (0, 508, 407) (up to 1.3e-06 allowed)

In [26]:
our_x = jax.random.normal(
    jax.random.PRNGKey(2),
    shape=(1, 64, 1024),
    dtype=numpy.float32,
)
their_x = torch.from_numpy(numpy.array(our_x))

their_y = their_gpt.h[0].ln_1(their_x)
# (10, 16, 1082, 64) is the shape of the attentions
our_y = jax.vmap(jax.vmap(our_gpt.h[0].ln_1))(our_x)

# print(their_y[0][0, 0, :10])
# print(our_y[0, 0, :10])
torch.testing.assert_close(their_y, torch.from_numpy(numpy.array(our_y)))

In [None]:
class GPT(eqx.Module):
    transformer: TransformerLayer
    lm_head: nn.Linear

    def __init__(self, config, key):
        key1, key2 = jax.random.split(key, 2)

        self.transformer = TransformerLayer(config, key1)
        self.lm_head = nn.Linear(
            config.n_embd, config.vocab_size, use_bias=False, key=key2
        )

    def __call__(self, token_ids):
        y = self.transformer(token_ids)
        logits = jax.vmap(self.lm_head)(y)
        return logits

We can compare our method with the one implemented in nanoJAXGPT: