In [1]:
import os
import os.path as osp
import json
import pickle
import collections as C
import itertools as I
import random
import regex as re
from typing import List, Optional, Dict, Tuple
import time

import msgspec
from tqdm import tqdm
from loguru import logger
from dacite import from_dict
import dacite

from common.constants import SYSTEM_PROMPT_FPG, CORE_OPTIONS, BANNED_TOKENS
from common.utils import remove_comments, replace_sorry, replace_calc, remove_multiline_comments, remove_singleline_comments, parse_idents
from common.pantograph.dataclasses import ProblemGenerationProcess, ProblemGenerationStep, Variable, normalize_draft, replace_span, Goal, GoalState, ProblemGenerationStep, ProblemGenerationProcess, TacticDraft
from common.pantograph.server import PersistentServer, TacticFailure, ServerError
from agent.problem_generation import AutoregressiveProblemGenerationAgent



In [2]:
def print_xtuner_sample(data: list):
    if 'system' in data['conversation'][0].keys():
        print('<SYSTEM>')
        print(data['conversation'][0]['system'])
        print('</SYSTEM>')
    print('<INPUT>')
    print(data['conversation'][0]['input'])
    print('</INPUT>\n<OUTPUT>')
    print(data['conversation'][0]['output'])
    print('</OUTPUT>')


In [3]:
base_dir = '/home/ma-user/workspace/formal_problem_generation/data/Numina-Lean/deductive_transformation'
data = []

for i in tqdm(list(range(41))):
    with open(osp.join(base_dir, f'done_v2_chunk_{1024*i}.pkl'), 'rb') as f:
        data.extend(pickle.load(f))

100%|██████████| 41/41 [01:32<00:00,  2.26s/it]


### Data Analysis

In [4]:
data_parsed = [d for d in data if d is not None]
assert all('parse_result' in d.keys() for d in data_parsed)
print(len(data), len(data_parsed), len(data)-len(data_parsed))

41109 39840 1269


In [5]:
n_total_units = 0
n_transformed_units = 0
for d in data_parsed:
    units = d['parse_result']['units']
    all_parsed_units = [i_u for i_u, u in enumerate(units) if len(u['invocations'] or []) > 0]
    n_total_units += len(all_parsed_units)
    n_transformed_units += len([i_u for i_u in all_parsed_units if 'deductive_steps' in units[i_u].keys()])
print(n_total_units, n_transformed_units, n_total_units-n_transformed_units)

41543 40069 1474


In [6]:
anonymous_name_cnt = C.Counter()
submission_cnt = C.Counter()

for d in data_parsed:
    for u in d['parse_result']['units']:
        if 'deductive_states' in u.keys():
            init_state = u['deductive_states'][0]
            assert len(init_state) == 1
            # init_state = from_dict(Goal, init_state[0])
            for v in init_state[0]['variables']:
                if '✝' in v['name']:
                    anonymous_name_cnt[v['name']] += 1
                    assert v['name'].replace('✝', '_') not in str(u['deductive_states'])
            
            submission = u['deductive_steps'][-1][-1][len('exact '):].strip()
            if ' ' in submission or '.' in submission:
                print(submission)
                raise
            else:
                submission_cnt[submission] += 1

eq.symm


RuntimeError: No active exception to reraise

In [7]:
server = PersistentServer(
    max_count=64,
    is_state_based=True,
    tag='',
    _sync_init=False,
    imports=["Mathlib", "Aesop"],
    project_path='/home/ma-user/workspace/formal_problem_generation/formal_problem_generation/data/MiniF2F',
    core_options=CORE_OPTIONS,
    timeout=300,
)

In [8]:
superscript_to_digit = {
    '⁰': '0', '¹': '1', '²': '2', '³': '3', '⁴': '4',
    '⁵': '5', '⁶': '6', '⁷': '7', '⁸': '8', '⁹': '9'
}

subscript_to_digit = {
    '₀': '0', '₁': '1', '₂': '2', '₃': '3', '₄': '4',
    '₅': '5', '₆': '6', '₇': '7', '₈': '8', '₉': '9'
}

digit_to_superscript = {v: k for k, v in superscript_to_digit.items()}
digit_to_subscript = {v: k for k, v in subscript_to_digit.items()}

allowed_prefices = ['h', 'h_']

