In [1]:
import pyarrow as pa
import json
from datasets import Dataset
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def process_and_save_gsm8k_dataset(input_path, output_path):
    dataset = Dataset.from_file(input_path)
    print(dataset if "train" in input_path else len(dataset))
    dataset_dict = dataset.to_dict()

    dataset_dict["extracted_answers"] = [
        float(answer.split("####")[-1].replace(",", ""))
        for answer in dataset_dict["answer"]
    ]

    dataset_dict["id"] = list(range(len(dataset_dict["answer"])))

    with open(output_path, "w") as json_file:
        json.dump(dataset_dict, json_file)

In [3]:
process_and_save_gsm8k_dataset(
    "../datasets/GSM8K/train/data-00000-of-00001.arrow",
    "../datasets/GSM8K/train/dataset.json",
)
process_and_save_gsm8k_dataset(
    "../datasets/GSM8K/test/data-00000-of-00001.arrow",
    "../datasets/GSM8K/test/dataset.json",
)

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})
1319


In [4]:
def process_and_save_math_dataset(
    load_path="../datasets/MATH/train",
    output_path="../datasets/MATH/train/dataset.json",
):
    math_dataset_dict = {
        "question": [],
        "level": [],
        "type": [],
        "answer": [],
        "extracted_answers": [],
    }

    directories = [
        d for d in os.listdir(load_path) if os.path.isdir(os.path.join(load_path, d))
    ]

    for directory in directories:
        dir_path = os.path.join(load_path, directory)
        json_files = [file for file in os.listdir(dir_path) if file.endswith(".json")]
        for json_file in json_files:
            with open(os.path.join(dir_path, json_file), "r") as f:
                contents = json.load(f)
                math_dataset_dict["question"].append(contents["problem"])
                math_dataset_dict["answer"].append(contents["solution"])
                try:
                    math_dataset_dict["level"].append(int(contents["level"][-1]))
                except Exception as e:
                    math_dataset_dict["level"].append(-1)
                math_dataset_dict["type"].append(contents["type"].lower())

    math_dataset_dict["extracted_answers"] = [
        (
            lambda s: s[
                s.find("\\boxed{") + len("\\boxed{") : next(
                    i
                    for i, c in enumerate(
                        s[s.find("\\boxed{") + len("\\boxed{") :],
                        start=s.find("\\boxed{") + len("\\boxed{"),
                    )
                    if c == "}"
                    and s[s.find("\\boxed{") + len("\\boxed{") : i].count("{")
                    == s[s.find("\\boxed{") + len("\\boxed{") : i].count("}")
                )
            ]
            if "\\boxed{" in s
            else ""
        )(answer)
        for answer in math_dataset_dict["answer"]
    ]

    # Add boolean flag for whether extracted answer can be converted to float
    math_dataset_dict["is_answer_numeric"] = []
    for answer in math_dataset_dict["extracted_answers"]:
        try:
            # Try to convert answer to float, handling common math notation
            cleaned = (
                answer.replace("\\", "")
                .replace("frac{", "")
                .replace("dfrac{", "")
                .replace("}", "")
                .replace(" ", "")
            )
            if cleaned:
                float(eval(cleaned))  # Using eval to handle fractions like "1/2"
            math_dataset_dict["is_answer_numeric"].append(True)
        except Exception as e:
            math_dataset_dict["is_answer_numeric"].append(False)

    # Verify all lists in the dictionary have the same length
    expected_length = len(math_dataset_dict["question"])
    for key, value in math_dataset_dict.items():
        if len(value) != expected_length:
            raise ValueError(
                f"Length mismatch for {key}: expected {expected_length}, got {len(value)}"
            )

    # Add sequential IDs to the dataset
    math_dataset_dict["id"] = list(range(len(math_dataset_dict["question"])))

    with open(output_path, "w") as json_file:
        json.dump(math_dataset_dict, json_file)


In [5]:
process_and_save_math_dataset()
process_and_save_math_dataset(
    "../datasets/MATH/test", "../datasets/MATH/test/dataset.json"
)



In [6]:
# Filter out non-numeric questions from MATH dataset
def filter_and_store_numeric_questions(dataset_path, output_path):
    with open(dataset_path, "r") as f:
        dataset = json.load(f)

    # Keep only entries where is_answer_numeric is True
    numeric_indices = [
        i for i, is_numeric in enumerate(dataset["is_answer_numeric"]) if is_numeric
    ]

    filtered_dataset = {
        key: [value[i] for i in numeric_indices] for key, value in dataset.items()
    }

    # Save filtered dataset
    with open(output_path, "w") as f:
        json.dump(filtered_dataset, f)

    print(len(filtered_dataset["question"]))


# Filter test dataset only
filter_and_store_numeric_questions(
    "../datasets/MATH/test/dataset.json", "../datasets/MATH/test/dataset_numeric.json"
)
filter_and_store_numeric_questions(
    "../datasets/MATH/train/dataset.json", "../datasets/MATH/train/dataset_numeric.json"
)


3203
4877


In [7]:
with open("../datasets/MATH/train/dataset_numeric.json", "r") as f:
    dataset = json.load(f)
print(f"Number of numeric questions in train dataset: {len(dataset['question'])}")

with open("../datasets/MATH/test/dataset_numeric.json", "r") as f:
    dataset = json.load(f)
print(f"Number of numeric questions in test dataset: {len(dataset['question'])}")


with open("../datasets/MATH/train/dataset.json", "r") as f:
    dataset = json.load(f)
print(f"Number of  questions in train dataset: {len(dataset['question'])}")

with open("../datasets/MATH/test/dataset.json", "r") as f:
    dataset = json.load(f)
print(f"Number of questions in test dataset: {len(dataset['question'])}")

Number of numeric questions in train dataset: 4877
Number of numeric questions in test dataset: 3203
Number of  questions in train dataset: 7500
Number of questions in test dataset: 5000
