In [None]:
from transformers import AutoModelForCausalLM

# Path to the Hugging Face reference checkpoint (a local directory here; the
# commented-out line below is an alternative hub model id).
model_id = "/mnt/disks/jacobplatin/models/llama4/maverick/4-layer-debug-hf/HF-4layers/"
# model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct-Original"

# Load options: allow custom model code shipped with the checkpoint and request
# float32 weights for the comparison against the MaxText parameters.
hf_load_kwargs = {
    "trust_remote_code": True,
    "torch_dtype": "float32",
}
hf_model = AutoModelForCausalLM.from_pretrained(model_id, **hf_load_kwargs)

In [None]:

import jax
import jax.numpy as jnp
import numpy as np

import MaxText.layers.models as models
import MaxText.layers.quantizations as quantizations
from MaxText import pyconfig
from MaxText import max_utils

import argparse
import sys

# argv-style arguments for pyconfig: element 0 is the script path, element 1 the
# base YAML config, and the rest are key=value overrides.
model_args = [
    '/mnt/disks/jacobplatin/code/maxtext/llama4_maverick_check_weight.py',
    'MaxText/configs/base.yml',
    'hardware=cpu',
    'scan_layers=false',
    'base_output_directory=llama4',
    'run_name=temp-testing-only',
    'model_name=llama4-17b-128e',
    'skip_jax_distributed_system=true',
    'load_parameters_path=/mnt/disks/jacobplatin/models/llama4/maverick/4-layer-unscanned/0/items/',
]
config = pyconfig.initialize(model_args)

# Derive the RNG key handed to the decode-state setup from the configured seed.
init_rng = jax.random.PRNGKey(config.init_weights_seed)
init_rng, rng1 = jax.random.split(init_rng)

# Build the device mesh and the (optionally quantized) Transformer on it.
devices_array = max_utils.create_device_mesh(config)
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
quant = quantizations.configure_quantization(config)
model = models.Transformer(config, mesh=mesh, quant=quant)

# NOTE(review): presumably this restores the weights from `load_parameters_path`
# into `state.params` -- confirm against max_utils.setup_decode_state.
state, _ = max_utils.setup_decode_state(model, config, rng1, mesh, None)


In [54]:
# Tolerances passed to the element-wise closeness checks below.
RTOL = 1e-3  # relative tolerance
ATOL = 1e-3  # absolute tolerance


def get_nested_robust(data, path: str, default = None, sep: str = '.'):
  """
  Accesses nested mapping/list elements using a separator-delimited path string.
  Handles mapping keys (dict or any other `Mapping`) and list/tuple indices.

  Args:
      data: The mapping/list/tuple to traverse.
      path: The separator-delimited path string (e.g., 'a.b.c' or 'a.d.1').
      default: The value to return if the path is not found or invalid. Defaults to None.
      sep: The separator character used in the path string. Defaults to '.'.

  Returns:
      The value found at the specified path, or `default` if any path segment is
      missing, is not a valid list index, or tries to descend into a value that
      is neither a mapping nor a list/tuple. A value of None stored at the final
      path segment is returned as None (not replaced by `default`); trying to
      descend *through* a None value returns `default`.
  """
  # Local import so the function stays self-contained within its notebook cell.
  from collections.abc import Mapping

  current_value = data
  for key in path.split(sep):
    if current_value is None:  # Cannot descend through a None value
      return default
    if isinstance(current_value, Mapping):
      # Test membership BEFORE the lookup so that a missing key is
      # distinguished from a key whose stored value is None.
      if key not in current_value:
        return default
      current_value = current_value[key]
    elif isinstance(current_value, (list, tuple)):
      try:
        current_value = current_value[int(key)]
      except (ValueError, IndexError):  # Non-integer key or out-of-bounds index
        return default
    else:
      # Cannot index further into this type
      return default

  return current_value

