In [1]:
import torch
import jax
import jax.numpy as jnp
import numpy as np
from models.llama.config import ModelConfig
from models.llama.model import LLaMa as TransformerJax
from models.llama.config import ModelConfig
from experiments.torch_llama import Transformer as TransformerTorch, ModelArgs
from utils.ops import precompute_freqs_cis as precompute_freqs_cis_jax
from utils.kvcache import KVCache
from experiments.torch_llama import precompute_freqs_cis as precompute_freqs_cis_torch

# Request float32 precision for JAX matmuls by default.
jax.config.update("jax_default_matmul_precision", "float32")

# Use the GPU for PyTorch when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Hyperparameters shared by the PyTorch (ModelArgs) and JAX (ModelConfig) models.
dim = 768
n_layers = 8
n_heads = 12
n_kv_heads = 4
ffn_hidden_dim = 2048
head_dim = 64  # equals dim // n_heads; not passed to either config object below
vocab_size = 1000
multiple_of = 256
norm_eps = 1e-5
rope_theta = 500000
use_scaled_rope = False
max_batch_size = 8
max_seq_len = 1024
flash = False

# PyTorch-side configuration.
# Every field is fed from the constants above (previously `use_scaled_rope` and
# `max_batch_size` were hard-coded literals, so editing the constants had no effect).
args = ModelArgs(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_kv_heads,
    ffn_hidden_dim=ffn_hidden_dim,
    vocab_size=vocab_size,
    multiple_of=multiple_of,
    norm_eps=norm_eps,
    rope_theta=rope_theta,
    use_scaled_rope=use_scaled_rope,
    max_batch_size=max_batch_size,
    max_seq_len=max_seq_len,
    flash=flash,
)

# JAX-side configuration built from the same constants.
config = ModelConfig(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_kv_heads,
    ffn_hidden_dim=ffn_hidden_dim,
    rms_norm_eps=norm_eps,
    activation_fn="silu",
    dtype="float32",
    max_seqlen=max_seq_len,
    vocab_size=vocab_size,
    rope_theta=rope_theta,
)

In [3]:
def create_shared_full_transformer_weights(dim, n_layers, n_heads, n_kv_heads, ffn_hidden_dim, vocab_size, seed=42):
    """Build one set of random NumPy weights to load into both the JAX and PyTorch models.

    Every matrix is drawn from a normal distribution (mean 0, std 0.02) using the
    global NumPy RNG and cast to float32; every norm weight is a vector of ones.
    The global RNG is seeded with ``seed`` before the embedding is drawn and
    re-seeded with ``seed + i + 1`` before layer ``i`` is drawn, so calling this
    function changes the global NumPy random state.

    Args:
        dim: Model width.
        n_layers: Number of per-layer weight dicts to create.
        n_heads: Number of query heads; the head size is ``dim // n_heads``.
        n_kv_heads: Number of key/value heads.
        ffn_hidden_dim: Hidden width of the feed-forward matrices.
        vocab_size: Number of rows in the embedding table.
        seed: Base seed for the global NumPy RNG.

    Returns:
        A dict with 'tok_embeddings' of shape (vocab_size, dim), 'final_norm' of
        shape (dim,), and 'layers', a list of ``n_layers`` dicts holding
        'wq', 'wk', 'wv', 'wo', 'w1', 'w2', 'w3', 'attention_norm' and 'ffn_norm'.
    """
    def draw(rows, cols):
        # Single place where random matrices are sampled and cast.
        return np.random.normal(0, 0.02, (rows, cols)).astype(np.float32)

    def unit_scale():
        return np.ones(dim, dtype=np.float32)

    per_head = dim // n_heads
    q_width = n_heads * per_head
    kv_width = n_kv_heads * per_head

    # (name, shape) pairs in sampling order; the order determines the values drawn.
    matrix_specs = (
        ('wq', (dim, q_width)),
        ('wk', (dim, kv_width)),
        ('wv', (dim, kv_width)),
        ('wo', (q_width, dim)),
        ('w1', (dim, ffn_hidden_dim)),
        ('w2', (ffn_hidden_dim, dim)),
        ('w3', (dim, ffn_hidden_dim)),
    )

    np.random.seed(seed)
    embedding_table = draw(vocab_size, dim)

    all_layers = []
    for layer_idx in range(n_layers):
        # Each layer gets its own seed so its weights do not depend on the other layers.
        np.random.seed(seed + layer_idx + 1)
        block = {name: draw(*shape) for name, shape in matrix_specs}
        block['attention_norm'] = unit_scale()
        block['ffn_norm'] = unit_scale()
        all_layers.append(block)

    return {
        'tok_embeddings': embedding_table,
        'final_norm': unit_scale(),
        'layers': all_layers,
    }

def set_torch_full_transformer_weights(transformer, weights):
    """Copy the shared NumPy weights into a PyTorch Transformer, in place.

    Args:
        transformer: Model exposing ``tok_embeddings``, ``norm`` and ``layers``;
            each layer exposes ``attention`` (wq/wk/wv/wo), ``feed_forward``
            (w1/w2/w3), ``attention_norm`` and ``ffn_norm``, all with a ``weight``.
        weights: Dict with 'tok_embeddings', 'final_norm' and 'layers' (a list of
            per-layer dicts of NumPy arrays).

    Projection matrices are transposed before being copied into the layer
    weights; embedding and norm weights are copied as they are. Nothing is
    returned.
    """
    def load(target, array):
        # In-place copy of a NumPy array into an existing torch parameter.
        target.copy_(torch.from_numpy(array))

    attention_names = ('wq', 'wk', 'wv', 'wo')
    feed_forward_names = ('w1', 'w2', 'w3')

    with torch.no_grad():
        load(transformer.tok_embeddings.weight, weights['tok_embeddings'])
        load(transformer.norm.weight, weights['final_norm'])

        for idx, source in enumerate(weights['layers']):
            block = transformer.layers[idx]

            for name in attention_names:
                load(getattr(block.attention, name).weight, source[name].T)

            for name in feed_forward_names:
                load(getattr(block.feed_forward, name).weight, source[name].T)

            load(block.attention_norm.weight, source['attention_norm'])
            load(block.ffn_norm.weight, source['ffn_norm'])

