In [1]:
import os
import os.path as osp
import json
import pickle
import collections as C
import itertools as I
import random
from typing import Optional, List, Tuple

import numpy as np
from loguru import logger
from math_verify import parse

from common.constants import CORE_OPTIONS
from common.utils import remove_comments, normalize_spaces, remove_spaces, replace_sorry, extract_code
from common.pantograph.dataclasses import ProblemGenerationProcess

In [5]:
# Read the evaluated valid samples: the file holds one JSON object per line.
with open('/sfs/liuqi/data/fpg_valid_fixed_evaluated/sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch-valid_samples.jsonl', 'r') as f:
    data = list(map(json.loads, f))

# Keep only problem-solving questions whose informal answer was parsed.
psq = [
    sample for sample in data
    if sample['informalization']['problem_type'] == 'Problem-Solving Question'
    and sample['informalization']['informal_answer_parsed'] is not None
]
print(len(data), len(psq))

# Seed the RNG so that the same 100 samples are drawn on every run.
random.seed(42)
sampled = random.sample(psq, 100)

2726 1326


In [6]:
# Dump the sampled problems to a Markdown file (one section per sample) for
# the manual check.
with open('/home/ma-user/workspace/formal_problem_generation/formal_problem_generation/informalization.1326.seed42.manual_check.original_answer.md', 'w') as f:
    for d in sampled:
        # Adjacent f-strings are concatenated, so each sample is one write.
        f.write(
            f'## Index {d["index"]}, Src {d["src"]}\n'
            f'### Formal Statement\n```lean4\n{d["formal_statement"].strip()}\n```\n'
            f'### Informal Problem\n{d["informalization"]["informal_problem"].strip()}\n'
            f'### Informal Answer\n{d["informalization"]["informal_answer"].strip()}\n\n\n'
        )

In [2]:
def is_falsified(d: ProblemGenerationProcess) -> bool:
    """Tell whether the evaluation found a disproof of `d`'s statement.

    A non-None entry in `metainfo['eval_old_result']['falsify_proofs']` is
    decisive. Otherwise `metainfo['eval_result']` must exist: any non-None
    satisfying proof overrides the falsifying proofs stored next to it.

    Raises:
        AssertionError: if the old result is not decisive and `metainfo` has
            no 'eval_result' entry.
    """
    meta = d.metainfo
    if 'eval_old_result' in meta:
        for old_proof in meta['eval_old_result']['falsify_proofs']:
            if old_proof is not None:
                return True
    assert 'eval_result' in meta
    current = meta['eval_result']
    # 'eval_result' may contain both falsifying and satisfying proofs; a
    # satisfying proof wins.
    if not all(p is None for p in current.get('satisfy_proofs', [None])):
        return False
    return not all(p is None for p in current['falsify_proofs'])

def is_proven(d: ProblemGenerationProcess) -> bool:
    """Tell whether `d` comes with a proof of its formal statement.

    * With a non-empty `d.trajectory` (deductive exploration), the statement
      must be non-empty and `metainfo['is_solution_validated']` truthy.
    * Otherwise (baselines), a non-empty `d.formal_solution_draft` counts as
      proven. Failing that, the last non-None entry of
      `metainfo['eval_result']['proofs']` is copied into
      `d.formal_solution_draft` (a warning is logged) and counts as proven.

    Always returns a real `bool`, so callers can safely apply `int(...)`.
    """
    if len(d.trajectory) > 0:
        # Deductive Exploration
        # `or ''` tolerates a missing (None) statement; `bool(...)` keeps the
        # result a bool when 'is_solution_validated' is absent (None).
        return len(d.formal_statement or '') > 0 and bool(d.metainfo.get('is_solution_validated'))

    # Baselines
    if len(d.formal_solution_draft or '') > 0:
        return True
    found_proofs = [
        p for p in d.metainfo.get('eval_result', dict()).get('proofs', [None])
        if p is not None
    ]
    if found_proofs:
        # Side effect: keep the last proof that was found as the draft.
        d.formal_solution_draft = found_proofs[-1]
        logger.warning(f"`len(d.formal_solution_draft or '') == 0` but `eval_result` is proven")
        return True
    return False

def is_valid(d: ProblemGenerationProcess) -> bool:
    """Return True when `d` is proven and has not been falsified.

    `is_falsified` is only consulted for proven processes.
    """
    # bool(...) keeps the result a real bool even if is_proven yields a
    # non-bool falsy value (e.g. None), so `int(is_valid(d))` cannot fail.
    return bool(is_proven(d)) and not is_falsified(d)