def generate_submission_name(name_list: List[str]) -> str:
    # Parse names
    numbers_existing = C.defaultdict(list)
    for n in name_list:
        for p in allowed_prefices:
            if n.startswith(p):
                num_str = n[len(p):]
                if num_str == '':
                    numbers_existing[-1].append((p, 'text'))
                elif all(c in superscript_to_digit for c in num_str):
                    num = int(''.join(superscript_to_digit[c] for c in num_str))
                    numbers_existing[num].append((p, 'sup'))
                elif all(c in subscript_to_digit for c in num_str):
                    num = int(''.join(subscript_to_digit[c] for c in num_str))
                    numbers_existing[num].append((p, 'sub'))
                elif all(c.isascii() and c.isdigit() for c in num_str):
                    num = int(num_str)
                    numbers_existing[num].append((p, 'text'))
                    
    if not numbers_existing:
        numbers_existing = C.defaultdict(list, {
            -1: [('h', 'text')]
        })
    # Generate new name
    max_number = sorted(numbers_existing.keys())[-1]
    number_chosen = max_number + 1
    prefix, format_type = random.choice(numbers_existing[max_number])
    
    if number_chosen == 0:
        formatted_num = ''
    else:
        num_str = str(number_chosen)
        if format_type == 'sup':
            formatted_num = ''.join(digit_to_superscript[c] for c in num_str)
        elif format_type == 'sub':
            formatted_num = ''.join(digit_to_subscript[c] for c in num_str)
        else:  # text
            formatted_num = num_str
    new_name = f"{prefix}{formatted_num}"
    logger.debug(f'numbers_existing={numbers_existing}, max_number={number_chosen}, new_name={new_name}')
    return new_name

In [9]:
datapoint = d
base_cnt = 0
idx = 0
i_p = 0

In [10]:
import_list = datapoint['parse_result']['import_list']
open_scoped_list = datapoint['parse_result']['open_scoped_list']
open_list = datapoint['parse_result']['open_list']
option_list = datapoint['parse_result']['option_list']
units = datapoint['parse_result']['units']

all_transformed_units = [i_u for i_u, u in enumerate(units) if 'deductive_steps' in units[i_u].keys()]
remaining_units = [i_u for i_u in all_transformed_units if 'generation_process' not in units[i_u].keys()]
logger.debug(f'async_worker({base_cnt+idx}): {len(remaining_units)}/{len(all_transformed_units)} units to reasseblme')
if len(remaining_units) == 0:
    raise

tactic_header = ''
load_header = ''
if len(open_scoped_list):
    tactic_header += 'open scoped ' + ' '.join(t for t in open_scoped_list) + ' in\n'
    load_header += 'open scoped ' + ' '.join(t for t in open_scoped_list) + '\n'
if len(open_list):
    tactic_header += 'open ' + ' '.join(t for t in open_list) + ' in\n'
    load_header += 'open ' + ' '.join(t for t in open_list) + '\n'
if len(option_list):
    tactic_header += '\n'.join('set_option ' + t + ' in' for t in option_list) + '\n'
    load_header += '\n'.join('set_option ' + t for t in option_list) + '\n'

# II. Reassemble trajectories
agent = AutoregressiveProblemGenerationAgent(0)

