# Llama to Jax

In this notebook we test the conversion of the Llama architecture to use Jax as the backend, by isolating the components individually and ensuring that they yield the same results, as function.

In [1]:
import jax
import jax.numpy as jnp
import torch
import numpy as np

In [2]:
import sys
import os

os.getcwd()
project_path = os.path.abspath("LLM")

if project_path not in sys.path:
    sys.path.append(project_path)

In [3]:
from safetensors.torch import load_file
safetensors_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/model.safetensors"  
weights = load_file(safetensors_path)


In [4]:
from llama.model import ModelArgs
import json

config_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/config.json"
with open(config_path, "r") as f:
    config = json.load(f)

# extract the necessary fields
model_args = ModelArgs(
    dim=config.get("hidden_size", 4096), 
    n_layers=config.get("num_hidden_layers", 32),  
    n_heads=config.get("num_attention_heads", 32), 
    n_kv_heads=config.get("num_key_value_heads", None), 
    vocab_size=config.get("vocab_size", -1), 
    multiple_of=256, # not in config so use the default
    norm_eps=config.get("rms_norm_eps", 1e-5),  # map "rms_norm_eps"
    max_batch_size=32,  # not in config so use the default
    use_scale_rope=True, # this is how it was in llama.ipynb
    rope_scale_factor=config.get("rope_scaling").get("factor"),
    original_rotary_embed_len=config.get("rope_scaling").get("original_max_position_embeddings"),
    cache_len = 2048,
)

print(model_args)

ModelArgs(dim=2048, n_layers=16, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=256, norm_eps=1e-05, rope_theta=500000, use_scaled_rope=True, rope_scale_factor=32.0, max_batch_size=32, original_rotary_embed_len=8192, cache_len=2048)


## `RMSNorm`

Here we'll demonstrate how to create jax arrays from the pytorch parameters, and see how our jax implementation of `RMSNorm` compares to the one we already know is a part of the working model

In [5]:
# load the functions to compare
from llama.model import RMSNorm as RMSNorm_pt
print("loaded source norm")
from llama_jax.model import RMSNorm as RMSNorm_jax
print("loaded jax norm")

loaded source norm


loaded jax norm


In [6]:
# check the weights to find a good demo tensor
[key for key in weights.keys() if 'norm' in key][0:2]

['model.layers.0.input_layernorm.weight',
 'model.layers.0.post_attention_layernorm.weight']

In [7]:
# pytorch tensor --> jax array
rms_weights_pt = weights['model.layers.0.input_layernorm.weight']
rms_weights_jax = jnp.array(rms_weights_pt.detach().float().numpy(), dtype=jnp.bfloat16)

print(f"pytorch: {rms_weights_jax}, jax: {rms_weights_jax}")
print(f"pytorch shape: {rms_weights_pt.shape}, jax shape: {rms_weights_jax.shape}")



pytorch: [0.158203 0.180664 0.269531 ... 0.22168 0.210938 0.152344], jax: [0.158203 0.180664 0.269531 ... 0.22168 0.210938 0.152344]
pytorch shape: torch.Size([2048]), jax shape: (2048,)


In [8]:
# put the weights in an isolated pytorch RMSNorm module
rms_norm_pt = RMSNorm_pt(2048)

with torch.no_grad():
    # overwrite the RMSNorm weight with the one from the loaded state_dict
    rms_norm_pt.weight.copy_(rms_weights_pt)

rms_norm_pt

RMSNorm()

In [9]:
# run the pytorch version on a sample tensor to get a "true" value

with torch.no_grad():
    x_torch = torch.randn(2, 2048) # add a batch dim

    y_torch = rms_norm_pt(x_torch)
    print("PyTorch output:", y_torch)

PyTorch output: tensor([[-0.0100,  0.1573, -0.4949,  ..., -0.0077,  0.1183,  0.1864],
        [-0.0805, -0.2562,  0.1460,  ..., -0.0104,  0.0233,  0.0328]])


In [10]:
# call our jax implementation and check the output is the same

rms_norm_jax = lambda x : RMSNorm_jax(x, rms_weights_jax)

x_jax = jnp.array(x_torch.detach().numpy(), dtype=jnp.bfloat16)
y_jax = rms_norm_jax(x_jax)

print("Jax output:", y_jax)

Jax output: [[-0.00994873 0.157227 -0.494141 ... -0.00769043 0.118652 0.186523]
 [-0.0805664 -0.257812 0.146484 ... -0.010376 0.0233154 0.0327148]]


