In [1]:
import os
import os.path as osp
import json
import pickle
import collections as C
import itertools as I
import random
import regex as re
from typing import List, Optional, Dict, Tuple

import msgspec
from tqdm import tqdm
from loguru import logger

from common.constants import SYSTEM_PROMPT_FPG, CORE_OPTIONS
from common.utils import remove_comments, replace_sorry, replace_calc, remove_multiline_comments, remove_singleline_comments
from common.pantograph.dataclasses import ProblemGenerationProcess, ProblemGenerationStep, Variable, normalize_draft, replace_span, Goal
from common.pantograph.server import PersistentServer, TacticFailure, ServerError

def print_xtuner_sample(data: list):
    print('<SYSTEM>')
    print(data['conversation'][0]['system'])
    print('</SYSTEM>\n<INPUT>')
    print(data['conversation'][0]['input'])
    print('</INPUT>\n<OUTPUT>')
    print(data['conversation'][0]['output'])
    print('</OUTPUT>')



In [20]:
def is_deductive(state_before: List[Goal], state_after: List[Goal]) -> bool:
    return len(state_before) == 1 and len(state_after) == 1 and state_before[0].target == state_after[0].target

superscript_to_digit = {
    '⁰': '0', '¹': '1', '²': '2', '³': '3', '⁴': '4',
    '⁵': '5', '⁶': '6', '⁷': '7', '⁸': '8', '⁹': '9'
}

subscript_to_digit = {
    '₀': '0', '₁': '1', '₂': '2', '₃': '3', '₄': '4',
    '₅': '5', '₆': '6', '₇': '7', '₈': '8', '₉': '9'
}

digit_to_superscript = {v: k for k, v in superscript_to_digit.items()}
digit_to_subscript = {v: k for k, v in subscript_to_digit.items()}

allowed_prefices = ['h', 'h_']

def generate_submission_name(name_list: List[str]) -> str:
    # Parse names
    numbers_existing = C.defaultdict(list)
    for n in name_list:
        for p in allowed_prefices:
            if n.startswith(p):
                num_str = n[len(p):]
                if num_str == '':
                    numbers_existing[-1].append((p, 'text'))
                elif all(c in superscript_to_digit for c in num_str):
                    num = int(''.join(superscript_to_digit[c] for c in num_str))
                    numbers_existing[num].append((p, 'sup'))
                elif all(c in subscript_to_digit for c in num_str):
                    num = int(''.join(subscript_to_digit[c] for c in num_str))
                    numbers_existing[num].append((p, 'sub'))
                elif all(c.isascii() and c.isdigit() for c in num_str):
                    num = int(num_str)
                    numbers_existing[num].append((p, 'text'))
                    
    if not numbers_existing:
        numbers_existing = C.defaultdict(list, {
            -1: [('h', 'text')]
        })
    # Generate new name
    max_number = sorted(numbers_existing.keys())[-1]
    number_chosen = max_number + 1
    prefix, format_type = random.choice(numbers_existing[max_number])
    
    if number_chosen == 0:
        formatted_num = ''
    else:
        num_str = str(number_chosen)
        if format_type == 'sup':
            formatted_num = ''.join(digit_to_superscript[c] for c in num_str)
        elif format_type == 'sub':
            formatted_num = ''.join(digit_to_subscript[c] for c in num_str)
        else:  # text
            formatted_num = num_str
    new_name = f"{prefix}{formatted_num}"
    logger.debug(f'numbers_existing={numbers_existing}, max_number={number_chosen}, new_name={new_name}')
    return new_name


In [3]:
server = PersistentServer(
    max_count=64,
    is_state_based=True,
    tag='',
    _sync_init=False,
    imports=["Mathlib", "Aesop"],
    project_path='/home/ma-user/workspace/formal_problem_generation/formal_problem_generation/data/MiniF2F',
    core_options=CORE_OPTIONS,
    timeout=300,
)

In [4]:
base_dir = '/home/ma-user/workspace/formal_problem_generation/data/Numina-Lean/deductive_transformation'
data = []

for i in tqdm(list(range(41))):
    with open(osp.join(base_dir, f'done_chunk_{1024*i}.pkl'), 'rb') as f:
        data.extend(pickle.load(f))

100%|██████████| 41/41 [01:17<00:00,  1.88s/it]


