In [3]:
import os
import os.path as osp
import json
import pickle
import collections as C
import itertools as I

from loguru import logger

from common.constants import CORE_OPTIONS
from common.utils import remove_comments, normalize_spaces, remove_spaces
from common.pantograph.dataclasses import ProblemGenerationProcess

In [4]:
def is_falsified(d: ProblemGenerationProcess) -> bool:
    """Return whether a falsifying proof has been found for `d`.

    Reads `d.metainfo['eval_old_result']` (optional) and
    `d.metainfo['eval_result']` (required unless the old result already
    decides the answer):

    * a non-None entry in the old result's `falsify_proofs` means falsified;
    * otherwise a non-None entry in the current result's `satisfy_proofs`
      means not falsified, even if `falsify_proofs` is also populated;
    * otherwise the current result's `falsify_proofs` decide.

    A missing `falsify_proofs` / `satisfy_proofs` list is treated as
    "no proof found" rather than raising `KeyError`.

    Raises:
        AssertionError: if `'eval_result'` is absent from `d.metainfo` and
            the old result did not already return True.
    """
    def has_proof(proofs) -> bool:
        # A slot holding None marks an attempt that produced no proof.
        return any(p is not None for p in proofs)

    if 'eval_old_result' in d.metainfo:
        if has_proof(d.metainfo['eval_old_result'].get('falsify_proofs', [])):
            return True
    assert 'eval_result' in d.metainfo
    eval_result = d.metainfo['eval_result']
    # 'eval_result' may contain both falsifying and satisfying proofs;
    # a satisfying proof takes precedence.
    if has_proof(eval_result.get('satisfy_proofs', [])):
        return False
    return has_proof(eval_result.get('falsify_proofs', []))

def is_proven(d: ProblemGenerationProcess) -> bool:
    """Return whether `d` comes with a proof of its statement.

    * Deductive Exploration (non-empty `d.trajectory`): the sample needs a
      non-empty `formal_statement` and a truthy
      `d.metainfo['is_solution_validated']`.
    * Baselines (empty `d.trajectory`): the sample is proven when
      `formal_solution_draft` is non-empty, or when
      `d.metainfo['eval_result']['proofs']` holds a non-None proof.

    Always returns a real `bool`, so callers can safely wrap the result in
    `int(...)`.

    Side effect: in the baseline case, when the draft is empty but
    `eval_result` holds a proof, the last non-None proof is written to
    `d.formal_solution_draft` and a warning is logged.
    """
    if len(d.trajectory) > 0:
        # Deductive Exploration
        # `formal_statement` may be None; a missing validation flag counts as False.
        return len(d.formal_statement or '') > 0 and bool(d.metainfo.get('is_solution_validated'))

    # Baselines
    if len(d.formal_solution_draft or '') > 0:
        return True
    found_proofs = [
        p for p in d.metainfo.get('eval_result', dict()).get('proofs', [])
        if p is not None
    ]
    if not found_proofs:
        return False
    # Keep the last available proof as the solution draft.
    d.formal_solution_draft = found_proofs[-1]
    logger.warning(f"`len(d.formal_solution_draft or '') == 0` but `eval_result` is proven")
    return True

def is_valid(d: ProblemGenerationProcess) -> bool:
    """Return True when `d` is proven and no falsifying proof exists.

    The result of `is_proven` is coerced to `bool` so this function never
    returns a non-bool falsy value such as None. `is_falsified` is only
    consulted for proven samples (short-circuit), so samples without an
    `eval_result` that are not proven do not trigger its assertion.
    """
    return bool(is_proven(d)) and not is_falsified(d)

