In [1]:
import sys
import os
sys.path.append(os.path.abspath('..'))
import jax
import jax.numpy as jnp
from jax import random, jit
from functools import partial


from models.llama.model import LLaMa
from models.llama.config import ModelConfig
from utils.kvcache import KVCache
from models.llama.load import load_llama_weights, pth_to_safetensors
from utils.memory import estimate_pytree_memory_footprint, format_bytes
from models.llama.tokenizer import Tokenizer
from sampling import Sampler
from sampling import TopPSampler

# Devices visible to JAX in this kernel (kept for interactive inspection).
devices = jax.devices()
# Ask JAX to use its "highest" default precision setting for matrix multiplications.
jax.config.update("jax_default_matmul_precision", "highest")
# Turn on 64-bit dtype support. NOTE(review): the model config shown further down
# reports dtype='float32', so confirm x64 is actually required for this notebook.
jax.config.update("jax_enable_x64", True)










In [2]:
# Converted (safetensors) Llama-3.2-3B checkpoint that the weight loader reads.
# NOTE(review): this is an absolute, user-specific path; adjust it when running
# the notebook on another machine.
model_path = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B-JAX/model.safetensors"
# NOTE(review): the KV cache built later is sized from model_config.max_seqlen,
# not from this value -- confirm which limit is intended.
max_seqlen = 2048

# The config is read from the directory that contains the weights file. Deriving
# that directory from model_path (instead of repeating the literal) keeps the two
# locations from drifting apart when the checkpoint is moved.
model_config = ModelConfig.from_json_file(os.path.dirname(model_path))










In [3]:
# Display the loaded configuration so the hyperparameters are visible in the notebook.
model_config

ModelConfig(vocab_size=128256, dim=3072, ffn_hidden_dim=8192, n_layers=28, n_heads=24, n_kv_heads=8, activation_fn='silu', max_seqlen=8192, rope_theta=500000.0, rms_norm_eps=1e-05, mode='inference', dtype='float32', use_scaled_rope=True)

In [4]:
# Load the model parameters from the safetensors file at model_path, using
# model_config to describe the architecture. The result is a PyTree (see the
# tree_map / footprint cell that follows).
params = load_llama_weights(model_path, model_config)

In [5]:
# Build a tree with the same structure as `params` in which every leaf array is
# replaced by its shape tuple; evaluate `param_shapes` in a cell to browse it.
param_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

# Report the estimated memory taken by the parameters as a human-readable string.
params_footprint = estimate_pytree_memory_footprint(params)
format_bytes(params_footprint)

'13.44GB'

In [6]:
# Instantiate the LLaMa module from the loaded configuration. The weights are
# supplied separately at call time through model.apply({"params": ...}, ...).
model = LLaMa(model_config)

In [7]:
# Placeholder inputs for a single forward pass through the model.
batch_size, seq_len = 1, 512
key = jax.random.PRNGKey(0)

# Dummy prompt: an int32 array of shape (batch_size, seq_len) filled with token id 1.
dummy_tokens = jnp.ones(shape=(batch_size, seq_len), dtype=jnp.int32)

# Fresh KV cache whose dimensions are all taken from the model configuration.
# NOTE(review): max_seqlen comes from model_config here, not from the
# notebook-level `max_seqlen` variable.
dummy_kvcache = KVCache.new(
    n_layers=model_config.n_layers,
    bsz=batch_size,
    max_seqlen=model_config.max_seqlen,
    kv_heads=model_config.n_kv_heads,
    head_dim=model_config.head_dim,
    dtype=model_config.dtype,
)

# Show how much memory the empty cache is estimated to occupy.
format_bytes(estimate_pytree_memory_footprint(dummy_kvcache))

'1.75GB'

In [8]:
# Run one forward pass over the dummy prompt with the loaded weights.
# The third positional argument (0) is presumably the start position in the
# sequence -- TODO confirm against the LLaMa module's call signature.
# `logits` holds the raw return value of apply(); a later cell unpacks it into
# the logits array and the updated KV cache.
variables = {"params": params}
logits = model.apply(variables, dummy_tokens, 0, dummy_kvcache)

In [3]:
# One-off conversion of the original PyTorch checkpoint into the safetensors
# file that the weight-loading cell reads. It has to have been run once before
# that file can be loaded.
config_dir = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B/"
pth_model_path = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B/original/consolidated.00.pth"
output_path = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B-JAX"

pth_to_safetensors(
    pth_model_path,
    config_dir,
    output_path,
)

Model weights saved to /home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B-JAX/model.safetensors


In [9]:
import torch

# Inspect the original PyTorch checkpoint to see how its tensors are named.
pth_model_path = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B/original/consolidated.00.pth"

# torch.load unpickles the file. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint cannot execute
# arbitrary code; this checkpoint is a flat dict of name -> tensor.
# map_location="cpu" keeps the loaded tensors in host memory.
tensors = torch.load(pth_model_path, map_location="cpu", weights_only=True)

# List every tensor name stored in the checkpoint (the output is long).
print(tensors.keys())

dict_keys(['tok_embeddings.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wk.weight', 'layers.0.attention.wv.weight', 'layers.0.attention.wo.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.attention_norm.weight', 'layers.0.ffn_norm.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wv.weight', 'layers.1.attention.wo.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.attention_norm.weight', 'layers.1.ffn_norm.weight', 'layers.2.attention.wq.weight', 'layers.2.attention.wk.weight', 'layers.2.attention.wv.weight', 'layers.2.attention.wo.weight', 'layers.2.feed_forward.w1.weight', 'layers.2.feed_forward.w3.weight', 'layers.2.feed_forward.w2.weight', 'layers.2.attention_norm.weight', 'layers.2.ffn_norm.weight', 'layers.3.attention.wq.weight', 'layers.3.attention.wk.weight', 'layers.3.atten

In [10]:
# model.apply() returned a (logits, updated_kvcache) pair that is still bound to
# the name `logits`. Split it into its two parts.
# The guard keeps this cell idempotent: on a second run `logits` is already the
# logits array, and unpacking that array into two names would raise.
if isinstance(logits, (tuple, list)):
    logits, updated_kvcache = logits

In [11]:
# Sanity check on the output: expect (batch_size, seq_len, vocab_size),
# i.e. (1, 512, 128256) for the dummy inputs and configuration used above.
print(logits.shape)

(1, 512, 128256)
