https://orbax.readthedocs.io/en/latest/guides/checkpoint/debug_guide.html

In [35]:
import jax
import numpy as np
from etils import epath
import orbax.checkpoint as ocp
import tensorstore as ts
import collections
import operator
import asyncio
import re

## inspect

In [36]:
def natural_sort_key(s: str):
    """
    Utility to sort parameter names, by splitting the string into text and number chunks.
    """
    return [int(text) if text.isdigit() else text for text in re.split(r'(\d+)', s)]


def inspect(path: str):
  """
  inspect orbax checkpoint by reading metadata
  """
  path = epath.Path(path)
  print(path)

  metadata = ocp.StandardCheckpointer().metadata(path)

  # 1 check size
  size_counts = collections.defaultdict(int)

  def get_arr_bytes(meta):
    dtype = meta.dtype
    shape = meta.shape
    size_counts[dtype] += 1
    return np.prod(shape) * np.dtype(dtype).itemsize

  total_bytes = jax.tree.reduce(operator.add, jax.tree.map(get_arr_bytes, metadata))
  print('{0:0.3f} GB'.format(float(total_bytes) / 1e9))
  print()


  # 2 check params only (skip opt_state): name and shape
  # k: name tuple, v: array meta data
  dictionary = ocp.tree.to_flat_dict(metadata)
  dictionary = {".".join(k): v.shape for k, v in dictionary.items() if k[0] == "params" }
  # sort layer name 1, 2 ...., 10 (rather than 1, 10, 2, ...)
  for k in sorted(dictionary, key=natural_sort_key):
      #print(f"{k}\n\t{dictionary[k]}")
      print(f"key | {k} | {dictionary[k]}")

  return metadata

In [37]:
  # print('leaf dtype counts:')
  # for dtype, count in size_counts.items():
  #   print(f'{dtype}: {count}')
  
  # # check parameter
  # metadata_contents = ['.'.join(k) for k in ocp.tree.to_flat_dict(metadata)]
  # # sort layer name 1, 2 ...., 10 (rather than 1, 10, 2, ...)
  # metadata_contents.sort(key=natural_sort_key)
  # # Here are the parameters present in the checkpoint tree.
  # for p in metadata_contents:
  #   # skip opt_state, params only
  #   if p.startswith("params"):
  #     print(p)
  # return metadata

### Random

In [38]:
# metadata = inspect("/home/shuningjin/tmp/llama4/v5-scout-instruct-bf16-final-scanned/0/items")
metadata = inspect("/home/shuningjin/tmp/gcsfuse-jacobplatin/llama4/v5-scout-instruct-bf16-final-scanned/0/items")

/home/shuningjin/tmp/gcsfuse-jacobplatin/llama4/v5-scout-instruct-bf16-final-scanned/0/items


215.540 GB

