In [1]:
import jax
import jax.numpy as jnp


from models.llama.model import LLaMa
from models.llama.config import ModelConfig
from models.llama.tokenizer import Tokenizer
from models.llama.load import load_llama_weights
from utils.memory import estimate_pytree_memory_footprint, format_bytes
from utils.kvcache import KVCache

# Set JAX's global `jax_default_matmul_precision` option to "float32" for this session.
jax.config.update("jax_default_matmul_precision", "float32")


In [2]:
# Directory holding the checkpoint artifacts; the config and weights are read from it as well.
# NOTE(review): machine-specific absolute path — adjust for your environment.
ckpt_dir = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B-JAX"
# The tokenizer model lives inside the checkpoint directory, so derive its path
# from `ckpt_dir` instead of repeating the literal (keeps the two in sync).
tokenizer_path = f"{ckpt_dir}/tokenizer.model"

In [3]:

# Read the model hyperparameters from the checkpoint directory.
model_config = ModelConfig.from_json_file(ckpt_dir)

# Build the tokenizer from the tokenizer model file.
tokenizer = Tokenizer(tokenizer_path)

# Echo the parsed configuration so the run records which settings were used.
print(f"Model config:  {model_config!s}")

Model config:  ModelConfig(vocab_size=128256, dim=3072, ffn_hidden_dim=8192, n_layers=28, n_heads=24, n_kv_heads=8, activation_fn='silu', max_seqlen=8192, rope_theta=500000.0, rms_norm_eps=1e-05, mode='inference', dtype='float32', use_scaled_rope=True)


In [None]:
# PRNG key built from a fixed seed.
seed = 1
rng_key = jax.random.PRNGKey(seed)

# Read the weights from the safetensors file inside the checkpoint directory.
params = load_llama_weights(f"{ckpt_dir}/model.safetensors", model_config)

# Instantiate the model definition from the loaded configuration.
model = LLaMa(model_config)

# Report the estimated memory footprint of the loaded parameters.
params_size_bytes = estimate_pytree_memory_footprint(params)
print(f"Estimated model params size: {format_bytes(params_size_bytes)}")


In [5]:
# Display the loaded configuration (a bare last expression renders its repr as cell output).
model_config

ModelConfig(vocab_size=128256, dim=3072, ffn_hidden_dim=8192, n_layers=28, n_heads=24, n_kv_heads=8, activation_fn='silu', max_seqlen=8192, rope_theta=500000.0, rms_norm_eps=1e-05, mode='inference', dtype='float32', use_scaled_rope=True)

In [None]:
# Dummy inputs for a forward pass: a (batch_size, seq_len) block of all-ones
# int32 token ids and a freshly created KV cache sized from the model config.
batch_size = 2
seq_len = 512
dummy_tokens = jnp.ones(shape=(batch_size, seq_len), dtype=jnp.int32)
dummy_kvcache = KVCache.new(
    n_layers=model_config.n_layers,
    bsz=batch_size,
    max_seqlen=model_config.max_seqlen,
    kv_heads=model_config.n_kv_heads,
    head_dim=model_config.head_dim,
    dtype=model_config.dtype,
)

# Report the estimated memory footprint of the cache.
kvcache_size_bytes = estimate_pytree_memory_footprint(dummy_kvcache)
print(f"Estimated kvcache size: {format_bytes(kvcache_size_bytes)}")

Estimated kvcache size: 3.50GB


In [7]:
def _leaf_shape(leaf):
    """Return the shape of a single PyTree leaf."""
    return leaf.shape


# Replace each leaf of `params` with its shape; the nesting of the tree is
# kept, so the result mirrors the layout of the parameters.
param_shapes = jax.tree_util.tree_map(_leaf_shape, params)
param_shapes

{'layer_0': {'attention_norm_weight': (3072,),
  'ffn_norm_weight': (3072,),
  'w_down': (8192, 3072),
  'w_gate': (3072, 8192),
  'w_up': (3072, 8192),
  'wk': (3072, 8, 128),
  'wo': (3072, 3072),
  'wq': (3072, 24, 128),
  'wv': (3072, 8, 128)},
 'layer_1': {'attention_norm_weight': (3072,),
  'ffn_norm_weight': (3072,),
  'w_down': (8192, 3072),
  'w_gate': (3072, 8192),
  'w_up': (3072, 8192),
  'wk': (3072, 8, 128),
  'wo': (3072, 3072),
  'wq': (3072, 24, 128),
  'wv': (3072, 8, 128)},
 'layer_10': {'attention_norm_weight': (3072,),
  'ffn_norm_weight': (3072,),
  'w_down': (8192, 3072),
  'w_gate': (3072, 8192),
  'w_up': (3072, 8192),
  'wk': (3072, 8, 128),
  'wo': (3072, 3072),
  'wq': (3072, 24, 128),
  'wv': (3072, 8, 128)},
 'layer_11': {'attention_norm_weight': (3072,),
  'ffn_norm_weight': (3072,),
  'w_down': (8192, 3072),
  'w_gate': (3072, 8192),
  'w_up': (3072, 8192),
  'wk': (3072, 8, 128),
  'wo': (3072, 3072),
  'wq': (3072, 24, 128),
  'wv': (3072, 8, 128)},
 '

In [10]:
# Run one forward pass over the dummy batch with the loaded parameters.
# NOTE(review): the positional 0 is presumably the start position into the
# KV cache — confirm against the model's call signature.
logits, updated_kvcache = model.apply(
    {"params": params},
    dummy_tokens,
    0,
    dummy_kvcache,
)

Input embeddings: [[[ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  ...
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]]

 [[ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  ...
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.00029373]
  [ 0.01342773  0.001297    0.02099609 ... -0.0378418  -0.015625
   -0.000293

Attention output: [[[-0.06846347  0.53539747  2.2893229  ...  0.09503698 -3.3458726
    0.13851827]
  [-0.06846341  0.5353974   2.2893229  ...  0.09503698 -3.3458712
    0.13851807]
  [-0.06846352  0.53539747  2.2893229  ...  0.09503698 -3.3458724
    0.1385183 ]
  ...
  [-0.06846346  0.53539747  2.2893238  ...  0.09503704 -3.3458729
    0.13851823]
  [-0.06846346  0.5353975   2.2893236  ...  0.095037   -3.3458726
    0.13851821]
  [-0.06846347  0.53539735  2.2893233  ...  0.09503701 -3.345872
    0.13851812]]

 [[-0.06846347  0.53539747  2.2893229  ...  0.09503698 -3.3458726
    0.13851827]
  [-0.06846341  0.5353974   2.2893229  ...  0.09503698 -3.3458712
    0.13851807]
  [-0.06846352  0.53539747  2.2893229  ...  0.09503698 -3.3458724
    0.1385183 ]
  ...
  [-0.06846346  0.53539747  2.2893238  ...  0.09503704 -3.3458729
    0.13851823]
  [-0.06846346  0.5353975   2.2893236  ...  0.095037   -3.3458726
    0.13851821]
  [-0.06846347  0.53539735  2.2893233  ...  0.09503701 -3.345872
  

(3072, 8192)

In [None]:
# Print the logits restricted to the first 10 entries along each of the first two axes.
# NOTE(review): if logits is (batch, seq, vocab), the vocab axis is left unsliced
# and this prints a very large array — confirm the intended slice.
print(logits[:10, :10])