In [2]:
from models.llama.config import ModelConfig
import jax
import jax.numpy as jnp
import torch
from models.llama.model import TransformerBlock as TransformerBlockJax
from models.llama.config import ModelConfig
from experiments.torch_llama import TransformerBlock as TransformerBlockTorch, ModelArgs
from utils.ops import precompute_freqs_cis as precompute_freqs_cis_jax
from utils.kvcache import KVCache
from experiments.torch_llama import precompute_freqs_cis as precompute_freqs_cis_torch

jax.config.update("jax_default_matmul_precision", "float32")

In [3]:
# Hyperparameters shared by the PyTorch reference block and the JAX block.
# Both config objects below must be built from these same variables.
dim = 768
n_layers = 8
n_heads = 12
n_kv_heads = 4
ffn_hidden_dim = 2048
head_dim = dim // n_heads  # 64 with the values above; derived so it cannot drift from dim/n_heads
vocab_size = 1000
multiple_of = 256
norm_eps = 1e-5
rope_theta = 500000
use_scaled_rope = False
max_batch_size = 4
max_seq_len = 1024
flash = False

# Fail fast on head counts that cannot be split evenly: the per-head width is
# dim // n_heads, and (grouped-query attention) each KV head is presumably
# shared by n_heads // n_kv_heads query heads.
assert dim % n_heads == 0, "dim must be divisible by n_heads"
assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

args = ModelArgs(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_kv_heads,
    ffn_hidden_dim=ffn_hidden_dim,
    vocab_size=vocab_size,
    multiple_of=multiple_of,
    norm_eps=norm_eps,
    rope_theta=rope_theta,
    use_scaled_rope=use_scaled_rope,  # was hard-coded False, ignoring the variable
    max_batch_size=max_batch_size,    # was hard-coded 4, ignoring the variable
    max_seq_len=max_seq_len,
    flash=flash,
)

# ModelConfig is not given head_dim; config.head_dim (used by the KV cache later)
# is presumably derived as dim // n_heads inside ModelConfig — TODO confirm.
config = ModelConfig(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_kv_heads,
    ffn_hidden_dim=ffn_hidden_dim,
    rms_norm_eps=norm_eps,
    activation_fn="silu",
    dtype="float32",
    max_seqlen=max_seq_len,
    vocab_size=vocab_size,
    rope_theta=rope_theta,
)

In [5]:
# One block per framework, each built from its own config object defined above.
# Their parameters are replaced with a shared set of weights in later cells.
transformer_block_torch = TransformerBlockTorch(args)
transformer_block_jax = TransformerBlockJax(config)

