In [8]:
import datasets

# Load the CombiBench metadata table. A single csv file lands in the "train"
# split, and asking for that split directly yields a Dataset (not a DatasetDict).
ds = datasets.load_dataset('csv', data_files='lean/CombiBench/metadata.csv', split='train')

In [9]:
import os
import re

from pprint import pprint


def normalize_newlines(text):
    """Collapse every run of three or more consecutive newlines into one blank line.

    Args:
        text: The string to normalize.

    Returns:
        ``text`` with each run of 3+ newline characters replaced by exactly two.
    """
    collapsed = re.compile(r'\n{3,}').sub('\n\n', text)
    return collapsed

_SOLUTION_ABBREV_RE = re.compile(r'abbrev\s+([^\s:(]+)')


def _solution_tag(line):
    """Return the name declared by a ``*_solution`` abbrev on ``line``, else None.

    A line counts as a solution declaration when it mentions both "abbrev" and
    "_solution"; this is the same test ``process_solutions`` uses when it
    substitutes answers, so the tags collected here stay index-aligned with them.

    Raises:
        ValueError: if the line passes that test but no ``abbrev <name>`` parses.
    """
    if "abbrev" not in line or "_solution" not in line:
        return None
    match = _SOLUTION_ABBREV_RE.search(line)
    if match is None:
        raise ValueError(f"Could not parse abbrev name from: {line!r}")
    return match.group(1)


def _resolve_source_file(source_files_dir, theorem_name):
    """Return the path of ``theorem_name``'s Lean file, or None (printing a notice).

    ``<theorem_name>.lean`` is preferred; ``<theorem_name>_sol.lean`` is the fallback.
    """
    source_file = os.path.join(source_files_dir, theorem_name + '.lean')
    if not os.path.exists(source_file):
        source_file = os.path.join(source_files_dir, theorem_name + '_sol.lean')
    if not os.path.exists(source_file):
        print(f"{source_file} not exist!")
        return None
    return source_file


def _split_lean_source(lines):
    """Split the lines of a Lean statement file into header, formal and informal parts.

    Returns:
        ``(header_lines, formal_lines, informal_lines, answer_tags)`` where
        - header_lines: leading ``import`` and blank lines, plus the first ``open``
          line, which ends the header;
        - formal_lines: every other line with ``--`` line comments removed; each
          ``/--`` docstring is replaced by a single blank line;
        - informal_lines: docstring text, one entry per docstring source line;
        - answer_tags: names of ``*_solution`` abbrevs, in file order.
    """
    header_lines = []
    formal_lines = []
    informal_lines = []
    answer_tags = []
    header_state = True
    in_comment_block = False

    for line in lines:
        stripped = line.strip()

        # Drop `--` line comments from the formal text. `stripped` keeps the raw
        # content so `/--` docstring openers are still recognised below.
        if "--" in line:
            line = re.sub(r"--.*", "", line)

        # Inside a `/-- ... -/` docstring: collect text until the closing `-/`.
        # Only a line inside a docstring may close it, so a stray `-/` (e.g. the
        # end of a plain `/- ... -/` comment) stays in the formal text.
        if in_comment_block:
            if stripped.endswith("-/"):
                in_comment_block = False
                informal_lines.append(stripped[:-2].strip())
            else:
                informal_lines.append(stripped)
            continue

        if stripped.startswith("/--"):
            header_state = False
            text = stripped[3:].strip()
            if text.endswith("-/"):
                # Single-line docstring `/-- ... -/`: opens and closes on this line.
                text = text[:-2].strip()
            else:
                in_comment_block = True
            informal_lines.append(text)
            formal_lines.append("\n")
            continue

        if header_state:
            if stripped.startswith("open "):
                header_lines.append(line)
                header_state = False
                continue
            if not stripped or stripped.startswith("import "):
                header_lines.append(line)
                continue
            # Any other line ends the header and is itself formal text.
            header_state = False

        formal_lines.append(line)
        tag = _solution_tag(line)
        if tag is not None:
            answer_tags.append(tag)

    return header_lines, formal_lines, informal_lines, answer_tags


def _split_answer(answer):
    """Normalize a raw ``answer`` cell into a list of strings; None stays None.

    ``"[a, b]"`` becomes ``["a", "b"]``; any other string becomes ``[answer]``.
    NOTE(review): splitting on ", " breaks answers that themselves contain ", "
    (e.g. a set or tuple literal); confirm no answer in the metadata does.
    """
    if answer is None:
        return None
    if answer.startswith("[") and answer.endswith("]"):
        return answer[1:-1].split(", ")
    return [answer]