key | params.params.decoder.decoder_norm.scale | (5120,)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel | (5120, 48, 16)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0 | (16, 48, 5120, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1 | (16, 48, 5120, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo | (16, 48, 8192, 5120)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel | (5120, 48, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel | (5120, 48, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel | (8192, 48, 5120)
key | params.params.decoder.layers.post_self_attention_layer_norm.scale | (5120, 48)
key | params.params.decoder.layers.pre_self_attention_layer_norm.scale | (5120, 48)
key | params.params.decoder.layers.self_attention.key.kernel | (5120, 48, 8, 128)
key | param

In [39]:
dictionary = ocp.tree.to_flat_dict(metadata)
dictionary = {".".join(k): v for k, v in dictionary.items()}
dictionary["params.params.decoder.decoder_norm.scale"]

ArrayMetadata :  name=params.params.decoder.decoder_norm.scale,  directory=/home/shuningjin/tmp/gcsfuse-jacobplatin/llama4/v5-scout-instruct-bf16-final-scanned/0/items,  shape=(5120,),  sharding=NamedShardingMetadata(shape=[16], axis_names=['checkpoint_sharding_axis'], partition_spec=('checkpoint_sharding_axis',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7), DeviceMetadata(id=8), DeviceMetadata(id=9), DeviceMetadata(id=10), DeviceMetadata(id=11), DeviceMetadata(id=12), DeviceMetadata(id=13), DeviceMetadata(id=14), DeviceMetadata(id=15)]),  dtype=bfloat16,  storage=StorageMetadata(chunk_shape=(320,), write_shape=(320,)),

### ckpt1

In [40]:
# jacob's ckpt from 4/15
# gs://jacobplatin/llama4/v5-scout-instruct-bf16-final-scanned/0/items
_ = inspect("/home/shuningjin/tmp/gcsfuse-jacobplatin/llama4/v5-scout-instruct-bf16-final-scanned/0/items")
# _ = inspect("/home/shuningjin/tmp/llama4/v5-scout-instruct-bf16-final-scanned/0/items")

/home/shuningjin/tmp/gcsfuse-jacobplatin/llama4/v5-scout-instruct-bf16-final-scanned/0/items
215.540 GB

key | params.params.decoder.decoder_norm.scale | (5120,)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel | (5120, 48, 16)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0 | (16, 48, 5120, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1 | (16, 48, 5120, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo | (16, 48, 8192, 5120)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel | (5120, 48, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel | (5120, 48, 8192)
key | params.params.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel | (8192, 48, 5120)
key | params.params.decoder.layers.post_self_attention_layer_norm.scale | (5120, 48)
key | params.params.decoder.layers.pre_self_attention_layer_norm.scale | (5120, 48)


### ckpt2

In [41]:
# my newly generated ckpt with current conversion script
# gs://shuningjin-multipod-dev/llama4-17b-16e/conversion/meta-scanned/0/items
_ = inspect("/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/llama4-17b-16e/conversion/meta-scanned/0/items")

/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/llama4-17b-16e/conversion/meta-scanned/0/items
215.540 GB

key | params.params.decoder.decoder_norm.scale | (5120,)
key | params.params.decoder.layers.Llama4MoEBlock_0.MoeBlock_0.gate.kernel | (5120, 48, 16)
key | params.params.decoder.layers.Llama4MoEBlock_0.MoeBlock_0.wi_0 | (16, 48, 5120, 8192)
key | params.params.decoder.layers.Llama4MoEBlock_0.MoeBlock_0.wi_1 | (16, 48, 5120, 8192)
key | params.params.decoder.layers.Llama4MoEBlock_0.MoeBlock_0.wo | (16, 48, 8192, 5120)
key | params.params.decoder.layers.Llama4MoEBlock_0.shared_experts.wi_0.kernel | (5120, 48, 8192)
key | params.params.decoder.layers.Llama4MoEBlock_0.shared_experts.wi_1.kernel | (5120, 48, 8192)
key | params.params.decoder.layers.Llama4MoEBlock_0.shared_experts.wo.kernel | (8192, 48, 5120)
key | params.params.decoder.layers.post_self_attention_layer_norm.scale | (5120, 48)
key | params.params.decoder.layers.pre_self_attention_layer_norm.scale | (5120, 48)
key | p

### ckpt3

In [42]:
# checkpoint from pretraing
# gs://shuningjin-multipod-dev/scout_pretrain/shuning-llama4-2025-06-22-23-49-54/checkpoints/0/items
_ = inspect("/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/scout_pretrain/shuning-llama4-2025-06-22-23-49-54/checkpoints/0/items")

/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/scout_pretrain/shuning-llama4-2025-06-22-23-49-54/checkpoints/0/items


646.619 GB

key | params.params.decoder.decoder_norm.scale | (5120,)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.gate.kernel | (5120, 12, 16)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.wi_0 | (16, 12, 5120, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.wi_1 | (16, 12, 5120, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.wo | (16, 12, 8192, 5120)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.shared_experts.wi_0.kernel | (5120, 12, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.shared_experts.wi_1.kernel | (5120, 12, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.shared_experts.wo.kernel | (8192, 12, 5120)
key | params.params.decoder.layers.layers_0.post_self_attention_layer_norm.scale | (5120, 12)
key | params.params.decoder.layers.layers_0.pre_self_attention_layer_norm.scale | (5120, 12)
key | params.params.decode

### ckpt 4 (scout new pretrain)

In [43]:
# gs://shuningjin-multipod-dev/scout_pretrain/shuning-llama4-2025-06-29-06-11-10/checkpoints/0/items
_ = inspect("/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/scout_pretrain/shuning-llama4-2025-06-29-06-11-10/checkpoints/0/items")

/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/scout_pretrain/shuning-llama4-2025-06-29-06-11-10/checkpoints/0/items
646.619 GB

key | params.params.decoder.decoder_norm.scale | (5120,)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.gate.kernel | (5120, 12, 16)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.wi_0 | (16, 12, 5120, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.wi_1 | (16, 12, 5120, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.MoeBlock_0.wo | (16, 12, 8192, 5120)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.shared_experts.wi_0.kernel | (5120, 12, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.shared_experts.wi_1.kernel | (5120, 12, 8192)
key | params.params.decoder.layers.layers_0.Llama4MoEBlock_0.shared_experts.wo.kernel | (8192, 12, 5120)
key | params.params.decoder.layers.layers_0.post_self_attention_layer_norm.scale | (5120, 

### ckpt 5 (maverick new pretrain)

In [44]:
# gs://shuningjin-multipod-dev/maverick_pretrain/shuning-llama4-2025-06-29-11-38-39/checkpoints/0/items
inspect("/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/maverick_pretrain/shuning-llama4-2025-06-29-11-38-39/checkpoints/0/items")

/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/maverick_pretrain/shuning-llama4-2025-06-29-11-38-39/checkpoints/0/items
2404.271 GB

key | params.params.decoder.decoder_norm.scale | (5120,)
key | params.params.decoder.layers.layers_0.mlp.wi_0.kernel | (5120, 12, 16384)
key | params.params.decoder.layers.layers_0.mlp.wi_1.kernel | (5120, 12, 16384)
key | params.params.decoder.layers.layers_0.mlp.wo.kernel | (16384, 12, 5120)
key | params.params.decoder.layers.layers_0.post_self_attention_layer_norm.scale | (5120, 12)
key | params.params.decoder.layers.layers_0.pre_self_attention_layer_norm.scale | (5120, 12)
key | params.params.decoder.layers.layers_0.self_attention.key.kernel | (5120, 12, 8, 128)
key | params.params.decoder.layers.layers_0.self_attention.out.kernel | (40, 12, 128, 5120)
key | params.params.decoder.layers.layers_0.self_attention.query.kernel | (5120, 12, 40, 128)
key | params.params.decoder.layers.layers_0.self_attention.value.kernel | (5120, 12, 8, 128)
key | par

TreeMetadata(
  custom_metadata=None
  tree={'opt_state': [{'count': ArrayMetadata :  name=opt_state.0.count,  directory=/home/shuningjin/tmp/gcsfuse-shuningjin-multipod-dev/maverick_pretrain/shuning-llama4-2025-06-29-11-38-39/checkpoints/0/items,  shape=(),  sharding=NamedShardingMetadata(shape=[  1   1 256   1   1   1   1   1   1   1   1   1], axis_names=['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], partition_spec=()) device_mesh=DeviceMetadataMesh(mesh=[[[[[[[[[[[[DeviceMetadata(id=100000)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100032)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100064)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100096)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100128)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100160)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100192)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100224)]]]]]]]]], [[[[[[[[[DeviceMetadata(id=100004)]]]]]]]]], [[[[[[[

### unscanned

In [45]:
# from xlml test: http://shortn/_dILntrz9GX
# gs://maxtext-llama/llama4-17b-16e/2025-06-27/unscanned
_ = inspect("/home/shuningjin/tmp/gcsfuse-maxtext-llama/llama4-17b-16e/2025-06-27/unscanned/0/items")

/home/shuningjin/tmp/gcsfuse-maxtext-llama/llama4-17b-16e/2025-06-27/unscanned/0/items
215.540 GB

key | params.params.decoder.decoder_norm.scale | (5120,)
key | params.params.decoder.layers_0.Llama4MoEBlock_0.MoeBlock_0.gate.kernel | (5120, 16)
key | params.params.decoder.layers_0.Llama4MoEBlock_0.MoeBlock_0.wi_0 | (16, 5120, 8192)
key | params.params.decoder.layers_0.Llama4MoEBlock_0.MoeBlock_0.wi_1 | (16, 5120, 8192)
key | params.params.decoder.layers_0.Llama4MoEBlock_0.MoeBlock_0.wo | (16, 8192, 5120)
key | params.params.decoder.layers_0.Llama4MoEBlock_0.shared_experts.wi_0.kernel | (5120, 8192)
key | params.params.decoder.layers_0.Llama4MoEBlock_0.shared_experts.wi_1.kernel | (5120, 8192)
key | params.params.decoder.layers_0.Llama4MoEBlock_0.shared_experts.wo.kernel | (8192, 5120)
key | params.params.decoder.layers_0.post_self_attention_layer_norm.scale | (5120,)
key | params.params.decoder.layers_0.pre_self_attention_layer_norm.scale | (5120,)
key | params.params.decoder.layers_0

## Mixtral

In [None]:
# gs://ml-auto-solutions/output/sparsity_diffusion_devx/maxtext/chained_tests_mixtral-8x7b_nightly-2024-11-15-01-06-09/unscanned_ckpt/checkpoints/0/items
# bash setup_gcsfuse.sh DATASET_GCS_BUCKET=ml-auto-solutions MOUNT_PATH=/home/shuningjin/tmp/gcsfuse-ml-auto-solutions
inspect("/home/shuningjin/tmp/gcsfuse-ml-auto-solutions/output/sparsity_diffusion_devx/maxtext/chained_tests_mixtral-8x7b_nightly-2024-11-15-01-06-09/unscanned_ckpt/checkpoints/0/items")

/home/shuningjin/tmp/gcsfuse-ml-auto-solutions/output/sparsity_diffusion_devx/maxtext/chained_tests_mixtral-8x7b_nightly-2024-11-15-01-06-09/unscanned_ckpt/checkpoints/0/items
93.406 GB

key | params.params.decoder.decoder_norm.scale | (4096,)
key | params.params.decoder.layers_0.MoeBlock_0.gate.kernel | (4096, 8)
key | params.params.decoder.layers_0.MoeBlock_0.wi_0 | (8, 4096, 14336)
key | params.params.decoder.layers_0.MoeBlock_0.wi_1 | (8, 4096, 14336)
key | params.params.decoder.layers_0.MoeBlock_0.wo | (8, 14336, 4096)
key | params.params.decoder.layers_0.post_self_attention_layer_norm.scale | (4096,)
key | params.params.decoder.layers_0.pre_self_attention_layer_norm.scale | (4096,)
key | params.params.decoder.layers_0.self_attention.key.kernel | (4096, 8, 128)
key | params.params.decoder.layers_0.self_attention.out.kernel | (32, 128, 4096)
key | params.params.decoder.layers_0.self_attention.query.kernel | (4096, 32, 128)
key | params.params.decoder.layers_0.self_attention.value.k

TreeMetadata(
  custom_metadata=None
  tree={'opt_state': {}, 'params': {'params': {'decoder': {'decoder_norm': {'scale': ArrayMetadata :  name=params.params.decoder.decoder_norm.scale,  directory=/home/shuningjin/tmp/gcsfuse-ml-auto-solutions/output/sparsity_diffusion_devx/maxtext/chained_tests_mixtral-8x7b_nightly-2024-11-15-01-06-09/unscanned_ckpt/checkpoints/0/items,  shape=(4096,),  sharding=NamedShardingMetadata(shape=[1 1 1 1 1 1 1 1], axis_names=['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive'], partition_spec=('tensor',)) device_mesh=DeviceMetadataMesh(mesh=[[[[[[[[DeviceMetadata(id=0)]]]]]]]]),  dtype=bfloat16,  storage=StorageMetadata(chunk_shape=(4096,), write_shape=None),}, 'layers_0': {'MoeBlock_0': {'gate': {'kernel': ArrayMetadata :  name=params.params.decoder.layers_0.MoeBlock_0.gate.kernel,  directory=/home/shuningjin/tmp/gcsfuse-ml-auto-solutions/output/sparsity_diffusion_devx/maxtext/chained_tests_mixtral-8x7b_nightly-2024

## other

In [46]:
state = {
    'a': {
        'x': np.arange(2 ** 24),
        'y': np.arange(1024),
    },
    'b': np.ones(8),
    'c': 42,
}

default_param_name = 'a.x'
default_path = epath.Path('/home/shuningjin/tmp/sample_checkpoint')
if default_path.exists():
  default_path.rmtree()
with ocp.StandardCheckpointer() as ckptr:
  ckptr.save(default_path, state)


# gs://jacobplatin/llama4/v5-scout-instruct-bf16-final-scanned/0/items
PATH="/home/shuningjin/tmp/llama4//v5-scout-instruct-bf16-final-scanned/0/items"
path = epath.Path(PATH)
print(path)


metadata = ocp.StandardCheckpointer().metadata(path)
size_counts = collections.defaultdict(int)

def get_arr_bytes(meta):
  dtype = meta.dtype
  shape = meta.shape
  size_counts[dtype] += 1
  return np.prod(shape) * np.dtype(dtype).itemsize

total_bytes = jax.tree.reduce(operator.add, jax.tree.map(get_arr_bytes, metadata))
print('{0:0.3f} GB'.format(float(total_bytes) / 1e9))
print()
print('leaf dtype counts:')
for dtype, count in size_counts.items():
  print(f'{dtype}: {count}')


metadata = ocp.StandardCheckpointer().metadata(path)
metadata_contents = ['.'.join(k) for k in ocp.tree.to_flat_dict(metadata)]
# Here are the parameters present in the checkpoint tree.
for p in metadata_contents:
  print(p)

/home/shuningjin/tmp/llama4/v5-scout-instruct-bf16-final-scanned/0/items


FileNotFoundError: Metadata file (named _METADATA) does not exist at /home/shuningjin/tmp/llama4/v5-scout-instruct-bf16-final-scanned/0/items.

In [None]:
async def disk_usage(path: epath.Path) -> int:
  """Returns the size of the checkpoint on disk.

  Note: this uses recurision because Orbax checkpoint directories are never
  more than a few levels deep.

  Args:
    path: The path to the checkpoint.
  Returns:
    The size of the checkpoint on disk.
  """

  async def helper(p):
    if p.is_dir():
      return await disk_usage(p)
    else:
      stat = await ocp.path.async_utils.async_stat(path)
      return stat.length

  futures = []
  for p in path.iterdir():
    futures.append(helper(p))
  return sum(await asyncio.gather(*futures))

print('{0:0.3f} GB'.format(float(asyncio.run(disk_usage(path))) / 1e9))

In [None]:
# Note: instead of "file", use:
#   - "gfile" on Google-internal filesystems.
#   - "gs" on GCS (do not repeat the "gs://" prefix)
ts_contents = ts.KvStore.open({"driver": "ocdbt", "base": f"file://{path.as_posix()}"}).result().list().result()
ts_contents = [p.decode("utf-8") for p in ts_contents]
ts_contents = [p.replace('.zarray', '')[:-1] for p in ts_contents if '.zarray' in p]

# We can assert that the parameters tracked by the metadata file are
# the same as those tracked by Tensorstore. If there is a discrepancy, there may
# be a deeper underlying problem.

assert len(metadata_contents) == len(ts_contents) and sorted(metadata_contents) == sorted(ts_contents)


In [None]:
path = ""  # @param {type:"string"}
# The `param_name` can be obtained by inspecting tree metadata (see above).
param_name = ""  # @param {type:"string"}
path = default_path or epath.Path(path)
param_name = default_param_name or param_name

In [None]:
metadata = ocp.StandardCheckpointer().metadata(path)
value_metadata = {'.'.join(k): v for k, v in ocp.tree.to_flat_dict(metadata).items()}[param_name]


In [None]:
print(f'shape: {value_metadata.shape}')
print(f'dtype: {value_metadata.dtype}')

In [None]:
ParamInfo = ocp.type_handlers.ParamInfo
ts_context = ts.Context({
    'file_io_concurrency': {'limit': 128},
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
})

info = ParamInfo(name=param_name, path=path / param_name, parent_dir=path, is_ocdbt_checkpoint=True, use_zarr3=False)
tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)

t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()
arr = t.read().result()
print(arr)

In [48]:
import numpy as np
np.zeros((8,) + (4,3)).shape

(8, 4, 3)