It all looks good :), the pytorch outputs `bfloat16` too, whereas jax has higher precision. We'll worry about that later, since I don't know whether that is the correct behaviour when the function is just one step in the overall architecture (for the pytorch implementation)

## `precompute_freq_cis`



In [11]:
from llama.model import precompute_freqs_cis as precompute_freqs_cis_pt
from llama_jax.model import precompute_freqs_cis as precompute_freqs_cis_jax

In [12]:
dim = 2048 // 32 

freq_cis_pt = precompute_freqs_cis_pt(dim, 2048, 500_000, True, 32, 8192)
print(f"{freq_cis_pt[1000,:]=}")
print(f"{freq_cis_pt.dtype=}")

freq_cis_pt[1000,:]=tensor([ 0.5624+8.2688e-01j, -0.7484-6.6329e-01j,  0.8558+5.1728e-01j,
        -0.9982-5.9692e-02j,  0.6555-7.5522e-01j, -0.9931+1.1765e-01j,
        -0.8397-5.4308e-01j,  0.9927+1.2066e-01j,  0.9957-9.2948e-02j,
         0.9843-1.7641e-01j, -0.6581-7.5291e-01j, -0.0060-9.9998e-01j,
         0.5323+8.4655e-01j,  0.1267-9.9194e-01j, -0.9976-6.9797e-02j,
        -0.5315+8.4708e-01j,  0.1559+9.8777e-01j,  0.5910+8.0666e-01j,
         0.9998+1.9460e-02j,  0.9999+1.2914e-02j,  1.0000+8.5702e-03j,
         1.0000+5.6872e-03j,  1.0000+3.7740e-03j,  1.0000+2.5045e-03j,
         1.0000+1.6620e-03j,  1.0000+1.1029e-03j,  1.0000+7.3187e-04j,
         1.0000+4.8567e-04j,  1.0000+3.2229e-04j,  1.0000+2.1387e-04j,
         1.0000+1.4193e-04j,  1.0000+9.4183e-05j])
freq_cis_pt.dtype=torch.complex64


In [13]:
freq_cis_jax = precompute_freqs_cis_jax(dim, 2048, 500_000, True, 32, 8192)
print(f"{freq_cis_jax[1000,:]=}")
print(f"{freq_cis_jax.dtype=}")