[32m2025-08-20 22:17:52.983[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [34m[1masync_worker(0): 1/1 units to reasseblme[0m


In [11]:
deductive_steps: List[Tuple[str, str]] = u['deductive_steps']
deductive_states: List[List[Dict]] = u['deductive_states']
if len(deductive_steps) == len(deductive_states):
    deductive_states.append([])
assert len(deductive_steps) + 1 == len(deductive_states)

states: List[GoalState] = []
steps: List[ProblemGenerationStep] = []
cur_problem_state = await server.load_statement_async('False')
states.append(cur_problem_state)

# Execute introducing steps
assert len(deductive_states[0]) == 1
init_parsed_goal = dacite.from_dict(Goal, deductive_states[0][0])
for v in init_parsed_goal.variables:
    name = v.name
    if '✝' in v.name:
        assert v.name.replace('✝', '_') not in str(u['deductive_states'])
        name = v.name.replace('✝', '_')
    cur_step = ProblemGenerationStep(   # ProblemGenerationStepCategory.Introduce
        step_draft=f'have {name} : {v.t} := sorry' if v.v is None else f'let {v.name} : {v.t} := {v.v}',
        proof=None,
        new_contexts=[]
    )
    
    try:
        new_problem_state = await server.goal_tactic_async(cur_problem_state, 0, cur_step.step)
    except (TacticFailure, ServerError):
        cur_step.step_draft = tactic_header + cur_step.step_draft
        new_problem_state = await server.goal_tactic_async(cur_problem_state, 0, cur_step.step)
    assert len(new_problem_state.goals) == 1 and new_problem_state.goals[0].target == 'False', str(new_problem_state)
    idents = set(cur_step.step.split())
    for banned_token in BANNED_TOKENS[1:]:
        if banned_token in idents:
            logger.critical(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Banned token "{banned_token}" in step "{step_code}"')
    
    cur_step.new_contexts = [
        v for v in new_problem_state.goals[0].variables if
            v.raw_name not in {vv.raw_name for vv in cur_problem_state.goals[0].variables}
            # v not in forward_state.goals[0].variables
    ]
    if len(cur_step.new_contexts) != 1:
        logger.critical(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Strange introducing step: {str(cur_step)}')
    
    states.append(new_problem_state)
    steps.append(cur_step)
    cur_problem_state = new_problem_state

if init_parsed_goal.variables != cur_problem_state.goals[0].variables:
    logger.warning(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): init_parsed_goal.variables != cur_problem_state.goals[0].variables: {[str(init_parsed_goal), str(cur_problem_state.goals[0])]}')


[32m2025-08-20 22:17:53.042[0m | [34m[1mDEBUG   [0m | [36mcommon.pantograph.server[0m:[36mcheck_restart_async[0m:[36m566[0m - [34m[1mPersistentServer(): Restarting...[0m


In [12]:
print(cur_problem_state)
print(len(steps), len(states))

x y : ℝ
hx : x < 0
hy : y < 0
h1 : |y| = 6
h2 : √((x - 8) ^ 2 + (y - 3) ^ 2) = 15
n : ℕ
hn : n > 0
h3 : √(x ^ 2 + y ^ 2) = √↑n
⊢ False
9 10


In [13]:
# Execute deriving steps
for ((step_header, step_code), next_parsed_state) in zip(deductive_steps[:-1], deductive_states[1:-1]):
    assert len(next_parsed_state) == 1
    next_parsed_goal = dacite.from_dict(Goal, next_parsed_state[0])
    cur_step = ProblemGenerationStep(   # ProblemGenerationStepCategory.Derive
        step_draft=step_header + step_code,
        proof=[],
        new_contexts=[]
    )
    
    new_problem_state = await server.goal_tactic_async(cur_problem_state, 0, cur_step.step)
    assert len(new_problem_state.goals) == 1 and new_problem_state.goals[0].target == 'False', str(new_problem_state)
    idents = set(cur_step.step.split())
    for banned_token in BANNED_TOKENS:
        if banned_token in idents:
            logger.critical(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Banned token "{banned_token}" in step "{step_code}"')
            
    cur_step.new_contexts = [
        v for v in new_problem_state.goals[0].variables if
            v.raw_name not in {vv.raw_name for vv in cur_problem_state.goals[0].variables}
            # v not in forward_state.goals[0].variables
    ]
    if len(cur_step.new_contexts) == 0:
        logger.warning(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Unused step: {str(cur_step)}')
    
    states.append(new_problem_state)
    steps.append(cur_step)
    cur_problem_state = new_problem_state

    if next_parsed_goal.variables != cur_problem_state.goals[0].variables:
        logger.warning(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): next_parsed_goal.variables != cur_problem_state.goals[0].variables: {[str(next_parsed_goal), str(cur_problem_state.goals[0])]}')


In [14]:
print(cur_problem_state)

x y : ℝ
hx : x < 0
hy✝ : y < 0
h1 : |y| = 6
h2 : √((x - 8) ^ 2 + (y - 3) ^ 2) = 15
n : ℕ
hn : n > 0
h3 : √(x ^ 2 + y ^ 2) = √↑n
this✝ : (x - 8) ^ 2 + (y - 3) ^ 2 = ↑(15 ^ 2)
hy : y = -6
this : x ^ 2 + y ^ 2 = 52
eq : 52 = n
⊢ False


In [15]:
# Execute submitting step
assert len(deductive_states[-1]) == 0
step_code = remove_comments(deductive_steps[-1][-1]).strip()
assert step_code.startswith('exact '), step_code
submission_name = step_code[len('exact '):]

if ' ' in submission_name or '.' in submission_name:
    new_name = generate_submission_name([v.name for v in cur_problem_state.goals[0].variables if v.name is not None])
    cur_step = ProblemGenerationStep(   # ProblemGenerationStepCategory.Derive
        step_draft=f'have {new_name} : {init_parsed_goal.target} := {submission_name}',
        proof=[],
        new_contexts=[]
    )
    submission_name = new_name
    
    try:
        new_problem_state = await server.goal_tactic_async(cur_problem_state, 0, cur_step.step)
    except (TacticFailure, ServerError):
        cur_step.step_draft = tactic_header + cur_step.step_draft
        new_problem_state = await server.goal_tactic_async(cur_problem_state, 0, cur_step.step)
    assert len(new_problem_state.goals) == 1 and new_problem_state.goals[0].target == 'False', str(new_problem_state)
    idents = set(cur_step.step.split())
    for banned_token in BANNED_TOKENS:
        if banned_token in idents:
            logger.critical(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Banned token "{banned_token}" in step "{step_code}"')
    
    cur_step.new_contexts = [
        v for v in new_problem_state.goals[0].variables if
            v.raw_name not in {vv.raw_name for vv in cur_problem_state.goals[0].variables}
            # v not in forward_state.goals[0].variables
    ]
    if len(cur_step.new_contexts) == 0:
        logger.warning(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Unused step: {str(cur_step)}')
    
    states.append(new_problem_state)
    steps.append(cur_step)
    cur_problem_state = new_problem_state

assert submission_name in [v.name for v in cur_problem_state.goals[0].variables], f'submission_name={submission_name}, cur_problem_state={cur_problem_state}'
steps.append(
    ProblemGenerationStep(   # ProblemGenerationStepCategory.Submit
        step_draft=f'submit_answer {submission_name}',
        proof=None,
        new_contexts=None
    )
)

# Parsed trajectory
result = ProblemGenerationProcess(
    informal_problem='',
    informal_answer='',
    informal_solution='',
    header=None,
    formal_statement='',
    formal_solution_draft='',
    formal_proofs='',
    steps=steps,
    dependencies=[],
    trajectory=[(S.goals[0].variables, i) for i, S in enumerate(states)],
    metainfo=dict()
)

[32m2025-08-20 22:18:02.724[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mgenerate_submission_name[0m:[36m55[0m - [34m[1mnumbers_existing=defaultdict(<class 'list'>, {1: [('h', 'text')], 2: [('h', 'text')], 3: [('h', 'text')]}), max_number=4, new_name=h4[0m


In [16]:
print(cur_problem_state)

x y : ℝ
hx : x < 0
hy✝ : y < 0
h1 : |y| = 6
h2 : √((x - 8) ^ 2 + (y - 3) ^ 2) = 15
n : ℕ
hn : n > 0
h3 : √(x ^ 2 + y ^ 2) = √↑n
this✝ : (x - 8) ^ 2 + (y - 3) ^ 2 = ↑(15 ^ 2)
hy : y = -6
this : x ^ 2 + y ^ 2 = 52
eq : 52 = n
h4 : n = 52
⊢ False


In [17]:
steps[-2]

ProblemGenerationStep(step_draft='have h4 : n = 52 := eq.symm', proof=[], new_contexts=[Variable(t='n = 52', v=None, name='h4')])

In [18]:
steps[-1]

ProblemGenerationStep(step_draft='submit_answer h4', proof=None, new_contexts=None)

In [19]:
len(states), len(steps)

(19, 19)

In [20]:
# Reassemble trajectory
is_analyzed = await agent.analyze_async(
    result=result,
    states=states,
    server=server,
    tag=str(base_cnt+idx),
    reassemble_trajectory=True
)


In [21]:
is_analyzed

True

In [22]:
result.metainfo = json.dumps(result.metainfo | {'time_consumption': time.time() - time_start})

u['generation_process'] = result
logger.debug(f'async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): succeeded.')

NameError: name 'time_start' is not defined

### Data Composition - Deductive Proof Generation

In [None]:
proof_lengths = []
for d in data_parsed:
    for u in d['parse_result']['units']:
        if 'deductive_steps' in u.keys():
            assert len(u['invocations'] or []) > 0
            proof_lengths.append(len(u['deductive_steps']))

In [None]:
C.Counter(proof_lengths)

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
# bins: Number of intervals to group integers (adjust based on your data range)
plt.hist(proof_lengths, bins=40, color='lightgreen', edgecolor='black', alpha=0.7)

# Add labels and title
plt.xlabel('Proof Length', fontsize=12)
plt.ylabel('Frequency', fontsize=12)

# Add grid lines
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Show plot
plt.tight_layout()
plt.show()

In [None]:
data_deductive_proof_generation = []

In [None]:
header = ("""
import Mathlib
import Aesop

""" + '\n'.join('set_option ' + t.replace('=', ' ') for t in CORE_OPTIONS)).strip()
print(header)

In [None]:
for d in data_parsed:
    for u in d['parse_result']['units']:
        if 'deductive_steps' in u.keys():
            assert len(u['invocations'] or []) > 0
            
            init_state = u['deductive_states'][0]
            assert len(init_state) == 1
            init_state = from_dict(Goal, init_state[0])
            
            whole_proof = u['whole_proof']
            if whole_proof is None:
                # There is only one `whole_proof is None`, and is manually validated by Qi
                whole_proof = ''
                for t, s in u['deductive_steps']:
                    if len(t) > 0:
                        whole_proof += t
                    whole_proof += s + '\n\n'
                whole_proof = whole_proof.strip()
            
            data_deductive_proof_generation.append({
                "conversation":[
                    {
                        "input": f"""
Assume the following header is executed:
```lean4
{header}
```

Generate a deductive proof for the following Lean 4 proof state:
```lean4
{str(init_state)}
```
""".strip(),
                        "output": whole_proof
                    }
                ]
            })

In [None]:
print_xtuner_sample(random.choice(data_deductive_proof_generation))

In [None]:
with open('/home/ma-user/workspace/formal_problem_generation/data/Numina-Lean/deductive_proof_generation.40069.jsonl', 'w') as f:
    for d in data_deductive_proof_generation:
        f.write(json.dumps(d))

### Data Composition - Autoregressive Problem Generation

In [None]:
print(SYSTEM_PROMPT_FPG)

In [None]:
def format_step(self):
    if self.proof is None:
        return self.step_draft
    else:
        normalized_step_draft = normalize_draft(self.step_draft)
        matches = list(re.finditer(':= sorry', normalized_step_draft))
        assert len(matches) == len(self.proof)
        for (m, p) in reversed(list(zip(matches, self.proof))):
            normalized_step_draft = replace_span(m.span(), ':= by {\n' + '\n'.join('  ' + l for l in p.splitlines()) + '\n}', normalized_step_draft)
        return normalized_step_draft

In [None]:
def format_forward_solution_step_prompt(d: ProblemGenerationProcess, state: List[Variable]) -> str:
    context = ''
    vars_to_format = [v for v in state]
    while len(vars_to_format) > 0:
        for i in range(len(vars_to_format)):
            if i + 1 == len(vars_to_format) or not (vars_to_format[i].t == vars_to_format[i+1].t and vars_to_format[i].v is None and vars_to_format[i+1].v is None):
                break
        if i == 0:
            context += str(vars_to_format[0]) + '\n'
            vars_to_format.pop(0)
        else:
            context += ' '.join([v.name if v.name is not None else "_" for v in vars_to_format[:i+1]]) + f' : {vars_to_format[0].t}\n'
            vars_to_format = vars_to_format[i+1:]
    
    prompt = f'''Given a Lean 4 context, propose the single most natural next step to explore toward a beautiful conclusion — either
- derive a new intermediate fact,
- introduce a fresh variable or hypothesis, or
- submit one of the local facts as the final answer.

Requirements
1. Flavoured {d.metainfo['problem_type']} and suitable for posting on forums about {d.metainfo['source']}.
2. Fully formal Lean 4 code (inline comments in natural language are fine for planning and reasoning). Assume `import Mathlib`.


# Lean 4 Context
```lean4
{context.rstrip()}
```
'''
    return prompt


def format_forward_solution_step_response(d: ProblemGenerationProcess, step: ProblemGenerationStep):
    step_type = 'Derive' if step.is_deducing else 'Introduce' if step.is_introducing else 'Submit'
    response = f'''# Step {step_type}
```lean4
{format_step(step).rstrip()}
```
'''
    return response

In [None]:
data_problem_generation = []

In [None]:
for d in tqdm(data_nonsynthetic_n15):
    for p in d.trajectory:
        data_problem_generation.append({
            "conversation":[
                {
                    "system": SYSTEM_PROMPT_FPG,
                    "input": format_forward_solution_step_prompt(d, p[0]),
                    "output": format_forward_solution_step_response(d, d.steps[p[1]])
                }
            ]
        })

In [None]:
len(data_problem_generation)

In [None]:
with open('/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle123_problem_generation_steps.jsonl', 'w') as f:
    for s in data_problem_generation:
        f.write(json.dumps(s) + '\n')

In [None]:
print(format_forward_solution_step_prompt(d, p[0]))

In [None]:
print(format_forward_solution_step_response(d, d.steps[p[1]]))