def count_prompt_token_cost(d: ProblemGenerationProcess) -> int:
    """Sum the prompt tokens recorded in `d.metainfo`.

    Deductive Exploration samples (non-empty `d.trajectory`) use either the
    aggregate `prompt_tokens` scaled by
    `server_failure_token_discounted_ratio` (default 1.0), or the per-step
    list `token_usage['prompt_tokens']` sliced from `first_failure_idx`
    (default 0). Baseline samples add up statement-generation /
    autoformalization tokens and prover tokens, depending on which keys are
    present.
    """
    meta = d.metainfo

    if len(d.trajectory) > 0:
        # Deductive Exploration
        if 'prompt_tokens' in meta.keys():
            aggregate = meta['prompt_tokens']
            return aggregate * meta.get('server_failure_token_discounted_ratio', 1.0)
        # NOTE(review): the slice keeps the entries from `first_failure_idx`
        # onward — confirm this is the intended side of the split.
        start = meta.get('first_failure_idx', 0)
        return sum(meta['token_usage']['prompt_tokens'][start:])

    # Baselines: collect the individual contributions, then add them up.
    contributions = []
    if 'token_usage' in meta and 'generate_statement' in meta['token_usage']:
        # Whole-statement generation baseline
        usage = meta['token_usage']
        assert 'token_usage:stmt_autoformalizer' not in meta
        assert len(usage['generate_statement']) == 2
        # Index 1 of the pair is the prompt side.
        contributions.append(usage['generate_statement'][1])
        if 'provers.prove' in usage:
            contributions.append(sum(usage['provers.prove']['prompt_tokens']))
    elif 'token_usage:stmt_autoformalizer' in meta:
        # Autoformalization-based baselines
        contributions.append(sum(meta['token_usage:stmt_autoformalizer']['prompt_tokens']))
        if 'eval_result' in meta.keys():
            contributions.append(sum(meta['eval_result']['prove_token_usage']['prompt_tokens']))
    else:
        # MUSTARD baseline
        contributions.append(sum(meta['eval_result']['prove_token_usage']['prompt_tokens']))
    return sum(contributions)

def count_token_cost(d: ProblemGenerationProcess) -> int:
    """Sum the completion tokens recorded in `d.metainfo`.

    Deductive Exploration samples (non-empty `d.trajectory`) use either the
    aggregate `completion_tokens` scaled by
    `server_failure_token_discounted_ratio` (default 1.0), or the per-step
    list `token_usage['completion_tokens']` sliced from `first_failure_idx`
    (default 0). Baseline samples add up statement-generation /
    autoformalization tokens and prover tokens, depending on which keys are
    present.
    """
    meta = d.metainfo

    if len(d.trajectory) > 0:
        # Deductive Exploration
        if 'completion_tokens' in meta.keys():
            aggregate = meta['completion_tokens']
            return aggregate * meta.get('server_failure_token_discounted_ratio', 1.0)
        # NOTE(review): the slice keeps the entries from `first_failure_idx`
        # onward — confirm this is the intended side of the split.
        start = meta.get('first_failure_idx', 0)
        return sum(meta['token_usage']['completion_tokens'][start:])

    # Baselines: collect the individual contributions, then add them up.
    contributions = []
    if 'token_usage' in meta and 'generate_statement' in meta['token_usage']:
        # Whole-statement generation baseline
        usage = meta['token_usage']
        assert 'token_usage:stmt_autoformalizer' not in meta
        assert len(usage['generate_statement']) == 2
        # Index 0 of the pair is the completion side.
        contributions.append(usage['generate_statement'][0])
        if 'provers.prove' in usage:
            contributions.append(sum(usage['provers.prove']['completion_tokens']))
    elif 'token_usage:stmt_autoformalizer' in meta:
        # Autoformalization-based baselines
        contributions.append(sum(meta['token_usage:stmt_autoformalizer']['completion_tokens']))
        if 'eval_result' in meta.keys():
            contributions.append(sum(meta['eval_result']['prove_token_usage']['completion_tokens']))
    else:
        # MUSTARD baseline
        contributions.append(sum(meta['eval_result']['prove_token_usage']['completion_tokens']))
    return sum(contributions)

def count_kc(d: ProblemGenerationProcess) -> int:
    """Return the complexity score ('KC') of `d`.

    Uses the value stored at `d.metainfo['eval_result']['KC']` when it is
    available. When that lookup fails (missing key, or a container that
    cannot be indexed that way), falls back to the length of
    `d.formal_solution_draft` after `remove_comments` and `remove_spaces`.

    Only lookup failures trigger the fallback; any other exception
    (including KeyboardInterrupt) propagates instead of being swallowed.
    """
    try:
        return d.metainfo['eval_result']['KC']
    except (KeyError, TypeError):
        return len(remove_spaces(remove_comments(d.formal_solution_draft)))