freq_cis_jax[1000,:]=Array([ 0.56237906+8.2687956e-01j, -0.7483619 -6.6329068e-01j,
        0.85581774+5.1727748e-01j, -0.99821687-5.9691951e-02j,
        0.6554753 -7.5521660e-01j, -0.9930554 +1.1764777e-01j,
       -0.839681  -5.4307991e-01j,  0.99269414+1.2065805e-01j,
        0.995671  -9.2947975e-02j,  0.98431766-1.7640516e-01j,
       -0.65812033-7.5291276e-01j, -0.00604559-9.9998170e-01j,
        0.53230125+8.4655499e-01j,  0.12669091-9.9194223e-01j,
       -0.9975612 -6.9796599e-02j, -0.53146   +8.4708339e-01j,
        0.15594384+9.8776591e-01j,  0.59101975+8.0665708e-01j,
        0.99981064+1.9460410e-02j,  0.9999166 +1.2914409e-02j,
        0.9999633 +8.5701505e-03j,  0.99998385+5.6872014e-03j,
        0.9999929 +3.7740455e-03j,  0.99999684+2.5044645e-03j,
        0.9999986 +1.6619666e-03j,  0.9999994 +1.1028834e-03j,
        0.99999976+7.3187490e-04j,  0.9999999 +4.8567311e-04j,
        0.99999994+3.2229329e-04j,  1.        +2.1387423e-04j,
        1.        +1.4192720e-04j,

In [14]:
np.linalg.norm(freq_cis_pt.detach().numpy() - np.array(freq_cis_jax))

0.0

these match too, since the arrays are bigger, we check by looking at the norm too - which shows they are identical.

## `apply_rotary_emb`

In [15]:
from llama.model import apply_rotary_emb as apply_rotary_emb_pt
from llama_jax.model import apply_rotary_emb as apply_rotary_emb_jax

In [16]:
# setting up dummy data

bsz = 1  # batch size
seqlen = 30 # dummy value
n_local_heads = 32 # no parallelism so local = total
head_dim = 2048 // 32

dummy_shape = (bsz, seqlen, n_local_heads, head_dim)

with torch.no_grad():
    freq_cis_pt = precompute_freqs_cis_pt(dim, seqlen, 500_000, True, 32, 8192)
    xq_torch = torch.randn(dummy_shape) 
    xk_torch = torch.randn(dummy_shape)

freq_cis_jax = precompute_freqs_cis_jax(dim, seqlen, 500_000, True, 32, 8192)
xq_jax = jnp.array(xq_torch.detach().numpy())
xk_jax = jnp.array(xk_torch.detach().numpy())

In [17]:
with torch.no_grad():
    yq_torch, yk_torch = apply_rotary_emb_pt(xq_torch, xk_torch, freq_cis_pt)

yq_jax, yk_jax = apply_rotary_emb_jax(xq_jax, xk_jax, freq_cis_jax)

In [18]:
print(f"{yq_torch.shape=}")
print(f"{yq_jax.shape=}")

print(f"{yk_torch.shape=}")
print(f"{yk_jax.shape=}")

print(f"{yq_torch.dtype=}")
print(f"{yq_jax.dtype=}")

yq_torch.shape=torch.Size([1, 30, 32, 64])
yq_jax.shape=(1, 30, 32, 64)
yk_torch.shape=torch.Size([1, 30, 32, 64])
yk_jax.shape=(1, 30, 32, 64)
yq_torch.dtype=torch.float32
yq_jax.dtype=dtype('float32')


In [19]:
print(f"q error: {np.linalg.norm(yq_torch.detach().numpy() - np.array(yq_jax))}")
print(f"k error: {np.linalg.norm(yk_torch.detach().numpy() - np.array(yk_jax))}")

q error: 7.465223916369723e-06
k error: 7.636589543835726e-06


there is a slight difference in the values here, but over the vast amount of dummy data it is well within numerical tolerances

## `attention_block`

In [20]:
from llama.model import Attention as Attention_pt
from llama_jax.model import attention_block as attention_block_jax

In [21]:
[key for key in weights.keys() if '10' in key]

['model.layers.10.input_layernorm.weight',
 'model.layers.10.mlp.down_proj.weight',
 'model.layers.10.mlp.gate_proj.weight',
 'model.layers.10.mlp.up_proj.weight',
 'model.layers.10.post_attention_layernorm.weight',
 'model.layers.10.self_attn.k_proj.weight',
 'model.layers.10.self_attn.o_proj.weight',
 'model.layers.10.self_attn.q_proj.weight',
 'model.layers.10.self_attn.v_proj.weight']

In [22]:
# get all the required parameters for one attention block
# check that the conversion leaves shape and dtype invariant

layer10_keys = [key for key in weights.keys() if ('10' in key) and ('norm' not in key) and ('mlp' not in key)]
layer10_weights_pt = { key : weights[key] for key in layer10_keys}
layer10_weights_jax = { key : jnp.array(weights[key].detach().float().numpy(), dtype = jnp.bfloat16) for key in layer10_keys}

print("pytorch params:")
for key, value in layer10_weights_pt.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

print("\njax params")
for key, value in layer10_weights_jax.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

pytorch params:
    model.layers.10.self_attn.k_proj.weight dtype: torch.bfloat16 shape: torch.Size([512, 2048])
    model.layers.10.self_attn.o_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 2048])
    model.layers.10.self_attn.q_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 2048])
    model.layers.10.self_attn.v_proj.weight dtype: torch.bfloat16 shape: torch.Size([512, 2048])

jax params
    model.layers.10.self_attn.k_proj.weight dtype: bfloat16 shape: (512, 2048)
    model.layers.10.self_attn.o_proj.weight dtype: bfloat16 shape: (2048, 2048)
    model.layers.10.self_attn.q_proj.weight dtype: bfloat16 shape: (2048, 2048)
    model.layers.10.self_attn.v_proj.weight dtype: bfloat16 shape: (512, 2048)


In [23]:
with torch.no_grad():
    attention_pt = Attention_pt(model_args)

with torch.no_grad():
    attention_pt.wq.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.q_proj.weight"])
    attention_pt.wk.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.k_proj.weight"])
    attention_pt.wv.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.v_proj.weight"])
    attention_pt.wo.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.o_proj.weight"])

attention block initialised


In [24]:
dummy_x_shape = (bsz, seqlen, 2048)

