In [2]:
"""
Minimal setup for a tiny LLaMA model with randomly initialized weights.

Run with:
    python examples/minimal_llama_setup.py
"""

import os
import sys
from pathlib import Path

# TPU runtime settings, assigned before jax is imported so they are already in
# the environment when the TPU backend starts up.
# NOTE(review): presumably these pin the process to a single chip (chip 0 with
# a 1x1x1 process grid) -- confirm against the TPU runtime documentation.
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_CHIPS"] = "0"


def _find_project_root(marker=("models", "llama")):
    """Return the nearest directory, searching upward from the current working
    directory, that contains the `models/llama` folder; None if there is none.
    """
    cwd = Path.cwd().resolve()
    for candidate in (cwd, *cwd.parents):
        if candidate.joinpath(*marker).is_dir():
            return candidate
    return None


# `models` and `utils` are project-local packages. When the kernel (or script)
# is started from a subdirectory such as `examples/`, the folder that holds
# them is not on `sys.path` and the imports below fail with
# "ModuleNotFoundError: No module named 'models'". Prepend that folder if it
# can be found; otherwise leave `sys.path` alone so the original error surfaces.
_project_root = _find_project_root()
if _project_root is not None and str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))

import jax
import jax.numpy as jnp

from models.llama.config import ModelConfig
from models.llama.model import LLaMa
from utils.kvcache import KVCache
from utils.ops import build_attn_mask

# Deliberately small hyperparameters so the example stays quick to run.
cfg = ModelConfig(
    # Vocabulary size and maximum sequence length
    vocab_size=128,
    max_seqlen=16,
    # Depth and layer widths
    n_layers=2,
    dim=32,
    ffn_hidden_dim=64,
    activation_fn="silu",
    # Head counts (presumably query heads vs. key/value heads -- TODO confirm)
    n_heads=4,
    n_kv_heads=2,
    # Rotary position embedding settings
    rope_theta=10000.0,
    use_scaled_rope=False,
    # Numerics
    rms_norm_eps=1e-5,
    dtype=jnp.float32,
)

# Build the model object from the config; no weights are created at this point
# in the script (they are produced later by `model.init`).
model = LLaMa(cfg)

# Toy batch of token ids: two rows of four positions each.
tokens = jnp.array([[1, 5, 7, 9], [4, 3, 2, 0]], dtype=jnp.int32)

# Number of real tokens in each row; the second row is padded at the end,
# so only its first three entries count.
true_lengths = jnp.array((4, 3), dtype=jnp.int32)

# Batch size and (padded) sequence length, read off the token array.
bsz = tokens.shape[0]
seqlen = tokens.shape[1]

# Create a new KV cache sized from the config and the batch, then build the
# attention mask from the sequence length, the cache and the per-row lengths.
kv_cache = KVCache.new(
    n_layers=cfg.n_layers, bsz=bsz, max_seqlen=cfg.max_seqlen,
    kv_heads=cfg.n_kv_heads, head_dim=cfg.head_dim, dtype=cfg.dtype,
)
mask = build_attn_mask(seqlen, kv_cache, true_lengths)

# Initialize the model variables from a fixed PRNG seed, using the dummy batch
# as the example input.
rng = jax.random.PRNGKey(0)
variables = model.init(
    rng,
    tokens,
    true_lengths,
    kv_cache,
    mask,
)

# The forward pass itself is currently disabled; the lines below are kept for
# reference and are not executed.
# logits, updated_cache = model.apply(variables, tokens, true_lengths, kv_cache, mask)

# print("Logits shape:", logits.shape)  # (batch, seq, vocab_size)
# print("Updated cache positions:", updated_cache.seq_positions)


ModuleNotFoundError: No module named 'models'

In [3]:
from models.llama.model import LLaMa

ModuleNotFoundError: No module named 'models'