In [9]:
import json
import yaml
from pathlib import Path
from typing import Any, Dict

def convert_mia_json_to_yaml(json_path: str | Path, yaml_path: str | None = None) -> Path:
    """
    将 MIA 结果 JSON 转换为结构化 YAML 文件
    （保留 threshold、direction、stats、ecdf、calibrator 等信息）

    参数：
    ----------
    json_path : str | Path
        输入 JSON 文件路径
    yaml_path : str | None
        输出 YAML 文件路径（若不指定则自动替换扩展名）

    返回：
    ----------
    Path: 生成的 YAML 文件路径
    """
    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"❌ 找不到 JSON 文件: {json_path}")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # 提取顶层元信息
    data_name = data.get("data_name")
    alpha = data.get("alpha")
    counts = data.get("counts", {})

    # 目标 YAML 结构
    yaml_dict: Dict[str, Any] = {
        "data_name": data_name,
        "alpha": alpha,
        "total": counts.get("total"),
        "member": counts.get("member"),
        "non_member": counts.get("non_member"),
        "thresholds": {}
    }

    for item in data.get("items", []):
        metric = item["metric"]
        direction = item.get("direction")
        threshold_bestJ = item.get("threshold_bestJ")
        threshold_fpr_alpha = item.get("threshold_fpr_alpha")

        # 根据 metric 名拆分层级（如 mink++_0.1 -> group=mink++, key=0.1）
        if "_" in metric and any(ch.isdigit() for ch in metric.split("_")[-1]):
            group = "_".join(metric.split("_")[:-1])
            key = metric.split("_")[-1]
        else:
            group = metric
            key = None

        if group not in yaml_dict["thresholds"]:
            yaml_dict["thresholds"][group] = {}

        entry = {
            "direction": direction,
            "threshold_bestJ": threshold_bestJ,
            "threshold_fpr_alpha": threshold_fpr_alpha,
        }

        # 可选：附加统计信息
        if "stats" in item:
            entry["stats"] = item["stats"]
        if "ecdf" in item:
            entry["ecdf"] = item["ecdf"]
        if "calibrator" in item:
            entry["calibrator"] = item["calibrator"]

        if key is not None:
            yaml_dict["thresholds"][group][key] = entry
        else:
            yaml_dict["thresholds"][group] = entry

    # 自动生成输出路径
    if yaml_path is None:
        yaml_path = json_path.with_suffix(".yaml")

    # 写出 YAML
    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(yaml_dict, f, sort_keys=False, allow_unicode=True)

    print(f"✅ 已生成 YAML: {yaml_path}")
    return Path(yaml_path)

In [10]:
# Convert the WikiMIA (length 256) threshold JSON; the YAML is written next to it
# and its Path is displayed as the cell output.
convert_mia_json_to_yaml(
    json_path="llama_data/threshold_WikiMIA_length256.json",
)

✅ 已生成 YAML: llama_data/threshold_WikiMIA_length256.yaml


PosixPath('llama_data/threshold_WikiMIA_length256.yaml')