# Llama to Jax

In this notebook we test the conversion of the Llama architecture to use Jax as the backend, by isolating the components individually and ensuring that they yield the same results, as function.

In [1]:
import jax
import jax.numpy as jnp
import torch
import numpy as np

In [2]:
import sys
import os

os.getcwd()
project_path = os.path.abspath("LLM")

if project_path not in sys.path:
    sys.path.append(project_path)

In [3]:
from safetensors.torch import load_file
safetensors_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/model.safetensors"  
weights = load_file(safetensors_path)


In [4]:
from llama.model import ModelArgs
import json

config_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/config.json"
with open(config_path, "r") as f:
    config = json.load(f)

# extract the necessary fields
model_args = ModelArgs(
    dim=config.get("hidden_size", 4096), 
    n_layers=config.get("num_hidden_layers", 32),  
    n_heads=config.get("num_attention_heads", 32), 
    n_kv_heads=config.get("num_key_value_heads", None), 
    vocab_size=config.get("vocab_size", -1), 
    multiple_of=256, # not in config so use the default
    norm_eps=config.get("rms_norm_eps", 1e-5),  # map "rms_norm_eps"
    max_batch_size=32,  # not in config so use the default
    use_scale_rope=True, # this is how it was in llama.ipynb
    rope_scale_factor=config.get("rope_scaling").get("factor"),
    original_rotary_embed_len=config.get("rope_scaling").get("original_max_position_embeddings"),
    cache_len = 2048,
)

print(model_args)

ModelArgs(dim=2048, n_layers=16, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=256, norm_eps=1e-05, rope_theta=500000, use_scaled_rope=True, rope_scale_factor=32.0, max_batch_size=32, original_rotary_embed_len=8192, cache_len=2048)


## `RMSNorm`

Here we'll demonstrate how to create jax arrays from the pytorch parameters, and see how our jax implementation of `RMSNorm` compares to the one we already know is a part of the working model

In [5]:
# load the functions to compare
from llama.model import RMSNorm as RMSNorm_pt
print("loaded source norm")
from llama_jax.model import RMSNorm as RMSNorm_jax
print("loaded jax norm")

loaded source norm
loaded jax norm


In [6]:
# check the weights to find a good demo tensor
[key for key in weights.keys() if 'norm' in key][0:2]

['model.layers.0.input_layernorm.weight',
 'model.layers.0.post_attention_layernorm.weight']

In [7]:
# pytorch tensor --> jax array
rms_weights_pt = weights['model.layers.0.input_layernorm.weight']
rms_weights_jax = jnp.array(rms_weights_pt.detach().float().numpy(), dtype=jnp.bfloat16)

print(f"pytorch: {rms_weights_jax}, jax: {rms_weights_jax}")
print(f"pytorch shape: {rms_weights_pt.shape}, jax shape: {rms_weights_jax.shape}")



pytorch: [0.158203 0.180664 0.269531 ... 0.22168 0.210938 0.152344], jax: [0.158203 0.180664 0.269531 ... 0.22168 0.210938 0.152344]
pytorch shape: torch.Size([2048]), jax shape: (2048,)


In [8]:
# put the weights in an isolated pytorch RMSNorm module
rms_norm_pt = RMSNorm_pt(2048)

with torch.no_grad():
    # overwrite the RMSNorm weight with the one from the loaded state_dict
    rms_norm_pt.weight.copy_(rms_weights_pt)

rms_norm_pt

RMSNorm()

In [9]:
# run the pytorch version on a sample tensor to get a "true" value

with torch.no_grad():
    x_torch = torch.randn(2, 2048) # add a batch dim

    y_torch = rms_norm_pt(x_torch)
    print("PyTorch output:", y_torch)

PyTorch output: tensor([[-0.2999,  0.0068, -0.6558,  ..., -0.4704,  0.0800, -0.0023],
        [-0.0136,  0.0844, -0.1716,  ..., -0.0436, -0.2050,  0.0508]])


In [10]:
# call our jax implementation and check the output is the same

rms_norm_jax = lambda x : RMSNorm_jax(x, rms_weights_jax)

x_jax = jnp.array(x_torch.detach().numpy())
y_jax = rms_norm_jax(x_jax)

print("Jax output:", y_jax)

Jax output: [[-0.29990137  0.00675914 -0.65584993 ... -0.470414    0.07996918
  -0.00227949]
 [-0.0135715   0.0843863  -0.17162883 ... -0.04360401 -0.20500495
   0.05083428]]