In [5]:
len(data)

41109

In [6]:
n_transformed = 0
n_remaining = 0

for datapoint in data:
    if 'exception' in datapoint.keys() or 'traceback' in datapoint.keys():
        continue
    p_raw = datapoint['formal_code']

    import_list = datapoint['parse_result']['import_list']
    open_scoped_list = datapoint['parse_result']['open_scoped_list']
    open_list = datapoint['parse_result']['open_list']
    option_list = datapoint['parse_result']['option_list']
    units = datapoint['parse_result']['units']

    
    all_parsed_units = [i_u for i_u, u in enumerate(units) if len(u['invocations'] or []) > 0]
    remaining_units = [i_u for i_u in all_parsed_units if 'deductive_steps' not in units[i_u].keys()]
    
    n_transformed += len(all_parsed_units) - len(remaining_units)
    n_remaining += len(remaining_units)

n_transformed, n_remaining

(37773, 3770)

In [7]:
def count_indent(s: str) -> int:
    count = 0
    for char in s:
        if char == ' ':
            count += 1
        else:
            break
    return count

def proof_decompose(formal_proof: str) -> list[str]:
    '''Decompose a formal solution draft into steps'''
    # Count the minimal indents of all tactics
    min_indents = float('inf')
    pure_proof_lines = replace_sorry(replace_calc(remove_comments(formal_proof))).split('\n')
    for l in pure_proof_lines:
        if l.strip() != '':
            min_indents = min(min_indents, count_indent(l))

    # Reset the minimal indents to zero
    levels = []
    raw_lines = []  # (n_indents, line)
    for l in replace_sorry(replace_calc(remove_multiline_comments(formal_proof))).rstrip().split('\n'):
        n_indent = count_indent(l)
        if n_indent < min_indents:
            assert len(remove_comments(l).strip()) == 0
        
        if len(remove_comments(l).strip()) == 0:
            level = float('inf')   # Comment
        else:
            level = n_indent - min(n_indent, min_indents)   # Tactic
        raw_lines.append(l[min(n_indent, min_indents):])
        levels.append(level)
    
    # print('\n'.join(raw_lines))
    is_first = True
    parse_result = []
    cur_block = []
    for (level, line) in zip(levels, raw_lines):
        # print(line)
        if len(line.strip()) == 0:
            continue
        if level != 0:
            cur_block.append(line)
        else:   # Root-level tactic
            if is_first:    # First tactic block: neglect and add
                is_first = False
                cur_block.append(line)
            else:   # Other tactic block: end and new
                parse_result.append('\n'.join(cur_block))
                # print('\n<begin>\n' + parse_result[-1], end='\n<end>\n')
                cur_block = [line]
    
    if len(cur_block) > 0:
        parse_result.append('\n'.join(cur_block))
        # print('\n<begin>\n' + parse_result[-1], end='\n<end>\n')
    
    return parse_result

bracket_pairings = {
    '(' : ')',
    '[' : ']',
    '{' : '}',
    '⦃' : '⦄'
}

def parse_variables(s : str) -> Tuple[str, str]:
    base = 0
    variables = []
    target = None
    while base < len(s):
        if s[base] in ['(', '[', '{', '⦃']:
            bracket_type = s[base]
            bracket_pairing = bracket_pairings[bracket_type]
        
            stack_cnt = 0
            start_end_positions = []

            for i, char in enumerate(s[base:]):
                if char == bracket_type:
                    if stack_cnt == 0:
                        start_position = i
                    stack_cnt += 1
                elif char == bracket_pairing:
                    if stack_cnt > 0:
                        stack_cnt -= 1
                        if stack_cnt == 0:
                            end_position = i
                            start_end_positions.append((start_position, end_position))
                            break
            
            start, end = start_end_positions[0]
            variables.append(s[base+start:base+end+1])
            base += i
        else:
            if s[base] == ':':
                target = s[base+1:]
                break
            base += 1
    
    return variables, target

