In [None]:
from typing import List, Sequence

def build_subtree_counts(values: Sequence[int], n: int, h: int) -> List[List[int]]:
    """
    Count how many values fall under each node of a binary prefix tree.

    Every value is treated as an n-bit binary number. The node with index k
    at level l counts the values whose top l bits equal k; the root (level 0)
    counts every value.

    Args:
        values: Integers, each in [0, 2^n). Any iterable is accepted; it is
            materialised once so single-pass iterators are counted correctly.
        n: Number of bits used to represent each value.
        h: Height of the tree (the root is level 0, so there are h + 1
            levels). Must satisfy 0 <= h <= n.

    Returns:
        counts: A list of length h + 1. counts[l] is a list of length 2^l and
            counts[l][k] is the count of node k at level l.

    Raises:
        ValueError: If h < 0, if h > n, or if any value is outside [0, 2^n).
    """
    if h < 0:
        raise ValueError("h 必须 >= 0")
    if h > n:
        raise ValueError("要求 h <= n，否则前缀长度超过可用比特数")

    # The data is traversed twice (validate, then count); materialise it so a
    # generator/iterator is not silently exhausted by the validation pass.
    items = list(values)

    # Range check before any counting so a bad value never yields partial work.
    upper = 1 << n
    for x in items:
        if not (0 <= x < upper):
            raise ValueError(f"值 {x} 超出范围 [0, 2^{n})")

    counts: List[List[int]] = [[0] * (1 << l) for l in range(h + 1)]

    # The root covers every value.
    counts[0][0] = len(items)

    if h >= 1:
        # Count only the deepest level directly: the index is the top h bits.
        deepest = counts[h]
        shift = n - h
        for x in items:
            deepest[x >> shift] += 1

        # A node's l-bit prefix is shared by exactly its two children
        # (indices 2k and 2k + 1 one level down), so each shallower level is
        # the pairwise sum of the level below it.
        for l in range(h - 1, 0, -1):
            below = counts[l + 1]
            counts[l] = [below[2 * k] + below[2 * k + 1] for k in range(1 << l)]

    return counts


def print_tree_counts(counts: List[List[int]]) -> None:
    """
    Print the per-node counts of a tree, one line per level.

    Args:
        counts: Nested list as returned by build_subtree_counts; counts[l]
            holds the node counts of level l.
    """
    for depth, row in enumerate(counts):
        # Fixed-width cells keep the numbers within a level evenly spaced.
        cells = [f"{c:4d}" for c in row]
        print(f"Level {depth:2d}: {' '.join(cells)}")


Level  0:    8
Level  1:    5    3
Level  2:    4    1    2    1
Level  3:    2    2    0    1    2    0    0    1


In [5]:
import os
import re
import torch
from glob import glob

ckpt_dir = "/home/gaochi2/SeLM_v2/ckpt/cte"

# File names must end with the generated checkpoint naming format.
pattern = re.compile(
    r"N(?P<N>\d+)_vocab(?P<vocab>\d+)_ps(?P<ps>0\.\d+)_ratio(?P<ratio>[0-9.]+)\.ckpt$"
)

# Keep only the *.ckpt files whose base name matches the pattern, in sorted order.
candidate_files = sorted(
    path
    for path in glob(os.path.join(ckpt_dir, "*.ckpt"))
    if pattern.search(os.path.basename(path))
)
print(f"[INFO] Found {len(candidate_files)} matched ckpt files")

# Maps checkpoint path -> {"train_locations": tensor, "shape": tuple, "dtype": dtype}.
extracted = {}

for f in candidate_files:
    ckpt_name = os.path.basename(f)

    # A checkpoint that fails to load is reported and skipped, not fatal.
    try:
        data = torch.load(f, map_location="cpu")
    except Exception as e:
        print(f"[WARN] Failed to load {f}: {e}")
        continue

    if "train_locations" not in data:
        print(f"[SKIP] No train_locations in {ckpt_name}")
        continue

    tl = data["train_locations"]  # annotated by the author as shape (N, tp)
    tl_shape = tuple(tl.shape)
    extracted[f] = {
        "train_locations": tl,
        "shape": tl_shape,
        "dtype": tl.dtype,
    }
    print(f"[OK] {ckpt_name}  shape={tl_shape}")

len(extracted)