In [25]:
x_torch = torch.randn(dummy_x_shape) 
y_torch = attention_pt.forward(x_torch, 0, freq_cis_pt, None)

In [26]:
x_jax = jnp.array(x_torch.detach().numpy(), dtype=jnp.bfloat16)

y_jax = attention_block_jax(x_jax, None, freq_cis_jax,
    wq = layer10_weights_jax["model.layers.10.self_attn.q_proj.weight"],
    wk = layer10_weights_jax["model.layers.10.self_attn.k_proj.weight"],
    wv = layer10_weights_jax["model.layers.10.self_attn.v_proj.weight"],
    wo = layer10_weights_jax["model.layers.10.self_attn.o_proj.weight"],
    n_heads = 32,
    n_kv_heads = 8)

In [27]:
np.linalg.norm(y_torch.detach().numpy() - np.array(y_jax))

0.33409098

###  `AttentionBlock` with a mask

In [28]:
zeros = torch.zeros(20) 
neg_inf = torch.full((10,), float('-inf')) 
mask_pt = torch.cat([zeros, neg_inf]) 
mask_pt

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf,
        -inf, -inf, -inf, -inf, -inf, -inf])

In [29]:
y_torch = attention_pt.forward(x_torch, 0, freq_cis_pt, mask_pt)

In [30]:
mask_jax = jnp.array(mask_pt.detach().numpy())

y_jax = attention_block_jax(x_jax, mask_jax, freq_cis_jax,
    wq = layer10_weights_jax["model.layers.10.self_attn.q_proj.weight"],
    wk = layer10_weights_jax["model.layers.10.self_attn.k_proj.weight"],
    wv = layer10_weights_jax["model.layers.10.self_attn.v_proj.weight"],
    wo = layer10_weights_jax["model.layers.10.self_attn.o_proj.weight"],
    n_heads = 32,
    n_kv_heads = 8)

In [31]:
np.linalg.norm(y_torch.detach().numpy() - np.array(y_jax))

0.3009763

## `FeedForward`

This is a special MLP implementation that passes a through a linear product of the values, which tempers the non-linearity of the `silu` activation function

In [32]:
from llama.model import FeedForward as FeedForward_pt
from llama_jax.model import feed_forward as feed_forward_jax

In [33]:
[key for key in weights.keys() if ('10' in key) and ('mlp' in key)]

# down -> w2
# up   -> w3
# gate -> w1

['model.layers.10.mlp.down_proj.weight',
 'model.layers.10.mlp.gate_proj.weight',
 'model.layers.10.mlp.up_proj.weight']

In [34]:
layer10_keys = [key for key in weights.keys() if ('10' in key) and ('mlp' in key)]
layer10_weights_pt = { key : weights[key] for key in layer10_keys}
layer10_weights_jax = { key : jnp.array(weights[key].detach().float().numpy(), dtype = jnp.bfloat16) for key in layer10_keys}

print("pytorch params:")
for key, value in layer10_weights_pt.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

print("\njax params")
for key, value in layer10_weights_jax.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

pytorch params:
    model.layers.10.mlp.down_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 8192])
    model.layers.10.mlp.gate_proj.weight dtype: torch.bfloat16 shape: torch.Size([8192, 2048])
    model.layers.10.mlp.up_proj.weight dtype: torch.bfloat16 shape: torch.Size([8192, 2048])

jax params
    model.layers.10.mlp.down_proj.weight dtype: bfloat16 shape: (2048, 8192)
    model.layers.10.mlp.gate_proj.weight dtype: bfloat16 shape: (8192, 2048)
    model.layers.10.mlp.up_proj.weight dtype: bfloat16 shape: (8192, 2048)


In [35]:
with torch.no_grad():
    feed_forward_pt = FeedForward_pt(2048, 4 * 2048, 256)

with torch.no_grad():
    feed_forward_pt.w1.weight.copy_(layer10_weights_pt["model.layers.10.mlp.gate_proj.weight"])
    feed_forward_pt.w2.weight.copy_(layer10_weights_pt["model.layers.10.mlp.down_proj.weight"])
    feed_forward_pt.w3.weight.copy_(layer10_weights_pt["model.layers.10.mlp.up_proj.weight"])

In [36]:
y_torch = feed_forward_pt.forward(x_torch)