### Results Analysis

In [23]:
# Root directory that holds one sub-directory of outputs per experiment.
# Set the FPG_OUTPUT_ROOT environment variable to analyse outputs stored elsewhere.
output_root = os.environ.get(
    'FPG_OUTPUT_ROOT',
    '/home/ma-user/workspace/formal_problem_generation/output_tmp/output',
)

# Known experiments: path_prefix -> (evaluation pickle, log file).
# Both file names are relative to `output_root/path_prefix`.
# A pickle of None means no evaluation pickle is registered for that experiment.
EXPERIMENT_FILES = {
    # Numina-Lean.whole_statement_generatior
    'sft_wg_starified/Goedel-Prover-V2-8B.Numina-Lean.whole_statement_generatior.nopack': (
        'fpg_evaluate_falsify_prove.20250911-162131.pkl',
        'problem_generation.20250909-232315.log',
    ),
    # Numina-Lean-linear.39980.problem_generator
    'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-linear.39980.problem_generator.nopack': (
        'fpg_evaluate_falsify_prove.20250911-230039.pkl',
        'problem_generation.20250908-092208.log',
    ),
    # Numina-Lean.problem_generator.nopack
    'sft_ar_v2_strict/Goedel-Prover-V2-8B.Numina-Lean.problem_generator.nopack': (
        'fpg_evaluate_falsify_prove.20250907-170721.pkl',
        'problem_generation.20250906-102831.log',
    ),
    # Numina-Lean-reasseblmed.39509.problem_generator
    # (only 'problem_generation.20250909-090734.pkl' exists; no evaluation pickle)
    'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch': (
        None,
        'problem_generation.20250909-090734.log',
    ),
    # ScaleQuest-Math_kimina7b
    'autoformalization_pg_kimina7b/ScaleQuest-Math_kimina7b': (
        'fpg_evaluate_falsify_prove.20250911-154955.pkl',
        'fpg_evaluate_falsify_prove.20250911-154955.log',
    ),
    # MUSTARDSauce_lean4_parsed
    'MUSTARDSauce_lean4_parsed': (
        'fpg_evaluate_falsify_prove.20250911-121840.pkl',
        'fpg_evaluate_falsify_prove.20250911-121840.log',
    ),
}

# Pick the experiment to analyse by changing this single line.
path_prefix = 'MUSTARDSauce_lean4_parsed'

pkl_name, log_name = EXPERIMENT_FILES[path_prefix]
if pkl_name is None:
    raise ValueError(f'No evaluation pickle is registered for {path_prefix!r}')
load_path = osp.join(output_root, path_prefix, pkl_name)
log_path = osp.join(output_root, path_prefix, log_name)

In [24]:
# NOTE: `pickle.load` can execute arbitrary code; only load pickles produced by this project.
with open(load_path, 'rb') as f:
    (conditions, finished_list) = pickle.load(f)

# gen_idx -> step indices whose generation attempt failed with an AssertionError,
# collected from the log file.
failure_steps = C.defaultdict(list)
with open(log_path, 'r') as f:
    for line in f:  # stream the log instead of materialising every line
        line = line.strip()
        if 'agent.problem_generation:generate_async' in line and line.endswith('failed due to AssertionError()'):
            fields = line.split()
            # Field 7 looks like 'generate_async(<gen_idx>):', field 8 starts with '<step_idx>/'.
            assert fields[7].startswith('generate_async(') and fields[7].endswith('):')
            gen_idx = int(fields[7][len('generate_async('):-len('):')])
            step_idx = int(fields[8].split('/')[0])
            failure_steps[gen_idx].append(step_idx)