def set_jax_full_transformer_weights(jax_variables, weights, n_heads, n_kv_heads):
    """Return JAX variables whose parameters are replaced by the shared NumPy weights.

    Args:
        jax_variables: Dict with a 'params' entry; a shallow copy of it is updated,
            so entries not overwritten here are carried over unchanged.
        weights: Dict with 'tok_embeddings', 'final_norm' and 'layers' (a list of
            per-layer dicts of NumPy arrays).
        n_heads: Number of query heads; the head size is the embedding width
            divided by this value.
        n_kv_heads: Number of key/value heads.

    Returns:
        ``{'params': ...}`` containing 'tok_embeddings', 'norm_weight' and one
        'layer_{i}' entry per element of ``weights['layers']``.
    """
    per_head = weights['tok_embeddings'].shape[1] // n_heads

    def split_heads(matrix, heads):
        # Reshape a 2-D projection to (-1, heads, per_head) for the JAX model.
        return jnp.array(matrix.reshape(-1, heads, per_head))

    def convert_layer(source):
        return {
            'wq': split_heads(source['wq'], n_heads),
            'wk': split_heads(source['wk'], n_kv_heads),
            'wv': split_heads(source['wv'], n_kv_heads),
            'wo': jnp.array(source['wo']),
            # The feed-forward names differ: shared w1 -> w1_gate, w3 -> w2_up, w2 -> w3_down.
            'w1_gate': jnp.array(source['w1']),
            'w2_up': jnp.array(source['w3']),
            'w3_down': jnp.array(source['w2']),
            'attention_norm_weight': jnp.array(source['attention_norm']),
            'ffn_norm_weight': jnp.array(source['ffn_norm']),
        }

    params = jax_variables['params'].copy()
    params['tok_embeddings'] = {'embedding': jnp.array(weights['tok_embeddings'])}
    params['norm_weight'] = jnp.array(weights['final_norm'])

    for idx, source in enumerate(weights['layers']):
        params[f'layer_{idx}'] = convert_layer(source)

    return {'params': params}


In [4]:
# JAX init: build the module and create its variables from dummy inputs.
key = jax.random.PRNGKey(0)
batch_size = 4
seq_len = 512

# All-ones int32 token batch; it is only passed to `init`.
dummy_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

# KV cache handed to `init` alongside the dummy tokens.
dummy_kvcache = KVCache.new(
    n_layers=config.n_layers,
    bsz=batch_size,
    max_seq_len=max_seq_len,
    kv_heads=config.n_kv_heads,
    head_dim=config.dim // config.n_heads,
    dtype=config.dtype,
)

transformer_jax = TransformerJax(config)

# The third positional argument (0) is presumably the start position — confirm against LLaMa.__call__.
jax_variables = transformer_jax.init(key, dummy_tokens, 0, dummy_kvcache)

In [5]:
# One set of NumPy weights, sized from the JAX config, to load into both models.
model_weights = create_shared_full_transformer_weights(
    dim=config.dim,
    n_layers=config.n_layers,
    n_heads=config.n_heads,
    n_kv_heads=config.n_kv_heads,
    ffn_hidden_dim=config.ffn_hidden_dim,
    vocab_size=config.vocab_size,
)

In [6]:
# Replace the initialised JAX parameters with the shared weights.
jax_variables = set_jax_full_transformer_weights(
    jax_variables=jax_variables,
    weights=model_weights,
    n_heads=n_heads,
    n_kv_heads=n_kv_heads,
)

In [7]:
import numpy as np

batch_size = 4
seq_len = 512

# Draw the token batch from a locally seeded generator so this cell produces the
# same inputs on every run, regardless of the global NumPy RNG state.
token_rng = np.random.default_rng(0)
input_tokens = token_rng.integers(0, vocab_size, size=(batch_size, seq_len))

# PyTorch reference model. It must live on the same device as its input tensor;
# without `.to(device)` the model stays on the CPU while the tokens go to CUDA.
transformer_torch = TransformerTorch(args).to(device)
set_torch_full_transformer_weights(transformer_torch, model_weights)

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    torch_output = transformer_torch.forward_inference(
        torch.tensor(input_tokens, device=device), 0
    )

# Fresh KV cache for the JAX forward pass.
jax_kvcache = KVCache.new(
    n_layers=config.n_layers,
    bsz=batch_size,
    max_seq_len=max_seq_len,
    kv_heads=config.n_kv_heads,
    head_dim=config.dim // config.n_heads,
    dtype=config.dtype,
)

# `apply` returns a pair; only the first element (the model output) is kept.
jax_output, _ = transformer_jax.apply(jax_variables, input_tokens, 0, jax_kvcache)

In [8]:
# Show both output shapes, one per line (PyTorch first, then JAX).
print(torch_output.shape, jax_output.shape, sep="\n")

torch.Size([4, 512, 1000])
(4, 512, 1000)
