In [None]:
#!/usr/bin/env python3
import os, sys, re, json, time
from collections import defaultdict
import torch

# Keys probed first, in this priority order, when searching a loaded
# checkpoint dict for the dictionary that holds its tensors.
CANDIDATES = [
    'state_dict',
    'model',
    'state_dict_ema',
    'ema',
    'module',
]

def human(n):
    """
    Format a count with a short magnitude suffix ('', K, M, B, T; powers of 1000).

    Values print as rounded integers with the largest fitting unit, e.g.
    ``human(999) == '999'`` and ``human(90_000_000) == '90M'``. Counts of a
    trillion or more print with one decimal, e.g. ``'1.5T'``.

    The unit is chosen on the *rounded* value, so 999_999 becomes '1M'
    rather than '1,000K'.
    """
    for unit in ('', 'K', 'M', 'B'):
        # Compare against 999.5, not 1000: the '.0f' format below rounds, so
        # anything >= 999.5 would otherwise render as '1,000' in this unit.
        if abs(n) < 999.5:
            return f"{n:,.0f}{unit}"
        n /= 1000.0
    return f"{n:.1f}T"

# Maximum nesting depth explored when the state_dict is not at the top level.
_MAX_SEARCH_DEPTH = 4


def _holds_tensor(d):
    """Return True if *d* is a dict with at least one ``torch.Tensor`` value."""
    return isinstance(d, dict) and any(isinstance(v, torch.Tensor) for v in d.values())


def _search_state_dict(node, parents, depth):
    """
    Depth-first search for a tensor-holding dict at or below *node*.

    *parents* is the tuple of keys leading to *node* (``()`` for the root);
    *depth* is how many further levels may be descended.
    Returns ``(dict, origin)`` or ``None`` if nothing was found.
    """
    if not isinstance(node, dict):
        return None

    def origin(key):
        # Top-level hits keep the bare key (same as before); deeper hits
        # report a dotted key path such as 'model.state_dict'.
        return key if not parents else ".".join(str(p) for p in (*parents, key))

    # 1) Well-known keys, in CANDIDATES priority order.
    for k in CANDIDATES:
        if k in node and _holds_tensor(node[k]):
            return node[k], origin(k)
    # 2) Any other child dict that directly holds tensors.
    for k, v in node.items():
        if _holds_tensor(v):
            return v, origin(k)
    # 3) The root itself is the state_dict. Only checked at the top: a nested
    #    tensor-holding dict is already caught by its parent's step 2.
    if not parents and _holds_tensor(node):
        return node, '<root>'
    # 4) Descend into nested dicts, e.g. {'model': {'state_dict': {...}}}.
    if depth > 0:
        for k, v in node.items():
            hit = _search_state_dict(v, (*parents, k), depth - 1)
            if hit is not None:
                return hit
    return None


def find_state_dict(obj):
    """
    Locate the parameter dictionary inside a loaded checkpoint object.

    Search order (first match wins):
      1. ``obj`` is an ``nn.Module`` (a whole model pickled with
         ``torch.save(model)``): return its ``state_dict()``.
      2. A key from ``CANDIDATES`` whose value is a dict holding a tensor.
      3. Any other top-level dict value holding a tensor.
      4. The root dict itself, if it directly holds a tensor.
      5. The same search applied to nested dicts, up to ``_MAX_SEARCH_DEPTH``
         levels deep.

    Returns:
        ``(state_dict, origin_key)``. ``state_dict`` is the matched dict itself
        (not a copy). It is only guaranteed to contain *at least one* tensor;
        other values may be non-tensors. ``origin_key`` has one of these forms:
        - the top-level key it was found under;
        - a dotted key path for nested hits;
        - ``'<root>'`` when the root dict is the state_dict;
        - ``'<module>'`` for an ``nn.Module``.

    Raises:
        RuntimeError: if no dict containing a tensor is found.
    """
    if isinstance(obj, torch.nn.Module):
        return obj.state_dict(), '<module>'
    hit = _search_state_dict(obj, (), _MAX_SEARCH_DEPTH)
    if hit is None:
        raise RuntimeError("未能在 ckpt 中找到像 state_dict 的字典（里面含有 tensor）。")
    return hit

def strip_prefix(name):
    """
    Normalize a parameter name for side-by-side comparison (display only).

    Removes at most one leading wrapper prefix, checking ``'model.'`` before
    ``'module.'``. It then rewrites a leading ``'keypoint_head.'`` to
    ``'head.'``. The main dump is unaffected; this only feeds the
    stripped-names preview.
    """
    wrapper = next((p for p in ('model.', 'module.') if name.startswith(p)), None)
    if wrapper is not None:
        name = name[len(wrapper):]
    return re.sub(r'^keypoint_head\.', 'head.', name)

