In [24]:
import os
import os.path as osp
import json
import pickle
import collections as C
import itertools as I
import random
import regex as re
from typing import List, Optional, Dict, Tuple

import msgspec
from tqdm import tqdm

from common.constants import SYSTEM_PROMPT_FPG
from common.pantograph.dataclasses import ProblemGenerationProcess, ProblemGenerationStep, Variable, normalize_draft, replace_span

In [2]:
def print_xtuner_sample(data: dict, turn: int = 0):
    """Pretty-print one turn of an XTuner-format conversation sample.

    Args:
        data: A sample dict shaped like
            ``{"conversation": [{"system": ..., "input": ..., "output": ...}, ...]}``
            (the same shape as the records appended to ``data_problem_generation``).
        turn: Index into ``data['conversation']`` to display; defaults to the first turn.

    Raises:
        KeyError / IndexError if the sample does not have the expected structure.
    """
    message = data['conversation'][turn]
    # Wrap each field in XML-ish tags so multi-line Lean code stays readable.
    for tag, key in (('SYSTEM', 'system'), ('INPUT', 'input'), ('OUTPUT', 'output')):
        print(f'<{tag}>')
        print(message[key])
        print(f'</{tag}>')



In [3]:
# Typed msgpack decoder: each chunk file decodes straight into a list whose
# entries are ProblemGenerationProcess objects or None.
dec = msgspec.msgpack.Decoder(List[Optional[ProblemGenerationProcess]])
# NOTE(review): `enc` is not used in any of the cells shown here.
enc = msgspec.msgpack.Encoder()

In [4]:
# Locate the finished chunk files. The asserts only check that the number of
# done_chunk_*.msgp files matches the number of raw* and exception* entries.
# NOTE(review): absolute, machine-specific path — consider a config constant.
data_root = '/cache/data/cycle0123_succeeded'
root_entries = os.listdir(data_root)
# Sort so chunks (and therefore `data` and the exported JSONL) come out in a
# reproducible order: os.listdir returns entries in arbitrary order.
done_paths = sorted(
    osp.join(data_root, p) for p in root_entries
    if p.startswith('done_chunk_') and p.endswith('.msgp')
)
assert len(done_paths) == sum(p.startswith('raw') for p in root_entries)
assert len(done_paths) == sum(p.startswith('exception') for p in root_entries)
print(len(done_paths))

210


In [5]:
# Decode every chunk and concatenate the results into one flat list
# (entries may be None; they are filtered in the next cell).
data = []
for chunk_path in tqdm(done_paths):
    with open(chunk_path, 'rb') as fh:
        decoded_chunk = dec.decode(fh.read())
    data.extend(decoded_chunk)

100%|██████████| 210/210 [02:38<00:00,  1.33it/s]


In [6]:
print(len(data))
# Discard entries that decoded to None, reporting the count before and after.
data = list(filter(lambda proc: proc is not None, data))
print(len(data))

214442
201762


In [7]:
# Decode each record's JSON-encoded metainfo into a dict and drop its
# `proof_search_results` entry. The type guard makes the cell idempotent:
# re-running it no longer crashes with json.loads(<dict>) -> TypeError.
for d in data:
    if isinstance(d.metainfo, (str, bytes, bytearray)):
        d.metainfo = json.loads(d.metainfo)
    d.metainfo.pop('proof_search_results', None)

In [8]:
# Which metainfo schemas (sets of keys) are present, and how often.
C.Counter(str(sorted(d.metainfo)) for d in data)

Counter({"['annotator', 'problem_is_valid', 'problem_type', 'question_type', 'solution_is_valid', 'source', 'synthetic', 'time_consumption']": 178714,
         "['level', 'subject']": 23048})

In [9]:
# Distribution of `problem_type` over records that carry that field.
C.Counter(
    d.metainfo['problem_type']
    for d in data
    if 'problem_type' in d.metainfo
)

Counter({'Algebra': 106358,
         'Geometry': 23386,
         'Combinatorics': 17412,
         'Logic and Puzzles': 12403,
         'Number Theory': 12147,
         'Other': 4417,
         'Calculus': 1469,
         'Inequalities': 1122})

In [10]:
# Distribution of `question_type` over records that carry that field.
C.Counter(meta['question_type'] for meta in (d.metainfo for d in data) if 'question_type' in meta)

Counter({'math-word-problem': 178714})

In [11]:
# Distribution of `source` over records that carry that field.
C.Counter(
    meta['source']
    for meta in (d.metainfo for d in data)
    if 'source' in meta
)

Counter({'orca_math': 81653,
         'synthetic_math': 45641,
         'cn_k12': 32583,
         'olympiads': 13792,
         'metamath': 2052,
         'aops_forum': 1557,
         'cn_contest': 1167,
         'number_theory': 93,
         'olympiads_ref': 88,
         'amc_aime': 47,
         'inequalities': 41})

In [12]:
# How many records are flagged synthetic vs. non-synthetic.
C.Counter(d.metainfo['synthetic'] for d in data if 'synthetic' in d.metainfo)