In [37]:
y_jax = feed_forward_jax(x_jax,
    gate = layer10_weights_jax["model.layers.10.mlp.gate_proj.weight"], 
    down = layer10_weights_jax["model.layers.10.mlp.down_proj.weight"], 
    up   = layer10_weights_jax["model.layers.10.mlp.up_proj.weight"])

In [38]:
np.linalg.norm(y_torch.detach().numpy() - np.array(y_jax))

0.8660234

## `TransfomerBlock`

In [39]:
from llama.model import TransformerBlock as TransformerBlock_pt
from llama_jax.model import transformer_block as transformer_block_jax

In [40]:
layer10_keys = [key for key in weights.keys() if ('10' in key)]
layer10_weights_pt = { key : weights[key] for key in layer10_keys}
layer10_weights_jax = { key : jnp.array(weights[key].detach().float().numpy(), dtype = jnp.bfloat16) for key in layer10_keys}

print("pytorch params:")
for key, value in layer10_weights_pt.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

print("\njax params")
for key, value in layer10_weights_jax.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

pytorch params:
    model.layers.10.input_layernorm.weight dtype: torch.bfloat16 shape: torch.Size([2048])
    model.layers.10.mlp.down_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 8192])
    model.layers.10.mlp.gate_proj.weight dtype: torch.bfloat16 shape: torch.Size([8192, 2048])
    model.layers.10.mlp.up_proj.weight dtype: torch.bfloat16 shape: torch.Size([8192, 2048])
    model.layers.10.post_attention_layernorm.weight dtype: torch.bfloat16 shape: torch.Size([2048])
    model.layers.10.self_attn.k_proj.weight dtype: torch.bfloat16 shape: torch.Size([512, 2048])
    model.layers.10.self_attn.o_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 2048])
    model.layers.10.self_attn.q_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 2048])
    model.layers.10.self_attn.v_proj.weight dtype: torch.bfloat16 shape: torch.Size([512, 2048])

jax params
    model.layers.10.input_layernorm.weight dtype: bfloat16 shape: (2048,)
    model.layers.10.mlp.down_proj.

In [41]:
with torch.no_grad():
    transformer_block_pt = TransformerBlock_pt(10, model_args)

with torch.no_grad():
    transformer_block_pt.attention_norm.weight.copy_(layer10_weights_pt["model.layers.10.input_layernorm.weight"])

    transformer_block_pt.attention.wq.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.q_proj.weight"])
    transformer_block_pt.attention.wk.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.k_proj.weight"])
    transformer_block_pt.attention.wv.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.v_proj.weight"])
    transformer_block_pt.attention.wo.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.o_proj.weight"])

    transformer_block_pt.feed_forward.w1.weight.copy_(layer10_weights_pt["model.layers.10.mlp.gate_proj.weight"])
    transformer_block_pt.feed_forward.w2.weight.copy_(layer10_weights_pt["model.layers.10.mlp.down_proj.weight"])
    transformer_block_pt.feed_forward.w3.weight.copy_(layer10_weights_pt["model.layers.10.mlp.up_proj.weight"])

    transformer_block_pt.ffn_norm.weight.copy_(layer10_weights_pt["model.layers.10.post_attention_layernorm.weight"])

attention block initialised


In [42]:
with torch.no_grad():
    y_torch = transformer_block_pt.forward(x_torch, 0, freq_cis_pt, mask_pt)

In [43]:
# build the jax param pytree - will need to write a script to do this automatically for the saved weights, once the full architecture is translated

attention_params = {
    "wq" : layer10_weights_jax["model.layers.10.self_attn.q_proj.weight"],
    "wk" : layer10_weights_jax["model.layers.10.self_attn.k_proj.weight"],
    "wv" : layer10_weights_jax["model.layers.10.self_attn.v_proj.weight"],
    "wo" : layer10_weights_jax["model.layers.10.self_attn.o_proj.weight"]
}

ff_params = {
    "up"   : layer10_weights_jax["model.layers.10.mlp.up_proj.weight"],
    "gate" : layer10_weights_jax["model.layers.10.mlp.gate_proj.weight"], 
    "down" : layer10_weights_jax["model.layers.10.mlp.down_proj.weight"]
}

norm_params = {
    "pre_attention_rms"  : layer10_weights_jax["model.layers.10.input_layernorm.weight"],
    "post_attention_rms" : layer10_weights_jax["model.layers.10.post_attention_layernorm.weight"]
}