In [38]:
random.shuffle(data)
for datapoint in data:
    if 'exception' in datapoint.keys() or 'traceback' in datapoint.keys():
        continue
    p_raw = datapoint['formal_code']

    import_list = datapoint['parse_result']['import_list']
    open_scoped_list = datapoint['parse_result']['open_scoped_list']
    open_list = datapoint['parse_result']['open_list']
    option_list = datapoint['parse_result']['option_list']
    units = datapoint['parse_result']['units']

    
    all_parsed_units = [i_u for i_u, u in enumerate(units) if len(u['invocations'] or []) > 0]
    remaining_units = [i_u for i_u in all_parsed_units if 'deductive_steps' not in units[i_u].keys()]
    
    if len(remaining_units) > 0:
       print(f'{len(remaining_units)}/{len(all_parsed_units)} units to transform')
       break

1/1 units to transform


In [39]:
base_cnt = 0
idx = 0
datapoint.keys()

dict_keys(['uuid', 'problem', 'question_type', 'answer', 'author', 'formal_statement', 'ground_truth_type', 'rl_data', 'source', 'problem_type', 'exam', 'formal_code', 'parse_result'])

In [40]:
assert 'exception' not in datapoint.keys() and 'traceback' not in datapoint.keys()

# I. Parse tactic invocation
p_raw = datapoint['formal_code']

import_list = datapoint['parse_result']['import_list']
open_scoped_list = datapoint['parse_result']['open_scoped_list']
open_list = datapoint['parse_result']['open_list']
option_list = datapoint['parse_result']['option_list']
units = datapoint['parse_result']['units']

all_parsed_units = [i_u for i_u, u in enumerate(units) if len(u['invocations'] or []) > 0]
remaining_units = [i_u for i_u in all_parsed_units if 'deductive_steps' not in units[i_u].keys()]
logger.debug(f'async_worker({base_cnt+idx}): {len(remaining_units)}/{len(all_parsed_units)} units to transform')

if len(remaining_units) == 0:
    raise


tactic_header = ''
load_header = ''
if len(open_scoped_list):
    tactic_header += 'open scoped ' + ' '.join(t for t in open_scoped_list) + ' in\n'
    load_header += 'open scoped ' + ' '.join(t for t in open_scoped_list) + '\n'
if len(open_list):
    tactic_header += 'open ' + ' '.join(t for t in open_list) + ' in\n'
    load_header += 'open ' + ' '.join(t for t in open_list) + '\n'
if len(option_list):
    tactic_header += '\n'.join('set_option ' + t + ' in' for t in option_list) + '\n'
    load_header += '\n'.join('set_option ' + t for t in option_list) + '\n'

p_injected: List[str] = p_raw.splitlines()
for (i, l) in reversed(list(enumerate(p_injected))):
    if l.startswith('import '):
        i += 1
        break
p_injected = '\n'.join(p_injected[:i]) + '\n\n' + '\n'.join('set_option ' + t.replace('=', ' ') for t in CORE_OPTIONS) + '\n\n' + '\n'.join(p_injected[i:])


