In [1]:
import sys
import os
sys.path.append(os.path.abspath('..'))
import jax
import jax.numpy as jnp
from jax import random, jit
from functools import partial


from models.llama.model import LLaMa
from models.llama.config import ModelConfig
from utils.kvcache import KVCache
from models.llama.load import load_llama_weights
from utils.memory import estimate_pytree_memory_footprint, format_bytes
from models.llama.tokenizer import Tokenizer
from sampling import Sampler
from sampling import TopPSampler

# Set global JAX flags before the backend is first touched: enabling x64 is
# documented as a startup-time setting, and jax.devices() initialises the backend.
jax.config.update("jax_default_matmul_precision", "highest")
jax.config.update("jax_enable_x64", True)

# Devices visible to this process (querying them initialises the backend).
devices = jax.devices()










In [2]:
# --- Transformer shape ---
dim = 3072               # hidden / embedding width
n_layers = 28            # number of layers
n_heads = 24             # attention (query) heads
n_kv_heads = 8           # key/value heads
ffn_hidden_dim = 8192    # feed-forward inner width
vocab_size = 128256      # tokenizer vocabulary size
max_seqlen = 8192        # maximum sequence length handed to the model config

# --- Numerics ---
rope_theta = 500000.0
rms_norm_eps = 1e-5
activation_fn = "silu"
dtype = jnp.float32

# Per-head width derived from the values above.
head_dim = dim // n_heads

# Bundle everything into the config object consumed by the model.
model_config = ModelConfig(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_kv_heads,
    ffn_hidden_dim=ffn_hidden_dim,
    vocab_size=vocab_size,
    max_seqlen=max_seqlen,
    rope_theta=rope_theta,
    rms_norm_eps=rms_norm_eps,
    activation_fn=activation_fn,
    dtype=dtype,
)

In [3]:
# Checkpoint directory. Set LLAMA_MODEL_PATH to run on another machine; the
# default is the original author's local path.
model_path = os.environ.get(
    "LLAMA_MODEL_PATH",
    "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B",
)

# NOTE(review): the output recorded for this cell is
# "TypeError: Got unsupported ScalarType BFloat16", so `params` used by later
# cells presumably came from an earlier kernel state -- confirm the loader
# handles bfloat16 checkpoints before trusting a fresh Restart & Run All.
params = load_llama_weights(model_path)

TypeError: Got unsupported ScalarType BFloat16

In [4]:
# Build the tokenizer path from `model_path` so the checkpoint location is
# defined in exactly one place (resolves to <model_path>/original/tokenizer.model).
tokenizer = Tokenizer(os.path.join(model_path, "original", "tokenizer.model"))

In [5]:
# Instantiate the LLaMa module from `model_config`; the weights are not bound
# here -- they are passed to `model.apply` as `params` at call time.
model = LLaMa(model_config)

In [29]:
# Prompt fed to generate().
# NOTE(review): the execution counts show this cell was last run after the
# generate() call further down; re-run the notebook top-to-bottom to confirm
# the recorded output corresponds to this prompt.
prompt = "What is the capital of France?"

In [26]:

@partial(jit, static_argnames=['model'])
def _model_step(model, params, tokens, kv_cache, start_pos):
    """Run one forward pass and return the logits of the last input position.

    Kept at module level (rather than nested inside `generate`) so the jitted
    function object, and therefore its compilation cache, survives across
    `generate` calls instead of being rebuilt and recompiled every time.

    Args:
        model: The model module; treated as a static (hashable) jit argument.
        params: Model parameters, passed to `model.apply` under the 'params' key.
        tokens: Integer token ids for this step, indexed as (batch, seq).
        kv_cache: Cache object forwarded to the model.
        start_pos: Position of the first token of `tokens` in the sequence.

    Returns:
        A tuple `(last_logits, updated_kv_cache)` where `last_logits` is
        `logits[:, -1, :]`.
    """
    logits, updated_kv_cache = model.apply(
        {'params': params},
        tokens,
        start_pos=start_pos,
        kv_cache=kv_cache
    )
    return logits[:, -1, :], updated_kv_cache


