# In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


# Jay路径
TRAIN_POOL_PATH = "/Users/yioha_/Desktop/Small-Data/data/splits/train_pool.csv"
OUT_SPLITS_DIR  = "/Users/yioha_/Desktop/Small-Data/data/splits"
REPORTS_DIR     = "/Users/yioha_/Desktop/Small-Data/reports"

# Iris路径
# TRAIN_POOL_PATH = "/Users/iriswu/Desktop/3001 Small Data/Small-Data/data/splits/train_pool.csv"
# OUT_SPLITS_DIR  = "/Users/iriswu/Desktop/3001 Small Data/Small-Data/data/splits"
# REPORTS_DIR     = "/Users/iriswu/Desktop/3001 Small Data/Small-Data/reports"

# Haydee路径
# TRAIN_POOL_PATH = "/Users/yinghanding/Desktop/Small-Data/data/splits/train_pool.csv"
# OUT_SPLITS_DIR  = "/Users/yinghanding/Desktop/Small-Data/data/splits"
# REPORTS_DIR     = "/Users/yinghanding/Desktop/Small-Data/reports"
# ===========================

from __future__ import annotations
from pathlib import Path
from collections import Counter
from typing import Dict, List
import json, math
import pandas as pd

# 固定参数
RANDOM_STATE = 2025
SIZES = [50, 200, 1000, 5000]
REQUIRED_COLS = {"query", "tools", "gold_call"}
TOOL_KEYS = ("tool_name", "name")  # gold_call 里解析工具名优先顺序

# 路径对象
TRAIN_POOL_PATH = Path(TRAIN_POOL_PATH)
OUT_SPLITS_DIR  = Path(OUT_SPLITS_DIR)
REPORTS_DIR     = Path(REPORTS_DIR)
REPORT_PATH     = REPORTS_DIR / "subset_stats.md"

def ensure_dirs(*dirs: Path) -> None:
    """Create each directory in ``dirs``, including any missing parents.

    A directory that already exists is not an error.
    """
    for directory in dirs:
        directory.mkdir(exist_ok=True, parents=True)

def read_train_pool(path: Path) -> pd.DataFrame:
    """Load the training-pool CSV, reading every column with ``dtype=str``.

    If the file has no ``_row_id`` column, one is inserted as the first
    column, numbered 0..n-1 in file order.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: at least one column from ``REQUIRED_COLS`` is absent.
    """
    if not path.exists():
        raise FileNotFoundError(f"找不到训练池文件：{path}")

    pool = pd.read_csv(path, dtype=str)

    missing = REQUIRED_COLS - set(pool.columns)
    if missing:
        raise ValueError(f"train_pool.csv 缺少必要列: {missing}")

    has_row_id = "_row_id" in pool.columns
    if not has_row_id:
        pool.insert(0, "_row_id", range(len(pool)))
    return pool

def parse_tool_name(gold_call_str: str) -> str:
    """Extract the tool name from a JSON-encoded gold call.

    The decoded value may be a dict, or a non-empty list whose first element
    is a dict. The keys in ``TOOL_KEYS`` are tried in order and the first one
    holding a string value wins.

    Returns:
        The tool name, or ``"__UNKNOWN__"`` when the input is not a non-blank
        string, is not valid JSON, or carries no string value under any of
        the ``TOOL_KEYS``.
    """
    unknown = "__UNKNOWN__"
    if not isinstance(gold_call_str, str) or not gold_call_str.strip():
        return unknown

    # Only the decode step can fail; keep the try body minimal and catch the
    # decoder's own errors (JSONDecodeError is a ValueError) instead of
    # silencing every exception.
    try:
        decoded = json.loads(gold_call_str)
    except (ValueError, RecursionError):
        return unknown

    # For a list of calls, only the first element is inspected.
    call = decoded[0] if isinstance(decoded, list) and decoded else decoded
    if isinstance(call, dict):
        for key in TOOL_KEYS:
            name = call.get(key)
            if isinstance(name, str):
                return name
    return unknown

