In [1]:
import os
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
import jax.random as random
from jax.sharding import NamedSharding, PartitionSpec as PS, Mesh

from models.llama.model import LLaMa
from models.llama.config import ModelConfig
from models.llama.load import load_llama_weights
from models.llama.tokenizer import Tokenizer

from utils.kvcache import KVCache
from utils.sharding import mesh_sharding



In [2]:
# Directory holding the LLaMa checkpoint. Set the LLAMA_CHECKPOINT_PATH
# environment variable to use a different location; when it is unset or empty
# the original hard-coded directory is used, so existing runs are unaffected.
checkpoint_path = Path(
    os.environ.get("LLAMA_CHECKPOINT_PATH")
    or "/home/ammar3.shaikh/ReLax/artifacts/weights/llama-3.2-3B"
)

In [3]:
# Everything below is read from the same checkpoint directory.
# The weight load does not depend on the config object, so it is done first.
params = load_llama_weights(checkpoint_path)

# Model hyper-parameters, then the model definition built from them.
model_config = ModelConfig.from_json_file(checkpoint_path)
model = LLaMa(model_config)

In [4]:
# The tokenizer model file lives under the checkpoint directory.
tokenizer_model_path = checkpoint_path / "original/tokenizer.model"
tokenizer = Tokenizer(tokenizer_model_path)

In [5]:
# Round-trip a sample sentence through the tokenizer with both the BOS and EOS
# flags enabled, then show the decoded string.
sample_ids = tokenizer.encode("Hello, how are you?", bos=True, eos=True)
text = tokenizer.decode(sample_ids)
print(text)

<|begin_of_text|>Hello, how are you?<|end_of_text|>