def generate(
    model: LLaMa,
    params,
    tokenizer: Tokenizer,
    prompt: str,
    max_gen_len: int,
    temperature: float,
    top_p: float,
    rng_key: jax.Array,
):
    """
    JAX-based text generation function.

    Encodes `prompt`, runs it through the model once to fill the KV cache, then
    samples one token per step with top-p sampling, feeding each sampled token
    back in as the next input.

    Args:
        model: LLaMa module; `model.args` supplies the cache dimensions.
        params: Model parameters passed to `model.apply`.
        tokenizer: Tokenizer used to encode the prompt and decode the result.
        prompt: Text to condition on.
        max_gen_len: Number of tokens to sample. Values <= 0 sample nothing.
        temperature: Sampling temperature handed to `TopPSampler`.
        top_p: Nucleus probability mass handed to `TopPSampler`.
        rng_key: JAX PRNG key; split once per sampled token.

    Returns:
        The decoded string of the prompt tokens followed by the sampled tokens.

    Raises:
        ValueError: If tokens are requested but the prompt encodes to zero
            tokens, or if prompt plus generation would not fit in
            `model.args.max_seqlen` cache positions.
    """
    # 1. Encode the prompt and validate sizes before allocating the cache.
    # NOTE(review): the prompt is encoded without a BOS token; Llama reference
    # generation code prepends BOS -- confirm that omitting it is intentional.
    prompt_tokens = tokenizer.encode(prompt, bos=False, eos=False)
    prompt_len = len(prompt_tokens)

    if max_gen_len > 0:
        if prompt_len == 0:
            raise ValueError(
                "prompt encoded to zero tokens; there is nothing to condition generation on"
            )
        # The prefill step consumes `prompt_len` positions and every later step
        # one more; the final sampled token is never fed back into the model.
        # Assumes each forward pass stores its input tokens in the cache.
        required_positions = prompt_len + max_gen_len - 1
        if required_positions > model.args.max_seqlen:
            raise ValueError(
                f"prompt ({prompt_len} tokens) plus max_gen_len ({max_gen_len}) needs "
                f"{required_positions} cache positions, but max_seqlen is "
                f"{model.args.max_seqlen}"
            )

    # 2. Initialize sampler and KVCache
    sampler = TopPSampler(p=top_p, temperature=temperature)
    kv_cache = KVCache.new(
        n_layers=model.args.n_layers,
        bsz=1,
        max_seqlen=model.args.max_seqlen,
        kv_heads=model.args.n_kv_heads,
        head_dim=model.args.head_dim,
        dtype=model.args.dtype,
    )
    print(f"KVCache size: {format_bytes(estimate_pytree_memory_footprint(kv_cache))}")

    # 3. Pre-fill with the whole prompt, then decode one token at a time.
    tokens = jnp.array([prompt_tokens], dtype=jnp.int32)
    current_pos = 0
    generated_tokens = list(prompt_tokens)

    for _ in range(max_gen_len):
        # First iteration sees the full prompt; later ones see a single token.
        logits, kv_cache = _model_step(model, params, tokens, kv_cache, current_pos)
        current_pos += tokens.shape[1]

        rng_key, sample_key = random.split(rng_key)
        next_token = sampler.sample(logits, sample_key)
        generated_tokens.append(next_token.item())
        # Restore the sequence axis so the sampled token is the next input.
        tokens = next_token[:, None]

    return tokenizer.decode(generated_tokens)


In [27]:
# Sample 32 tokens with temperature 0.6 / top-p 0.9 from a fixed PRNG seed.
response = generate(
    model=model,
    params=params,
    tokenizer=tokenizer,
    prompt=prompt,
    max_gen_len=32,
    temperature=0.6,
    top_p=0.9,
    rng_key=jax.random.PRNGKey(1),
)

KVCache size: 1.75GB


In [28]:
# Show the decoded text; generate() returns the prompt tokens followed by the
# sampled tokens, so the prompt is echoed at the start of the output.
print(response)

What is the capital of France?  2 3 2 2 3 3 3 3 3 3 3 3 3 3 3 