def add_label(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``_label_tool`` column added.

    The label of each row is ``parse_tool_name`` applied to that row's
    ``gold_call`` value. The input frame is not modified.
    """
    labels = [parse_tool_name(raw_call) for raw_call in df["gold_call"]]
    labeled = df.copy()
    labeled["_label_tool"] = labels
    return labeled

def proportional_allocate(counts: Dict[str, int], k: int) -> Dict[str, int]:
    """Split ``k`` samples across classes in proportion to ``counts``.

    Largest-remainder allocation: every class first gets the floor of its
    proportional quota, then the leftover units go to the classes with the
    largest fractional parts. ``k`` is clipped to the total available, and no
    class receives more than it has.

    Args:
        counts: available rows per class.
        k: requested total number of samples.

    Returns:
        A dict with the same keys as ``counts`` whose values sum to
        ``min(k, sum(counts.values()))``; all zeros when that is <= 0.
    """
    total = sum(counts.values())
    k = min(k, total)
    if k <= 0:
        return {c: 0 for c in counts}

    raw = {c: k * n / total for c, n in counts.items()}
    base = {c: math.floor(v) for c, v in raw.items()}
    rem = k - sum(base.values())

    # Hand out the remainder by descending fractional part
    # (ties broken by descending class name through the tuple ordering).
    frac_sorted = sorted(((raw[c] - base[c], c) for c in counts), reverse=True)
    for i in range(rem):
        base[frac_sorted[i % len(frac_sorted)][1]] += 1

    # Never exceed what a class actually has.
    base = {c: min(base[c], counts[c]) for c in counts}

    # If capping left a shortfall, fill it round-robin from the classes with
    # the most spare capacity.
    deficit = k - sum(base.values())
    if deficit > 0:
        room = {c: counts[c] - base[c] for c in counts}
        order = sorted(room.items(), key=lambda x: (-x[1], x[0]))
        idx = 0
        while deficit > 0 and any(v > 0 for v in room.values()):
            c = order[idx % len(order)][0]
            if room[c] > 0:
                base[c] += 1
                room[c] -= 1
                deficit -= 1
            idx += 1

    assert sum(base.values()) == k
    return base


def make_nested_subsets(df_labeled: pd.DataFrame, sizes: List[int], seed: int) -> Dict[int, pd.DataFrame]:
    """Build nested, class-stratified subsets of ``df_labeled``.

    Steps:
      1) Each class (value of ``_label_tool``) gets a reproducibly shuffled
         queue of its row indices (``random_state = seed + class position``).
      2) For each target size, ascending, per-class quotas come from
         ``proportional_allocate``.
      3) A subset is the union of the per-class queue prefixes, so every
         smaller subset is contained in every larger one.

    Proportional quotas are not monotone in the target size: a class can be
    allotted fewer rows at a larger target. Rows already taken are never
    dropped (that would break nesting), so in that case the surplus is
    removed from the classes that were about to grow in this round, keeping
    each subset at exactly its target size.

    Args:
        df_labeled: frame with ``query``, ``tools``, ``gold_call`` and
            ``_label_tool`` columns.
        sizes: requested subset sizes; each is clipped to ``len(df_labeled)``.
        seed: base random seed for the per-class shuffles.

    Returns:
        Mapping from (clipped) size to a frame holding only the ``query``,
        ``tools`` and ``gold_call`` columns, rows in ascending index order.
    """
    label_col = "_label_tool"
    counts = Counter(df_labeled[label_col])
    classes = sorted(counts)

    # One reproducibly shuffled index queue per class.
    per_class_queue: Dict[str, List[int]] = {}
    for i, cls in enumerate(classes):
        idx = df_labeled.index[df_labeled[label_col] == cls]
        per_class_queue[cls] = idx.to_series().sample(frac=1.0, random_state=seed + i).tolist()

    taken_per_class = {c: 0 for c in classes}
    subsets: Dict[int, pd.DataFrame] = {}
    total_n = len(df_labeled)

    for target in sorted(sizes):
        target = min(target, total_n)
        ideal = proportional_allocate(counts, target)

        # Advance each class to its quota, but never below what it already holds.
        grown = {c: max(ideal[c], taken_per_class[c]) for c in classes}

        # A quota that shrank leaves the total above target; give the surplus
        # back from this round's growth, largest growth first (ties resolved
        # in sorted class order).
        excess = sum(grown.values()) - target
        if excess > 0:
            growth = {c: grown[c] - taken_per_class[c] for c in classes}
            while excess > 0:
                cls = max(classes, key=growth.get, default=None)
                if cls is None or growth[cls] <= 0:
                    break
                grown[cls] -= 1
                growth[cls] -= 1
                excess -= 1
        taken_per_class = grown

        # Current subset = union of the per-class prefixes taken so far.
        cur_idx: List[int] = []
        for c in classes:
            cur_idx.extend(per_class_queue[c][:taken_per_class[c]])
        cur_idx = sorted(set(cur_idx))

        # Only the three data columns are exported.
        subsets[target] = df_labeled.loc[cur_idx, ["query", "tools", "gold_call"]].copy()

    return subsets

def write_subset_stats(pool_labeled: pd.DataFrame,
                       subsets: Dict[int, pd.DataFrame],
                       report_path: Path) -> None:
    """Write a Markdown coverage report for the generated subsets.

    The report opens with the pool size and its number of distinct tools.
    Each subset, in ascending size order, then gets a section with its row
    count, its distinct-tool count, its tool coverage relative to the pool,
    and a per-tool count/share table sorted by count (descending) and then
    tool name.

    Args:
        pool_labeled: the training pool, including a ``_label_tool`` column.
        subsets: mapping of size to subset frame; labels are re-derived from
            each subset's ``gold_call`` column.
        report_path: destination file, written as UTF-8.
    """
    pool_tools = set(pool_labeled["_label_tool"].unique())
    n_pool_tools = len(pool_tools)

    out: List[str] = [
        "# Subset Coverage Statistics\n\n",
        f"- train_pool size: **{len(pool_labeled)}**\n",
        f"- train_pool unique tools: **{n_pool_tools}**\n\n",
    ]

    for size in sorted(subsets):
        labels = pd.Series(
            [parse_tool_name(call) for call in subsets[size]["gold_call"]],
            dtype=object,
            name="_label_tool",
        )
        n_rows = len(labels)
        n_tools = len(set(labels.unique()))
        coverage = n_tools / max(1, n_pool_tools)

        out.append(f"## train_{size}\n\n")
        out.append(f"- size: **{n_rows}**\n")
        out.append(f"- unique tools: **{n_tools}**\n")
        out.append(f"- tool coverage vs train_pool: **{coverage:.2%}**\n\n")

        table = labels.value_counts().rename("count").to_frame()
        table.index.name = "tool"
        table["share"] = (table["count"] / n_rows).round(6)
        table = table.sort_values(["count", "tool"], ascending=[False, True])

        out.append("| tool | count | share |\n")
        out.append("|---|---:|---:|\n")
        # itertuples yields (index, count, share) for each row.
        for tool, count, share in table.itertuples(name=None):
            out.append(f"| {tool} | {int(count)} | {share:.6f} |\n")
        out.append("\n")

    report_path.write_text("".join(out), encoding="utf-8")

def main():
    """Generate the nested subsets and the coverage report.

    Creates the output directories, loads and labels the training pool,
    writes one ``train_<size>.csv`` per subset (UTF-8, with header, without
    the index), writes the coverage report, and prints the produced paths.
    """
    ensure_dirs(OUT_SPLITS_DIR, REPORTS_DIR)
    pool = add_label(read_train_pool(TRAIN_POOL_PATH))

    subsets = make_nested_subsets(pool, SIZES, RANDOM_STATE)

    # One CSV per subset.
    csv_paths = {size: OUT_SPLITS_DIR / f"train_{size}.csv" for size in subsets}
    for size, subset in subsets.items():
        subset.to_csv(csv_paths[size], index=False, encoding="utf-8")

    # Coverage statistics.
    write_subset_stats(pool, subsets, REPORT_PATH)

    print("✅ Done.")
    for size in sorted(csv_paths):
        print(csv_paths[size])
    print(REPORT_PATH)


if __name__ == "__main__":
    main()


# Output:
# ✅ Done.
# /Users/yioha_/Desktop/Small-Data/data/splits/train_50.csv
# /Users/yioha_/Desktop/Small-Data/data/splits/train_200.csv
# /Users/yioha_/Desktop/Small-Data/data/splits/train_1000.csv
# /Users/yioha_/Desktop/Small-Data/data/splits/train_5000.csv
# /Users/yioha_/Desktop/Small-Data/reports/subset_stats.md
