# Llama to Jax

In this notebook we test the conversion of the Llama architecture to use Jax as the backend, by isolating each component and checking that the Jax version yields the same results as the original PyTorch one, function by function.

In [28]:
import jax
import jax.numpy as jnp
import torch

In [1]:
import sys
import os

os.getcwd()
project_path = os.path.abspath("LLM")

if project_path not in sys.path:
    sys.path.append(project_path)

In [4]:
from safetensors.torch import load_file
safetensors_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/model.safetensors"  
weights = load_file(safetensors_path)


## `RMSNorm`

Here we'll demonstrate how to create jax arrays from the pytorch parameters, and see how our jax implementation of `RMSNorm` compares to the pytorch one, which we already know works as part of the full model.
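
For orientation, here is a minimal sketch of what a functional RMSNorm can look like in jax. It is not the code in `llama_jax.model`, which this notebook only imports. The name `rms_norm_sketch`, the `eps=1e-5` default, and the choice to compute in `float32` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def rms_norm_sketch(x, weight, eps=1e-5):
    """Hypothetical functional RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    # Upcast so the reduction runs in float32 (an assumption, not
    # necessarily what llama_jax.model does).
    x32 = x.astype(jnp.float32)
    # Mean of squares over the feature (last) axis, kept for broadcasting.
    mean_sq = jnp.mean(x32 * x32, axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(mean_sq + eps)
    # Per-feature learned scale, broadcast over the batch dimension.
    return normed * weight.astype(jnp.float32)


x = jnp.array([[3.0, 4.0], [1.0, 1.0]])
w = jnp.ones((2,), dtype=jnp.bfloat16)
y = rms_norm_sketch(x, w)
print(y.shape, y.dtype)
```

With a weight of all ones, each output row has a root mean square close to 1, which is a quick sanity check on any implementation.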

In [None]:
# load the functions to compare
from llama.model import RMSNorm as RMSNorm_pt
print("loaded source norm")
from llama_jax.model import RMSNorm as RMSNorm_jax
print("loaded jax norm")

loaded source norm
loaded jax norm


In [None]:
# check the weights to find a good demo tensor
[key for key in weights.keys() if 'norm' in key][0:2]

['model.layers.0.input_layernorm.weight',
 'model.layers.0.post_attention_layernorm.weight']

In [None]:
# pytorch tensor --> jax array
rms_weights_pt = weights['model.layers.0.input_layernorm.weight']
rms_weights_jax = jnp.array(rms_weights_pt.detach().float().numpy(), dtype=jnp.bfloat16)

print(f"pytorch: {rms_weights_pt}, jax: {rms_weights_jax}")
print(f"pytorch shape: {rms_weights_pt.shape}, jax shape: {rms_weights_jax.shape}")
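
The `.float()` call in the conversion above matters: NumPy has no native `bfloat16`, so calling `.numpy()` directly on a `bfloat16` tensor raises a `TypeError`. A small sketch of the same conversion wrapped in a helper is below. The helper name `torch_to_jax` and the toy tensor are hypothetical, used only to show the round trip.

```python
import jax.numpy as jnp
import torch


def torch_to_jax(t, dtype=jnp.bfloat16):
    """Hypothetical helper: torch tensor -> jax array of the given dtype."""
    # bfloat16 -> float32 is exact, so going through float32 loses nothing
    # when the target dtype is bfloat16 again.
    as_numpy = t.detach().cpu().float().numpy()
    return jnp.asarray(as_numpy, dtype=dtype)


# Toy weights that are exactly representable in bfloat16.
w_pt = torch.tensor([0.5, 1.0, 1.5], dtype=torch.bfloat16)
w_jax = torch_to_jax(w_pt)
print(w_jax.dtype, w_jax.shape)
```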



In [None]:
# put the weights in an isolated pytorch RMSNorm module
rms_norm_pt = RMSNorm_pt(2048)

with torch.no_grad():
    # overwrite the RMSNorm weight with the one from the loaded state_dict
    rms_norm_pt.weight.copy_(rms_weights_pt)

rms_norm_pt

In [None]:
# run the pytorch version on a sample tensor to get a "true" value

with torch.no_grad():
    x_torch = torch.randn(2, 2048) # add a batch dim

    y_torch = rms_norm_pt(x_torch)
    print("PyTorch output:", y_torch)

PyTorch output: tensor([[-0.1956, -0.0468, -0.1102,  ..., -0.2386,  0.2464, -0.1231],
        [ 0.0766, -0.4311, -0.1458,  ..., -0.3946, -0.2244,  0.1960]])


In [None]:
# call our jax implementation and check the output is the same

rms_norm_jax = lambda x : RMSNorm_jax(x, rms_weights_jax)

x_jax = jnp.array(x_torch.detach().numpy())
y_jax = rms_norm_jax(x_jax)

print("Jax output:", y_jax)

Jax output: [[-0.19556196 -0.04684718 -0.11015126 ... -0.23862486  0.24641658
  -0.12314538]
 [ 0.0766468  -0.43110186 -0.14580318 ... -0.39460558 -0.22443385
   0.19601265]]


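It all looks good :) The two outputs agree to the four decimal places that pytorch prints. One caveat on precision: pytorch shows fewer digits than jax, but that looks like its default print precision rather than a `bfloat16` result, since a `bfloat16` tensor would print with a `dtype=torch.bfloat16` suffix. We'll worry about dtypes later, since I don't know yet whether casting to `bfloat16` is the correct behaviour when the function is just one step in the overall architecture (for the pytorch implementation).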
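
Comparing printed values by eye only covers the few entries that are displayed, so a numeric check with an explicit tolerance is more reliable. The sketch below is self-contained: a NumPy reference stands in for the pytorch module, and both functions, their names, the `eps` value, and the tolerances are assumptions for illustration.

```python
import numpy as np
import jax.numpy as jnp


def rms_norm_ref(x, weight, eps=1e-5):
    # NumPy reference, standing in for the pytorch RMSNorm module.
    x = x.astype(np.float32)
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def rms_norm_new(x, weight, eps=1e-5):
    # jax implementation under test (hypothetical, not llama_jax.model).
    x = x.astype(jnp.float32)
    return x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * weight


rng = np.random.default_rng(0)
x_np = rng.standard_normal((2, 2048)).astype(np.float32)
w_np = rng.standard_normal(2048).astype(np.float32)

y_ref = rms_norm_ref(x_np, w_np)
y_new = np.asarray(rms_norm_new(jnp.asarray(x_np), jnp.asarray(w_np)))

# Report the largest elementwise difference, then fail loudly on a mismatch.
max_abs_diff = float(np.max(np.abs(y_ref - y_new)))
print("max abs diff:", max_abs_diff)
np.testing.assert_allclose(y_new, y_ref, rtol=1e-4, atol=1e-5)
```

In this notebook the same check would presumably compare `np.asarray(y_jax)` against `y_torch.numpy()`, with the tolerance loosened if either side is computed in `bfloat16`.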