param_pytree = {
    "attention"    : attention_params,
    "feed_forward" : ff_params,
    "norms"        : norm_params
}

In [44]:
x_jax

Array([[[0.640625, -0.613281, 0.527344, ..., 1.46094, -0.714844,
         0.785156],
        [0.371094, 0.0722656, -0.263672, ..., -0.238281, 0.00637817,
         0.511719],
        [-2.01562, -0.447266, -1.26562, ..., 1.00781, -0.578125,
         1.03906],
        ...,
        [-0.166016, -0.988281, 0.15918, ..., -1.33594, -0.746094,
         -0.582031],
        [-0.546875, -0.863281, 0.550781, ..., 0.839844, -1.03906,
         -1.41406],
        [1.4375, 1.17188, 0.964844, ..., 0.486328, 0.652344, 1.98438]]],      dtype=bfloat16)

In [45]:
y_jax = transformer_block_jax(x_jax,
    param_pytree,
    mask_jax, freq_cis_jax,
    n_heads = 32, n_kv_heads = 8)

In [46]:
np.linalg.norm(y_torch.detach().numpy() - np.array(y_jax))

0.757297

we're starting to see some drift here, but given that there is still a difference in the precision used, I think this is not a dealbreaker.

## `Transformer`

**NOTE** - tokens usually have a **preceding** space, adding a space at the end of the prompt really messes things up, and we get weird predictions

In [47]:
from llama.model import Transformer as Transformer_pt
from llama_jax.model import transformer as transformer_jax

In [48]:
transformer_pt = Transformer_pt(model_args)

attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised


In [49]:
fixed_state_dict = {}

for key in weights.keys():
    
    new_key = key
    new_key = new_key.replace("model.", "")
    
    new_key = new_key.replace("embed_tokens.weight", "tok_embeddings.weight")

    new_key = new_key.replace("self_attn.q_proj", "attention.wq")
    new_key = new_key.replace("self_attn.k_proj", "attention.wk")
    new_key = new_key.replace("self_attn.v_proj", "attention.wv")
    new_key = new_key.replace("self_attn.o_proj", "attention.wo")

    new_key = new_key.replace("mlp.gate_proj", "feed_forward.w1")
    new_key = new_key.replace("mlp.up_proj", "feed_forward.w3")
    new_key = new_key.replace("mlp.down_proj", "feed_forward.w2")

    new_key = new_key.replace("input_layernorm", "attention_norm")
    new_key = new_key.replace("post_attention_layernorm", "ffn_norm")

    new_key = new_key.replace("model.norm", "norm")

    fixed_state_dict[new_key] = weights[key]

fixed_state_dict["output.weight"] = fixed_state_dict["tok_embeddings.weight"]

In [50]:
transformer_pt.load_state_dict(fixed_state_dict, strict = True)


<All keys matched successfully>

In [51]:
from llama.tokenizer import Tokenizer

tok_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf-tok/tokenizer.model"
tok = Tokenizer(tok_path)
tok

<llama.tokenizer.Tokenizer at 0x7febf3f284a0>

In [52]:
toks = tok.encode("The capital of France is", bos=True, eos=False)
print(f"{toks=}")
print(f"{tok.decode(toks)=}")

toks = torch.tensor(toks, dtype=torch.int).unsqueeze(0) # dummy batch dim

toks=[128000, 791, 6864, 315, 9822, 374]
tok.decode(toks)='<|begin_of_text|>The capital of France is'


In [53]:
out_torch = transformer_pt(toks, 0)
out_torch

tensor([[[ 7.0544,  9.0268, 13.3233,  ..., -3.7595, -3.7596, -3.7596],
         [ 5.3197,  6.3458,  5.5625,  ..., -0.9394, -0.9385, -0.9387],
         [ 9.4542,  7.4223,  5.0388,  ..., -0.4380, -0.4379, -0.4378],
         [ 8.4220,  8.3505,  5.5084,  ...,  0.1847,  0.1850,  0.1848],
         [14.8629, 12.0516,  9.5194,  ...,  0.8880,  0.8878,  0.8880],
         [10.2351,  8.6261,  4.6638,  ...,  0.8458,  0.8463,  0.8458]]])

In [54]:
print(tok.encode(" Paris", bos=False, eos = False))

[12366]


In [55]:
probs = torch.softmax(out_torch[0, -1, :], dim=0)

top_k = 10
top_probs, top_tokens = torch.topk(probs, top_k)