[32m2025-08-20 16:27:16.490[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m14[0m - [34m[1masync_worker(0): 1/1 units to transform[0m


In [41]:
i_u = remaining_units[0]
u = units[i_u]

In [42]:
invocations = u['invocations']
assert len(invocations[0]['before']) == 1, 'Initial state contains multiple goals'
# nonhygienic_transformer = factory_nonhygienic_transformer(invocations[0]['before'][0])

code_segment = remove_comments(p_injected.encode()[u['i_begin']:u['i_end']].decode())
start_pos = None
for start_pos in re.finditer(r':=\s*by', code_segment):
    break
assert start_pos is not None, '":= by" not found'
statement_code, proof_code = code_segment[:start_pos.span(0)[0]], code_segment[start_pos.span(0)[1]:]

# Preprocess steps (deprecated in Pantograph v0.3.5)
# for ivc in invocations:
#     ivc['tactic'] = ivc['tactic'].replace('native_decide', 'decide')
# proof_code = proof_code.replace('native_decide', 'decide')

# 1. Parse Context from statement code
context, target = parse_variables(statement_code)
assert target is not None, f'Target parsing failed: {statement_code}'

# 2. Parse intros
hypotheses = []
intros = []
for i_ctx, declaration in enumerate(context):
    if declaration[0] == '[':
        intros.append('_')
        hypotheses.append(declaration)
    else:
        assert '✝' not in declaration, f'declaration: {declaration}'
        try:
            var_names, var_type = declaration[1:-1].split(':', 1)
        except ValueError:
            var_names = declaration[1:-1]
        # var_names = [n if '✝' not in n else '_' for n in var_names.strip().split(' ')]
        intros.extend(var_names.strip().split(' '))
        hypotheses.append('(' + declaration[1:-1] + ')')    # Replace '{v : T}' into '(v : T)

# 3. Load statement (before Pantograph v0.3.5)
formal_statement = (('∀ ' + '\n'.join(hypotheses) + '\n, ') if len(hypotheses) > 0 else '') + target
assert '⊢' not in formal_statement, '⊢ in formal_statement'
try:
    init_state = await server.load_statement_async(formal_statement, intros=intros, header=load_header)
except Exception as e:
    raise RuntimeError(e, context, target, load_header)
# assert all(match_wo_mvar(nonhygienic_transformer(g_parsed), str(g_now)) for g_parsed, g_now in zip(invocations[0]['before'], init_state.goals)), 'initial state not equivalent w/ parse results'
assert len(init_state.goals) == 1, 'deductive step execution failed' #* Non-strict match

# 3. Load statement (after Pantograph v0.3.5)
# formal_statement = ('example\n' + '\n'.join(hypotheses) + '\n: ') + target + '\n:= sorry'
# try:
#     init_units = await server.load_sorry_async(tactic_header + formal_statement)
#     assert all(m.severity != Severity.ERROR for u in init_units for m in u.messages), f'State initialization failed: {str([m for u in init_units for m in u.messages])}'
#     init_state = init_units[-1].goal_state
#     assert init_state is not None and len(init_state.goals) == 1, f'State initialization failed: {str(init_state)}' #* Non-strict match
# except Exception as e:
#     raise RuntimeError(*e.args, context, target, tactic_header)
# assert all(match_wo_mvar(nonhygienic_transformer(g_parsed), str(g_now)) for g_parsed, g_now in zip(invocations[0]['before'], init_state.goals)), 'initial state not equivalent w/ parse results'

# Start transforming
raw_steps = proof_decompose(proof_code)

states: List[List[Goal]] = [init_state.goals[:]]
deductive_steps: List[Tuple[str, str]] = []
cur_state = init_state

In [43]:
print(formal_statement)


    ∑ᶠ x ∈ {x : ℝ | 0 < x ∧ 2 * x ^ 2 - ⌊x⌋ * x = 5}, x = (3 + Real.sqrt 41 + 2 * Real.sqrt 11) / 4 


In [44]:
print(p_injected)

import Mathlib

set_option maxHeartbeats 0
set_option maxRecDepth 100000
set_option tactic.hygienic false
set_option pp.fullNames true
set_option pp.funBinderTypes true
set_option pp.piBinderTypes true


theorem algebra_608727 :
    ∑ᶠ x ∈ {x : ℝ | 0 < x ∧ 2 * x ^ 2 - ⌊x⌋ * x = 5}, x = (3 + Real.sqrt 41 + 2 * Real.sqrt 11) / 4 := by
  have h1 : {x : ℝ | 0 < x ∧ 2 * x ^ 2 - ⌊x⌋ * x = 5} = {(1 + Real.sqrt 41) / 4, (1 + Real.sqrt 11) / 2} := by
    ext x
    simp only [Set.mem_setOf_eq, Set.mem_insert_iff, Set.mem_singleton_iff, and_imp]
    constructor
    · -- Assume x is a positive solution
      intro h
      rcases h with ⟨hx_pos, heq⟩
      let n := ⌊x⌋
      have h2 : n ≤ (x : ℝ) := by exact Int.floor_le x
      have h3 : x < (n + 1 : ℝ) := by exact Int.lt_floor_add_one x
      have h4 : (n : ℝ) ≤ x := by exact_mod_cast h2
      have h5 : x < (n + 1 : ℝ) := by exact_mod_cast h3
      have h6 : 2 * x ^ 2 - (n : ℝ) * x = 5 := by linarith
      have hn : n ≥ 0 := by
        by_contra 

In [45]:
print(init_state)

⊢ ∑ᶠ (x : ℝ) (_ : x ∈ {x : ℝ | 0 < x ∧ 2 * x ^ 2 - ↑⌊x⌋ * x = 5}), x = (3 + √41 + 2 * √11) / 4


In [46]:
while len(raw_steps) > 0:
    # Execute cur_step
    cur_step = raw_steps[0]
    used_tactic_header = ''
    try:
        next_state = await server.goal_tactic_async(cur_state, 0, cur_step)
    except (TacticFailure, ServerError):
        used_tactic_header = tactic_header
        next_state = await server.goal_tactic_async(cur_state, 0, tactic_header + cur_step)
    
    if next_state.is_solved:
        if remove_comments(cur_step).strip().startswith('exact '):
            # If (solved) and (the final step is `exact`): add cur_step and break
            raw_steps = []
            cur_state = next_state
            states.append(cur_state.goals[:])
            deductive_steps.append((used_tactic_header, cur_step))
            logger.info(f"async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Detected `exact` submission: {[remove_comments(cur_step).strip()]}")
            break
        else:
            # If (solved) but (the final step is not `exact`): don't add cur_step, don't update state
            raw_steps = [cur_step]
            logger.info(f"async_worker({base_cnt+idx}-{i_p}/{len(remaining_units)}): Detected non-`exact` submission: {[remove_comments(cur_step).strip()]}")
            break   # If the final step is not `exact`, 1) do not add to `steps` - leave it for final submission; 2) do not update `cur_state`
    else:
        if not is_deductive(cur_state.goals, next_state.goals):
            # If (not solved) but (not deductive): don't add cur_step, don't update state
            break
        else:
            # If (not solved) and (is deductive): add cur_step and continue
            raw_steps.pop(0)
            cur_state = next_state
            states.append(cur_state.goals[:])
            deductive_steps.append((used_tactic_header, cur_step))


In [47]:
print(cur_state)
print('\n')
print(next_state)

h1 : {x : ℝ | 0 < x ∧ 2 * x ^ 2 - ↑⌊x⌋ * x = 5} = {(1 + √41) / 4, (1 + √11) / 2}
⊢ ∑ᶠ (x : ℝ) (_ : x ∈ {x : ℝ | 0 < x ∧ 2 * x ^ 2 - ↑⌊x⌋ * x = 5}), x = (3 + √41 + 2 * √11) / 4


h1 : {x : ℝ | 0 < x ∧ 2 * x ^ 2 - ↑⌊x⌋ * x = 5} = {(1 + √41) / 4, (1 + √11) / 2}
⊢ ∑ᶠ (x : ℝ) (_ : x ∈ {(1 + √41) / 4, (1 + √11) / 2}), x = (3 + √41 + 2 * √11) / 4


In [48]:
# Remaining non-deductive steps
if len(raw_steps) > 0:
    proof_state = cur_state

    submission_name = generate_submission_name([v.name for v in cur_state.goals[0].variables if v.name is not None])
    have_step = f'have {submission_name}: {target} := by {{\n' + '\n'.join(raw_steps) + '\n}'
    states.append(proof_state.goals[:])
    try:
        proof_state = await server.goal_tactic_async(proof_state, 0, have_step)
        assert (len(proof_state.goals) == 1 and proof_state.goals[0].target == cur_state.goals[0].target), f'`have {submission_name}` failed due to proof state: ' + str(proof_state)
        deductive_steps.append(('', have_step))
    except:
        proof_state = await server.goal_tactic_async(proof_state, 0, tactic_header + have_step)
        assert (len(proof_state.goals) == 1 and proof_state.goals[0].target == cur_state.goals[0].target), f'`have {submission_name}` failed due to proof state: ' + str(proof_state)
        deductive_steps.append((tactic_header, have_step))

    states.append(proof_state.goals[:])
    submit_step = f'exact {submission_name}'
    try:
        proof_state = await server.goal_tactic_async(proof_state, 0, submit_step)
        assert proof_state.is_solved, f'`exact {submission_name}` failed due to proof state: ' + str(proof_state)
        deductive_steps.append(('', submit_step))
    except:
        proof_state = await server.goal_tactic_async(proof_state, 0, tactic_header + submit_step)
        assert proof_state.is_solved, f'`exact {submission_name}` failed due to proof state: ' + str(proof_state)
        deductive_steps.append((tactic_header, submit_step))

[32m2025-08-20 16:27:45.314[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mgenerate_submission_name[0m:[36m58[0m - [34m[1mnumbers_existing=defaultdict(<class 'list'>, {1: [('h', 'text')]}), max_number=2, new_name=h2[0m


In [50]:
# Validate whole proof
whole_proof = ''
for t, s in deductive_steps:
    if len(t) > 0:
        whole_proof += t
    whole_proof += s + '\n\n'
whole_proof = whole_proof.strip()

try:
    final_state = await server.goal_tactic_async(init_state, 0, '{\n' + whole_proof + '\n}')
    assert final_state.is_solved, 'final_state.is_solved Failed'
except Exception as e:
    whole_proof = None
    print(e)

In [108]:
steps = proof_decompose(proof_code)
for s in steps:
    print(s)
    print('\n\n')

rw [hr] at h 



have eq1 : speed * 3 = 2 * (2 * π * (9 : ℝ)) := h 



have eq2 : speed * 3 = 36 * π := calc
    speed * 3 = 2 * (2 * π * (9 : ℝ)) := h 
    _ = 2 * (18 * π) := by ring 
    _ = 36 * π := by ring



have eq3 : speed = 12 * π := by
  linarith



exact eq3





In [109]:
formal_proof = proof_code

In [39]:
parse_result

['rw [hr] at h',
 'have eq1 : speed * 3 = 2 * (2 * π * (9 : ℝ)) := h',
 'have eq2 : speed * 3 = 36 * π := calc\n    speed * 3 = 2 * (2 * π * (9 : ℝ)) := h\n    _ = 2 * (18 * π) := by ring\n    _ = 36 * π := by ring',
 'have eq3 : speed = 12 * π := by\n  linarith',
 'exact eq3\n']

In [22]:
print('\n\n'.join(steps))

  rw [hr] at h
  have eq1 : speed * 3 = 2 * (2 * π * (9 : ℝ)) := h
  have eq2 : speed * 3 = 36 * π := calc
      speed * 3 = 2 * (2 * π * (9 : ℝ)) := h
      _ = 2 * (18 * π) := by ring
      _ = 36 * π := by ring
  have eq3 : speed = 12 * π := by
    linarith
  exact eq3


In [23]:
print(proof_code)


  rw [hr] at h 
  have eq1 : speed * 3 = 2 * (2 * π * (9 : ℝ)) := h 
  have eq2 : speed * 3 = 36 * π := by 
    calc
      speed * 3 = 2 * (2 * π * (9 : ℝ)) := h 
      _ = 2 * (18 * π) := by ring 
      _ = 36 * π := by ring
  have eq3 : speed = 12 * π := by
    linarith
  exact eq3


In [24]:
print(p_injected)

import Mathlib

set_option maxHeartbeats 0
set_option maxRecDepth 100000
set_option tactic.hygienic false
set_option pp.fullNames true
set_option pp.funBinderTypes true
set_option pp.piBinderTypes true


theorem algebra_20287 (radius : ℝ) (hr : radius = 9) (speed : ℝ) (h : speed * 3 = 2 * (2 * π * radius)) :
    speed = 12 * π := by
  rw [hr] at h 
  have eq1 : speed * 3 = 2 * (2 * π * (9 : ℝ)) := h 
  have eq2 : speed * 3 = 36 * π := by 
    calc
      speed * 3 = 2 * (2 * π * (9 : ℝ)) := h 
      _ = 2 * (18 * π) := by ring 
      _ = 36 * π := by ring
  have eq3 : speed = 12 * π := by
    linarith
  exact eq3


In [None]:
code_segment = remove_comments(p_injected.encode()[u['i_begin']:u['i_end']].decode())

In [None]:
print(p_injected.encode()[u['i_begin']:u['i_end']].decode())

In [None]:
u

In [None]:

print(d.formal_statement)

In [None]:
print(len(data))
data = [d for d in data if d is not None]
print(len(data))

In [None]:
for d in data:
    d.metainfo = json.loads(d.metainfo)
    if 'proof_search_results' in d.metainfo:
        d.metainfo.pop('proof_search_results')

In [None]:
C.Counter([
    str(sorted(d.metainfo.keys())) for d in data
])

In [None]:
C.Counter([
    d.metainfo['problem_type'] for d in data if 'problem_type' in d.metainfo
])

In [None]:
C.Counter([
    d.metainfo['question_type'] for d in data if 'question_type' in d.metainfo
])

In [None]:
C.Counter([
    d.metainfo['source'] for d in data if 'source' in d.metainfo
])

In [None]:
C.Counter([
    d.metainfo['synthetic'] for d in data if 'synthetic' in d.metainfo
])

In [None]:
C.Counter([
    d.metainfo['subject'] for d in data if 'subject' in d.metainfo
])

In [None]:
C.Counter([
    str(sorted(d.metainfo.keys())) for d in data
])

In [None]:
C.Counter([
    type(d.metainfo) for d in data if d is not None
])

In [None]:
data_nonsynthetic_n15 = [
    d for d in data if 'synthetic' in d.metainfo and d.metainfo['synthetic'] is False
]
len(data_nonsynthetic_n15)

In [None]:
d = random.choice(data_nonsynthetic_n15)
print(d.formal_statement)

In [None]:
C.Counter([
    d.metainfo['problem_type'] for d in data_nonsynthetic_n15 if 'problem_type' in d.metainfo
])

In [None]:
C.Counter([
    d.metainfo['source'] for d in data_nonsynthetic_n15 if 'source' in d.metainfo
])

In [None]:
SYSTEM_PROMPT_FPG = '''You are an Olympiad problem setter and a Lean 4 expert.
You revel in conjuring elegant problems — starting from a spare set of hypotheses, you let rigorous deduction lead you to surprising and beautiful conclusions.'''

In [None]:
def format_step(self):
    if self.proof is None:
        return self.step_draft
    else:
        normalized_step_draft = normalize_draft(self.step_draft)
        matches = list(re.finditer(':= sorry', normalized_step_draft))
        assert len(matches) == len(self.proof)
        for (m, p) in reversed(list(zip(matches, self.proof))):
            normalized_step_draft = replace_span(m.span(), ':= by {\n' + '\n'.join('  ' + l for l in p.splitlines()) + '\n}', normalized_step_draft)
        return normalized_step_draft

In [None]:
def format_forward_solution_step_prompt(d: ProblemGenerationProcess, state: List[Variable]) -> str:
    context = ''
    vars_to_format = [v for v in state]
    while len(vars_to_format) > 0:
        for i in range(len(vars_to_format)):
            if i + 1 == len(vars_to_format) or not (vars_to_format[i].t == vars_to_format[i+1].t and vars_to_format[i].v is None and vars_to_format[i+1].v is None):
                break
        if i == 0:
            context += str(vars_to_format[0]) + '\n'
            vars_to_format.pop(0)
        else:
            context += ' '.join([v.name if v.name is not None else "_" for v in vars_to_format[:i+1]]) + f' : {vars_to_format[0].t}\n'
            vars_to_format = vars_to_format[i+1:]
    
    prompt = f'''Given a Lean 4 context, propose the single most natural next step to explore toward a beautiful conclusion — either
- derive a new intermediate fact,
- introduce a fresh variable or hypothesis, or
- submit one of the local facts as the final answer.

Requirements
1. Flavoured {d.metainfo['problem_type']} and suitable for posting on forums about {d.metainfo['source']}.
2. Fully formal Lean 4 code (inline comments in natural language are fine for planning and reasoning). Assume `import Mathlib`.


# Lean 4 Context
```lean4
{context.rstrip()}
```
'''
    return prompt


def format_forward_solution_step_response(d: ProblemGenerationProcess, step: ProblemGenerationStep):
    step_type = 'Derive' if step.is_deducing else 'Introduce' if step.is_introducing else 'Submit'
    response = f'''# Step {step_type}
```lean4
{format_step(step).rstrip()}
```
'''
    return response

In [None]:
data_problem_generation = []

In [None]:
for d in tqdm(data_nonsynthetic_n15):
    for p in d.trajectory:
        data_problem_generation.append({
            "conversation":[
                {
                    "system": SYSTEM_PROMPT_FPG,
                    "input": format_forward_solution_step_prompt(d, p[0]),
                    "output": format_forward_solution_step_response(d, d.steps[p[1]])
                }
            ]
        })

In [None]:
len(data_problem_generation)

In [None]:
with open('/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle123_problem_generation_steps.jsonl', 'w') as f:
    for s in data_problem_generation:
        f.write(json.dumps(s) + '\n')

In [None]:
print(format_forward_solution_step_prompt(d, p[0]))

In [None]:
print(format_forward_solution_step_response(d, d.steps[p[1]]))