In [1]:
import jax
import jax.numpy as jnp


from models.llama.model import LLaMa
from models.llama.config import ModelConfig
from models.llama.tokenizer import Tokenizer
from models.llama.load import load_llama_weights
from utils.memory import estimate_pytree_memory_footprint, format_bytes
from utils.kvcache import KVCache

jax.config.update("jax_default_matmul_precision", "float32")


In [5]:
ckpt_dir = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B-JAX"
tokenizer_path = "/home/ammar3.shaikh/ReLax/artifacts/weights/Llama-3.2-3B-JAX/tokenizer.model"

In [6]:

model_config = ModelConfig.from_json_file(ckpt_dir)

tokenizer = Tokenizer(tokenizer_path)
print("Model config: ", model_config)

Model config:  ModelConfig(vocab_size=128256, dim=3072, ffn_hidden_dim=8192, n_layers=28, n_heads=24, n_kv_heads=8, activation_fn='silu', max_seqlen=8192, rope_theta=500000.0, rms_norm_eps=1e-05, mode='inference', dtype='float32', use_scaled_rope=True)


In [10]:
# Load model weights
params = load_llama_weights(ckpt_dir+"/model.safetensors", model_config)

# Initialize the model
model = LLaMa(model_config)

# Create a JAX random key
seed = 1
rng_key = jax.random.PRNGKey(seed)


# Estimate and print memory usage
params_size_bytes = estimate_pytree_memory_footprint(params)
print(f"Estimated model params size: {format_bytes(params_size_bytes)}")


Estimated model params size: 13.44GB


In [16]:
# JAX init
batch_size = 2
seq_len = 512
dummy_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
dummy_kvcache = KVCache.new(n_layers=model_config.n_layers, bsz=batch_size, max_seqlen=model_config.max_seqlen, kv_heads=model_config.n_kv_heads, head_dim=model_config.head_dim, dtype=model_config.dtype)

kvcache_size_bytes = estimate_pytree_memory_footprint(dummy_kvcache)
print(f"Estimated kvcache size: {format_bytes(kvcache_size_bytes)}")

Estimated kvcache size: 1.75GB


In [17]:
logits,updated_kvcache = model.apply({"params": params}, dummy_tokens, 0, dummy_kvcache)

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 896.00M. That was not possible. There are 612.30M free.; (0x0x0_HBM0)