In [3]:
from datasets import load_dataset, Dataset, DatasetDict
import re
from tqdm import tqdm
import os
import json  # ✅ 添加：用于标准 JSON 保存

# ---------------------------
# 1. Split answer field into steps
# ---------------------------
def split_answer(answer):
    """Break a GSM8K-style ``answer`` string into an ordered list of steps.

    Blank lines are dropped and every kept line is whitespace-stripped.
    Lines that begin with ``####`` mark the final answer; they are excluded
    from the reasoning steps, and the first such line (with every ``####``
    token removed and re-stripped) is appended as the last step.

    Args:
        answer: Multi-line solution text.

    Returns:
        list[str]: Reasoning steps in order, followed by the final answer
        when a ``####`` line exists.
    """
    reasoning = []
    final_answers = []
    for raw_line in answer.split('\n'):
        text = raw_line.strip()
        if not text:
            continue
        if text.startswith("####"):
            final_answers.append(text.replace("####", "").strip())
        else:
            reasoning.append(text)
    if final_answers:
        # Only the first "####" line is used; any later ones are ignored.
        reasoning.append(final_answers[0])
    return reasoning

# ---------------------------
# 2. Convert one example into stepwise prediction pairs
# ---------------------------
def expand_gsm_example(example):
    """Expand one GSM8K record into next-step prediction pairs.

    The answer is split into steps with ``split_answer``. For every step
    after the first, a prompt is built from the question plus all earlier
    steps (numbered ``Step 1:``, ``Step 2:``, ...), and that step becomes
    the prediction target.

    Args:
        example: Mapping with ``"question"`` and ``"answer"`` string fields.

    Returns:
        list[dict]: One record per predicted step, with keys ``input``,
        ``target``, ``question``, ``correct_answer`` (the last step),
        ``step_number`` (1-based index of the target step) and
        ``total_steps``. Empty when the answer yields fewer than two steps.
    """
    steps = split_answer(example["answer"])
    total = len(steps)
    if total < 2:
        return []

    question = example["question"]
    # The last step is the extracted "####" answer when one is present.
    # NOTE(review): if the answer has no "####" line, this is simply the
    # last reasoning line rather than a numeric answer.
    final_answer = steps[-1]

    # Formatted "Step k: ..." lines seen so far; grows by one per target.
    history = [f"Step 1: {steps[0]}"]
    records = []
    for number, target in enumerate(steps[1:], start=2):
        prev_steps = "\n".join(history)
        input_text = f"""Question: {question}

Previous Steps:
{prev_steps}

Next Step:"""
        records.append({
            "input": input_text,
            "target": target,
            "question": question,
            "correct_answer": final_answer,
            "step_number": number,
            "total_steps": total,
        })
        history.append(f"Step {number}: {target}")
    return records

# ---------------------------
# 3. Process all splits
# ---------------------------
def process_gsm8k(save_path="gsm8k_stepwise", json_dir="gsm8k_json"):
    """Build the step-wise GSM8K dataset and write it to disk.

    Loads every split of ``openai/gsm8k`` (config ``"main"``), expands each
    example with ``expand_gsm_example``, and saves the result twice: as a
    Hugging Face ``DatasetDict`` under ``save_path`` and as one JSON array
    per split under ``json_dir`` (file name ``gsm8k_<split>.json``).

    Args:
        save_path: Directory passed to ``DatasetDict.save_to_disk``.
        json_dir: Directory for the per-split JSON files; created if missing.

    Returns:
        DatasetDict: The processed splits, keyed by split name.
    """
    dataset = load_dataset("openai/gsm8k", "main")
    processed_splits = {}

    for split in dataset.keys():
        print(f"\nProcessing split: {split}")
        examples = dataset[split]
        all_pairs = []

        for ex in tqdm(examples, desc=f"{split}"):
            all_pairs.extend(expand_gsm_example(ex))

        processed_splits[split] = Dataset.from_list(all_pairs)

        # Each example contributes (steps - 1) pairs, so this ratio is pairs per
        # example, not the mean step count. Guard against an empty split.
        n_examples = len(examples)
        avg_pairs = len(all_pairs) / n_examples if n_examples else 0.0
        print(f"  Original: {n_examples}, Processed: {len(all_pairs)}, Avg pairs/example: {avg_pairs:.2f}")

    processed_dataset = DatasetDict(processed_splits)

    # Save in the native Hugging Face on-disk format.
    processed_dataset.save_to_disk(save_path)
    print(f"\n✅ Saved full dataset to: {save_path}")

    # Also save each split as a standard JSON array. An explicit UTF-8
    # encoding is required: ensure_ascii=False writes non-ASCII characters
    # verbatim, and the platform default encoding may not support them.
    os.makedirs(json_dir, exist_ok=True)
    for split in processed_dataset:
        out_path = os.path.join(json_dir, f"gsm8k_{split}.json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(processed_dataset[split].to_list(), f, indent=4, ensure_ascii=False)
        print(f"✅ Saved {split} as standard JSON to: {out_path}")

    return processed_dataset

# ---------------------------
# Run it
# ---------------------------
# Run the full pipeline only when executed directly, not when imported.
# (A notebook kernel also runs with __name__ == "__main__", so the cell runs it.)
if __name__ == "__main__":
    process_gsm8k(save_path="gsm8k_stepwise")



Processing split: train


train: 100%|██████████| 7473/7473 [00:00<00:00, 56701.21it/s]


  Original: 7473, Processed: 26720, Avg steps: 3.58

Processing split: test


test: 100%|██████████| 1319/1319 [00:00<00:00, 61658.94it/s]


  Original: 1319, Processed: 4819, Avg steps: 3.65


Saving the dataset (1/1 shards): 100%|██████████| 26720/26720 [00:00<00:00, 1997145.25 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 4819/4819 [00:00<00:00, 1201471.26 examples/s]



✅ Saved full dataset to: gsm8k_stepwise
✅ Saved train as standard JSON to: gsm8k_json/gsm8k_train.json
✅ Saved test as standard JSON to: gsm8k_json/gsm8k_test.json
