In [13]:
# pip 依赖（第一次运行需要网络）
# 如果你的环境里已安装，可以注释掉这几行
import sys, subprocess
def _pip_install(pkg):
    try:
        __import__(pkg.split("==")[0])
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg, "-q"])

# Run the import-or-install helper for each third-party requirement before the imports below.
for _pkg in ["datasets>=2.19.0", "pandas>=2.0.0", "pyarrow>=14.0.0"]:
    _pip_install(_pkg)

from datasets import load_dataset
import pandas as pd
from collections import Counter

# ========= 1) Load the train split of deepmind/code_contests from Hugging Face =========
# Note: the dataset is large; the first download is cached locally.
ds = load_dataset("deepmind/code_contests", split="train")

# ========= 2) Dump the whole split to Parquet (original structure) =========
# Prefer the export built into datasets; fall back to pandas if that call fails.
full_parquet_path = "code_contests_train_full.parquet"
try:
    # Newer datasets releases can write Parquet directly.
    ds.to_parquet(full_parquet_path)
except Exception:
    # Fallback: convert to pandas first, then write Parquet (uses more memory).
    ds.to_pandas().to_parquet(full_parquet_path, engine="pyarrow", index=False)

print(f"[保存] 原始 train split 已导出：{full_parquet_path}")

# ========= 3) 按条件筛选 =========
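# Conditions:
# a) both the correct and the incorrect solutions must contain at least `miniset_num`
#    (currently 5) non-blank Python solutions
#    Note: the dataset's language enum has PYTHON (1, Python 2) and PYTHON3 (3); both count as Python here
# b) public_tests.input (i.e. test_input) must contain at least one non-blank entry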
PY_LANGS = {1, 3}  # language codes treated as Python: 1 (Python 2) and 3 (Python 3)

def extract_python_codes(solutions_dict):
    """Return the non-blank Python sources stored in a solutions record.

    Args:
        solutions_dict: Expected to be a dict with parallel ``"language"``
            and ``"solution"`` lists. Anything that is not a dict yields an
            empty result.

    Returns:
        A list with every ``solution`` entry whose paired language code is
        in ``PY_LANGS`` and which is a string containing non-whitespace
        text, in their original order.
    """
    if not isinstance(solutions_dict, dict):
        return []
    languages = solutions_dict.get("language", []) or []
    sources = solutions_dict.get("solution", []) or []
    # zip() stops at the shorter list, so mismatched lengths cannot raise.
    return [
        code
        for lang_id, code in zip(languages, sources)
        if lang_id in PY_LANGS and isinstance(code, str) and code.strip()
    ]

def has_nonempty_public_input(public_tests):
    """Tell whether ``public_tests["input"]`` holds at least one non-blank string.

    Args:
        public_tests: Expected to be a dict whose ``"input"`` value is a
            list. Any other shape is treated as "no usable input".

    Returns:
        True if some list element is a string with non-whitespace content,
        otherwise False.
    """
    candidates = public_tests.get("input", []) if isinstance(public_tests, dict) else None
    if not isinstance(candidates, list):
        return False
    for item in candidates:
        if isinstance(item, str) and item.strip():
            return True
    return False

filtered_rows = []
filtered_difficulties = []     # difficulty of each kept problem; statistics only, not written to the output file
test_input_lengths = []        # number of test_input entries of each kept problem, for the min/max statistic

# Minimum number of Python solutions required on each side (correct and
# incorrect), and also how many of each are kept per problem. It never
# changes, so it is set once here rather than on every loop iteration.
miniset_num = 5

for ex in ds:
    py_correct = extract_python_codes(ex.get("solutions"))
    py_incorrect = extract_python_codes(ex.get("incorrect_solutions"))
    public_tests = ex.get("public_tests", {}) or {}

    # Skip problems with too few Python solutions on either side ...
    if len(py_correct) < miniset_num or len(py_incorrect) < miniset_num:
        continue
    # ... or without a single non-blank public test input.
    if not has_nonempty_public_input(public_tests):
        continue

    # ========= 4) Keep only the requested fields, under their new names =========
    row = {
        "problem": ex.get("description", ""),                  # dataset field: description
        "correct_solutions": py_correct[:miniset_num],         # first `miniset_num` correct Python solutions
        "incorrect_solutions": py_incorrect[:miniset_num],     # first `miniset_num` incorrect Python solutions
        "test_input": public_tests.get("input", []),           # dataset field: public_tests.input
        "test_output": public_tests.get("output", []),         # dataset field: public_tests.output
    }
    filtered_rows.append(row)

    # Collected for the statistics only (not written to the final parquet).
    filtered_difficulties.append(ex.get("difficulty", 0))
    test_input_lengths.append(len(row["test_input"]))

# Save the filtered rows.
filtered_df = pd.DataFrame(filtered_rows)
filtered_parquet_path = "code_contests_train_filtered.parquet"
# Note: list-valued columns need the pyarrow engine.
filtered_df.to_parquet(filtered_parquet_path, engine="pyarrow", index=False)

print(f"[保存] 筛选并重组后的数据已导出：{filtered_parquet_path}")

# ========= 5) Compute and print statistics =========
# 5.1 Number of samples left after filtering
total_count = len(filtered_df)
print(f"[统计] 筛选后样本总数：{total_count}")

# 5.2 Difficulty of the re-stored problems
#    The distribution of the `difficulty` field is reported per numeric enum
#    value; DIFF_NAME_MAP turns each value into its enum name.
#    Values 0..6 are the named levels of the dataset's Difficulty enum
#    (4 is HARDER, 5 is HARDEST, 6 is EXTERNAL); 7..28 are the letters A..V.
DIFF_NAME_MAP = {
    0: "UNKNOWN_DIFFICULTY",
    1: "EASY",
    2: "MEDIUM",
    3: "HARD",
    4: "HARDER",
    5: "HARDEST",
    6: "EXTERNAL",
    # A..V -> 7..28
    **{i: chr(ord("A") + (i - 7)) for i in range(7, 29)},
}
# Tally how many kept problems carry each difficulty value.
diff_counter = Counter(filtered_difficulties)
if not diff_counter:
    print("[统计] 无满足条件的数据，难度分布为空。")
else:
    print("[统计] 难度分布（difficulty_value -> count / name）：")
    # Keys are unique, so sorting the (value, count) pairs orders by value.
    for level, count in sorted(diff_counter.items()):
        label = DIFF_NAME_MAP.get(level, "UNKNOWN")
        print(f"  {level:>2} -> {count:>4} / {label}")

# 5.3 Smallest and largest number of test_input entries (items in public_tests.input)
if not test_input_lengths:
    print("[统计] 无满足条件的数据，无法统计 test_input 长度。")
else:
    shortest = min(test_input_lengths)
    longest = max(test_input_lengths)
    print(f"[统计] test_input 数量：最短 = {shortest}, 最长 = {longest}")

# ========== Summary ==========
print("\n[完成] 文件输出：")
print(f"  - 原始 train Parquet：{full_parquet_path}")
print(f"  - 筛选后 Parquet：    {filtered_parquet_path}")


Creating parquet from Arrow format: 100%|██████████| 14/14 [00:44<00:00,  3.17s/ba]


[保存] 原始 train split 已导出：code_contests_train_full.parquet
[保存] 筛选并重组后的数据已导出：code_contests_train_filtered.parquet
[统计] 筛选后样本总数：6502
[统计] 难度分布（difficulty_value -> count / name）：
   0 -> 1685 / UNKNOWN_DIFFICULTY
   7 -> 1271 / A
   8 -> 1191 / B
   9 ->  997 / C
  10 ->  706 / D
  11 ->  397 / E
  12 ->  149 / F
  13 ->   49 / G
  14 ->   18 / H
  15 ->    5 / I
  16 ->    8 / J
  17 ->    8 / K
  19 ->    5 / M
  20 ->    7 / N
  21 ->    2 / O
  22 ->    1 / P
  23 ->    1 / Q
  24 ->    1 / R
  25 ->    1 / S
[统计] test_input 数量：最短 = 1, 最长 = 8

[完成] 文件输出：
  - 原始 train Parquet：code_contests_train_full.parquet
  - 筛选后 Parquet：    code_contests_train_filtered.parquet