It all looks good :), the pytorch outputs `bfloat16` too, whereas jax has higher precision. We'll worry about that later, since I don't know whether that is the correct behaviour when the function is just one step in the overall architecture (for the pytorch implementation)

## `precompute_freq_cis`



In [11]:
from llama.model import precompute_freqs_cis as precompute_freqs_cis_pt
from llama_jax.model import precompute_freqs_cis as precompute_freqs_cis_jax

In [12]:
dim = 2048 // 32 

freq_cis_pt = precompute_freqs_cis_pt(dim, 2048, 500_000, True, 32, 8192)
print(f"{freq_cis_pt[1000,:]=}")
print(f"{freq_cis_pt.dtype=}")

freq_cis_pt[1000,:]=tensor([ 0.5624+8.2688e-01j, -0.7484-6.6329e-01j,  0.8558+5.1728e-01j,
        -0.9982-5.9692e-02j,  0.6555-7.5522e-01j, -0.9931+1.1765e-01j,
        -0.8397-5.4308e-01j,  0.9927+1.2066e-01j,  0.9957-9.2948e-02j,
         0.9843-1.7641e-01j, -0.6581-7.5291e-01j, -0.0060-9.9998e-01j,
         0.5323+8.4655e-01j,  0.1267-9.9194e-01j, -0.9976-6.9797e-02j,
        -0.5315+8.4708e-01j,  0.1559+9.8777e-01j,  0.5910+8.0666e-01j,
         0.9998+1.9460e-02j,  0.9999+1.2914e-02j,  1.0000+8.5702e-03j,
         1.0000+5.6872e-03j,  1.0000+3.7740e-03j,  1.0000+2.5045e-03j,
         1.0000+1.6620e-03j,  1.0000+1.1029e-03j,  1.0000+7.3187e-04j,
         1.0000+4.8567e-04j,  1.0000+3.2229e-04j,  1.0000+2.1387e-04j,
         1.0000+1.4193e-04j,  1.0000+9.4183e-05j])
freq_cis_pt.dtype=torch.complex64


In [13]:
freq_cis_jax = precompute_freqs_cis_jax(dim, 2048, 500_000, True, 32, 8192)
print(f"{freq_cis_jax[1000,:]=}")
print(f"{freq_cis_jax.dtype=}")


