In [1]:
import jax
import jax.numpy as jnp
from safetensors import safe_open
from jax.sharding import PartitionSpec as PS,NamedSharding,Mesh
from typing import Optional
from pathlib import Path
from models.llama.load import load_llama_weights
from models.llama.model import LLaMA
from utils.kvcache import KVCache
from models.llama.config import ModelConfig

In [2]:
checkpoint_path = Path("/home/ammar3.shaikh/ReLax/artifacts/weights/llama-3.2-3B")

In [3]:
default_mesh = jax.make_mesh((2, 2), ('a', 'b'))

def mesh_sharding(
    pspec: PS, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)

In [4]:
model_config = ModelConfig.from_json_file(checkpoint_path)

In [5]:
params = load_llama_weights(checkpoint_path)

In [6]:
jax.debug.visualize_array_sharding(params["norm_weight"])

In [7]:
def print_pytree_keys(pytree):
  """Prints the full path to each leaf in a PyTree."""
  jax.tree_util.tree_map_with_path(
      # The lambda function is called for each leaf.
      # `path` is a tuple of objects describing the route to the leaf.
      # `_` is the leaf value itself, which we ignore here.
      lambda path, _: print(jax.tree_util.keystr(path)),
      pytree
  )


In [8]:
# Corrected sharding logic
params = jax.tree_util.tree_map(
    lambda leaf: jax.device_put(
        leaf,
        mesh_sharding(PS('a', 'b') if leaf.ndim > 1 else PS('a'))
    ),
    params
)

# Optional: print a leaf to check the sharding
print("Example sharded leaf:\n", jax.tree_util.tree_leaves(params)[0].sharding)

Example sharded leaf:
 NamedSharding(mesh=Mesh('a': 2, 'b': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('a',), memory_kind=device)


In [9]:
model = LLaMA(model_config)

In [10]:
print(model)

LLaMA(
    # attributes
    args = ModelConfig(vocab_size=128256, dim=3072, ffn_hidden_dim=8192, n_layers=28, n_heads=24, n_kv_heads=8, activation_fn='silu', max_seq_len=8192, rope_theta=500000.0, rms_norm_eps=1e-05, mode='inference')
)


In [12]:
input_ids = jax.device_put(jnp.arange(16)[None,:],mesh_sharding(PS(None,'a')))
kv_cache = jax.device_put(KVCache.new(n_layers=model_config.n_layers,bsz=1,max_seq_len=model_config.max_seq_len,kv_heads=model_config.n_kv_heads,head_dim=model_config.head_dim),mesh_sharding(PS('a',None,'b')))
start_pos = 0

In [13]:
output = model.apply({"params":params},input_ids,start_pos,kv_cache)

In [None]:
params["layers"].keys()

frozen_dict_keys(['tok_embeddings', 'norm_weight', 'output', 'layers_0', 'layers_1', 'layers_2', 'layers_3', 'layers_4', 'layers_5', 'layers_6', 'layers_7', 'layers_8', 'layers_9', 'layers_10', 'layers_11', 'layers_12', 'layers_13', 'layers_14', 'layers_15', 'layers_16', 'layers_17', 'layers_18', 'layers_19', 'layers_20', 'layers_21', 'layers_22', 'layers_23', 'layers_24', 'layers_25', 'layers_26', 'layers_27'])

In [18]:
out,kvcache=output

In [20]:
out.shape
kvcache.k.shape

(28, 1, 8192, 8, 128)