Counter({True: 129346, False: 49368})

In [13]:
# Distribution of `subject` (only present in the level/subject schema).
C.Counter(
    meta['subject']
    for meta in (d.metainfo for d in data)
    if 'subject' in meta
)

Counter({'Algebra': 7704,
         'Intermediate Algebra': 6401,
         'Prealgebra': 5334,
         'Precalculus': 2027,
         'Number Theory': 1582})

In [14]:
# NOTE(review): identical to the key-set census cell earlier in the notebook;
# consider deleting one of the two.
C.Counter([str(sorted(meta.keys())) for meta in (d.metainfo for d in data)])

Counter({"['annotator', 'problem_is_valid', 'problem_type', 'question_type', 'solution_is_valid', 'source', 'synthetic', 'time_consumption']": 178714,
         "['level', 'subject']": 23048})

In [15]:
# Sanity check: after parsing, every metainfo should be a dict.
C.Counter(type(d.metainfo) for d in data if d is not None)

Counter({dict: 201762})

In [16]:
# Keep only records explicitly marked synthetic=False; records without a
# `synthetic` key (the level/subject schema) are excluded.
data_nonsynthetic_n15 = [d for d in data if d.metainfo.get('synthetic') is False]
len(data_nonsynthetic_n15)

49368

In [17]:
# Problem-type distribution within the non-synthetic subset.
C.Counter(
    meta['problem_type']
    for meta in (d.metainfo for d in data_nonsynthetic_n15)
    if 'problem_type' in meta
)

Counter({'Algebra': 24124,
         'Geometry': 9286,
         'Combinatorics': 7330,
         'Number Theory': 3283,
         'Logic and Puzzles': 2338,
         'Calculus': 1440,
         'Inequalities': 948,
         'Other': 619})

In [18]:
# Source distribution within the non-synthetic subset.
C.Counter(d.metainfo['source'] for d in data_nonsynthetic_n15 if 'source' in d.metainfo)

Counter({'cn_k12': 32583,
         'olympiads': 13792,
         'aops_forum': 1557,
         'cn_contest': 1167,
         'number_theory': 93,
         'olympiads_ref': 88,
         'amc_aime': 47,
         'inequalities': 41})

In [None]:
# NOTE(review): this cell was never executed (In [None]). If it is run, it
# rebinds SYSTEM_PROMPT_FPG and shadows the value imported from
# common.constants in the first cell. A fresh "Run All" would then put THIS
# text into every exported sample. Confirm it matches the imported constant,
# or delete this cell.
SYSTEM_PROMPT_FPG = '''You are an Olympiad problem setter and a Lean 4 expert.
You revel in conjuring elegant problems — starting from a spare set of hypotheses, you let rigorous deduction lead you to surprising and beautiful conclusions.'''

In [25]:
def format_step(self):
    """Render a step's Lean draft with its proofs inlined.

    `self` is a step object (called with a ProblemGenerationStep below) exposing
    `step_draft` and `proof`.

    - If `proof` is None, the raw `step_draft` is returned unchanged.
    - Otherwise the draft is passed through `normalize_draft`. Then each
      `:= sorry` occurrence is replaced, in order, by `:= by {` + the
      corresponding proof with every line indented two spaces + `}`.

    Raises AssertionError if the number of `:= sorry` sites differs from
    len(self.proof).
    """
    if self.proof is None:
        return self.step_draft

    draft = normalize_draft(self.step_draft)
    sorry_spans = [match.span() for match in re.finditer(':= sorry', draft)]
    assert len(sorry_spans) == len(self.proof)

    # Substitute right-to-left so spans computed on the original text stay valid.
    for span, proof_text in reversed(list(zip(sorry_spans, self.proof))):
        indented_body = '\n'.join(f'  {line}' for line in proof_text.splitlines())
        draft = replace_span(span, ':= by {\n' + indented_body + '\n}', draft)
    return draft

In [43]:
def format_forward_solution_step_prompt(d: ProblemGenerationProcess, state: List[Variable]) -> str:
    """Build the user prompt asking for the next forward step from `state`.

    The Lean context is rendered one binder per line via `str(var)`. A run of
    adjacent variables that share the same type `t` and all have `v is None`
    is collapsed into a single `a b c : T` line, with unnamed ones shown as `_`.
    The context is then embedded in a fixed instruction template that also
    interpolates d.metainfo['problem_type'] and d.metainfo['source'].

    Raises KeyError if either metainfo key is missing.
    """
    def _can_share_line(left, right):
        # Adjacent binders merge only if same type and neither carries a value.
        return left.t == right.t and left.v is None and right.v is None

    variables = list(state)
    pieces = []
    start = 0
    while start < len(variables):
        # Extend `end` over the maximal mergeable run beginning at `start`.
        end = start
        while end + 1 < len(variables) and _can_share_line(variables[end], variables[end + 1]):
            end += 1
        if end == start:
            pieces.append(str(variables[start]) + '\n')
        else:
            group = variables[start:end + 1]
            names = ' '.join(v.name if v.name is not None else "_" for v in group)
            pieces.append(names + f' : {group[0].t}\n')
        start = end + 1
    context = ''.join(pieces)

    return f'''Given a Lean 4 context, propose the single most natural next step to explore toward a beautiful conclusion — either
- derive a new intermediate fact,
- introduce a fresh variable or hypothesis, or
- submit one of the local facts as the final answer.

Requirements
1. Flavoured {d.metainfo['problem_type']} and suitable for posting on forums about {d.metainfo['source']}.
2. Fully formal Lean 4 code (inline comments in natural language are fine for planning and reasoning). Assume `import Mathlib`.


# Lean 4 Context
```lean4
{context.rstrip()}
```
'''