for d in finished_list:
    # Decode `metainfo` when it is a JSON string. It is left untouched when it is
    # not JSON text (TypeError, e.g. already a dict) or not valid JSON (ValueError).
    try:
        d.metainfo = json.loads(d.metainfo)
    except (TypeError, ValueError):
        pass

if path_prefix == 'MUSTARDSauce_lean4_parsed':
    # Keep only the samples that were actually evaluated, filtering both lists in one pass.
    evaluated_pairs = [
        (cond, proc) for (cond, proc) in zip(conditions, finished_list)
        if 'eval_result' in proc.metainfo.keys()
    ]
    conditions = [cond for (cond, proc) in evaluated_pairs]
    finished_list = [proc for (cond, proc) in evaluated_pairs]

len(conditions), len(finished_list), len(failure_steps)

(3794, 3794, 0)

In [25]:
# For every generation run with logged failures: when the failures form one
# contiguous tail (first failed step .. 80), record a token discount ratio and the
# index of the first failure in the sample's metainfo; otherwise report the run.
# NOTE(review): 80 looks like the maximum number of steps per run — confirm.
for gen_idx, failed_steps in failure_steps.items():
    first_failed = failed_steps[0]
    expected_tail = range(first_failed, 81)
    if failed_steps != list(expected_tail):
        n_missing = len(set(expected_tail) - set(failed_steps))
        print(f'False, {gen_idx}, {n_missing}')
        continue
    process = finished_list[gen_idx]
    assert len(process.steps) < first_failed
    process.metainfo['server_failure_token_discounted_ratio'] = (1 - len(failed_steps) / 80)
    process.metainfo['first_failure_idx'] = first_failed - 1

In [26]:
# Show the metainfo of the first sample that carries an evaluation result.
# Fail loudly when there is none, instead of silently showing the last sample.
d = next((item for item in finished_list if 'eval_result' in item.metainfo), None)
assert d is not None, "no entry of `finished_list` has an 'eval_result'"
d.metainfo