log_probs = torch.log(top_probs)

for i in range(top_k):
    token_str = tok.decode([top_tokens[i].item()])
    print(f"Rank {i+1}: Token '{token_str}' with log probability {log_probs[i].item()}")

next_token = top_tokens[0]
print(f"\nNext token is: '{tok.decode([next_token.item()])}'")

token_str = " Paris"
print(f"Token '{token_str}' with log probability {torch.log(probs[12366]).item()}")


Rank 1: Token ' the' with log probability -2.0233542919158936
Rank 2: Token ' Paris' with log probability -2.0792510509490967
Rank 3: Token ' a' with log probability -2.3678767681121826
Rank 4: Token ' located' with log probability -2.5045363903045654
Rank 5: Token ' situated' with log probability -2.869053602218628
Rank 6: Token ' in' with log probability -3.737622022628784
Rank 7: Token ' known' with log probability -3.8099029064178467
Rank 8: Token ' one' with log probability -3.9084270000457764
Rank 9: Token ' not' with log probability -4.317896366119385
Rank 10: Token ' called' with log probability -4.401501178741455

Next token is: ' the'
Token ' Paris' with log probability -2.0792510509490967


In [56]:
num_layers = 16

In [57]:
to_jax = lambda pt : jnp.array(pt.detach().float().numpy(), dtype=jnp.bfloat16)

In [58]:
tok_embed_jax = to_jax(weights["model.embed_tokens.weight"])

In [59]:
from tqdm import tqdm

layer_weights = []

keys_left = weights.keys()

# reverse the range because we'll remove keys we have seen already
# this will help because e.g. 1 is in 15, but 15 is not in 1 

layer_params = []

for i in tqdm(reversed(range(num_layers)), total = num_layers):

    layer_keys = [key for key in keys_left if (str(i) in key)]
    keys_left  = list(set(keys_left) - set(layer_keys))
    
    attention_params = {
        "wq" : to_jax(weights[f"model.layers.{i}.self_attn.q_proj.weight"]),
        "wk" : to_jax(weights[f"model.layers.{i}.self_attn.k_proj.weight"]),
        "wv" : to_jax(weights[f"model.layers.{i}.self_attn.v_proj.weight"]),
        "wo" : to_jax(weights[f"model.layers.{i}.self_attn.o_proj.weight"])
    }

    ff_params = {
        "up"   : to_jax(weights[f"model.layers.{i}.mlp.up_proj.weight"]),
        "gate" : to_jax(weights[f"model.layers.{i}.mlp.gate_proj.weight"]), 
        "down" : to_jax(weights[f"model.layers.{i}.mlp.down_proj.weight"])
    }

    norm_params = {
        "pre_attention_rms"  : to_jax(weights[f"model.layers.{i}.input_layernorm.weight"]),
        "post_attention_rms" : to_jax(weights[f"model.layers.{i}.post_attention_layernorm.weight"])
    }

    param_pytree = {
        "attention"    : attention_params,
        "feed_forward" : ff_params,
        "norms"        : norm_params
    }

    layer_params.append(param_pytree)

layer_params.reverse()

100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:41<00:00,  2.60s/it]


In [60]:
params_jax = {
    "tok_embeddings" : to_jax(weights["model.embed_tokens.weight"]),
    "freqs_cis"      : freq_cis_jax,
    "layers"         : layer_params,
    "norm_scale"     : to_jax(weights["model.norm.weight"]),
    "output_weight"  : to_jax(weights["model.embed_tokens.weight"])
}

In [61]:
freq_cis_jax.shape

(30, 32)

In [62]:
toks = tok.encode("The capital of France is", bos=True, eos=False)
toks = toks + [0] * (freq_cis_jax.shape[0] - len(toks))

print(f"{len(toks)=}") # pad out since we need the precomputed freq_cis to match the number of tokens

mask = [0 if tok != 0 else -jnp.inf for tok in toks] # mask the pad tokens

toks = jnp.array(toks)[None, :] # add dummy "batch" axis
mask = jnp.array(mask)

len(toks)=30


In [63]:
out_jax = transformer_jax(toks, params_jax, mask, n_heads = 32, n_kv_heads = 8)

In [64]:
probs = jax.nn.softmax(out_jax[0, -1, :], axis=0)

top_k = 25
top_probs, top_tokens = jax.lax.top_k(probs, top_k)