def format_forward_solution_step_response(d: ProblemGenerationProcess, step: ProblemGenerationStep):
    """Build the target response for one step.

    The response is a `# Step <kind>` header followed by the step's Lean code
    (rendered with `format_step`) in a lean4 fence. The kind is 'Derive' when
    `step.is_deducing` is truthy, else 'Introduce' when `step.is_introducing`
    is truthy, else 'Submit'. `d` is not used by this function.
    """
    if step.is_deducing:
        step_type = 'Derive'
    elif step.is_introducing:
        step_type = 'Introduce'
    else:
        step_type = 'Submit'
    body = format_step(step).rstrip()
    return f'''# Step {step_type}
```lean4
{body}
```
'''

In [44]:
data_problem_generation = list()  # accumulator for the XTuner-format samples

In [47]:
# Build one single-turn XTuner conversation per trajectory entry: the input
# prompt is rendered from the entry's state, the output from the step it
# points to. The list is re-initialised here so re-running this cell cannot
# append duplicates (which the JSONL export would otherwise silently contain).
data_problem_generation = []
for d in tqdm(data_nonsynthetic_n15):
    for p in d.trajectory:
        # p[0] is the context passed as `state`; p[1] indexes into d.steps.
        data_problem_generation.append({
            "conversation": [
                {
                    "system": SYSTEM_PROMPT_FPG,
                    "input": format_forward_solution_step_prompt(d, p[0]),
                    "output": format_forward_solution_step_response(d, d.steps[p[1]]),
                }
            ]
        })

100%|██████████| 49368/49368 [00:39<00:00, 1259.23it/s]


In [48]:
# Total number of samples built (one per trajectory entry across all records).
len(data_problem_generation)

624235

In [49]:
# Export the samples as JSON Lines: one conversation record per line.
# NOTE(review): absolute, user-specific output path — consider a config constant.
with open('/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle123_problem_generation_steps.jsonl', 'w') as out_file:
    out_file.writelines(f'{json.dumps(sample)}\n' for sample in data_problem_generation)

In [45]:
# Preview the prompt for one sample. Bind `d`/`p` explicitly, to the last step
# of the last non-synthetic record, the values a full Run-All leaves behind,
# instead of relying on variables leaked from a loop in another cell.
# Assumes that record's trajectory is non-empty (IndexError otherwise).
d = data_nonsynthetic_n15[-1]
p = d.trajectory[-1]
print(format_forward_solution_step_prompt(d, p[0]))

Given a Lean 4 context, propose the single most natural next step to explore toward a beautiful conclusion — either
- derive a new intermediate fact,
- introduce a fresh variable or hypothesis, or
- submit one of the local facts as the final answer.

Requirements
1. Flavoured Geometry and suitable for posting on forums about cn_k12.
2. Fully formal Lean 4 code (inline comments in natural language are fine for planning and reasoning). Assume `import Mathlib`.


# Lean 4 Context
```lean4
answer : ℝ × ℝ
translation_down : ℝ × ℝ → ℝ × ℝ
h₀ : ∀ (p : ℝ × ℝ), translation_down p = (p.1, p.2 - 2)
translation_left : ℝ × ℝ → ℝ × ℝ
h₁ : ∀ (p : ℝ × ℝ), translation_left p = (p.1 - 4, p.2)
h_answer✝ : translation_left (translation_down (0, 1)) = answer
h_translate_down : translation_down (0, 1) = (0, 1 - 2)
h_down : translation_down (0, 1) = (0, -1)
h_translate_left : translation_left (0, -1) = (0 - 4, -1)
h_left : translation_left (0, -1) = (-4, -1)
h_answer : answer = (-4, -1)
```



In [46]:
# Preview the target response for one sample. Bind `d`/`p` explicitly, to the
# last step of the last non-synthetic record, the values a full Run-All leaves
# behind, instead of relying on variables leaked from a loop in another cell.
# Assumes that record's trajectory is non-empty (IndexError otherwise).
d = data_nonsynthetic_n15[-1]
p = d.trajectory[-1]
print(format_forward_solution_step_response(d, d.steps[p[1]]))

# Step Submit
```lean4
submit_answer h_answer
```