In [4]:
# Rotary (RoPE) tables for 2 * max_seq_len positions, built independently by each framework.
freqs_cis_jax = precompute_freqs_cis_jax(dim // n_heads, end=2 * max_seq_len, theta=rope_theta)
freqs_cis_torch = precompute_freqs_cis_torch(dim // n_heads, end=2 * max_seq_len, theta=rope_theta)

print(freqs_cis_jax.shape)
print(freqs_cis_torch.shape)

import numpy as np

freqs_np_jax = np.asarray(freqs_cis_jax)
freqs_np_torch = freqs_cis_torch.cpu().numpy()
# Unlike `assert np.allclose(...)`, assert_allclose fails on mismatched shapes instead of
# broadcasting, and reports the offending elements on failure. The loose atol presumably
# absorbs float32 differences in cos/sin at large angles.
np.testing.assert_allclose(freqs_np_jax, freqs_np_torch, rtol=1e-8, atol=2e-4, equal_nan=False)

(2048, 32, 2)
torch.Size([2048, 32, 2])


In [5]:

def create_shared_transformer_weights(dim, n_heads, n_kv_heads, ffn_hidden_dim, seed=42):
    """Draw one set of random TransformerBlock weights and return it in both framework layouts.

    Args:
        dim: model width.
        n_heads: number of query heads.
        n_kv_heads: number of key/value heads.
        ffn_hidden_dim: hidden width of the feed-forward layer.
        seed: seed for NumPy's *global* RNG. Because the global RNG is reseeded,
            later np.random draws in the notebook are affected too.

    Returns:
        (weights_torch, weights_jax):
        - weights_torch maps 'wq', 'wk', 'wv', 'wo', 'w1', 'w2', 'w3', 'attention_norm' and
          'ffn_norm' to torch tensors in (in_features, out_features) layout. They share
          memory with the underlying NumPy arrays.
        - weights_jax holds jnp arrays with the attention projections reshaped per head
          and the feed-forward weights renamed by role.
    """
    np.random.seed(seed)
    head_dim = dim // n_heads
    q_width = n_heads * head_dim
    kv_width = n_kv_heads * head_dim

    def draw(shape):
        return np.random.normal(0, 0.02, shape).astype(np.float32)

    # Draw order fixes which random values each matrix receives for a given seed.
    random_shapes = [
        ('wq', (dim, q_width)),
        ('wk', (dim, kv_width)),
        ('wv', (dim, kv_width)),
        ('wo', (q_width, dim)),
        ('w1', (dim, ffn_hidden_dim)),
        ('w2', (ffn_hidden_dim, dim)),
        ('w3', (dim, ffn_hidden_dim)),
    ]
    weights_np = {name: draw(shape) for name, shape in random_shapes}
    # RMSNorm scales start at one.
    weights_np['attention_norm'] = np.ones(dim, dtype=np.float32)
    weights_np['ffn_norm'] = np.ones(dim, dtype=np.float32)

    weights_torch = {}
    for name, array in weights_np.items():
        weights_torch[name] = torch.from_numpy(array)

    # The JAX block stores Q/K/V as (dim, heads, head_dim) and names FFN weights by role:
    # w1 (dim -> hidden) is the gate, w3 (dim -> hidden) the up projection and
    # w2 (hidden -> dim) the down projection.
    jax_layout = {
        'wq': weights_np['wq'].reshape(dim, n_heads, head_dim),
        'wk': weights_np['wk'].reshape(dim, n_kv_heads, head_dim),
        'wv': weights_np['wv'].reshape(dim, n_kv_heads, head_dim),
        'wo': weights_np['wo'],
        'w1_gate': weights_np['w1'],
        'w2_up': weights_np['w3'],
        'w3_down': weights_np['w2'],
        'attention_norm_weight': weights_np['attention_norm'],
        'ffn_norm_weight': weights_np['ffn_norm'],
    }
    weights_jax = {name: jnp.array(array) for name, array in jax_layout.items()}

    return weights_torch, weights_jax

def set_torch_weights(transformer_block_torch, weights_torch):
    """Copy the shared weights into a PyTorch TransformerBlock, in place.

    Projection weights in `weights_torch` are laid out (in, out). They are transposed
    before copying because the block's projection layers are assumed to be
    nn.Linear-style and store weight as (out_features, in_features). Norm scales are
    copied unchanged.
    """
    attn = transformer_block_torch.attention
    ffn = transformer_block_torch.feed_forward

    transposed_targets = [
        (attn.wq.weight, 'wq'),
        (attn.wk.weight, 'wk'),
        (attn.wv.weight, 'wv'),
        (attn.wo.weight, 'wo'),
        (ffn.w1.weight, 'w1'),
        (ffn.w2.weight, 'w2'),
        (ffn.w3.weight, 'w3'),
    ]
    direct_targets = [
        (transformer_block_torch.attention_norm.weight, 'attention_norm'),
        (transformer_block_torch.ffn_norm.weight, 'ffn_norm'),
    ]

    # no_grad: copy_ on parameters that require grad is only allowed outside autograd.
    with torch.no_grad():
        for param, key in transposed_targets:
            param.copy_(weights_torch[key].T)
        for param, key in direct_targets:
            param.copy_(weights_torch[key])

def set_jax_weights(jax_variables, weights_jax):
    """Return a copy of `jax_variables` with the block's parameters replaced by the shared weights.

    Args:
        jax_variables: variables mapping as returned by the JAX module's `init`. It must
            contain a 'params' entry keyed by parameter name (plain dict or Flax FrozenDict).
        weights_jax: mapping from parameter name to array, as produced by
            `create_shared_transformer_weights`. Extra keys are ignored.

    Returns:
        A new plain dict. Its 'params' holds the shared weights; any other collections
        in `jax_variables` are carried over unchanged. The input is not mutated.

    Raises:
        KeyError: if one of the expected names is not an existing parameter of the
            initialised block, or is missing from `weights_jax`. Without this check the
            weight would be added as a new, presumably unused, key and the random init
            value would stay in effect.
        ValueError: if a shared weight's shape differs from the initialised parameter's.
    """
    param_names = (
        'wq', 'wk', 'wv', 'wo',
        'w1_gate', 'w2_up', 'w3_down',
        'attention_norm_weight', 'ffn_norm_weight',
    )
    # dict() yields a mutable shallow copy for both plain dicts and Flax FrozenDicts
    # (FrozenDict.copy() would return another immutable FrozenDict).
    new_params = dict(jax_variables['params'])

    for name in param_names:
        if name not in new_params:
            raise KeyError(
                f"{name!r} is not a parameter of the initialised JAX block; "
                f"available: {sorted(new_params)}"
            )
        value = weights_jax[name]
        current_shape = tuple(new_params[name].shape)
        if current_shape != tuple(value.shape):
            raise ValueError(
                f"shape mismatch for {name!r}: initialised {current_shape}, "
                f"shared weight {tuple(value.shape)}"
            )
        new_params[name] = value

    # Keep any non-'params' collections produced by init instead of dropping them.
    return {**jax_variables, 'params': new_params}

# Draw one seeded set of weights and load it into the PyTorch block now.
# The JAX copy is applied after Flax initialisation, in the JAX cell below.
weights_torch, weights_jax = create_shared_transformer_weights(
    dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, ffn_hidden_dim=ffn_hidden_dim
)
set_torch_weights(transformer_block_torch, weights_torch)

In [6]:
seq_len = 512
# The PyTorch block's KV cache is sized by max_seq_len, so the prompt must fit in it.
assert seq_len <= max_seq_len, "seq_len must not exceed max_seq_len"

# Random activations in [0, 10); shape (max_batch_size, seq_len, dim) = (4, 512, 768).
np_tokens = np.random.uniform(0, 10, (max_batch_size, seq_len, dim)).astype(np.float32)

# Causal mask: -inf strictly above the diagonal, zeros elsewhere. Assuming the block adds
# it to the attention scores, position i only attends to positions j <= i.
torch_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Pure forward pass for comparison: no_grad avoids building an autograd graph
# over the copied weights (and yields an output that can be converted to NumPy directly).
with torch.no_grad():
    torch_output = transformer_block_torch(
        torch.tensor(np_tokens).float(),
        start_pos=0,
        freqs_cis=freqs_cis_torch[:seq_len],
        mask=torch_mask,
    )


In [7]:
key = jax.random.PRNGKey(0)

# Same cache configuration for the init pass and the real pass.
kvcache_kwargs = dict(
    n_layers=config.n_layers,
    bsz=max_batch_size,
    max_seq_len=config.max_seqlen,
    kv_heads=config.n_kv_heads,
    head_dim=config.head_dim,
    dtype=jnp.float32,
)

# Initialise with a length-1 dummy input to obtain the parameter tree,
# then replace its parameters with the shared weights.
dummy_input = jnp.ones((max_batch_size, 1, dim), dtype=jnp.float32)
dummy_freqs = freqs_cis_jax[:1]
dummy_kvcache = KVCache.new(**kvcache_kwargs)
init_variables = transformer_block_jax.init(key, dummy_input, dummy_freqs, dummy_kvcache, 0, 0)
variables = set_jax_weights(init_variables, weights_jax)

# Fresh cache for the real forward pass.
kvcache = KVCache.new(**kvcache_kwargs)

# Slice the rotary table to the sequence length, as the init call and the PyTorch call do;
# previously the whole 2 * max_seq_len table was passed.
# The trailing `0, 0` are presumably (layer index, start position) and the second
# return value the updated cache — TODO confirm against the JAX module's signature.
jax_output, _ = transformer_block_jax.apply(
    variables, jnp.asarray(np_tokens), freqs_cis_jax[:seq_len], kvcache, 0, 0
)