In [66]:
import os
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
import jax.random as random
from jax.sharding import NamedSharding, PartitionSpec as PS, Mesh

from models.llama.config import ModelConfig
from models.llama.load import load_llama_weights
from models.llama.model import LLaMa
from models.llama.tokenizer import Tokenizer
from utils.kvcache import KVCache



In [14]:
# Notebook-wide fallback mesh used by mesh_sharding: shape (2, 2), axes named 'a' and 'b'.
default_mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('a', 'b'))

def mesh_sharding(
    pspec: PS, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  """Wrap a partition spec in a ``NamedSharding``.

  Args:
    pspec: Partition spec describing how array axes map onto mesh axes.
    mesh: Mesh to shard over. When omitted (``None``), the notebook-level
      ``default_mesh`` is used.

  Returns:
    A ``NamedSharding`` built from the chosen mesh and ``pspec``.
  """
  # Resolve the fallback at call time so later rebinding of default_mesh is honoured.
  target_mesh = default_mesh if mesh is None else mesh
  return NamedSharding(target_mesh, pspec)

In [15]:
# Checkpoint directory for the llama-3.2-3B weights. Set the LLAMA_CHECKPOINT_PATH
# environment variable to point the notebook at a different location; when it is
# unset, the path this notebook was originally written against is used.
checkpoint_path = Path(
    os.environ.get(
        "LLAMA_CHECKPOINT_PATH",
        "/home/ammar3.shaikh/ReLax/artifacts/weights/llama-3.2-3B",
    )
)

In [16]:
# Build the model configuration; note that from_json_file is handed the checkpoint
# path itself (presumably it locates the config JSON under it — TODO confirm).
model_config = ModelConfig.from_json_file(checkpoint_path)
# Construct the model from that configuration.
model = LLaMa(model_config)
# Load the weights from the same checkpoint path; later cells pass them to
# model.apply under the "params" key.
params = load_llama_weights(checkpoint_path)

In [None]:
# The tokenizer model file is read from "original/tokenizer.model" under the checkpoint path.
tokenizer = Tokenizer(checkpoint_path/"original/tokenizer.model")

'<|begin_of_text|>Hello, world!<|end_of_text|>'

(1, 9)


In [35]:
# One forward pass over `tokens` from `start_pos`; the result is unpacked as
# (output, new_kvcache).
# NOTE(review): `tokens`, `start_pos` and `kvcache` are not assigned in any earlier
# cell visible here — confirm they are defined before this cell, otherwise a fresh
# Restart & Run All fails with NameError.
output,new_kvcache = model.apply({"params":params},tokens,start_pos,kvcache)


In [45]:
# Highest-scoring index along the last axis of `output`, converted to nested Python lists.
pred = output.argmax(axis=-1).tolist()

In [49]:
# Decode the first row of the argmax predictions back into text.
response = tokenizer.decode(pred[0])

In [50]:
# Display the decoded string as the cell output.
response

',1 again2  a   '

In [71]:
def generate_text(kvcache,start_pos, text, max_new_tokens, temperature=0.8, rng_key=None):
    """Autoregressively extend ``text`` and return prompt plus continuation as a string.

    The full prompt is fed through the model on the first step; every later
    step feeds only the token chosen on the previous step, together with the
    KV cache returned by the previous ``model.apply`` call.

    Relies on the notebook-level ``tokenizer``, ``model`` and ``params``.

    Args:
        kvcache: Cache object passed to ``model.apply`` on the first step.
        start_pos: Position passed to ``model.apply`` for the first step; it is
            advanced by the number of tokens fed on each step.
        text: Prompt string, encoded with ``bos=False, eos=False``.
        max_new_tokens: Number of tokens to generate. Values <= 0 return the
            decoded prompt unchanged.
        temperature: Divisor applied to the last-position logits before
            sampling. ``0`` selects greedy (argmax) decoding instead of
            dividing by zero.
        rng_key: JAX PRNG key used for sampling; defaults to
            ``random.PRNGKey(42)`` when ``None``.

    Returns:
        The decoded text of the prompt tokens followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    # NOTE(review): the prompt is encoded without a BOS token — confirm this is
    # what the checkpoint expects.
    tokens = jnp.array([tokenizer.encode(text,bos=False,eos=False)])
    prompt_tokens = tokens[0]

    if rng_key is None:
        rng_key = random.PRNGKey(42)

    # Collected per-step token arrays; joined once after the loop instead of
    # re-concatenating the growing response on every iteration.
    generated = []
    for _ in range(max_new_tokens):
        output, kvcache = model.apply({"params":params},tokens,start_pos,kvcache)
        start_pos += tokens.shape[1]

        # Only the logits at the last fed position decide the next token.
        logits = output[:, -1, :]
        if temperature == 0:
            # Greedy decoding: dividing by zero would hand inf/NaN logits to the sampler.
            next_token = jnp.argmax(logits, axis=-1)
        else:
            rng_key, subkey = random.split(rng_key)
            next_token = random.categorical(subkey, logits / temperature, axis=-1)

        # Re-add the sequence axis so the next step feeds a (batch, 1) input.
        tokens = next_token[None, :]
        generated.append(tokens[0])

    response = jnp.concatenate([prompt_tokens, *generated])
    # Decode from a plain list of ints, matching how decode is called elsewhere in this notebook.
    return tokenizer.decode(response.tolist())

In [72]:
# Fresh batch-size-1 KV cache dimensioned from the model configuration.
kvcache = KVCache.new(
    n_layers=model_config.n_layers,
    bsz=1,
    max_seq_len=model_config.max_seq_len,
    kv_heads=model_config.n_kv_heads,
    head_dim=model_config.head_dim,
)

# Generation settings: start at position 0 and produce 50 new tokens.
start_pos, max_new_tokens = 0, 50
text = "There once was a ship that put to sea"

# Uses generate_text's default temperature and RNG key.
response = generate_text(kvcache, start_pos, text, max_new_tokens)
print(response)

There once was a ship that put to sea 7 7
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0


In [57]:
# Repeat of the standalone forward pass: runs the model over `tokens` from
# `start_pos` with the notebook-level `kvcache`.
# NOTE(review): within the cells visible here `tokens` is only ever assigned inside
# generate_text, never at notebook level — confirm where this cell's `tokens` comes from.
output,new_kvcache = model.apply({"params":params},tokens,start_pos,kvcache)

In [59]:
# Inspect the shape of the forward-pass output (the recorded result is (1, 17, 128256)).
output.shape

(1, 17, 128256)