freq_cis_jax[1000,:]=Array([ 0.56237906+8.2687956e-01j, -0.7483619 -6.6329068e-01j,
        0.85581774+5.1727748e-01j, -0.99821687-5.9691951e-02j,
        0.6554753 -7.5521660e-01j, -0.9930554 +1.1764777e-01j,
       -0.839681  -5.4307991e-01j,  0.99269414+1.2065805e-01j,
        0.995671  -9.2947975e-02j,  0.98431766-1.7640516e-01j,
       -0.65812033-7.5291276e-01j, -0.00604559-9.9998170e-01j,
        0.53230125+8.4655499e-01j,  0.12669091-9.9194223e-01j,
       -0.9975612 -6.9796599e-02j, -0.53146   +8.4708339e-01j,
        0.15594384+9.8776591e-01j,  0.59101975+8.0665708e-01j,
        0.99981064+1.9460410e-02j,  0.9999166 +1.2914409e-02j,
        0.9999633 +8.5701505e-03j,  0.99998385+5.6872014e-03j,
        0.9999929 +3.7740455e-03j,  0.99999684+2.5044645e-03j,
        0.9999986 +1.6619666e-03j,  0.9999994 +1.1028834e-03j,
        0.99999976+7.3187490e-04j,  0.9999999 +4.8567311e-04j,
        0.99999994+3.2229329e-04j,  1.        +2.1387423e-04j,
        1.        +1.4192720e-04j,

In [14]:
np.linalg.norm(freq_cis_pt.detach().numpy() - np.array(freq_cis_jax))

0.0

these match too, since the arrays are bigger, we check by looking at the norm too - which shows they are identical.

## `apply_rotary_emb`

In [15]:
from llama.model import apply_rotary_emb as apply_rotary_emb_pt
from llama_jax.model import apply_rotary_emb as apply_rotary_emb_jax

In [16]:
# setting up dummy data

bsz = 1  # batch size
seqlen = 30 # dummy value
n_local_heads = 32 # no parallelism so local = total
head_dim = 2048 // 32

dummy_shape = (bsz, seqlen, n_local_heads, head_dim)

with torch.no_grad():
    freq_cis_pt = precompute_freqs_cis_pt(dim, seqlen, 500_000, True, 32, 8192)
    xq_torch = torch.randn(dummy_shape) 
    xk_torch = torch.randn(dummy_shape)

freq_cis_jax = precompute_freqs_cis_jax(dim, seqlen, 500_000, True, 32, 8192)
xq_jax = jnp.array(xq_torch.detach().numpy())
xk_jax = jnp.array(xk_torch.detach().numpy())

In [17]:
with torch.no_grad():
    yq_torch, yk_torch = apply_rotary_emb_pt(xq_torch, xk_torch, freq_cis_pt)

yq_jax, yk_jax = apply_rotary_emb_jax(xq_jax, xk_jax, freq_cis_jax)

In [18]:
print(f"{yq_torch.shape=}")
print(f"{yq_jax.shape=}")

print(f"{yk_torch.shape=}")
print(f"{yk_jax.shape=}")

print(f"{yq_torch.dtype=}")
print(f"{yq_jax.dtype=}")

yq_torch.shape=torch.Size([1, 30, 32, 64])
yq_jax.shape=(1, 30, 32, 64)
yk_torch.shape=torch.Size([1, 30, 32, 64])
yk_jax.shape=(1, 30, 32, 64)
yq_torch.dtype=torch.float32
yq_jax.dtype=dtype('float32')


In [19]:
print(f"q error: {np.linalg.norm(yq_torch.detach().numpy() - np.array(yq_jax))}")
print(f"k error: {np.linalg.norm(yk_torch.detach().numpy() - np.array(yk_jax))}")

q error: 7.464219379471615e-06
k error: 7.546961569460109e-06


there is a slight difference in the values here, but over the vast amount of dummy data it is well within numerical tolerances

## `attention_block`

In [20]:
from llama.model import Attention as Attention_pt
from llama_jax.model import attention_block as attention_block_jax

In [21]:
[key for key in weights.keys() if '10' in key]

['model.layers.10.input_layernorm.weight',
 'model.layers.10.mlp.down_proj.weight',
 'model.layers.10.mlp.gate_proj.weight',
 'model.layers.10.mlp.up_proj.weight',
 'model.layers.10.post_attention_layernorm.weight',
 'model.layers.10.self_attn.k_proj.weight',
 'model.layers.10.self_attn.o_proj.weight',
 'model.layers.10.self_attn.q_proj.weight',
 'model.layers.10.self_attn.v_proj.weight']

In [42]:
# get all the required parameters for one attention block
# check that the conversion leaves shape and dtype invariant

layer10_keys = [key for key in weights.keys() if ('10' in key) and ('norm' not in key) and ('mlp' not in key)]
layer10_weights_pt = { key : weights[key] for key in layer10_keys}
layer10_weights_jax = { key : jnp.array(weights[key].detach().float().numpy(), dtype = jnp.bfloat16) for key in layer10_keys}

print("pytorch params:")
for key, value in layer10_weights_pt.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

print("\njax params")
for key, value in layer10_weights_jax.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

pytorch params:
    model.layers.10.self_attn.k_proj.weight dtype: torch.bfloat16 shape: torch.Size([512, 2048])
    model.layers.10.self_attn.o_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 2048])
    model.layers.10.self_attn.q_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 2048])
    model.layers.10.self_attn.v_proj.weight dtype: torch.bfloat16 shape: torch.Size([512, 2048])

jax params
    model.layers.10.self_attn.k_proj.weight dtype: bfloat16 shape: (512, 2048)
    model.layers.10.self_attn.o_proj.weight dtype: bfloat16 shape: (2048, 2048)
    model.layers.10.self_attn.q_proj.weight dtype: bfloat16 shape: (2048, 2048)
    model.layers.10.self_attn.v_proj.weight dtype: bfloat16 shape: (512, 2048)


In [43]:
with torch.no_grad():
    attention_pt = Attention_pt(model_args)

with torch.no_grad():
    attention_pt.wq.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.q_proj.weight"])
    attention_pt.wk.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.k_proj.weight"])
    attention_pt.wv.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.v_proj.weight"])
    attention_pt.wo.weight.copy_(layer10_weights_pt["model.layers.10.self_attn.o_proj.weight"])

attention block initialised


In [44]:
dummy_x_shape = (bsz, seqlen, 2048)

In [45]:
x_torch = torch.randn(dummy_x_shape) 
y_torch = attention_pt.forward(x_torch, 0, freq_cis_pt, None)

In [46]:
x_jax = jnp.array(x_torch.detach().numpy())

