## Converting and saving the weights

In [5]:
import sys
import os

# Put the local project directory on the import path (presumably so packages
# such as `llama_jax`, imported in a later cell, can be found — verify).
# os.path.abspath resolves "LLM" against the current working directory, so
# the notebook must be started from the directory that contains LLM/.
project_path = os.path.abspath("LLM")

# Only append once, so re-running this cell doesn't grow sys.path.
if project_path not in sys.path:
    sys.path.append(project_path)

In [6]:
import os

import torch
from safetensors.torch import load_file

# Path to the checkpoint file. The default is a machine-specific location;
# set the LLAMA_SAFETENSORS_PATH environment variable to point elsewhere
# without editing the notebook.
safetensors_path = os.environ.get(
    "LLAMA_SAFETENSORS_PATH",
    "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/model.safetensors",
)

# load_file returns a dict mapping tensor names to torch tensors.
weights = load_file(safetensors_path)

In [None]:
from llama_jax.model import precompute_freqs_cis

# Model geometry (values as used originally; TODO confirm against config).
model_dim = 2048             # hidden size
n_heads = 32                 # number of attention heads
dim = model_dim // n_heads   # per-head dimension passed to the RoPE table
theta = 500_000              # base passed as the `theta` argument

# NOTE(review): the remaining positional arguments are passed unchanged; the
# meanings below are guesses from their values — confirm them against the
# signature of precompute_freqs_cis.
max_seq_len = 2048           # presumably the number of positions precomputed
use_scaled = True            # presumably toggles RoPE frequency scaling
rope_scale_factor = 32       # presumably the scaling factor
original_context_len = 8192  # presumably the pre-scaling context length

freq_cis_jax = precompute_freqs_cis(
    dim, max_seq_len, theta, use_scaled, rope_scale_factor, original_context_len
)



In [None]:
from tqdm import tqdm
import jax.numpy as jnp

def to_jax(pt):
    """Convert a torch tensor into a JAX array of dtype bfloat16.

    The tensor is detached from autograd and moved to the CPU, because
    ``Tensor.numpy()`` rejects tensors that require grad or live on another
    device. It is then upcast to float32, since NumPy has no native bfloat16
    dtype for ``.numpy()`` to produce. Finally JAX casts the values back down
    to bfloat16.
    """
    return jnp.array(pt.detach().cpu().float().numpy(), dtype=jnp.bfloat16)

# Number of transformer blocks, inferred from the checkpoint's
# "model.layers.<i>.*" key names instead of being hard-coded.
num_layers = 1 + max(
    int(key.split(".")[2]) for key in weights if key.startswith("model.layers.")
)

# Checkpoint tensors that do not belong to any transformer layer
# (embeddings, final norm, optional output head); kept for inspection.
keys_left = [key for key in weights if not key.startswith("model.layers.")]


def _convert_layer(i):
    """Build the JAX parameter pytree for transformer layer ``i``.

    Reads ``model.layers.{i}.<module>.weight`` from the global ``weights``
    dict and converts each tensor with ``to_jax``.
    """
    prefix = f"model.layers.{i}"

    def param(module):
        # Convert one weight tensor of this layer.
        return to_jax(weights[f"{prefix}.{module}.weight"])

    return {
        "attention": {
            "wq": param("self_attn.q_proj"),
            "wk": param("self_attn.k_proj"),
            "wv": param("self_attn.v_proj"),
            "wo": param("self_attn.o_proj"),
        },
        "feed_forward": {
            "up": param("mlp.up_proj"),
            "gate": param("mlp.gate_proj"),
            "down": param("mlp.down_proj"),
        },
        "norms": {
            "pre_attention_rms": param("input_layernorm"),
            "post_attention_rms": param("post_attention_layernorm"),
        },
    }


# Converted in ascending order, so layer_params[i] holds layer i.
layer_params = [_convert_layer(i) for i in tqdm(range(num_layers))]

# Convert the token embedding once. An explicit "lm_head.weight" is used as
# the output projection when the checkpoint has one (untied weights).
# Otherwise the output projection is tied to the embedding and the same
# (immutable) JAX array is reused rather than converting it a second time.
tok_embeddings_jax = to_jax(weights["model.embed_tokens.weight"])
if "lm_head.weight" in weights:
    output_weight_jax = to_jax(weights["lm_head.weight"])
else:
    output_weight_jax = tok_embeddings_jax

params_jax = {
    "tok_embeddings": tok_embeddings_jax,
    "freqs_cis": freq_cis_jax,
    "layers": layer_params,
    "norm_scale": to_jax(weights["model.norm.weight"]),
    "output_weight": output_weight_jax,
}

100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:12<00:00,  1.29it/s]


In [14]:
import os
import pickle

weights_out_path = "Data/ModelWeights/llama_jax_weights.pkl"

# open(..., "wb") does not create missing parent directories, so make sure
# the output directory exists (relative to the current working directory).
os.makedirs(os.path.dirname(weights_out_path), exist_ok=True)

# Serialize the whole parameter pytree (nested dicts/lists of JAX arrays).
with open(weights_out_path, "wb") as f:
    pickle.dump(params_jax, f)

In [1]:
import pickle

# Read the pickled parameter pytree back from disk.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open("Data/ModelWeights/llama_jax_weights.pkl", mode="rb") as f:
    params_jax_loaded = pickle.load(f)

