In [1]:
import os

# Must be done before importing JAX or any JAX-related modules
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_CHIPS"] = "0"

import jax
import jax.numpy as jnp
from pathlib import Path
import dataclasses
from models.llama.model import LLaMa 
from models.llama.config import ModelConfig
from models.llama.load import load_llama_weights
from models.llama.tokenizer import Tokenizer
from utils.kvcache import KVCache

print(jax.devices())


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]


In [2]:
model_path = Path("/home/ammar/weights/Llama-3.2-1B")
jax_dtype = jnp.bfloat16
use_scaled_rope = False
batch_size = 1
max_seq_len = 256

In [3]:
config = ModelConfig.from_json_file(model_path)
config = dataclasses.replace(config, dtype=jax_dtype)

In [4]:
config

ModelConfig(vocab_size=128256, dim=2048, ffn_hidden_dim=8192, n_layers=16, n_heads=32, n_kv_heads=8, activation_fn='silu', max_seqlen=8192, rope_theta=500000.0, rms_norm_eps=1e-05, mode='inference', dtype=<class 'jax.numpy.bfloat16'>, use_scaled_rope=True)

In [5]:
# Initialize JAX model
model = LLaMa(config)
print("✓ JAX model initialized")


✓ JAX model initialized


In [6]:
loaded_params = load_llama_weights(str(model_path), config)

Loading model from /home/ammar/weights/Llama-3.2-1B
Found 1 safetensor file(s)
  Loading model.safetensors...
Loaded 146 tensors total
Converting HuggingFace weights to ReLax format...
  ✓ Embeddings: (128256, 2048)
  ✓ LM head (tied embeddings): (128256, 2048) -> (2048, 128256)
  ✓ Final norm: (2048,)
  Converting 16 transformer layers...
    Layer 0 shapes:
      wq: (2048, 32, 64)
      wk: (2048, 8, 64)
      wv: (2048, 8, 64)
      wo: (2048, 2048)
      w_gate: (2048, 8192)
      w_up: (2048, 8192)
      w_down: (8192, 2048)
      attention_norm_weight: (2048,)
      ffn_norm_weight: (2048,)
  ✓ Converted all 16 layers


In [16]:
test_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\n\nToday Date: 23 July 2024\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

In [17]:
tokenizer_path = model_path / "original" / "tokenizer.model"
if not tokenizer_path.exists():
    tokenizer_path = model_path / "tokenizer.model"

if not tokenizer_path.exists():
    raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")

# Initialize tokenizer
tokenizer = Tokenizer(model_path=str(tokenizer_path))

# Tokenize input text
prompt_tokens = tokenizer.encode(test_prompt, bos=False, eos=False)
print(f"Test prompt: {test_prompt}")
print(f"Prompt tokens: {len(prompt_tokens)}")

max_gen_len = 256

Test prompt: <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023

Today Date: 23 July 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Prompt tokens: 111


In [18]:
print(tokenizer.decode(prompt_tokens))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023

Today Date: 23 July 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


In [21]:
print("\n" + "="*80)
print("GENERATING TEXT WITH GREEDY SAMPLING")
print("="*80)

# Initialize KV cache for generation
max_seq_len = len(prompt_tokens) + max_gen_len
head_dim = config.head_dim
kv_cache = KVCache.new(
    config.n_layers,
    batch_size,
    max_seq_len,
    config.n_kv_heads,
    head_dim,
    dtype=jax_dtype
)
print(f"✓ KV cache initialized: {config.n_layers} layers, "
        f"batch={batch_size}, max_seq_len={max_seq_len}")

# Convert prompt tokens to JAX array
tokens = jnp.array([prompt_tokens], dtype=jnp.int32)  # [1, prompt_len]
current_seq_len = len(prompt_tokens)



GENERATING TEXT WITH GREEDY SAMPLING
✓ KV cache initialized: 16 layers, batch=1, max_seq_len=367


In [None]:
print(f"Prefilling with {len(prompt_tokens)} tokens...")
seq_lengths = jnp.array([len(prompt_tokens)], dtype=jnp.int32)
logits, kv_cache = model.apply(
    {"params": loaded_params},
    tokens=tokens,
    seq_lengths=seq_lengths,
    kv_cache=kv_cache,
)
print(f"✓ Prefill complete")