[INFO] Found 324 matched ckpt files
[OK] N2048_vocab384_ps0.125_ratio0.00.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.01.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.02.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.04.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.08.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.16.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.32.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.64.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.80.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.95.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio0.99.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.125_ratio1.00.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.5_ratio0.00.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.5_ratio0.01.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.5_ratio0.02.ckpt  shape=(2048, 4)
[OK] N2048_vocab384_ps0.5_ratio0.04.ckpt  shape=(2048, 4)
[OK] N2048_v

  data = torch.load(f, map_location="cpu")


[OK] N32768_vocab384_ps0.125_ratio0.04.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio0.08.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio0.16.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio0.32.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio0.64.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio0.80.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio0.95.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio0.99.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.125_ratio1.00.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.5_ratio0.00.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.5_ratio0.01.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.5_ratio0.02.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.5_ratio0.04.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.5_ratio0.08.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.5_ratio0.16.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_ps0.5_ratio0.32.ckpt  shape=(32768, 4)
[OK] N32768_vocab384_p

324

In [13]:
import torch
import os
from typing import List

def compute_tree_distribution_for_file(
    file_path: str,
    h: int,
    n: int,
):
    """
    Build one prefix-tree count table per column of a checkpoint's
    'train_locations' tensor.

    The checkpoint is loaded onto the CPU and must contain the key
    'train_locations' holding a 2-D tensor (annotated as (N, tp)). For each
    column, only the strictly positive entries are kept and handed to
    build_subtree_counts(values, n, h).

    Args:
        file_path: Path to the checkpoint file.
        h: Tree height forwarded to build_subtree_counts.
        n: Bit width forwarded to build_subtree_counts.

    Returns:
        per_dim_counts: A list with one entry per column of train_locations;
            each entry is the return value of build_subtree_counts for that
            column.

    Raises:
        FileNotFoundError: If file_path does not exist.
        KeyError: If the checkpoint has no 'train_locations' entry.
        ValueError: If train_locations is not 2-D (build_subtree_counts may
            also raise ValueError for invalid h or out-of-range values).
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(file_path)

    ckpt = torch.load(file_path, map_location="cpu")
    if "train_locations" not in ckpt:
        raise KeyError(f"'train_locations' not found in {file_path}")

    loc = ckpt["train_locations"]
    if loc.ndim != 2:
        raise ValueError(f"train_locations must be (N, tp), got {loc.shape}")

    num_dims = loc.shape[1]

    # Work on a NumPy copy so each column converts to plain Python numbers.
    loc_np = loc.cpu().numpy()

    def positive_column(dim):
        # Only entries with v > 0 are kept; negative and zero entries are
        # excluded from the tree.
        # NOTE(review): an earlier comment here described taking abs(), which
        # is not what the code does - confirm that dropping non-positive
        # entries (including 0) is the intended behaviour.
        return [v for v in loc_np[:, dim].tolist() if v > 0]

    return [
        build_subtree_counts(positive_column(dim), n, h)
        for dim in range(num_dims)
    ]

# Example: inspect a single checkpoint taken from ckpt_dir.
file_path = os.path.join(ckpt_dir, "N2048_vocab384_ps0.125_ratio0.00.ckpt")

# n has to be chosen to match the CT encoding used for this checkpoint.
tree_dist = compute_tree_distribution_for_file(file_path, h=5, n=11)

# tree_dist has one entry (one tree of counts) per column of train_locations.
for dim, counts in enumerate(tree_dist):
    print(f"===== dim {dim} =====")
    print_tree_counts(counts)


===== dim 0 =====
Level  0: 1029
Level  1:  563  466
Level  2:  219  344  300  166
Level  3:   82  137  186  158  192  108  120   46
Level  4:   34   48   92   45   97   89   92   66   81  111   57   51   73   47   23   23
Level  5:    0   34    4   44   71   21   33   12   20   77   38   51   68   24   65    1   29   52   66   45   46   11   22   29    8   65   23   24   20    3   15    8
===== dim 1 =====
Level  0: 1023
Level  1:  427  596
Level  2:  218  209  270  326
Level  3:  142   76  107  102  217   53  137  189
Level  4:   64   78   25   51   42   65   31   71   90  127   51    2   85   52   79  110
Level  5:   49   15   21   57    4   21   35   16   21   21   33   32   20   11   55   16    0   90   46   81   36   15    2    0   33   52    1   51   49   30   91   19
===== dim 2 =====
Level  0:  950
Level  1:  470  480
Level  2:  272  198  314  166
Level  3:  100  172  122   76  149  165   66  100
Level  4:   43   57   81   91   60   62   11   65  111   38   66   99   33   33  

  data = torch.load(file_path, map_location="cpu")
