In [2]:
# %% [markdown]
# # 批量排序 MUSK-feature `.pt`
# - 按 (x,y) 坐标对 `bag_feats / coords` 排序  
# - `sort_one_pt` 默认保存为 `<原文件名>_sorted.pt`；如需覆盖原文件把 `suffix` 设为 `""`（下方参数单元当前即设为 `""`，会**直接覆盖原文件**）  
# - 建议在 **CPU** 上处理，避免显存占用  
# - 如需并行，可自行把「批量遍历」单元中的 `for` 循环换成 `joblib.Parallel` / `multiprocessing.Pool`

# %%  导入依赖
import os
from pathlib import Path
import torch
import numpy as np
from tqdm import tqdm

# %%  参数：修改这里即可
# Folder holding the MUSK-feature `.pt` files to process. Set the
# MUSK_FEATURE_DIR environment variable to point at another location without
# editing this cell; otherwise the original hardcoded path is used.
root_dir = os.environ.get(
    "MUSK_FEATURE_DIR",
    "/remote-home/share/lisj/Workspace/SOTA_NAS/datasets/core/MUSK-feature",
)
# Output filename suffix: "" overwrites every input file in place;
# a value such as "_sorted" writes `<stem>_sorted.pt` next to each input instead.
suffix = ""

# %%  核心函数：处理单个 .pt
def sort_one_pt(pt_path: Path, suffix: str = "_sorted") -> None:
    """Reorder the rows of one MUSK-feature `.pt` file by coordinate and save it.

    Pipeline: load -> validate -> sort rows by ``coords`` -> save.

    The rows of ``bag_feats`` and ``coords`` are permuted together so that the
    saved ``coords`` are ascending by column 0, with ties broken by column 1
    (``np.lexsort`` uses its *last* key as the primary key).

    Args:
        pt_path: File to process. It must contain a dict with ``"bag_feats"``
            and ``"coords"`` tensors; ``coords`` must be 2-D with 2 columns and
            have the same number of rows as ``bag_feats``.
        suffix: Appended to the file stem to build the output file name.
            ``""`` overwrites ``pt_path`` itself.

    Returns:
        None. Files that fail to load or fail validation are reported with
        ``print`` and skipped, so one bad file does not stop a batch run.
        Errors raised while saving are propagated.

    Notes:
        - Any other keys in the loaded dict are copied to the output unchanged
          (they are NOT reordered).
        - The output is first written to ``<out_name>.tmp`` and then moved into
          place with ``os.replace``, so a failed or interrupted save never leaves
          a truncated file behind -- important when ``suffix == ""`` overwrites
          the original data.
    """
    try:
        # weights_only=False unpickles arbitrary objects: only use on trusted files.
        data = torch.load(pt_path, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"[❌] 载入失败 {pt_path.name}: {e}")
        return

    if not isinstance(data, dict) or "bag_feats" not in data or "coords" not in data:
        print(f"[⚠️] {pt_path.name} 缺少 bag_feats / coords，跳过")
        return
    feats, coords = data["bag_feats"], data["coords"]
    # Non-tensor values would pass the shape checks below and then crash at
    # `.numpy()`, aborting the whole batch; skip them like other bad files.
    if not isinstance(feats, torch.Tensor) or not isinstance(coords, torch.Tensor):
        print(f"[⚠️] {pt_path.name} bag_feats / coords 不是 torch.Tensor，跳过")
        return
    if coords.ndim != 2 or coords.shape[1] != 2:
        print(f"[⚠️] {pt_path.name} coords 维度异常 {coords.shape}，跳过")
        return
    if feats.ndim == 0 or feats.shape[0] != coords.shape[0]:
        print(f"[⚠️] {pt_path.name} feats/coords 行数不匹配，跳过")
        return

    # Sort: primary key column 0, secondary key column 1.
    coords_np = coords.detach().numpy()
    order = torch.from_numpy(np.lexsort((coords_np[:, 1], coords_np[:, 0])))

    # Start from a copy of the loaded dict so extra entries (metadata) survive.
    sorted_data = dict(data)
    sorted_data["bag_feats"] = feats[order]
    sorted_data["coords"] = coords[order]

    out_path = pt_path if suffix == "" else pt_path.with_name(pt_path.stem + suffix + pt_path.suffix)
    # Write-then-rename: `.tmp` does not match the "*.pt" glob used for batching.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        torch.save(sorted_data, tmp_path)
        os.replace(tmp_path, out_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

# %%  批量遍历
root = Path(root_dir).expanduser().resolve()
# With a non-empty suffix, outputs (`<stem><suffix>.pt`) are written into the
# same folder and would match "*.pt" on a re-run, producing
# `<stem><suffix><suffix>.pt`. Exclude them so re-running this cell is
# idempotent. (An input whose stem already ends with `suffix` is excluded too.)
pt_files = sorted(
    p for p in root.glob("*.pt") if not (suffix and p.stem.endswith(suffix))
)
if not pt_files:
    print(f"在 {root} 下没有找到 .pt 文件")
else:
    print(f"共发现 {len(pt_files)} 个 .pt，开始处理…")
    for f in tqdm(pt_files, unit="file"):
        sort_one_pt(f, suffix=suffix)
    print("✅ 全部完成！")

# %% [markdown]
# ## 抽样验证（可选）
# 运行下方单元，检查第一个输出文件（`<原文件名><suffix>.pt`；`suffix = ""` 时即原文件本身）的 coords 是否已按 `(x, y)` 递增

# %%
# Spot-check the output written for the first input file. With suffix == ""
# that is the input file itself, now overwritten.
if not pt_files:
    print("没有可供抽样验证的 .pt 文件，跳过")
else:
    sample_file = pt_files[0].with_name(pt_files[0].stem + suffix + ".pt")
    # weights_only=False matches how sort_one_pt loads these (trusted) files.
    data = torch.load(sample_file, map_location="cpu", weights_only=False)
    print("前 5 行 coords：\n", data["coords"][:5])
    print("后 5 行 coords：\n", data["coords"][-5:])

    # Check every row, not just the printed ones: each consecutive pair must be
    # non-decreasing in column 0, and in column 1 whenever column 0 ties.
    prev_xy, next_xy = data["coords"][:-1], data["coords"][1:]
    in_order = (next_xy[:, 0] > prev_xy[:, 0]) | (
        (next_xy[:, 0] == prev_xy[:, 0]) & (next_xy[:, 1] >= prev_xy[:, 1])
    )
    assert bool(in_order.all()), f"{sample_file.name} 的 coords 未按 (x, y) 递增排序"


共发现 560 个 .pt，开始处理…


100%|██████████| 560/560 [00:43<00:00, 12.97file/s]

✅ 全部完成！
前 5 行 coords：
 tensor([[146, 315],
        [146, 316],
        [146, 317],
        [146, 334],
        [146, 335]], dtype=torch.int32)
后 5 行 coords：
 tensor([[289, 381],
        [289, 382],
        [289, 383],
        [289, 384],
        [289, 385]], dtype=torch.int32)