def _count_generation_tokens(d: ProblemGenerationProcess, usage_key: str, statement_slot: int) -> float:
    """Shared implementation of the prompt/completion token counters.

    Args:
        d: the generation process whose `metainfo` holds the token records.
        usage_key: 'prompt_tokens' or 'completion_tokens'.
        statement_slot: index into the two-element
            `metainfo['token_usage']['generate_statement']` record that holds
            this kind of token.

    Returns:
        The token total (a float when the failure discount ratio is applied).
    """
    meta = d.metainfo
    if len(d.trajectory) > 0:
        # Deductive Exploration
        if usage_key in meta:
            # A single total: scale it by the server-failure discount, if any.
            return meta[usage_key] * meta.get('server_failure_token_discounted_ratio', 1.0)
        # Per-step record.
        # NOTE(review): this sums from `first_failure_idx` onward; confirm this
        # is intended rather than the tokens spent before the first failure.
        return sum(meta['token_usage'][usage_key][meta.get('first_failure_idx', 0):])

    # Baselines
    total = 0
    if 'token_usage' in meta and 'generate_statement' in meta['token_usage']:
        # Whole-statement generation baseline
        assert 'token_usage:stmt_autoformalizer' not in meta
        assert len(meta['token_usage']['generate_statement']) == 2
        total += meta['token_usage']['generate_statement'][statement_slot]
        if 'provers.prove' in meta['token_usage']:
            total += sum(meta['token_usage']['provers.prove'][usage_key])
    elif 'token_usage:stmt_autoformalizer' in meta:
        # Autoformalization-based baselines
        total += sum(meta['token_usage:stmt_autoformalizer'][usage_key])
        if 'eval_result' in meta:
            total += sum(meta['eval_result']['prove_token_usage'][usage_key])
    else:
        # MUSTARD baseline
        total += sum(meta['eval_result']['prove_token_usage'][usage_key])
    return total


def count_prompt_token_cost(d: ProblemGenerationProcess) -> float:
    """Total prompt tokens spent on `d` (generation plus proving records)."""
    # The prompt count is the second slot of the 'generate_statement' record.
    return _count_generation_tokens(d, 'prompt_tokens', 1)


def count_token_cost(d: ProblemGenerationProcess) -> float:
    """Total completion tokens spent on `d` (generation plus proving records)."""
    # The completion count is the first slot of the 'generate_statement' record.
    return _count_generation_tokens(d, 'completion_tokens', 0)

def count_kc(d: ProblemGenerationProcess) -> float:
    """Complexity (KC) of the proof attached to `d`.

    When `metainfo` has a 'proving_results' entry, the value is the length of
    `d.formal_solution_draft` after `remove_comments` and `remove_spaces`.
    Otherwise the stored `metainfo['eval_result']['KC']` is returned.

    Returns NaN when the value cannot be computed (e.g. a missing key).
    """
    try:
        if 'proving_results' in d.metainfo:
            return len(remove_spaces(remove_comments(d.formal_solution_draft)))
        else:
            return d.metainfo['eval_result']['KC']
    except Exception:
        # Best effort: any ordinary failure yields NaN. KeyboardInterrupt and
        # SystemExit are deliberately not swallowed.
        return float('nan')

def count_step_kc(d: ProblemGenerationProcess) -> float:
    """Length of `d.formal_solution_draft` after `remove_comments` and `remove_spaces`.

    Returns NaN when the length cannot be computed (e.g. the draft is None).
    """
    try:
        return len(remove_spaces(remove_comments(d.formal_solution_draft)))
    except Exception:
        # Best effort, but do not swallow KeyboardInterrupt / SystemExit.
        return float('nan')

def count_falsifier_prompt_token_cost(d: ProblemGenerationProcess) -> int:
    """Sum all prompt tokens recorded in `metainfo['falsifier_token_usage']`.

    The record is a list of lists of token counts; every inner list is summed.
    Returns 0 when the record is missing or malformed.
    """
    try:
        return sum((sum(v) for v in d.metainfo['falsifier_token_usage']['prompt_tokens']))
    except Exception:
        # Best effort, but do not swallow KeyboardInterrupt / SystemExit.
        return 0

def count_falsifier_token_cost(d: ProblemGenerationProcess) -> int:
    """Sum all completion tokens recorded in `metainfo['falsifier_token_usage']`.

    The record is a list of lists of token counts; every inner list is summed.
    Returns 0 when the record is missing or malformed.
    """
    try:
        return sum((sum(v) for v in d.metainfo['falsifier_token_usage']['completion_tokens']))
    except Exception:
        # Best effort, but do not swallow KeyboardInterrupt / SystemExit.
        return 0

### Results Analysis

In [3]:
# Root directory of the experiment outputs; each run lives in the
# sub-directory named by its `path_prefix`.
output_root = '/home/ma-user/workspace/formal_problem_generation/output_tmp/output'


# # Numina-Lean.whole_statement_generatior
# path_prefix = 'sft_wg_starified/Goedel-Prover-V2-8B.Numina-Lean.whole_statement_generatior.nopack'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-162131.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'problem_generation.20250909-232315.log'
# )

# # Numina-Lean-linear.39980.problem_generator
# path_prefix = 'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-linear.39980.problem_generator.nopack'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_kc.20250912-172618.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'problem_generation.20250908-092208.log'
# )

# # Numina-Lean.problem_generator.nopack
# path_prefix = 'sft_ar_v2_strict/Goedel-Prover-V2-8B.Numina-Lean.problem_generator.nopack'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_kc.20250912-202850.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'problem_generation.20250906-102831.log'
# )