In [6]:
def generate_text(kvcache,start_pos, text, max_new_tokens, temperature=0.8, rng_key=None):
    """Extend ``text`` one token at a time and return prompt plus continuation.

    Uses the notebook-level ``tokenizer``, ``model`` and ``params``.

    Args:
        kvcache: Cache object handed to ``model.apply``; it is replaced by the
            cache that call returns after every step.
        start_pos: Position passed to the first ``model.apply`` call. It is
            advanced by the number of tokens fed in on each step.
        text: Prompt string, encoded with ``bos=False`` and ``eos=False``.
        max_new_tokens: Number of generation steps (one model call each).
        temperature: Divisor applied to the last-position logits before
            sampling. ``0`` selects greedy (argmax) decoding instead of
            dividing by zero.
        rng_key: JAX PRNG key used for sampling; defaults to
            ``random.PRNGKey(42)``.

    Returns:
        The decoded string for the prompt tokens followed by the new tokens.

    Raises:
        ValueError: If ``temperature`` is negative or the prompt encodes to
            zero tokens.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    prompt_ids = tokenizer.encode(text,bos=False,eos=False)
    if len(prompt_ids) == 0:
        # A zero-length batch gives the model nothing to condition on and has
        # no "last position" to read logits from.
        raise ValueError("prompt encodes to zero tokens; nothing to condition on")

    tokens = jnp.array([prompt_ids])
    if rng_key is None:
        rng_key = random.PRNGKey(42)

    # Collect the prompt and every sampled token, then concatenate once at the
    # end instead of rebuilding the whole response array on every iteration.
    pieces = [tokens[0]]

    for _ in range(max_new_tokens):
        output, kvcache = model.apply({"params":params},tokens,start_pos,kvcache)
        # After the first step only the newly sampled token is fed in, so the
        # position advances by the prompt length once and then by 1 per step.
        start_pos += tokens.shape[1]

        # Only the logits at the last position are used to pick the next token.
        logits = output[:, -1, :]
        if temperature == 0:
            # Greedy decoding: avoids dividing the logits by zero.
            next_token = jnp.argmax(logits, axis=-1)
        else:
            # Sample from the model's output instead of always picking the best token
            rng_key, subkey = random.split(rng_key)
            next_token = random.categorical(subkey, logits / temperature, axis=-1)

        # Re-add the leading batch axis so the sampled token is the next input.
        tokens = next_token[None, :]
        pieces.append(tokens[0])

    response = jnp.concatenate(pieces)
    # Hand the tokenizer a plain list of Python ints rather than a device array.
    return tokenizer.decode(response.tolist())

In [7]:
# Create a KV cache via `KVCache.new` for a batch of one, sized from the model
# config, then generate a continuation of the prompt starting at position 0.
kvcache = KVCache.new(
    n_layers=model_config.n_layers,
    bsz=1,
    max_seq_len=model_config.max_seq_len,
    kv_heads=model_config.n_kv_heads,
    head_dim=model_config.head_dim,
)
start_pos = 0
text = "There once was a ship that put to sea"
max_new_tokens = 50

# Temperature and PRNG key are left at the defaults of `generate_text`.
response = generate_text(
    kvcache=kvcache,
    start_pos=start_pos,
    text=text,
    max_new_tokens=max_new_tokens,
)
print(response)

calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
callin

In [8]:
# Inputs for the `model.init` call in the next cell: a starting position, a
# PRNG key, a batch holding one encoded prompt, and a new KV cache.
start_pos = 0
key = random.PRNGKey(0)

tokens = jnp.array(
    [tokenizer.encode("Hello, how are you?", bos=False, eos=False)]
)

kvcache = KVCache.new(
    n_layers=model_config.n_layers,
    bsz=1,
    max_seq_len=model_config.max_seq_len,
    kv_heads=model_config.n_kv_heads,
    head_dim=model_config.head_dim,
)

In [9]:
# Initialise a second set of variables directly from the model definition; the
# following cells print its key paths next to those of the loaded `params`.
dummy_params = model.init(
    key,
    tokens,
    start_pos,
    kvcache,
)

calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup
calling setup


In [10]:
def print_pytree_keys(pytree):
    """Print one line per leaf of ``pytree`` giving the key path that reaches it."""

    def _show_path(key_path, _leaf):
        # `key_path` is the tuple of key entries leading to this leaf. The leaf
        # value itself is not needed for the listing.
        print(jax.tree_util.keystr(key_path))

    # The mapped tree that comes back is discarded; only the printing matters.
    jax.tree_util.tree_map_with_path(_show_path, pytree)

In [11]:
# Key paths of the weights loaded from the checkpoint.
print_pytree_keys(pytree=params)

['layer_0']['attention_norm_weight']
['layer_0']['ffn_norm_weight']
['layer_0']['w1_gate']
['layer_0']['w2_up']
['layer_0']['w3_down']
['layer_0']['wk']
['layer_0']['wo']
['layer_0']['wq']
['layer_0']['wv']
['layer_1']['attention_norm_weight']
['layer_1']['ffn_norm_weight']
['layer_1']['w1_gate']
['layer_1']['w2_up']
['layer_1']['w3_down']
['layer_1']['wk']
['layer_1']['wo']
['layer_1']['wq']
['layer_1']['wv']
['layer_10']['attention_norm_weight']
['layer_10']['ffn_norm_weight']
['layer_10']['w1_gate']
['layer_10']['w2_up']
['layer_10']['w3_down']
['layer_10']['wk']
['layer_10']['wo']
['layer_10']['wq']
['layer_10']['wv']
['layer_11']['attention_norm_weight']
['layer_11']['ffn_norm_weight']
['layer_11']['w1_gate']
['layer_11']['w2_up']
['layer_11']['w3_down']
['layer_11']['wk']
['layer_11']['wo']
['layer_11']['wq']
['layer_11']['wv']
['layer_12']['attention_norm_weight']
['layer_12']['ffn_norm_weight']
['layer_12']['w1_gate']
['layer_12']['w2_up']
['layer_12']['w3_down']
['layer_12']['

In [12]:
# Key paths of the variables produced by `model.init`, for comparison with the
# listing of the loaded weights.
init_param_tree = dummy_params["params"]
print_pytree_keys(pytree=init_param_tree)

['layer_0']['attention_norm_weight']
['layer_0']['ffn_norm_weight']
['layer_0']['w1_gate']
['layer_0']['w2_up']
['layer_0']['w3_down']
['layer_0']['wk']
['layer_0']['wo']
['layer_0']['wq']
['layer_0']['wv']
['layer_1']['attention_norm_weight']
['layer_1']['ffn_norm_weight']
['layer_1']['w1_gate']
['layer_1']['w2_up']
['layer_1']['w3_down']
['layer_1']['wk']
['layer_1']['wo']
['layer_1']['wq']
['layer_1']['wv']
['layer_10']['attention_norm_weight']
['layer_10']['ffn_norm_weight']
['layer_10']['w1_gate']
['layer_10']['w2_up']
['layer_10']['w3_down']
['layer_10']['wk']
['layer_10']['wo']
['layer_10']['wq']
['layer_10']['wv']
['layer_11']['attention_norm_weight']
['layer_11']['ffn_norm_weight']
['layer_11']['w1_gate']
['layer_11']['w2_up']
['layer_11']['w3_down']
['layer_11']['wk']
['layer_11']['wo']
['layer_11']['wq']
['layer_11']['wv']
['layer_12']['attention_norm_weight']
['layer_12']['ffn_norm_weight']
['layer_12']['w1_gate']
['layer_12']['w2_up']
['layer_12']['w3_down']
['layer_12']['

In [None]:
from safetensors import safe_open
import jax
import jax.numpy as jnp

# List the tensor names stored in the checkpoint's safetensors file.
with safe_open(checkpoint_path/"model.safetensors", framework="flax") as f:
    # The loop variable is deliberately not called `key`: that name holds the
    # JAX PRNG key passed to `model.init` in an earlier cell, and reusing it
    # here would silently replace it with a string.
    for tensor_name in f.keys():
        print(tensor_name)

In [1]:
import jax
import jax.numpy as jnp
import chex

In [3]:
# Named dimension sizes; any letters can be used as names.
dims = chex.Dimensions(B=3, T=5, N=7)

x = jnp.array(
    [
        [2, 0, 5, 6, 3],
        [5, 4, 4, 3, 3],
        [0, 0, 5, 2, 0],
    ]
)

# x holds 3 rows of 5 entries, so it is checked against the (B, T) sizes.
chex.assert_shape(x, dims["BT"])

In [None]:
# Small scalar sizes for shape experiments.
# NOTE(review): bsz/seqlen/n_heads equal the B/T/N sizes in the chex cell
# above, presumably on purpose. Confirm.
seqlen, start_pos = 5, 2
bsz, n_heads, head_dim = 3, 7, 10