y_jax = attention_block_jax(x_jax, None, freq_cis_jax,
    wq = layer10_weights_jax["model.layers.10.self_attn.q_proj.weight"],
    wk = layer10_weights_jax["model.layers.10.self_attn.k_proj.weight"],
    wv = layer10_weights_jax["model.layers.10.self_attn.v_proj.weight"],
    wo = layer10_weights_jax["model.layers.10.self_attn.o_proj.weight"],
    n_heads = 32,
    n_kv_heads = 8)

In [47]:
np.linalg.norm(y_torch.detach().numpy() - np.array(y_jax))

3.3401127e-05

###  `AttentionBlock` with a mask

In [49]:
zeros = torch.zeros(20) 
neg_inf = torch.full((10,), float('-inf')) 
mask_pt = torch.cat([zeros, neg_inf]) 
mask_pt

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf,
        -inf, -inf, -inf, -inf, -inf, -inf])

In [50]:
y_torch = attention_pt.forward(x_torch, 0, freq_cis_pt, mask_pt)

In [51]:
mask_jax = jnp.array(mask_pt.detach().numpy())

y_jax = attention_block_jax(x_jax, mask_jax, freq_cis_jax,
    wq = layer10_weights_jax["model.layers.10.self_attn.q_proj.weight"],
    wk = layer10_weights_jax["model.layers.10.self_attn.k_proj.weight"],
    wv = layer10_weights_jax["model.layers.10.self_attn.v_proj.weight"],
    wo = layer10_weights_jax["model.layers.10.self_attn.o_proj.weight"],
    n_heads = 32,
    n_kv_heads = 8)

In [52]:
np.linalg.norm(y_torch.detach().numpy() - np.array(y_jax))

3.4069857e-05

## `FeedForward`

This is a special MLP implementation that passes a through a linear product of the values, which tempers the non-linearity of the `silu` activation function

In [28]:
from llama.model import FeedForward as FeedForward_pt
from llama_jax.model import feed_forward as feed_forward_jax

In [29]:
[key for key in weights.keys() if ('10' in key) and ('mlp' in key)]

# down -> w2
# up   -> w3
# gate -> w1

['model.layers.10.mlp.down_proj.weight',
 'model.layers.10.mlp.gate_proj.weight',
 'model.layers.10.mlp.up_proj.weight']

In [30]:
layer10_keys = [key for key in weights.keys() if ('10' in key) and ('mlp' in key)]
layer10_weights_pt = { key : weights[key] for key in layer10_keys}
layer10_weights_jax = { key : jnp.array(weights[key].detach().float().numpy(), dtype = jnp.bfloat16) for key in layer10_keys}

print("pytorch params:")
for key, value in layer10_weights_pt.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

print("\njax params")
for key, value in layer10_weights_jax.items():
    print(f"    {key} dtype: {value.dtype} shape: {value.shape}")

pytorch params:
    model.layers.10.mlp.down_proj.weight dtype: torch.bfloat16 shape: torch.Size([2048, 8192])
    model.layers.10.mlp.gate_proj.weight dtype: torch.bfloat16 shape: torch.Size([8192, 2048])
    model.layers.10.mlp.up_proj.weight dtype: torch.bfloat16 shape: torch.Size([8192, 2048])

jax params
    model.layers.10.mlp.down_proj.weight dtype: bfloat16 shape: (2048, 8192)
    model.layers.10.mlp.gate_proj.weight dtype: bfloat16 shape: (8192, 2048)
    model.layers.10.mlp.up_proj.weight dtype: bfloat16 shape: (8192, 2048)


In [31]:
with torch.no_grad():
    feed_forward_pt = FeedForward_pt(2048, 4 * 2048, 256)

with torch.no_grad():
    feed_forward_pt.w1.weight.copy_(layer10_weights_pt["model.layers.10.mlp.gate_proj.weight"])
    feed_forward_pt.w2.weight.copy_(layer10_weights_pt["model.layers.10.mlp.down_proj.weight"])
    feed_forward_pt.w3.weight.copy_(layer10_weights_pt["model.layers.10.mlp.up_proj.weight"])

In [32]:
y_torch = feed_forward_pt.forward(x_torch)

In [33]:
y_jax = feed_forward_jax(x_jax,
    gate = layer10_weights_jax["model.layers.10.mlp.gate_proj.weight"], 
    down = layer10_weights_jax["model.layers.10.mlp.down_proj.weight"], 
    up   = layer10_weights_jax["model.layers.10.mlp.up_proj.weight"])

In [34]:
np.linalg.norm(y_torch.detach().numpy() - np.array(y_jax))

8.845371e-05