# Active run: Numina-Lean-reasseblmed.39509.problem_generator
path_prefix = 'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch'
run_dir = osp.join(output_root, path_prefix)
# Pickled evaluation results and the generation log of this run.
load_path = osp.join(run_dir, 'fpg_evaluate_kc.20250913-004304.pkl')
log_path = osp.join(run_dir, 'problem_generation.20250909-090734.log')

# # Main (Repeated)
# path_prefix = 'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch.0913'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250913-220057.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'problem_generation.20250913-124624.log'
# )

# # Ablation (-Order)
# path_prefix = 'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch.0913.staged'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250914-104447.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'staged_problem_generation.20250913-214247.log'
# )

# # PromptCoT-QwQ_kimina7b
# path_prefix = 'autoformalization_pg_kimina7b/PromptCoT-QwQ_kimina7b'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-154953.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-154953.log'
# )

# # PromptCoT-DS_kimina7b
# path_prefix = 'autoformalization_pg_kimina7b/PromptCoT-DS_kimina7b'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-154956.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-154956.log'
# )

# ScaleQuest-Math_kimina7b
# path_prefix = 'autoformalization_pg_kimina7b/ScaleQuest-Math_kimina7b'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-154955.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-154955.log'
# )

# # MUSTARDSauce_lean4_parsed
# path_prefix = 'MUSTARDSauce_lean4_parsed'
# load_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-121840.pkl'
# )
# log_path = osp.join(
#     output_root, path_prefix,
#     'fpg_evaluate_falsify_prove.20250911-121840.log'
# )

In [4]:
# Maps a generation index to the step indices whose generation call failed.
failure_steps = C.defaultdict(list)

# NOTE: pickle.load can execute arbitrary code; only open trusted result files.
with open(load_path, 'rb') as f:
    (conditions, finished_list) = pickle.load(f)

# Collect the failed steps reported in the generation log.
with open(log_path, 'r') as f:
    for l in f:
        l = l.strip()
        if 'agent.problem_generation:generate_async' in l and l.endswith('failed due to AssertionError()'):
            e = l.split()
            # Whitespace-separated field 7 is `generate_async(<gen_idx>):` and
            # field 8 starts with `<step_idx>/`.
            assert e[7].startswith('generate_async(') and e[7].endswith('):')
            gen_idx = int(e[7][len('generate_async('):-len('):')])
            step_idx = int(e[8].split('/')[0])
            failure_steps[gen_idx].append(step_idx)

# `metainfo` may be stored as a JSON string: decode it, and leave values that
# are not JSON text (e.g. dicts that are already decoded) untouched.
for d in finished_list:
    try:
        d.metainfo = json.loads(d.metainfo)
    except (TypeError, ValueError):
        # TypeError: not a str/bytes value; ValueError: invalid JSON text
        # (json.JSONDecodeError is a ValueError subclass).
        pass

if path_prefix == 'MUSTARDSauce_lean4_parsed':
    # Keep only the (condition, process) pairs that carry an 'eval_result'.
    evaluated_pairs = [(c, d) for (c, d) in zip(conditions, finished_list) if 'eval_result' in d.metainfo.keys()]
    conditions = [c for (c, _) in evaluated_pairs]
    finished_list = [d for (_, d) in evaluated_pairs]

len(conditions), len(finished_list), len(failure_steps)

(5000, 5000, 517)

In [5]:
# Number of the last step; presumably the per-run step budget (TODO confirm).
MAX_STEPS = 80

for gen_idx, failed_steps in failure_steps.items():
    first_failed = failed_steps[0]
    # Every step from the first failure up to MAX_STEPS, inclusive.
    expected_tail = list(range(first_failed, MAX_STEPS + 1))
    if failed_steps != expected_tail:
        # Failures were not contiguous up to the end: report how many of the
        # remaining steps did not fail, and leave the process untouched.
        print(f'False, {gen_idx}, {len(set(expected_tail)-set(failed_steps))}')
        continue
    process = finished_list[gen_idx]
    assert len(process.steps) < first_failed
    # Fraction of the steps that did not fail; used to discount token totals.
    process.metainfo['server_failure_token_discounted_ratio'] = (1 - len(failed_steps) / MAX_STEPS)
    process.metainfo['first_failure_idx'] = first_failed - 1

False, 215, 30
False, 783, 21
False, 1202, 60
False, 1802, 31
False, 2374, 2
False, 2548, 43
False, 3183, 7
False, 3564, 2
False, 4264, 27


In [6]:
# Stop at the first process that carries an 'eval_result' entry and show its
# metadata; if none does, `d` is left as the last element of `finished_list`.
for d in finished_list:
    if 'eval_result' in d.metainfo:
        break
d.metainfo

