In [None]:
import torch
from safetensors import safe_open
import argparse
import os
import pathlib

# example cmd: python3 ckpt.py --ckpt_dir=/home/ranran/ran_ckpt/hf-16b/hf/ --checkpoint_type=safetensors

CHECKPOINT_TYPES = ("pth", "safetensors")

def print_nested_keys(data, prefix=""):
  """
  Prints nested keys of a dictionary-like structure in a directory-like format.
  Args:
      data: The dictionary-like structure to traverse.
      prefix: The current path prefix.
  """
  if isinstance(data, dict):
    for key, value in data.items():
      current_path = f"{prefix}{key}."
      print_nested_keys(value, current_path)
  else:
    print(f"key: {prefix}")
    print(f"value shape: {data.shape}")


def load_pth_checkpoint(ckpt_paths):
  chkpt_vars_raw = {}
  for i, ckpt_path in enumerate(ckpt_paths):
    print(f"Loading checkpointpath {i+1} of {len(ckpt_paths)} ...")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    chkpt_vars_raw[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint  
  print_nested_keys(chkpt_vars_raw)


def load_safetensors_checkpoint(ckpt_paths):
  chkpt_vars_raw = {}
  for i, ckpt_path in enumerate(ckpt_paths):
    print(f"Loading checkpoint path {i+1} of {len(ckpt_paths)} ...")
    with safe_open(ckpt_path, framework="pt") as f:
      for k in f.keys():
        chkpt_vars_raw[k] = f.get_tensor(k)
  print_nested_keys(chkpt_vars_raw)


def main(argv):
  parser = argparse.ArgumentParser(description="Print the contents (keys and shapes) of safetensors and PyTorch checkpoint files.")
  parser.add_argument("--ckpt_dir", type=str, required=True)
  parser.add_argument("--checkpoint_type", type=str, required=True)
  args = parser.parse_args(argv)
  print(args)

  if args.checkpoint_type not in CHECKPOINT_TYPES:
    raise NotImplementedError

  ckpt_paths = sorted(pathlib.Path(args.ckpt_dir).glob(f"[!.]*.{args.checkpoint_type}"))
  if args.checkpoint_type == "safetensors":
    load_safetensors_checkpoint(ckpt_paths)
  else:
    assert args.checkpoint_type == "pth"
    load_pth_checkpoint(ckpt_paths)


In [16]:
cmd = "--ckpt_dir=/home/shuningjin/llama4-17b-16e/meta-bf16 --checkpoint_type=pth".split()
main(cmd)

Namespace(ckpt_dir='/home/shuningjin/llama4-17b-16e/meta-bf16', checkpoint_type='pth')
hi
Loading checkpointpath 1 of 8 ...
Loading checkpointpath 2 of 8 ...
Loading checkpointpath 3 of 8 ...
Loading checkpointpath 4 of 8 ...
Loading checkpointpath 5 of 8 ...
Loading checkpointpath 6 of 8 ...
Loading checkpointpath 7 of 8 ...
Loading checkpointpath 8 of 8 ...
key: 0.tok_embeddings.weight.
value shape: torch.Size([25256, 5120])
key: 0.vision_projection.weight.
value shape: torch.Size([640, 4096])
key: 0.norm.weight.
value shape: torch.Size([5120])
key: 0.output.weight.
value shape: torch.Size([25256, 5120])
key: 0.layers.0.feed_forward.norm.weight.
value shape: torch.Size([5120])
key: 0.layers.0.attention.wo.weight.
value shape: torch.Size([5120, 640])
key: 0.layers.0.feed_forward.global_gate_stats_3E.
value shape: torch.Size([3, 16])
key: 0.layers.0.feed_forward.expert_activation_DE.
value shape: torch.Size([5120, 16])
key: 0.layers.0.feed_forward.running_gate_stats_3E.
value shape: to