log_probs = jnp.log(top_probs)

for i in range(top_k):
    token_str = tok.decode([top_tokens[i].item()])
    print(f"Rank {i+1}: Token '{token_str}' with log probability {log_probs[i].item()}")

next_token = top_tokens[0]
print(f"\nNext token is: '{tok.decode([next_token.item()])}'")

token_str = " Paris"
print(f"Token '{token_str}' with log probability {jnp.log(probs[12366]).item()}")


Rank 1: Token 'The' with log probability -3.1875
Rank 2: Token ' The' with log probability -3.25
Rank 3: Token 'A' with log probability -3.875
Rank 4: Token ' I' with log probability -3.875
Rank 5: Token 'France' with log probability -3.875
Rank 6: Token ' This' with log probability -4.125
Rank 7: Token 'This' with log probability -4.25
Rank 8: Token ' ' with log probability -4.4375
Rank 9: Token ' It' with log probability -4.4375
Rank 10: Token ' France' with log probability -4.4375
Rank 11: Token ' A' with log probability -4.5
Rank 12: Token ' ' with log probability -4.5625
Rank 13: Token 'I' with log probability -4.6875
Rank 14: Token 'In' with log probability -4.6875
Rank 15: Token 'It' with log probability -4.75
Rank 16: Token 'French' with log probability -4.875
Rank 17: Token ' In' with log probability -4.9375
Rank 18: Token ' and' with log probability -5.0
Rank 19: Token ' We' with log probability -5.0
Rank 20: Token 'Paris' with log probability -5.0625
Rank 21: Token 'We' with

In [65]:
np.linalg.norm(np.array(out_jax[0, -1, :]) - out_torch[0, -1, :].detach().numpy())

953.33966

In [66]:
np.array(out_jax[0, -1, :])

array([8.8125, 4.53125, 12.0625, ..., -0.486328, -0.486328, -0.486328],
      dtype=bfloat16)

In [67]:
out_torch[0, -1, :].detach().numpy()

array([10.23513   ,  8.626071  ,  4.6638365 , ...,  0.8458147 ,
        0.846267  ,  0.84582824], dtype=float32)

Here we see a significant degredation as the errors compound through the model - however the desired token still does appear in the prediction list. This is likely to do with the datatypes used througout - as the bfloat 16 can be hurting the performances. And there may be underlying implementation differences in the Jax vs pytorch routines.

### Just in time compilation

In [73]:
jitted_transformer = jax.jit(transformer_jax, static_argnames=["n_heads", "n_kv_heads"])

In [74]:
jitted_transformer(toks, params_jax, mask, n_heads = 32, n_kv_heads = 8)

Array([[[9.625, 12.0625, 14.1875, ..., -5.15625, -5.15625, -5.15625],
        [6.0625, 9.125, 6.9375, ..., -0.570312, -0.570312, -0.570312],
        [12.6875, 10.4375, 7.25, ..., 0.351562, 0.351562, 0.351562],
        ...,
        [7.8125, 4.34375, 12.3125, ..., -1.23438, -1.23438, -1.23438],
        [8.5, 4.59375, 12.125, ..., -0.847656, -0.847656, -0.847656],
        [8.875, 4.59375, 12.0625, ..., -0.507812, -0.507812, -0.507812]]],      dtype=bfloat16)

In [75]:
jitted_transformer(toks, params_jax, mask, n_heads = 32, n_kv_heads = 8)

Array([[[9.625, 12.0625, 14.1875, ..., -5.15625, -5.15625, -5.15625],
        [6.0625, 9.125, 6.9375, ..., -0.570312, -0.570312, -0.570312],
        [12.6875, 10.4375, 7.25, ..., 0.351562, 0.351562, 0.351562],
        ...,
        [7.8125, 4.34375, 12.3125, ..., -1.23438, -1.23438, -1.23438],
        [8.5, 4.59375, 12.125, ..., -0.847656, -0.847656, -0.847656],
        [8.875, 4.59375, 12.0625, ..., -0.507812, -0.507812, -0.507812]]],      dtype=bfloat16)

We see that doing just in time compilation can roughly halve the time it takes to run the model in jax

## Conclusions

Whilst the Jax implementation does leave a lot to be desired, we can see that it does roughly match the pytorch. We aren't intending to use it for generative purposes, just as a way to access the hidden layers' values.

The functional implementation exposes the logic more cleanly and hopefully will make further analysis easier.