{'formal_proof': 'import data.real.basic\n\ntheorem fraction_multiply : (1 / 2 : ℝ) * (1 / 3 : ℝ) = 1 / 6 :=\nbegin\n  norm_num\nend',
 'source_path': '/cache/data/MUSTARDSauce/data/subset_filtered',
 'source_idx': 'new_step_formalans_kwgiven__ELEM_theorem_proving_5th_grade_915',
 'file_name': 'tp_ELEM_k1_5th_915_step.json',
 'lean4_source_path': 'TpELEMK15th915Step.lean',
 'code_with_sorry': 'import Mathlib\nimport Aesop\n\n\n\ntheorem fraction_multiply : (1 / 2 : ℝ) * (1 / 3 : ℝ) = 1 / 6  := by\n  sorry',
 'unit_idx': 0,
 'eval_result': {'provers': ['/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B'],
  'proofs': ['\n  have h_main : (1 / 2 : ℝ) * (1 / 3 : ℝ) = 1 / 6 := by\n    norm_num\n    <;>\n    rfl\n    <;>\n    simp_all\n    <;>\n    ring\n    <;>\n    norm_num\n    <;>\n    rfl\n  \n  exact h_main'],
  'KC': 94,
  'prove_token_usage': {'completion_tokens': [751], 'prompt_tokens': [170]},
  'satisfy_provers': ['/home/ma-user/local_cache/Goedel-LM/Goedel-Prover-V2-8B'],
 

In [27]:
# Per-sample metrics, one record per entry of `finished_list`.
# The records are kept in `sample_metrics` so they can be inspected afterwards.
sample_metrics = []
for (i, (c, d)) in enumerate(zip(conditions, finished_list)):
    d_is_submitted = int(len(d.formal_statement or '') > 0)
    # Unsubmitted samples are never evaluated; `bool(...)` guards against non-bool results.
    d_is_falsified = 0 if (d_is_submitted == 0) else int(bool(is_falsified(d)))
    d_is_proven = 0 if (d_is_submitted == 0) else int(bool(is_proven(d)))
    # Valid == proven and not falsified (same rule as `is_valid`), reusing the flags above.
    d_is_valid = int(d_is_proven == 1 and d_is_falsified == 0)
    d_prompt_token_cost = count_prompt_token_cost(d)
    d_token_cost = count_token_cost(d)
    d_kc = 0 if (d_is_valid == 0) else count_kc(d)
    sample_metrics.append({
        'index': i,
        'condition': c,
        'is_submitted': d_is_submitted,
        'is_falsified': d_is_falsified,
        'is_proven': d_is_proven,
        'is_valid': d_is_valid,
        'prompt_token_cost': d_prompt_token_cost,
        'token_cost': d_token_cost,
        'kc': d_kc,
    })

In [29]:
# Aggregate statistics over the loaded experiment.
# Token costs are summed over ALL finished samples and divided by the number of
# VALID samples; the complexity is the mean KC over valid samples.
submitted_list = [d for d in finished_list if len(d.formal_statement or '') > 0]
falsified_list = [d for d in submitted_list if is_falsified(d)]
proven_list = [d for d in submitted_list if is_proven(d)]
valid_list = [d for d in submitted_list if is_valid(d)]
prompt_token_cost_list = [count_prompt_token_cost(d) for d in finished_list]
token_cost_list = [count_token_cost(d) for d in finished_list]
kc_list = [count_kc(d) for d in valid_list]


def _per_valid_sample(values):
    """Sum `values` and divide by the number of valid samples (NaN when there are none)."""
    return sum(values) / len(valid_list) if valid_list else float('nan')


print('#All', '#Submitted', '#Proven', '#Falsified', '#Valid', 'Prompt Token Cost', 'Token Cost', 'Complexity', sep='\t')
print(
    len(finished_list), len(submitted_list), len(proven_list), len(falsified_list), len(valid_list),
    _per_valid_sample(prompt_token_cost_list), _per_valid_sample(token_cost_list), _per_valid_sample(kc_list),
    sep='\t',
)

#All	#Submitted	#Proven	#Falsified	#Valid	Prompt Token Cost	Token Cost	Complexity
3794	3794	3794	3	3791	249.05750461619624	1068.2833025586917	inf


In [40]:
# Dump every falsified sample: its statement, then its solution draft, then blank lines.
# (Optional filter used during debugging: `if count_kc(d) == float('inf')`.)
for d in falsified_list:
    print(d.formal_statement, d.formal_solution_draft, '\n\n', sep='\n')

-- The definition of the linear function.
-- The definition of the linear function.
def f (a b x : ℝ) :=
  a * x + b

-- The proof that b must be 3.

theorem b_must_be_three (a b : ℝ) (h : ∀ x, f a b x = f a (b + 3) x) : b = 3  := by
  sorry

  have h1 : f a b 2 = f a (b + 3) 2 := h 2
  unfold f at h1
  linarith




open Real

-- defining the function for the population of birds
def f (t : ℝ) :=
  2 * t ^ 3 - 3 * t ^ 2 + 5 * t - 2

-- defining the integral of the function from 0 to 20
-- In Lean, we don't have the ability to compute definite integrals directly. So here we just define it.
def totalChange :=
  "integral (λ t, f t) 0 20"

-- calculation of the total change in bird population

example (h : totalChange = "8000") : totalChange = "8000"  := by
  sorry
-- calculation of the integral
  rw [h]

-- defining the number of ways of forming groups of 4 birds from 8000
-- In Lean, we don't have the ability to compute combinations directly. So here we just define it.




open scoped Cl

In [31]:
# (number of infinite KC values, mean of the remaining KC values)
finite_kcs = [kc for kc in kc_list if kc != float('inf')]
len(kc_list) - len(finite_kcs), sum(finite_kcs) / len(finite_kcs)

(2, 119.02903140670361)

In [10]:
# Export every valid sample as one JSON object per line to
# `<output_root>/<path_prefix with '/' replaced by '-'>-valid_samples.jsonl`.
valid_samples_path = osp.join(output_root, path_prefix.replace('/', '-') + '-valid_samples.jsonl')
with open(valid_samples_path, 'w') as f:
    for i, (c, d) in enumerate(zip(conditions, finished_list)):
        if is_valid(d):
            # A valid sample must carry both a statement and a proof.
            assert len(d.formal_statement or '') > 0
            assert len(d.formal_solution_draft or '') > 0
            record = {
                'header': d.header or '',
                'formal_statement': d.formal_statement,
                'formal_proof': d.formal_solution_draft,
                'condition': c,
                'index': i,  # position within `finished_list`
                'src': path_prefix
            }
            f.write(json.dumps(record) + '\n')

In [26]:
# Intentional stop: a bare `raise` with no active exception raises RuntimeError.
# NOTE(review): presumably here to keep "Run All" from reaching the cells below — confirm.
raise

RuntimeError: No active exception to reraise

### Extract falsifying tasks

In [None]:
# Load the (conditions, finished samples) pair of the falsify/prove evaluation run.
# NOTE: `pickle.load` can execute arbitrary code; only load pickles produced by this project.
falsify_eval_pkl = (
    '/home/ma-user/workspace/formal_problem_generation/output_tmp/output/'
    'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-linear.39980.problem_generator.nopack/'
    'fpg_evaluate_falsify_prove.20250908-234856.pkl'
)
with open(falsify_eval_pkl, 'rb') as f:
    conditions_sampled, finished_list = pickle.load(f)
len(finished_list)

In [None]:
# with open('/home/ma-user/workspace/formal_problem_generation/output_tmp/output/sft_ar_v2_strict/Goedel-Prover-V2-8B.Numina-Lean.problem_generator.nopack/fpg_evaluate_falsify_prove.20250907-170721.pkl', 'rb') as f:
#     conditions_sampled, finished_list = pickle.load(f)

In [None]:
for d in finished_list:
    # Decode `metainfo` when it is a JSON string. It is left untouched when it is
    # not JSON text (TypeError, e.g. already a dict) or not valid JSON (ValueError).
    try:
        d.metainfo = json.loads(d.metainfo)
    except (TypeError, ValueError):
        pass

In [None]:
# Sanity check: a sample has an 'eval_result' exactly when it has a non-empty statement.
for d in finished_list:
    has_eval_result = 'eval_result' in d.metainfo.keys()
    is_submitted = len(d.formal_statement or '') > 0
    assert has_eval_result == is_submitted

In [None]:
# Move evaluation results whose last falsifying attempt found nothing from
# 'eval_result' to 'eval_old_result'.
for d in finished_list:
    if 'eval_result' not in d.metainfo.keys():
        continue
    falsify_proofs = d.metainfo['eval_result'].get('falsify_proofs', [None])
    if falsify_proofs[-1] is not None:
        continue
    # If the last attempt failed, every attempt is expected to have failed.
    assert all(p is None for p in falsify_proofs)
    d.metainfo['eval_old_result'] = d.metainfo.pop('eval_result')

In [None]:
# Persist the updated samples next to the original evaluation pickle.
hidden_falsifying_pkl = (
    '/home/ma-user/workspace/formal_problem_generation/output_tmp/output/'
    'sft_ar_v3/Goedel-Prover-V2-8B.Numina-Lean-linear.39980.problem_generator.nopack/'
    'fpg_evaluate_falsify_prove.20250908-234856.hidden_unsuccessful_falsifying.pkl'
)
with open(hidden_falsifying_pkl, 'wb') as f:
    pickle.dump((conditions_sampled, finished_list), f)

In [None]:
# Number of samples with a non-empty statement (shown as Counter({1: n})).
# `formal_statement` may be None, hence the `or ''`.
C.Counter(1 for d in finished_list if len(d.formal_statement or '') > 0)

In [None]:
# Number of samples that currently carry an 'eval_result' (shown as Counter({1: n})).
with_eval_result = [1 for d in finished_list if 'eval_result' in d.metainfo.keys()]
C.Counter(with_eval_result)

In [None]:
# Number of samples that carry an 'eval_old_result' (shown as Counter({1: n})).
with_eval_old_result = [1 for d in finished_list if 'eval_old_result' in d.metainfo.keys()]
C.Counter(with_eval_old_result)