In [29]:
import jax
import jax.numpy as jnp
from safetensors import safe_open
from jax.sharding import PartitionSpec as PS,NamedSharding,Mesh
from typing import Optional
from pathlib import Path
from models.llama.load import load_llama_weights
from models.llama.model import LLaMA
from utils.kvcache import KVCache

In [30]:
# Directory holding the llama-3.2-3B checkpoint. Set the RELAX_LLAMA_CKPT
# environment variable to point elsewhere instead of editing this
# user-specific absolute path; the original location remains the default.
import os

checkpoint_path = Path(
    os.environ.get(
        "RELAX_LLAMA_CKPT",
        "/home/ammar3.shaikh/ReLax/artifacts/weights/llama-3.2-3B",
    )
)

In [31]:
# Module-level 2x2 device mesh with axis names 'a' and 'b'; mesh_sharding
# falls back to it when called without an explicit mesh.
# NOTE(review): jax.make_mesh needs at least 4 visible devices for this
# shape and raises otherwise — confirm the target host provides them.
default_mesh = jax.make_mesh((2, 2), ('a', 'b'))

def mesh_sharding(pspec: PS, mesh: Optional[Mesh] = None) -> NamedSharding:
    """Bind a PartitionSpec to a device mesh.

    Args:
        pspec: Partition spec mapping array axes onto mesh axes.
        mesh: Mesh to shard over. When ``None``, the module-level
            ``default_mesh`` is used (resolved at call time).

    Returns:
        ``NamedSharding(mesh, pspec)`` for the chosen mesh.
    """
    target_mesh = default_mesh if mesh is None else mesh
    return NamedSharding(target_mesh, pspec)

In [32]:
# Load the checkpoint weights found under checkpoint_path into a params tree.
# NOTE(review): `params` is reassigned to freshly initialised weights by the
# model.init cell further down, so this loaded tree is shadowed once that
# cell runs — consider a distinct name if both are needed.
params = load_llama_weights(checkpoint_path)

In [33]:
# Show the top-level groups of the loaded tree: tok_embeddings, norm_weight,
# output, and one `layers_<i>` entry per layer (28 for this checkpoint).
# NOTE(review): model.init below produces `layer_<i>` (singular) plus nested
# `output/kernel` and `tok_embeddings/embedding` keys, so this loaded tree
# does not match the module's parameter structure as-is — a key remapping
# is presumably needed before model.apply; confirm against load_llama_weights.
print(params.keys())

frozen_dict_keys(['tok_embeddings', 'norm_weight', 'output', 'layers_0', 'layers_1', 'layers_2', 'layers_3', 'layers_4', 'layers_5', 'layers_6', 'layers_7', 'layers_8', 'layers_9', 'layers_10', 'layers_11', 'layers_12', 'layers_13', 'layers_14', 'layers_15', 'layers_16', 'layers_17', 'layers_18', 'layers_19', 'layers_20', 'layers_21', 'layers_22', 'layers_23', 'layers_24', 'layers_25', 'layers_26', 'layers_27'])


In [6]:
import json
from pathlib import Path
from models.llama.config import ModelConfig

# Directory that receives the generated config.json (read back by
# ModelConfig.from_json_file in a later cell).
tmp_path = Path("/home/ammar3.shaikh/ReLax/experiments")

# Deliberately tiny configuration: 2 layers, hidden size 64, 4 attention
# heads sharing 2 key/value heads, 1000-token vocabulary.
config_data = {
    "hidden_size": 64,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "num_key_value_heads": 2,
    "intermediate_size": 128,
    "vocab_size": 1000,
    "rms_norm_eps": 1e-6,
    "rope_theta": 1000.0,
    "max_position_embeddings": 512,
    "hidden_act": "silu"
}

config_path = tmp_path / "config.json"
# Make sure the directory exists; open() would raise FileNotFoundError on a
# fresh checkout otherwise.
tmp_path.mkdir(parents=True, exist_ok=True)
# Explicit encoding so the file does not depend on the platform locale.
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config_data, f, indent=2)



In [25]:
# Build the model config from the experiments directory that the config-writing
# cell populates. Reusing tmp_path (same string as the former literal) keeps a
# single source of truth for the location instead of a duplicated absolute path.
model_config = ModelConfig.from_json_file(str(tmp_path))

print(model_config)

ModelConfig(vocab_size=1000, dim=64, ffn_hidden_dim=128, n_layers=2, n_heads=4, n_kv_heads=2, activation_fn='silu', max_seq_len=512, rope_theta=1000.0, rms_norm_eps=1e-06, mode='inference')


In [26]:
# Construct the LLaMA module from the small config; its parameters are
# produced separately by model.init in the next cell.
model = LLaMA(model_config)


In [27]:
# One sequence (batch size 1) of the token ids 1..10, shape (1, 10).
tokens = jnp.arange(1, 11)[None, :]

# Fresh KV cache sized from the model config; batch size follows `tokens`.
kv_cache = KVCache.new(
    n_layers=model_config.n_layers,
    bsz=tokens.shape[0],
    max_seq_len=model_config.max_seq_len,
    kv_heads=model_config.n_kv_heads,
    head_dim=model_config.head_dim,
)

start_pos = 0  # the input starts at cache position 0
init_key = jax.random.PRNGKey(0)  # fixed seed -> reproducible initial weights
params = model.init(init_key, tokens, start_pos, kv_cache)["params"]










In [28]:
def print_pytree_keys(pytree):
    """Print the key path of every leaf in ``pytree``, one per line.

    Each path is rendered with ``jax.tree_util.keystr`` (for example
    ``['layer_0']['wq']``) and emitted in JAX's flattening order.
    Leaf values are ignored. Returns ``None``.
    """
    path_leaf_pairs, _treedef = jax.tree_util.tree_flatten_with_path(pytree)
    for key_path, _leaf in path_leaf_pairs:
        print(jax.tree_util.keystr(key_path))

# List every leaf path of `params` — at this point the tree returned by
# model.init, not the checkpoint loaded earlier.
print_pytree_keys(params)

['layer_0']['attention_norm_weight']
['layer_0']['ffn_norm_weight']
['layer_0']['w1_gate']
['layer_0']['w2_up']
['layer_0']['w3_down']
['layer_0']['wk']
['layer_0']['wo']
['layer_0']['wq']
['layer_0']['wv']
['layer_1']['attention_norm_weight']
['layer_1']['ffn_norm_weight']
['layer_1']['w1_gate']
['layer_1']['w2_up']
['layer_1']['w3_down']
['layer_1']['wk']
['layer_1']['wo']
['layer_1']['wq']
['layer_1']['wv']
['norm_weight']
['output']['kernel']
['tok_embeddings']['embedding']


(64, 4, 16)