In [4]:
# Paths used by the split cell below.
#
# All three paths live under one project root. Instead of editing this cell,
# set the SMALL_DATA_ROOT environment variable to your own checkout. When it
# is unset, the previous hard-coded default (Haydee's checkout) is used.
# Known checkouts:
#   Haydee: /Users/yinghanding/Desktop/Small-Data
#   Iris:   /Users/iriswu/Desktop/3001 Small Data/Small-Data
#   Jay:    /Users/yioha_/Desktop/Small-Data
import os

project_root = os.environ.get("SMALL_DATA_ROOT", "/Users/yinghanding/Desktop/Small-Data")

# os.path.join avoids doubled separators if the root is given with a trailing slash.
input_path = os.path.join(project_root, "data", "processed", "apigen_normalized.csv")
out_splits_dir = os.path.join(project_root, "splits")
reports_dir = os.path.join(project_root, "reports")


In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations
import argparse
import json
import math
import os
from collections import Counter
from dataclasses import dataclass
from typing import Dict, List
import pandas as pd


@dataclass
class Args:
    """Container for the split settings (input CSV, output dirs, test size, seed, tool key).

    NOTE(review): not referenced by any visible cell. ``run_split`` receives
    these values as plain parameters, and the imported ``argparse`` is never
    used to fill this class. The fields appear to mirror ``run_split``'s
    parameters (``input`` <-> ``input_path``); confirm before relying on it.
    """

    input: str
    out_splits_dir: str
    reports_dir: str
    test_size: int
    seed: int
    tool_key: str = "tool_name"  # same default as run_split's tool_key


def ensure_dirs(*dirs: str) -> None:
    """Create each directory in ``dirs``, including missing parents.

    Directories that already exist are left alone (``exist_ok=True``), so the
    call is safe to repeat.
    """
    for directory in dirs:
        os.makedirs(directory, exist_ok=True)


def read_normalized_csv(path: str) -> pd.DataFrame:
    """Load the normalized dataset CSV, reading every column with ``dtype=str``.

    Args:
        path: Location of the CSV file.

    Returns:
        The loaded frame. It contains ``query``, ``tools`` and ``gold_call``.
        If the file has no ``_row_id`` column, one holding ``0..n-1`` is
        inserted at position 0. An existing ``_row_id`` column is kept as-is.

    Raises:
        ValueError: if any of ``query``, ``tools`` or ``gold_call`` is missing.
    """
    needed = {"query", "tools", "gold_call"}
    frame = pd.read_csv(path, dtype=str)
    absent = needed - set(frame.columns)
    if absent:
        raise ValueError(f"缺少必要列: {absent}. 需要列: {sorted(needed)}")
    if "_row_id" in frame.columns:
        return frame
    frame.insert(0, "_row_id", range(len(frame)))
    return frame


def parse_tool_name(gold_call_str: str, tool_key: str) -> str:
    """Extract the tool name from a JSON-encoded gold call.

    The decoded JSON may be an object, or a non-empty list whose first element
    is an object; later list elements are ignored. Within that object the value
    under ``tool_key`` is preferred, then the value under ``"name"``. Only
    string values are accepted.

    Returns:
        The tool name, or ``"__UNKNOWN__"`` in three cases: the input is not a
        non-blank string, decoding or lookup raises, or no string name is found.
    """
    fallback = "__UNKNOWN__"
    if not isinstance(gold_call_str, str) or not gold_call_str.strip():
        return fallback
    try:
        parsed = json.loads(gold_call_str)
        # A list of calls is labelled by its first call.
        candidate = parsed[0] if isinstance(parsed, list) and parsed else parsed
        if isinstance(candidate, dict):
            for key in (tool_key, "name"):
                value = candidate.get(key)
                if isinstance(value, str):
                    return value
    except Exception:
        # Best-effort labelling: malformed payloads fall through to the fallback.
        pass
    return fallback