# Per-layer parameter name templates: Hugging Face state_dict key -> dotted path
# into the MaxText parameter tree. `{layer_idx}` is filled in for each layer.
hf_to_maxtext_mapping = {
  "model.layers.{layer_idx}.input_layernorm.weight": "decoder.layers_{layer_idx}.pre_self_attention_layer_norm.scale",
  "model.layers.{layer_idx}.post_attention_layernorm.weight": "decoder.layers_{layer_idx}.post_self_attention_layer_norm.scale",
  "model.layers.{layer_idx}.self_attn.q_proj.weight": "decoder.layers_{layer_idx}.self_attention.query.kernel",
  "model.layers.{layer_idx}.self_attn.k_proj.weight": "decoder.layers_{layer_idx}.self_attention.key.kernel",
  "model.layers.{layer_idx}.self_attn.v_proj.weight": "decoder.layers_{layer_idx}.self_attention.value.kernel",
  "model.layers.{layer_idx}.self_attn.o_proj.weight": "decoder.layers_{layer_idx}.self_attention.out.kernel",
}
params = state.params["params"]

# Build the state dict once instead of on every inner-loop iteration.
hf_state_dict = hf_model.state_dict()

for hf_key_template, maxtext_key_template in hf_to_maxtext_mapping.items():
    print("On key:", hf_key_template)
    for i in range(4):
        # Format into fresh names: overwriting the template variables would
        # strip the `{layer_idx}` placeholder after the first iteration, so
        # every later iteration would silently re-check layer 0.
        hf_key = hf_key_template.format(layer_idx=i)
        maxtext_key = maxtext_key_template.format(layer_idx=i)
        a = hf_state_dict[hf_key].detach().numpy()
        b = get_nested_robust(params, maxtext_key)
        if b is None:
            raise ValueError(f"MaxText parameter not found: {maxtext_key}")
        if "self_attn." in hf_key:
            # Flatten the MaxText attention kernel to 2-D with a leading
            # dimension of 5120 so it can be compared with the 2-D HF weight;
            # q/k/v are additionally transposed, o_proj is not.
            if "o_proj" in hf_key:
                b = b.reshape(5120, -1)
            else:
                b = b.reshape(5120, -1).transpose()
        if not np.allclose(a, b, rtol=RTOL, atol=ATOL):
            raise ValueError(f"Failed on {hf_key} and {maxtext_key}")

On key: model.layers.{layer_idx}.input_layernorm.weight
On key: model.layers.{layer_idx}.post_attention_layernorm.weight
On key: model.layers.{layer_idx}.self_attn.q_proj.weight
On key: model.layers.{layer_idx}.self_attn.k_proj.weight
On key: model.layers.{layer_idx}.self_attn.v_proj.weight
On key: model.layers.{layer_idx}.self_attn.o_proj.weight


In [None]:
RTOL, ATOL = 1e-3, 1e-3

decoder_params = state.params["params"]["decoder"]

# (label, Hugging Face tensor, MaxText array) triples to compare. The output
# projection kernel is transposed before comparison with `lm_head.weight`.
final_checks = [
    (
        "layers_0.pre_self_attention_layer_norm.scale",
        hf_model.model.layers[0].input_layernorm.weight,
        decoder_params["layers_0"]["pre_self_attention_layer_norm"]["scale"],
    ),
    (
        "decoder_norm.scale",
        hf_model.model.norm.weight,
        decoder_params["decoder_norm"]["scale"],
    ),
    (
        "logits_dense.kernel",
        hf_model.lm_head.weight,
        decoder_params["logits_dense"]["kernel"].transpose(),
    ),
]

for label, hf_tensor, maxtext_array in final_checks:
    # `.cpu()` is applied uniformly so the conversion to NumPy also works when
    # the tensor does not live on the CPU; `err_msg` names the failing tensor.
    np.testing.assert_allclose(
        hf_tensor.detach().cpu().numpy(),
        maxtext_array,
        rtol=RTOL,
        atol=ATOL,
        err_msg=f"Mismatch between HF and MaxText weights for {label}",
    )

[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 1. 1. 1.]


KeyboardInterrupt: 