def get_data(ds, source_files_dir):
    """Build one record per metadata row from its Lean source file.

    For each row, reads ``<theorem_name>.lean`` (or ``<theorem_name>_sol.lean``)
    from ``source_files_dir`` and sets on the row:
    - ``formal_statement``: the header, a blank line, then the formal Lean text;
    - ``natural_language``: the docstring text, one source line per output line;
    - ``answer_tags``: names of the ``*_solution`` abbrevs, in file order;
    - ``answer``: the raw answer string split into a list (None is kept).
    Rows whose file is missing are reported with a print and skipped.

    Args:
        ds: iterable of row dicts with at least ``theorem_name`` and ``answer``.
        source_files_dir: directory containing the Lean files.

    Returns:
        List of the augmented row dicts.

    Raises:
        ValueError: if a ``*_solution`` abbrev line has no parseable name.
    """
    new_ds = []
    for row in ds:
        source_file = _resolve_source_file(source_files_dir, row['theorem_name'])
        if source_file is None:
            continue
        with open(source_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        header_lines, formal_lines, informal_lines, answer_tags = _split_lean_source(lines)

        header = normalize_newlines(''.join(header_lines).strip())
        row["formal_statement"] = header + "\n\n" + normalize_newlines(''.join(formal_lines).strip())
        # Docstring lines were stripped of their newlines; re-join with "\n" so the
        # last word of one line is not glued to the first word of the next.
        row["natural_language"] = normalize_newlines('\n'.join(informal_lines).strip())
        row["answer_tags"] = answer_tags
        row["answer"] = _split_answer(row["answer"])

        new_ds.append(row)

    return new_ds

# Parse the original statements, whose `*_solution` abbrevs are still `sorry`.
wo_solution_rows = get_data(ds, source_files_dir='lean/CombiBench/')

pprint(wo_solution_rows[5])  # spot-check one parsed record

wo_solution = datasets.Dataset.from_list(wo_solution_rows)



{'answer': ['3 / 4', '3 / 4', '1 / 2'],
 'answer_tags': ['hackmath_6_1_solution',
                 'hackmath_6_2_solution',
                 'hackmath_6_3_solution'],
 'comment': None,
 'formal_statement': 'import Mathlib\n'
                     '\n'
                     'noncomputable abbrev hackmath_6_1_solution : ENNReal := '
                     'sorry\n'
                     '\n'
                     'noncomputable abbrev hackmath_6_2_solution : ENNReal := '
                     'sorry\n'
                     '\n'
                     'noncomputable abbrev hackmath_6_3_solution : ENNReal := '
                     'sorry\n'
                     '\n'
                     'theorem hackmath_6 : PMF.binomial (1/2 : _) '
                     'ENNReal.half_le_self 2 1 +\n'
                     '    PMF.binomial (1/2 : _) ENNReal.half_le_self 2 2 = '
                     'hackmath_6_1_solution ∧\n'
                     '    PMF.binomial (1/2 : _) ENNReal.half_le_self 2 0 +\n'
            

In [10]:
from pathlib import Path

def process_solutions(ds, source_files_dir):
    """Write a copy of each theorem's Lean file with its ``*_solution`` answers filled in.

    For every row whose ``<theorem_name>.lean`` exists in ``source_files_dir``:
    - each line containing "abbrev" and "_solution" (when ``row["answer"]`` is not
      None) must look like ``... : T := sorry``; its final ``sorry`` is replaced
      by the next entry of ``row["answer"]``, and ``T`` is recorded for the
      matching name in ``row["answer_tags"]``;
    - on every later non-abbrev line, each whole-identifier use of a recorded
      name is rewritten to ``((name) : T)`` (one pair of outer parentheses around
      ``T`` is dropped), stating the expected type at the use site.

    Output goes to ``<source_files_dir>/with_solution/<theorem_name>_sol.lean``;
    that directory is created if needed. Rows without a source file are skipped.

    Args:
        ds: iterable of rows with ``theorem_name``, ``answer`` (list of str or
            None) and ``answer_tags`` (solution names, same order as ``answer``).
        source_files_dir: directory holding the original ``.lean`` files.

    Raises:
        ValueError: if a solution abbrev line cannot be parsed, or a row has
            fewer answers or answer tags than solution abbrevs.
    """
    output_dir = Path(source_files_dir) / "with_solution"
    output_dir.mkdir(parents=True, exist_ok=True)

    for row in ds:
        source_file = os.path.join(source_files_dir, row['theorem_name'] + '.lean')
        if not os.path.exists(source_file):
            continue
        with open(source_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        answers = row["answer"]
        solution_to_type = {}
        index = 0
        new_lines = []
        for line in lines:
            if "abbrev" in line and "_solution" in line and answers is not None:
                # The type is taken between the first ":" and ":= sorry"; the
                # trailing newline is optional (last line of a file).
                # NOTE(review): an abbrev with binders `(x : X)` would yield a wrong type.
                m = re.fullmatch(r".*?:\s*(.*?)\s*:=\s*sorry\s*", line)
                if m is None:
                    raise ValueError(f"Could not parse type from: {line}")
                if index >= len(answers) or index >= len(row["answer_tags"]):
                    raise ValueError(
                        f"{row['theorem_name']}: more solution abbrevs than answers/answer_tags"
                    )
                # Replace only the final `sorry`, i.e. the definition body.
                before, _, after = line.rpartition("sorry")
                line = before + answers[index] + after
                solution_to_type[row["answer_tags"][index]] = m.group(1).strip()
                index += 1
            elif "abbrev" not in line:
                for solution_name, solution_type in solution_to_type.items():
                    expected_type = solution_type
                    if expected_type.startswith("(") and expected_type.endswith(")"):
                        expected_type = expected_type[1:-1]
                    annotated = f"(({solution_name}) : {expected_type})"
                    # Whole identifiers only, so `a_solution` never matches inside `ba_solution`.
                    line = re.sub(
                        rf"(?<![\w']){re.escape(solution_name)}(?![\w'])",
                        lambda _match: annotated,
                        line,
                    )
            new_lines.append(line)

        with open(output_dir / f"{row['theorem_name']}_sol.lean", "w", encoding="utf-8") as f:
            f.writelines(new_lines)

# Set to True to regenerate the `*_sol.lean` files under `with_solution/` from
# `wo_solution` before they are parsed below.
REGENERATE_SOLUTION_FILES = False
if REGENERATE_SOLUTION_FILES:
    process_solutions(wo_solution, "lean/CombiBench/")

with_solution_rows = get_data(ds, source_files_dir='lean/CombiBench/with_solution/')

pprint(with_solution_rows[5])  # spot-check the same record with answers filled in
with_solution = datasets.Dataset.from_list(with_solution_rows)


{'answer': ['3 / 4', '3 / 4', '1 / 2'],
 'answer_tags': [],
 'comment': None,
 'formal_statement': 'import Mathlib\n'
                     '\n'
                     'theorem hackmath_6 : PMF.binomial (1/2 : _) '
                     'ENNReal.half_le_self 2 1 +\n'
                     '    PMF.binomial (1/2 : _) ENNReal.half_le_self 2 2 = '
                     '((3 / 4) : ENNReal ) ∧\n'
                     '    PMF.binomial (1/2 : _) ENNReal.half_le_self 2 0 +\n'
                     '    PMF.binomial (1/2 : _) ENNReal.half_le_self 2 1 = '
                     '((3 / 4) : ENNReal ) ∧\n'
                     '    PMF.binomial (1/2 : _) ENNReal.half_le_self 2 1 = '
                     '((1 / 2) : ENNReal ) := by sorry',
 'formal_statement_existence': None,
 'natural_language': 'Two coins are tossed simultaneously. What is the '
                     'probability of getting (i) At least one head? (ii) At '
                     'most one tail? (iii) A head and a tail?',
 'source': 'https:

In [11]:
# Compare the inferred schemas of the two splits. The with_solution rows have
# empty `answer_tags`, so that column is inferred as a null sequence there.
for split_ds in (with_solution, wo_solution):
    print(split_ds.features)

{'theorem_name': Value(dtype='string', id=None), 'natural_language': Value(dtype='string', id=None), 'answer': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'source': Value(dtype='string', id=None), 'tag': Value(dtype='string', id=None), 'formal_statement_existence': Value(dtype='bool', id=None), 'comment': Value(dtype='string', id=None), 'formal_statement': Value(dtype='string', id=None), 'answer_tags': Sequence(feature=Value(dtype='null', id=None), length=-1, id=None)}
{'theorem_name': Value(dtype='string', id=None), 'natural_language': Value(dtype='string', id=None), 'answer': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'source': Value(dtype='string', id=None), 'tag': Value(dtype='string', id=None), 'formal_statement_existence': Value(dtype='bool', id=None), 'comment': Value(dtype='string', id=None), 'formal_statement': Value(dtype='string', id=None), 'answer_tags': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)

In [12]:

# Published schema: parsing-only columns are dropped, the rest cast to fixed types.
features = datasets.Features({
    'theorem_name': datasets.Value('string'),
    'natural_language': datasets.Value('string'),
    'answer': datasets.Sequence(datasets.Value('string')),
    'source': datasets.Value('string'),
    'tag': datasets.Value('string'),
    'formal_statement': datasets.Value('string'),
})
dropped_columns = ['formal_statement_existence', 'comment', 'answer_tags']

wo_solution = wo_solution.remove_columns(dropped_columns).cast(features)
with_solution = with_solution.remove_columns(dropped_columns).cast(features)

dataset = datasets.DatasetDict({
    "test": wo_solution,
    "test_with_solution": with_solution,
})

# Token comes from the environment; never hardcode it in the notebook.
hf_token = os.environ.get('HF_TOKEN')
dataset.push_to_hub("AI-MO/CombiBench", private=True, token=hf_token)

Casting the dataset: 100%|██████████| 100/100 [00:00<00:00, 35234.41 examples/s]
Casting the dataset: 100%|██████████| 100/100 [00:00<00:00, 39624.98 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 816.81ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.43s/it]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 873.81ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.36s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/AI-MO/CombiBench/commit/5e4682cedd1c662aca2240c99e87b09f3fc0447b', commit_message='Upload dataset', commit_description='', oid='5e4682cedd1c662aca2240c99e87b09f3fc0447b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/AI-MO/CombiBench', endpoint='https://huggingface.co', repo_type='dataset', repo_id='AI-MO/CombiBench'), pr_revision=None, pr_num=None)

In [13]:
# print(datasets.load_dataset("AI-MO/rl-promptset-v4.1", split="train"))