def add_stratify_label(df: pd.DataFrame, tool_key: str) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``_label_tool`` column for stratification.

    Each label is ``parse_tool_name(gold_call, tool_key)`` for that row, which
    yields ``"__UNKNOWN__"`` when no name can be extracted. The input frame is
    not modified.
    """
    labelled = df.copy()
    labelled["_label_tool"] = [parse_tool_name(raw_call, tool_key) for raw_call in df["gold_call"].tolist()]
    return labelled


def proportional_allocate(counts: Dict[str, int], k: int) -> Dict[str, int]:
    """Split ``k`` slots across classes in proportion to their counts.

    Uses the largest-remainder method. Each class first receives the floor of
    its exact quota ``k * count / total``. The leftover slots then go one each
    to the classes with the largest fractional parts. Ties on the fractional
    part are broken by class key in descending order, because
    ``(fraction, key)`` tuples are sorted with ``reverse=True``.

    Args:
        counts: Number of available items per class. Assumed positive (in
            ``stratified_sample`` they come from a ``Counter``). If every count
            is 0 and ``counts`` is non-empty, the quota division raises
            ``ZeroDivisionError``. An empty dict with ``k == 0`` returns ``{}``.
        k: Total number of slots to hand out.
            NOTE(review): negative ``k`` is not rejected here.

    Returns:
        Mapping from class to number of slots. The values sum to ``k`` and no
        class gets more than its count.

    Raises:
        ValueError: if ``k`` exceeds the sum of ``counts``.
    """
    total = sum(counts.values())
    if k > total:
        # Message: "test_size=k exceeds total number of samples=total".
        raise ValueError(f"test_size={k} 大于样本总数={total}")

    # Exact (fractional) quota per class, then its floor.
    raw_alloc = {c: (k * counts[c] / total) for c in counts}
    base = {c: int(math.floor(v)) for c, v in raw_alloc.items()}
    # Slots lost to flooring. This is the sum of the fractional parts, so it
    # is smaller than len(counts) and the modulo below should never wrap.
    remainder = k - sum(base.values())

    frac_sorted = sorted(((raw_alloc[c] - base[c], c) for c in counts), reverse=True)
    for i in range(remainder):
        base[frac_sorted[i % len(frac_sorted)][1]] += 1

    # Safeguard: never ask a class for more items than it has. Any shortfall
    # is refilled round-robin, favouring classes with the most spare items
    # (ties broken by class key, ascending).
    # NOTE(review): with 0 <= k <= total, base[c] <= ceil(k * counts[c] / total)
    # <= counts[c], so this clamp/refill should never actually trigger.
    available = counts.copy()
    alloc = {c: min(base[c], available[c]) for c in counts}
    deficit = k - sum(alloc.values())
    if deficit > 0:
        pool = {c: available[c] - alloc[c] for c in counts}
        order = sorted(pool.items(), key=lambda x: (-x[1], x[0]))
        idx = 0
        while deficit > 0 and any(v > 0 for v in pool.values()):
            c = order[idx % len(order)][0]
            if pool[c] > 0:
                alloc[c] += 1
                pool[c] -= 1
                deficit -= 1
            idx += 1

    assert sum(alloc.values()) == k
    return alloc


def stratified_sample(df: pd.DataFrame, label_col: str, k: int, seed: int) -> pd.DataFrame:
    """Draw ``k`` rows from ``df`` so each class of ``label_col`` keeps its share.

    Per-class quotas come from ``proportional_allocate``. Within each class,
    rows are drawn uniformly without replacement.

    Args:
        df: Frame to sample from. Its index labels identify the chosen rows.
        label_col: Column holding each row's class label.
        k: Total number of rows to draw.
        seed: ``random_state`` used for every per-class draw.

    Returns:
        The selected rows of ``df`` (via ``df.loc``), ordered by index label.

    Raises:
        ValueError: propagated from ``proportional_allocate`` when ``k``
            exceeds ``len(df)``.
    """
    counts = Counter(df[label_col])
    alloc = proportional_allocate(dict(counts), k)

    # Group row positions once instead of building one boolean mask per class,
    # which cost O(rows * classes). Positions within each group keep their
    # original row order, so each per-class draw sees the same candidate list
    # as a mask-based lookup would, and the sample is unchanged.
    positions_by_class = df.groupby(label_col, sort=False).indices

    taken_idx: List[int] = []
    for cls, need in alloc.items():
        if need <= 0:
            continue
        positions = positions_by_class.get(cls)
        cls_idx = [] if positions is None else df.index[positions].tolist()
        # NOTE: the same seed is reused for every class, so classes of equal
        # size draw the same relative positions. Kept so that existing splits
        # remain reproducible for a given seed.
        chosen = pd.Index(cls_idx).to_series().sample(n=need, random_state=seed, replace=False).tolist()
        taken_idx.extend(chosen)

    return df.loc[sorted(taken_idx)]


def write_split_reports(df_full: pd.DataFrame, df_test: pd.DataFrame, df_train: pd.DataFrame, report_path: str) -> None:
    """Write a Markdown report comparing tool distributions across splits.

    The report has three parts:
      - the size of each split;
      - a table of tool counts and shares for each non-empty split, sorted by
        count (descending) and then tool name;
      - how many distinct tools each split covers relative to the full set.

    Args:
        df_full: All rows; must have a ``_label_tool`` column.
        df_test: Test rows; must have a ``_label_tool`` column.
        df_train: Train-pool rows; must have a ``_label_tool`` column.
        report_path: Destination file, overwritten as UTF-8.
    """

    def summarize(df: pd.DataFrame) -> pd.DataFrame:
        # One row per tool with columns tool, count and share (count / split size).
        cnt = df["_label_tool"].value_counts().rename("count").to_frame()
        cnt.index.name = "tool"
        cnt["share"] = (cnt["count"] / len(df)).round(6)
        return cnt.reset_index()

    def md_cell(value) -> str:
        # A literal "|" would split a Markdown table cell, so escape it.
        return str(value).replace("|", "\\|")

    def share_of_full(n_tools: int, n_full: int) -> str:
        # Guard the ratio so an empty full set cannot raise ZeroDivisionError.
        return f"{n_tools / n_full:.2%}" if n_full else "n/a"

    # Explicit presentation order instead of relying on groupby's alphabetical sort.
    splits = [("full", df_full), ("test", df_test), ("train_pool", df_train)]

    lines: List[str] = [
        "# Split Statistics\n",
        f"Total samples: {len(df_full)}\n\n",
        "## Per-split overview\n\n",
        f"- test: {len(df_test)}\n",
        f"- train_pool: {len(df_train)}\n\n",
        "## Tool distribution\n\n",
    ]
    for split_name, split_df in splits:
        sub = summarize(split_df)
        if sub.empty:
            # A split with no rows gets no section (same as the previous output).
            continue
        lines.append(f"### {split_name}\n\n")
        lines.append("| tool | count | share |\n")
        lines.append("|---|---:|---:|\n")
        ordered = sub.sort_values(["count", "tool"], ascending=[False, True])
        for tool, count, share in ordered[["tool", "count", "share"]].itertuples(index=False, name=None):
            lines.append(f"| {md_cell(tool)} | {int(count)} | {share:.6f} |\n")
        lines.append("\n")

    tools_full = set(df_full["_label_tool"].unique())
    tools_test = set(df_test["_label_tool"].unique())
    tools_train = set(df_train["_label_tool"].unique())
    lines.append("## Tool coverage\n\n")
    lines.append(f"- full tools: {len(tools_full)}\n")
    lines.append(f"- test tools: {len(tools_test)} ({share_of_full(len(tools_test), len(tools_full))} of full)\n")
    lines.append(f"- train_pool tools: {len(tools_train)} ({share_of_full(len(tools_train), len(tools_full))} of full)\n\n")

    with open(report_path, "w", encoding="utf-8") as f:
        f.write("".join(lines))


def run_split(input_path, out_splits_dir, reports_dir, test_size=2000, seed=2025, tool_key="tool_name"):
    """Create a stratified test split and the complementary train pool.

    Each row of the normalized CSV is labelled with the tool name parsed from
    ``gold_call``. ``test_size`` rows are then drawn so that each tool keeps
    roughly its share of the data, and all remaining rows form the train pool.

    Args:
        input_path: Normalized CSV with ``query``, ``tools`` and ``gold_call`` columns.
        out_splits_dir: Directory that receives ``test.csv`` and
            ``train_pool.csv``; only those three columns are written. Created
            if missing.
        reports_dir: Directory that receives ``split_stats.md``. Created if missing.
        test_size: Number of test rows; must satisfy ``0 < test_size < n_rows``.
        seed: Random seed for the per-class draws.
        tool_key: JSON key tried first when reading the tool name from ``gold_call``.

    Raises:
        ValueError: if required columns are missing or ``test_size`` is out of range.
    """
    ensure_dirs(out_splits_dir, reports_dir)

    df = read_normalized_csv(input_path)
    df = add_stratify_label(df, tool_key)

    total_n = len(df)
    if test_size <= 0 or test_size >= total_n:
        raise ValueError(f"test-size 必须在 (0, {total_n}) 之间，当前为 {test_size}")

    df_test = stratified_sample(df, "_label_tool", test_size, seed)
    # Take the complement by index label, not by `_row_id`. An input CSV may
    # already carry its own `_row_id` column (kept as-is on load). Duplicate
    # or missing ids there would silently drop extra rows from the train pool.
    # `stratified_sample` returns rows of `df` via `df.loc`, so their index
    # labels identify them exactly.
    df_train_pool = df.drop(index=df_test.index)
    assert len(df_test) + len(df_train_pool) == total_n, "test/train_pool must partition the input"

    test_path = os.path.join(out_splits_dir, "test.csv")
    train_pool_path = os.path.join(out_splits_dir, "train_pool.csv")

    df_test[["query", "tools", "gold_call"]].to_csv(test_path, index=False, encoding="utf-8")
    df_train_pool[["query", "tools", "gold_call"]].to_csv(train_pool_path, index=False, encoding="utf-8")

    report_path = os.path.join(reports_dir, "split_stats.md")
    write_split_reports(df, df_test, df_train_pool, report_path)

    print("✅ Done.")
    print(f"test.csv -> {test_path}")
    print(f"train_pool.csv -> {train_pool_path}")
    print(f"reports/split_stats.md -> {report_path}")



if __name__ == "__main__":
    # Inside a notebook kernel __name__ is "__main__", so running this cell
    # performs the split using the paths defined in the configuration cell.
    run_split(
        input_path=input_path,
        out_splits_dir=out_splits_dir,
        reports_dir=reports_dir,
        test_size=2000,
        seed=2025,
    )

✅ Done.
test.csv -> /Users/yinghanding/Desktop/Small-Data/splits/test.csv
train_pool.csv -> /Users/yinghanding/Desktop/Small-Data/splits/train_pool.csv
reports/split_stats.md -> /Users/yinghanding/Desktop/Small-Data/reports/split_stats.md