Prefilling with 111 tokens...
✓ Prefill complete


: 

In [21]:
print(logits.shape)

(1, 112, 128256)


In [8]:
dummy_tokens = jnp.ones((batch_size,1),dtype=jnp.int32)
seq_lengths = jnp.ones((batch_size,),dtype=jnp.int32)
kv_cache = KVCache.new(
    config.n_layers,
    batch_size,
    max_seq_len,
    config.n_kv_heads,
    config.head_dim,
    dtype=config.dtype
)


In [None]:
params = model.init(model_key,dummy_tokens,seq_lengths,kv_cache)

See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.


In [13]:
print(params["params"].keys())

dict_keys(['norm_weight', 'output', 'tok_embeddings', 'layer_0', 'layer_1', 'layer_2', 'layer_3', 'layer_4', 'layer_5', 'layer_6', 'layer_7', 'layer_8', 'layer_9', 'layer_10', 'layer_11', 'layer_12', 'layer_13', 'layer_14', 'layer_15'])


In [12]:
import jax

print(jax.tree.structure(params["params"]))
print(jax.tree.structure(loaded_params))
if jax.tree.structure(params["params"]) != jax.tree.structure(loaded_params):
    raise ValueError("params and loaded_params do not share the same pytree structure")



PyTreeDef({'layer_0': {'attention_norm_weight': *, 'ffn_norm_weight': *, 'w_down': *, 'w_gate': *, 'w_up': *, 'wk': *, 'wo': *, 'wq': *, 'wv': *}, 'layer_1': {'attention_norm_weight': *, 'ffn_norm_weight': *, 'w_down': *, 'w_gate': *, 'w_up': *, 'wk': *, 'wo': *, 'wq': *, 'wv': *}, 'layer_10': {'attention_norm_weight': *, 'ffn_norm_weight': *, 'w_down': *, 'w_gate': *, 'w_up': *, 'wk': *, 'wo': *, 'wq': *, 'wv': *}, 'layer_11': {'attention_norm_weight': *, 'ffn_norm_weight': *, 'w_down': *, 'w_gate': *, 'w_up': *, 'wk': *, 'wo': *, 'wq': *, 'wv': *}, 'layer_12': {'attention_norm_weight': *, 'ffn_norm_weight': *, 'w_down': *, 'w_gate': *, 'w_up': *, 'wk': *, 'wo': *, 'wq': *, 'wv': *}, 'layer_13': {'attention_norm_weight': *, 'ffn_norm_weight': *, 'w_down': *, 'w_gate': *, 'w_up': *, 'wk': *, 'wo': *, 'wq': *, 'wv': *}, 'layer_14': {'attention_norm_weight': *, 'ffn_norm_weight': *, 'w_down': *, 'w_gate': *, 'w_up': *, 'wk': *, 'wo': *, 'wq': *, 'wv': *}, 'layer_15': {'attention_norm_wei

In [14]:
jax.tree.map(lambda x,y: x.shape == y.shape, params["params"], loaded_params)

{'layer_0': {'attention_norm_weight': True,
  'ffn_norm_weight': True,
  'w_down': True,
  'w_gate': True,
  'w_up': True,
  'wk': True,
  'wo': True,
  'wq': True,
  'wv': True},
 'layer_1': {'attention_norm_weight': True,
  'ffn_norm_weight': True,
  'w_down': True,
  'w_gate': True,
  'w_up': True,
  'wk': True,
  'wo': True,
  'wq': True,
  'wv': True},
 'layer_10': {'attention_norm_weight': True,
  'ffn_norm_weight': True,
  'w_down': True,
  'w_gate': True,
  'w_up': True,
  'wk': True,
  'wo': True,
  'wq': True,
  'wv': True},
 'layer_11': {'attention_norm_weight': True,
  'ffn_norm_weight': True,
  'w_down': True,
  'w_gate': True,
  'w_up': True,
  'wk': True,
  'wo': True,
  'wq': True,
  'wv': True},
 'layer_12': {'attention_norm_weight': True,
  'ffn_norm_weight': True,
  'w_down': True,
  'w_gate': True,
  'w_up': True,
  'wk': True,
  'wo': True,
  'wq': True,
  'wv': True},
 'layer_13': {'attention_norm_weight': True,
  'ffn_norm_weight': True,
  'w_down': True,
  'w_ga

In [None]:
# ============================================================================
# PYTORCH MODEL SETUP
# ============================================================================
import torch

# Device selection for PyTorch
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    TORCH_XLA_AVAILABLE = True
except ImportError:
    TORCH_XLA_AVAILABLE = False
    torch_xla = None

if TORCH_XLA_AVAILABLE:
    torch_device = torch_xla.device()
    print(f"Using PyTorch XLA device: {torch_device}")
elif torch.cuda.is_available():
    torch_device = "cuda"
    print("Using CUDA device")
else:
    torch_device = "cpu"
    print("Using CPU device")

# Import PyTorch model components
from experiments.torch_llama import Llama as Llama_wrapper


In [None]:
# Load PyTorch model
torch_original_path = model_path / "original"
torch_tokenizer_path = torch_original_path / "tokenizer.model"

if not torch_original_path.exists():
    torch_original_path = model_path
    torch_tokenizer_path = model_path / "tokenizer.model"

print(f"Loading PyTorch model from: {torch_original_path}")
print(f"Tokenizer: {torch_tokenizer_path}")

# Calculate max_seq_len needed
max_seq_len = len(prompt_tokens) + max_gen_len
max_seq_len = min(max_seq_len, 8192)  # Cap at model max

llama_wrapper = Llama_wrapper.build(
    ckpt_dir=str(torch_original_path),
    tokenizer_path=str(torch_tokenizer_path),
    max_seq_len=max_seq_len,
    max_batch_size=1,
    flash=False,
)
torch_model = llama_wrapper.model
torch_model.eval()
print("✓ PyTorch model loaded successfully")


In [None]:
# ============================================================================
# PYTORCH FORWARD PASS
# ============================================================================
print("\n" + "="*80)
print("PYTORCH FORWARD PASS")
print("="*80)

# Get device for tensor creation
if TORCH_XLA_AVAILABLE:
    device_for_tensor = torch_device
elif isinstance(torch_device, str):
    device_for_tensor = torch_device
else:
    device_for_tensor = torch_device

# Initialize KV caches in all layers
torch_params = torch_model.params
total_len = len(prompt_tokens) + max_gen_len
total_len = min(total_len, max_seq_len)

print(f"Initializing KV caches for max sequence length: {total_len}")
from experiments.torch_llama import KVCache as KVCache_torch
for i in range(torch_params.n_layers):
    layer_dtype = torch_model.layers[i].attention.wq.weight.dtype
    layer_device = torch_model.layers[i].attention.wq.weight.device
    torch_model.layers[i].attention.cache = KVCache_torch(
        batch_size=1,
        seq_length=total_len,
        n_kv_heads=torch_params.n_kv_heads,
        head_dim=torch_params.dim // torch_params.n_heads,
        dtype=layer_dtype,
        device=layer_device,
    )
print(f"✓ KV caches initialized in all {torch_params.n_layers} layers")


In [None]:
# Convert prompt tokens to tensor
tokens_torch = torch.tensor([prompt_tokens], device=device_for_tensor, dtype=torch.long)  # [1, prompt_len]
start_pos = 0

# Prefill: process the prompt
print(f"Prefilling with {len(prompt_tokens)} tokens...")
with torch.no_grad():
    logits_torch = torch_model.forward_inference(tokens_torch, start_pos)
    # Synchronize XLA operations if using XLA
    if TORCH_XLA_AVAILABLE:
        torch_xla.sync()  # Synchronize XLA operations
print(f"✓ Prefill complete")


In [None]:
# Check logits shape
print(f"PyTorch logits shape: {logits_torch.shape}")
print(f"JAX logits shape: {logits.shape}")
print(f"Shapes match: {logits_torch.shape == logits.shape}")