{'is_statement_validated': True,
 'is_solution_validated': True,
 'prompt_tokens': 9265,
 'completion_tokens': 34581,
 'time_consumption': 1537.0823423862457,
 'eval_result': {'satisfy_provers': ['/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B'],
  'satisfy_proofs': ["\n  have h_main : ∃ (n : ℕ), n ≠ 0 ∧ (∀ p ∈ n.primeFactors, p - 1 ∣ n - 1) ∧ True := by\n    use 1\n    constructor\n    · -- Prove that 1 ≠ 0\n      norm_num\n    constructor\n    · -- Prove that for all p in 1.primeFactors, p - 1 ∣ 1 - 1\n      intro p hp\n      -- Since 1 has no prime factors, the condition is vacuously true\n      simp [Nat.primeFactors] at hp\n      <;> aesop\n    · -- Prove that True is trivially true\n      trivial\n  -- Extract the witness from h_main\n  rcases h_main with ⟨n, hn₁, hn₂, hn₃⟩\n  refine' ⟨n, hn₁, hn₂, _⟩\n  <;> simp_all"],
  'satisfy_token_usage': {'completion_tokens': [2569], 'prompt_tokens': [183]},
  'provers': ['/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B',


In [14]:
# Cumulative completion tokens per step if the recorded total were spread
# evenly over the steps of `d`.
(np.arange(len(d.steps)) + 1) * (d.metainfo['completion_tokens']/ len(d.steps))

array([ 4940.14285714,  9880.28571429, 14820.42857143, 19760.57142857,
       24700.71428571, 29640.85714286, 34581.        ])

In [19]:
d.metainfo['completion_tokens'], len(d.steps), d.metainfo['completion_tokens']/len(d.steps)

(34581, 7, 4940.142857142857)

In [7]:
d.metainfo['token_usage']

KeyError: 'token_usage'

In [57]:
# Evaluate every metric for every process. Unsubmitted processes (empty
# formal statement) get 0 for the flags, and invalid ones get 0 for the KCs.
# NOTE(review): the computed values are discarded after each iteration —
# presumably a smoke test that no process makes a metric raise; confirm.
for (i, (c, d)) in enumerate(zip(conditions, finished_list)):
    d_is_submitted = int(len(d.formal_statement or '') > 0)
    d_is_falsified = 0 if (d_is_submitted == 0) else int(is_falsified(d))
    d_is_proven = 0 if (d_is_submitted == 0) else int(is_proven(d))
    d_is_valid = 0 if (d_is_submitted == 0) else int(is_valid(d))
    d_prompt_token_cost = count_prompt_token_cost(d)
    d_token_cost = count_token_cost(d)
    d_kc = 0 if (d_is_valid == 0) else count_kc(d)
    d_step_kc = 0 if (d_is_valid == 0) else count_step_kc(d)

In [58]:
# Partition the processes: submitted = non-empty formal statement.
submitted_list = [d for d in finished_list if len(d.formal_statement or '') > 0]
falsified_list = [d for d in submitted_list if is_falsified(d)]
proven_list = [d for d in submitted_list if is_proven(d)]
valid_list = [d for d in submitted_list if is_valid(d)]

# Token costs are collected over ALL processes, not only the valid ones.
prompt_token_cost_list = [count_prompt_token_cost(d) for d in finished_list]
token_cost_list = [count_token_cost(d) for d in finished_list]

falsifier_prompt_token_cost_list = [count_falsifier_prompt_token_cost(d) for d in finished_list]
falsifier_token_cost_list = [count_falsifier_token_cost(d) for d in finished_list]

# Complexities of the valid problems; infinite values are counted separately
# and excluded from the averages. Both lists are sorted ascending so that
# `[-1000:]` selects the 1000 largest values.
kc_list = [count_kc(d) for d in valid_list]
n_inf = len([kc for kc in kc_list if kc == float('inf')])
finite_kc_list = sorted([kc for kc in kc_list if kc != float('inf')])
step_kc_list = sorted([count_step_kc(d) for d in valid_list])

print(
    '#All', '#Submitted', '#Proven', '#Falsified', '#Valid',
    'Prompt Token Cost', 'Token Cost',
    # The second falsifier column is the completion-token cost (it used to be
    # mislabelled as a second 'Falsifier Prompt Token Cost').
    'Falsifier Prompt Token Cost', 'Falsifier Token Cost',
    'Complexity (Proof)', 'Complexity (Gen)',
    'Top-1000 Complexity (Proof)', 'Top-1000 Complexity (Gen)',
    '#Inf',
    sep='\t')
# All four token costs are totals over every process divided by the number of
# valid problems, i.e. the cost per valid problem.
print(len(finished_list), len(submitted_list), len(proven_list), len(falsified_list), len(valid_list),
      sum(prompt_token_cost_list)/len(valid_list), sum(token_cost_list)/len(valid_list),
      sum(falsifier_prompt_token_cost_list)/len(valid_list), sum(falsifier_token_cost_list)/len(valid_list),
      sum(finite_kc_list)/len(finite_kc_list), sum(step_kc_list)/len(step_kc_list),
      sum(finite_kc_list[-1000:])/len(finite_kc_list[-1000:]), sum(step_kc_list[-1000:])/len(step_kc_list[-1000:]),
      n_inf,
      sep='\t')

#All	#Submitted	#Proven	#Falsified	#Valid	Prompt Token Cost	Token Cost	Falsifier Prompt Token Cost	Falsifier Prompt Token Cost	Complexity (Proof)	Complexity (Gen)	Top-1000 Complexity (Proof)	Top-1000 Complexity (Gen)	#Inf
5000	2683	2605	269	2340	34868.85213675214	8800.142735042735	1850.8867521367522	4953.513247863248	470.2117903930131	346.2846153846154	853.793	659.194	50


In [59]:
# inf_ds = [d for d in valid_list if count_kc(d) == float('inf')]
# if len(inf_ds) > 0:
#     print(len(inf_ds))
#     with open(osp.join(
#         output_root, path_prefix, 'inf_ds.pkl'
#     ), 'wb') as f:
#         for d in inf_ds:
#             for k in ['provers', 'proofs', 'KC', 'prove_token_usage']:
#                 d.metainfo['eval_result'].pop(k)
#         pickle.dump(inf_ds, f, pickle.HIGHEST_PROTOCOL)
#     print(output_root, path_prefix)

In [60]:
d.metainfo

{'is_statement_validated': True,
 'is_solution_validated': True,
 'token_usage': {'completion_tokens': [19,
   27,
   19,
   21,
   22,
   19,
   27,
   27,
   68,
   28,
   26,
   15],
  'prompt_tokens': [190,
   202,
   231,
   241,
   257,
   264,
   273,
   281,
   293,
   306,
   316,
   321]},
 'time_consumption': 106.39736080169678,
 'falsifier_token_usage': {'completion_tokens': [[85], [188]],
  'prompt_tokens': [[126], [144]]},
 'eval_result': {'provers': ['/home/ma-user/local_cache/deepseek-ai/DeepSeek-Prover-V2-7B'],
  'proofs': ['\n  have h_contradiction : False := by\n    have h₁ : l > 0 := by\n      have h₂ : l ∈ S := hl.1\n      rw [hS] at h₂\n      exact h₂\n    have h₃ : ∀ x ∈ S, x ≤ l := hl.2\n    have h₄ : (l + 1 : ℚ) ∈ S := by\n      rw [hS]\n      norm_num [h₁]\n      <;> linarith\n    have h₅ : (l + 1 : ℚ) ≤ l := h₃ (l + 1) h₄\n    norm_num at h₅\n    <;> linarith\n  exact h_contradiction'],
  'KC': 216,
  'prove_token_usage': {'completion_tokens': [2074], 'prompt

In [61]:
def _cumulative_cost_and_validity(token_counts, max_num: int, valid: bool) -> Tuple[np.ndarray, np.ndarray]:
    """Turn per-attempt token counts into two budget-indexed curves.

    Entry k of either result describes a budget of k+1 attempts: the tokens
    spent so far, and whether the final outcome `valid` is already available
    (only from the last attempt onward). Both arrays have length `max_num`.
    """
    n_attempts = len(token_counts)
    assert n_attempts <= max_num
    cumulative = np.zeros(max_num)
    valid_mask = np.zeros(max_num, dtype=bool)
    if n_attempts > 0:
        cumulative[:n_attempts] = np.cumsum(token_counts)
        # Budgets beyond the last attempt cost nothing extra.
        cumulative[n_attempts:] = cumulative[n_attempts - 1]
    # With no attempt at all, the outcome applies to every budget.
    valid_mask[max(n_attempts - 1, 0):] = valid
    return cumulative, valid_mask


def compute_kc_cost_validnum_curve(d: ProblemGenerationProcess, max_num: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Completion-token cost and validity of `d` as a function of the budget.

    Args:
        d: the generation process to analyse.
        max_num: length of the returned curves. Defaults to the longest
            per-step token record in the global `finished_list` for deductive
            exploration, and to 12 (the proof-attempt limit used here) for
            baselines. Pass it explicitly to avoid rescanning `finished_list`
            on every call.

    Returns:
        `(total_completion_tokens, is_valid)`, two arrays of length `max_num`.
        Index k corresponds to a budget of k+1 steps (deductive exploration)
        or k+1 proof attempts (baselines).
    """
    meta = d.metainfo
    if len(d.trajectory) > 0:   # By limiting the number of steps
        # Deductive Exploration
        if max_num is None:
            max_num = max(len(p.metainfo['token_usage']['completion_tokens']) for p in finished_list)
        return _cumulative_cost_and_validity(
            meta['token_usage']['completion_tokens'], max_num, is_valid(d))

    # Baselines: by limiting the number of proofs
    if max_num is None:
        max_num = 12
    total_completion_tokens_by_proof_num = np.zeros(max_num)
    is_valid_by_proof_num = np.zeros(max_num, dtype=bool)

    if 'token_usage' in meta and 'generate_statement' in meta['token_usage']:
        # Whole-statement generation baseline
        assert 'token_usage:stmt_autoformalizer' not in meta
        assert len(meta['token_usage']['generate_statement']) == 2
        if 'provers.prove' in meta['token_usage'].keys():
            total_completion_tokens_by_proof_num, is_valid_by_proof_num = _cumulative_cost_and_validity(
                meta['token_usage']['provers.prove']['completion_tokens'], max_num, is_valid(d))
        # Only slot 0 is the completion count (as in count_token_cost);
        # summing both slots would also add the prompt tokens.
        total_completion_tokens_by_proof_num += meta['token_usage']['generate_statement'][0]
    elif 'token_usage:stmt_autoformalizer' in meta:
        # Autoformalization-based baselines
        if 'eval_result' in meta.keys():
            total_completion_tokens_by_proof_num, is_valid_by_proof_num = _cumulative_cost_and_validity(
                meta['eval_result']['prove_token_usage']['completion_tokens'], max_num, is_valid(d))
        total_completion_tokens_by_proof_num += sum(meta['token_usage:stmt_autoformalizer']['completion_tokens'])
    else:
        # MUSTARD baseline
        total_completion_tokens_by_proof_num, is_valid_by_proof_num = _cumulative_cost_and_validity(
            meta['eval_result']['prove_token_usage']['completion_tokens'], max_num, is_valid(d))

    return total_completion_tokens_by_proof_num, is_valid_by_proof_num

In [62]:
# Pair every generation process with the curves computed for it.
cost_valid_by_k = [
    (compute_kc_cost_validnum_curve(d), d) for d in finished_list
]
# Complexity and difficulty could also be visualized!

In [63]:
# Cost / yield / complexity table, one row per budget index K (0-based index
# into the per-process curves), for runs with per-step token records.
print(path_prefix)
print('K', 'Cost', '#Valid', 'Complexity', '#Inf', sep='\t')
for k in range(max(len(d.metainfo['token_usage']['completion_tokens']) for d in finished_list)):
    # Total cost over ALL processes at this budget.
    cost = sum(cost_list[k] for ((cost_list, is_valid_list), d) in cost_valid_by_k)
    valid_list_k = [
        d for ((cost_list, is_valid_list), d) in cost_valid_by_k if is_valid_list[k]
    ]
    if len(valid_list_k) == 0:
        continue
    kc_list_k = [count_kc(d) for d in valid_list_k]
    finite_kc_list_k = [kc for kc in kc_list_k if kc != float('inf')]

    # Every valid sample may have infinite KC: report NaN instead of raising
    # ZeroDivisionError.
    avg_complexity = sum(finite_kc_list_k) / len(finite_kc_list_k) if finite_kc_list_k else float('nan')
    print(k, cost, len(valid_list_k), avg_complexity, len(kc_list_k) - len(finite_kc_list_k), sep='\t')
    

sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch.0913.staged
K	Cost	#Valid	Complexity	#Inf
1	272623.0	46	128.6304347826087	0
2	460584.0	65	120.61904761904762	2
3	700706.0	111	177.30275229357798	2
4	986764.0	175	190.3815028901734	2
5	1302268.0	271	244.16728624535315	2
6	1676879.0	384	257.50397877984085	7
7	2070318.0	533	274.376673040153	10
8	2468847.0	650	283.7890625	10
9	2915040.0	751	299.88243243243244	11
10	3347577.0	868	323.344262295082	14
11	3821840.0	961	345.35163674762407	14
12	4321499.0	1054	353.8198458574181	16
13	4742159.0	1131	363.55395683453236	19
14	5125337.0	1213	367.3509212730318	19
15	5548459.0	1276	376.92919649960226	19
16	5951150.0	1325	387.604134762634	19
17	6321583.0	1393	393.1914119359534	19
18	6703875.0	1445	403.0680701754386	20
19	7067605.0	1497	409.427894380501	20
20	7440808.0	1545	413.12860892388454	21
21	7801814.0	1590	416.9074664964901	23
22	8127858.0	1627	420.1996257018091	24
23	8465517.0	1655	420.78847332924585	24
2

In [52]:
# Cost / yield / complexity table for the baselines, one row per budget index
# K (0-based index into the 12-entry per-process curves).
print(path_prefix)
# The rows print five values, so the header needs the '#Inf' column as well.
print('K', 'Cost', '#Valid', 'Complexity', '#Inf', sep='\t')
for k in range(12):
    # Total cost over ALL processes at this budget.
    cost = sum(cost_list[k] for ((cost_list, is_valid_list), d) in cost_valid_by_k)
    valid_list_k = [
        d for ((cost_list, is_valid_list), d) in cost_valid_by_k if is_valid_list[k]
    ]
    if len(valid_list_k) == 0:
        continue
    kc_list_k = [count_kc(d) for d in valid_list_k]
    finite_kc_list_k = [kc for kc in kc_list_k if kc != float('inf')]

    # Every valid sample may have infinite KC: report NaN instead of raising
    # ZeroDivisionError.
    avg_complexity = sum(finite_kc_list_k) / len(finite_kc_list_k) if finite_kc_list_k else float('nan')
    print(k, cost, len(valid_list_k), avg_complexity, len(kc_list_k) - len(finite_kc_list_k), sep='\t')
    

MUSTARDSauce_lean4_parsed
K	Cost	#Valid	Complexity
0	3686728.0	3578	118.42537730575741	0
1	3907534.0	3747	119.16706698692288	0
2	3963650.0	3776	118.90492584745763	0
3	3986550.0	3779	119.0500132310135	0
4	4005955.0	3781	119.04522613065326	0
5	4019846.0	3783	119.0253766851705	0
6	4031094.0	3785	119.0340819022457	0
7	4036953.0	3787	119.01241087932401	0
8	4042212.0	3788	119.01715945089757	0
9	4045194.0	3789	119.02903140670361	0
10	4047726.0	3789	119.02903140670361	0
11	4049862.0	3791	119.02903140670361	2


In [None]:
# Histogram over the valid samples of how many completion-token entries the
# 'provers.prove' record holds.
C.Counter(
    len(d.metainfo['token_usage']['provers.prove']['completion_tokens']) for d in valid_list
)

Counter({1: 725,
         2: 236,
         3: 84,
         4: 30,
         5: 23,
         6: 18,
         8: 12,
         7: 10,
         11: 8,
         10: 8,
         9: 6,
         12: 4})

In [None]:
d.metainfo

{'formal_proof': "import data.real.basic\n\n-- let's denote the cost of the bike as `b`, the weekly savings as `s`, the number of weeks as `w`, \n-- the discount rate as `d` and the remaining amount to save as `r`\nvariables (b s w d r : ℝ)\n\n-- the cost of the bike after the discount is `b - b * d`\n-- the total savings after `w` weeks is `w * s`\n-- the remaining amount to save is the cost of the bike after the discount minus the total savings\n-- therefore, we have `r = b - b * d - w * s`\n-- given that `b = 200`, `s = 20`, `w = 4` and `d = 0.25`, we can solve for `r`\n\n-- defining the equation\ndef remaining_to_save (b s w d : ℝ) : ℝ := b - b * d - w * s\n\n-- asserting the known values and solving for r\nexample : remaining_to_save 200 20 4 0.25 = 70 :=\nbegin\n  unfold remaining_to_save, -- expanding the definition of `remaining_to_save`\n  norm_num, -- simplifying the equation\nend",
 'source_path': '/cache/data/MUSTARDSauce/data/subset_filtered',
 'source_idx': 'new_step_form

In [None]:
d.metainfo['eval_result']['prove_token_usage']['completion_tokens']

[2083, 2346, 7994, 1962, 2315, 7994, 2265]

In [None]:
C.Counter(
    len(d.metainfo['token_usage:stmt_autoformalizer']['completion_tokens']) for d in valid_list
)

Counter({1: 1024})

In [None]:
# The length is computed and discarded, so this loop only has an effect when a
# lookup raises (e.g. a valid sample without 'prove_token_usage').
for d in valid_list:
    len(d.metainfo['eval_result']['prove_token_usage'])

In [None]:
d.metainfo

{'source_path': '/home/ma-user/local_cache/zhaoxlpku/PromptCoT-QwQ-Dataset/train.jsonl',
 'source_idx': 22206,
 'token_usage:stmt_autoformalizer': {'completion_tokens': [71],
  'prompt_tokens': [106]},
 'eval_result': {'falsify_provers': ['/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B',
   '/home/ma-user/local_cache/AI-MO/Kimina-Prover-Distill-8B',
   '/home/ma-user/local_cache/deepseek-ai/DeepSeek-Prover-V2-7B',
   '/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B',
   '/home/ma-user/local_cache/AI-MO/Kimina-Prover-Distill-8B',
   '/home/ma-user/local_cache/deepseek-ai/DeepSeek-Prover-V2-7B',
   '/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B',
   '/home/ma-user/local_cache/AI-MO/Kimina-Prover-Distill-8B',
   '/home/ma-user/local_cache/deepseek-ai/DeepSeek-Prover-V2-7B',
   '/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B',
   '/home/ma-user/local_cache/AI-MO/Kimina-Prover-Distill-8B',
   '/home/ma-user/local_cache/deepseek-ai/DeepSeek-Prover-V2-7B'],

In [None]:
# for d in falsified_list:
#     # if count_kc(d) == float('inf'):
#         print(d.formal_statement)
#         print(d.formal_solution_draft)
#         print('\n\n')

In [None]:
# len([kc for kc in kc_list if kc == float('inf')]), sum([kc for kc in kc_list if kc != float('inf')])/len([kc for kc in kc_list if kc != float('inf')])

In [None]:
# Export every valid problem as one JSON line; the file name is `path_prefix`
# with '/' replaced by '-', plus the '-valid_samples.jsonl' suffix.
valid_samples_path = osp.join(output_root, path_prefix.replace('/', '-') + '-valid_samples.jsonl')
with open(valid_samples_path, 'w') as f:
    for i, (c, d) in enumerate(zip(conditions, finished_list)):
        d: ProblemGenerationProcess
        if is_valid(d):
            # A valid problem must carry both a statement and a proof.
            assert len(d.formal_statement or '') > 0
            assert len(d.formal_solution_draft or '') > 0
            record = {
                'header': d.header or '',
                'formal_statement': d.formal_statement,
                'formal_proof': d.formal_solution_draft,
                'condition': c,
                # Position in `finished_list`, not a running count of valid samples.
                'index': i,
                'src': path_prefix
            }
            f.write(json.dumps(record) + '\n')

In [None]:
output_root

'/home/ma-user/workspace/formal_problem_generation/output_tmp/output'

In [None]:
# Submitted statements that were neither proven nor falsified.
strange_list = [
    d for d in finished_list
    if len(d.formal_statement or '') > 0 and not (is_proven(d) or is_falsified(d))
]
len(strange_list)

75

In [None]:
d = random.choice(strange_list)
print(d.informal_problem)
print(d.informal_answer)
print(d.formal_statement)


Let $P$ be a point chosen at random inside isosceles triangle $ABC$ with sides of length $1$ . The expected value of the expression \[\log_{10} \frac{AP}{BP}+\log_{10} \frac{BP}{CP}+\log_{10} \frac{CP}{AP}\] can be written in the form $\frac{a}{b}$ , where $a$ and $b$ are relatively prime positive integers. Find $a+b$
1
example (A B C P : EuclideanSpace ℝ (Fin 2))
    (hABC : AffineIndependent ℝ ![A, B, C])
    (hP : P ∈ interior (convexHull ℝ {A, B, C}))
    (hP1 : dist A P = 1) (hP2 : dist B P = 1) (hP3 : dist C P = 1) :
    ∃ a b : ℕ, a > 0 ∧ b > 0 ∧ a.Coprime b ∧
    (logb 10 (dist A P / dist B P) + logb 10 (dist B P / dist C P) +
     logb 10 (dist C P / dist A P)) = a / b := sorry


In [None]:
# Presumably a barrier that stops "Run All" here: the cells below rebind
# `finished_list` from another pickle and write a new pickle file.
raise RuntimeError('Intentional stop: run the "Extract falsifying tasks" cells below manually.')

RuntimeError: No active exception to reraise

### Extract falsifying tasks

In [None]:
# NOTE: pickle.load can execute arbitrary code; only open trusted result files.
# This rebinds `finished_list` to the results stored in this pickle (the
# Numina-Lean-linear.39980 run) and defines `conditions_sampled`.
with open('/home/ma-user/workspace/formal_problem_generation/output_tmp/output/sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-linear.39980.problem_generator.nopack/fpg_evaluate_falsify_prove.20250908-234856.pkl', 'rb') as f:
    conditions_sampled, finished_list = pickle.load(f)
len(finished_list)

In [None]:
# with open('/home/ma-user/workspace/formal_problem_generation/output_tmp/output/sft_ar_v2_strict/Goedel-Prover-V2-8B.Numina-Lean.problem_generator.nopack/fpg_evaluate_falsify_prove.20250907-170721.pkl', 'rb') as f:
#     conditions_sampled, finished_list = pickle.load(f)

In [None]:
# `metainfo` may be stored as a JSON string: decode it, and leave values that
# are not JSON text (e.g. dicts that are already decoded) untouched.
for d in finished_list:
    try:
        d.metainfo = json.loads(d.metainfo)
    except (TypeError, ValueError):
        # TypeError: not a str/bytes value; ValueError: invalid JSON text
        # (json.JSONDecodeError is a ValueError subclass).
        pass

In [None]:
# Consistency check: a process carries an 'eval_result' exactly when it has a
# non-empty formal statement.
for d in finished_list:
    has_eval_result = 'eval_result' in d.metainfo.keys()
    has_statement = len(d.formal_statement or '') > 0
    assert has_eval_result == has_statement

In [None]:
# Move evaluations whose last falsifying attempt failed (asserted to mean that
# every attempt failed) from 'eval_result' to 'eval_old_result'.
for d in finished_list:
    meta = d.metainfo
    if 'eval_result' not in meta.keys():
        continue
    falsify_proofs = meta['eval_result'].get('falsify_proofs', [None])
    if falsify_proofs[-1] is None:
        assert all(p is None for p in falsify_proofs)
        meta['eval_old_result'] = meta.pop('eval_result')

In [None]:
with open('/home/ma-user/workspace/formal_problem_generation/output_tmp/output/sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-linear.39980.problem_generator.nopack/fpg_evaluate_falsify_prove.20250908-234856.hidden_unsuccessful_falsifying.pkl', 'wb') as f:
    pickle.dump((conditions_sampled, finished_list), f)

In [None]:
# Number of processes with a non-empty formal statement; `or ''` tolerates a
# missing (None) statement, as in the other cells.
C.Counter(1 for d in finished_list if len(d.formal_statement or '') > 0)

In [None]:
C.Counter(1 for d in finished_list if 'eval_result' in d.metainfo.keys())

In [None]:
C.Counter(1 for d in finished_list if 'eval_old_result' in d.metainfo.keys())