def main(path, save_all=True, preview=30):
    """
    Print (and optionally save) a summary of the tensors in a PyTorch checkpoint.

    The report lists, in order:
      - the checkpoint's type and top-level keys;
      - where the state_dict was found (via ``find_state_dict``);
      - tensor and parameter totals;
      - tensor counts grouped by the first dot-separated name component;
      - a preview of the first *preview* entries;
      - the full ``name shape dtype`` list;
      - a preview of the names after ``strip_prefix``.

    Args:
        path: checkpoint file to load (onto the CPU).
        save_all: if true, also write the report next to the checkpoint as
            ``ckpt_dump_<basename>_<YYYYmmdd_HHMMSS>.txt``.
        preview: number of entries shown in each preview section (default 30).

    Raises:
        FileNotFoundError: *path* is not an existing file.
        RuntimeError: from ``find_state_dict`` if no tensor dict is found.
    """
    # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"文件不存在: {path}")
    preview = max(0, int(preview))

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only inspect checkpoints from trusted sources.
    obj = torch.load(path, map_location='cpu')

    lines = [f"# CKPT FILE: {path}", f"# TYPE: {type(obj)}"]
    if isinstance(obj, dict):
        toplvl_keys = list(obj.keys())
        lines.append(f"# TOP-LEVEL KEYS ({len(toplvl_keys)}): {toplvl_keys}")

    # Locate the actual state_dict.
    sd, origin = find_state_dict(obj)
    lines.append(f"# STATE_DICT found at: {origin}  (keys={len(sd)})")

    # find_state_dict only guarantees *some* tensor values. Drop anything else
    # (ints, strings, nested dicts, ...) so .numel()/.shape below cannot fail.
    tensors = {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
    skipped = [k for k, v in sd.items() if not isinstance(v, torch.Tensor)]
    if skipped:
        lines.append(f"# SKIPPED NON-TENSOR ENTRIES ({len(skipped)}): {skipped}")

    # Totals, plus tensor counts per first name component (e.g. 'backbone').
    total_params = sum(v.numel() for v in tensors.values())
    groups = defaultdict(int)
    for k in tensors:
        groups[k.split('.')[0]] += 1

    lines.append(f"# TOTAL TENSORS: {len(tensors)}  TOTAL PARAMS: {human(total_params)}")
    lines.append("# GROUP COUNT (by first prefix before dot):")
    for g, c in sorted(groups.items(), key=lambda x: (-x[1], x[0])):
        lines.append(f"- {g}: {c}")

    def row(name, t):
        # One report line: name padded to 80 columns, then shape and dtype.
        return f"{name:80s} {tuple(t.shape)} {t.dtype}"

    items = list(tensors.items())

    # Preview of the first entries.
    lines.append(f"# SAMPLE KEYS (first {preview}):")
    lines.extend(row(k, v) for k, v in items[:preview])

    lines.append("# " + "-"*80)
    lines.append("# FULL LIST (name -> shape, dtype):")
    lines.extend(row(k, v) for k, v in items)

    # Optional: prefix-stripped mirror, to ease comparison with a model's own names.
    lines.append("# " + "-"*80)
    lines.append(f"# STRIPPED NAMES (model./module./keypoint_head.→head.) PREVIEW (first {preview}):")
    lines.extend(row(strip_prefix(k), v) for k, v in items[:preview])

    text = "\n".join(lines)

    # Print to the console (can be very long).
    print(text)

    # Also save to a file for grep / side-by-side comparison.
    if save_all:
        ts = time.strftime("%Y%m%d_%H%M%S")
        out = os.path.join(os.path.dirname(path), f"ckpt_dump_{os.path.basename(path)}_{ts}.txt")
        with open(out, "w", encoding="utf-8") as f:
            f.write(text)
        print(f"\n[OK] 完整键名与形状已写入：{out}")


if __name__ == "__main__":
    # Guarded so importing this file (e.g. to reuse the helpers) does not load a
    # checkpoint. Both `python script.py` and a notebook cell still run it,
    # since __name__ is "__main__" in both.
    main("pretrained/swin_tiny.pth")

  from .autonotebook import tqdm as notebook_tqdm


# CKPT FILE: pretrained/best_coco_AP_epoch_210.pth
# TYPE: <class 'dict'>
# TOP-LEVEL KEYS (3): ['meta', 'state_dict', 'message_hub']
# STATE_DICT found at: state_dict  (keys=165)
# TOTAL TENSORS: 165  TOTAL PARAMS: 90M
# GROUP COUNT (by first prefix before dot):
- backbone: 149
- head: 16
# SAMPLE KEYS (first 30):
backbone.pos_embed                                                               (1, 192, 768) torch.float32
backbone.patch_embed.projection.weight                                           (768, 3, 16, 16) torch.float32
backbone.patch_embed.projection.bias                                             (768,) torch.float32
backbone.layers.0.ln1.weight                                                     (768,) torch.float32
backbone.layers.0.ln1.bias                                                       (768,) torch.float32
backbone.layers.0.attn.qkv.weight                                                (2304, 768) torch.float32
backbone.layers.0.attn.qkv.bias                  