In [1]:
import os
import os.path as osp
import pickle
import json
import collections as C
import itertools as I
import random
import regex as re
import multiprocessing as mp
import dataclasses as D
from typing import List, Optional, Dict, Tuple

import networkx as nx
from tqdm import tqdm
from dacite import from_dict
from loguru import logger

from common.constants import BANNED_TOKENS_IN_ANSWER_TYPE, BANNED_TOKENS_IN_SOLVING_STATE, CORE_OPTIONS, FPS_GLOBAL_SETTING, OPEN_HEADER
from common.pantograph.dataclasses import Goal, GoalState, Variable, CompilationUnit, TacticDraft, FormalProblem, SolutionAutoformalizationResult, ProblemGenerationStep, ProblemGenerationProcess
from common.pantograph.server import Server, TacticFailure
from common.pantograph.solving_server import PersistentPropSolvingServer
from common.utils import remove_comments, normalize_spaces, format_forward_solution_step_prompt, replace_span, chunk_list, parse_idents, remove_min_whitespace, normalize_draft

In [2]:
# MessagePack codec shared by the load/save cells below.
# NOTE(review): this import lives outside the main import cell; consider moving it there.
import msgspec

enc = msgspec.msgpack.Encoder()
# The decoder is typed for a list of optional `ProblemGenerationProcess` records.
# NOTE(review): another cell decodes `/cache/data_succeeded_0123.msgp` with `dec` and its output
# shows `SolutionAutoformalizationResult` items, so `dec` must have had a different element type
# when that cell ran — confirm before a Restart & Run All.
dec = msgspec.msgpack.Decoder(List[Optional[ProblemGenerationProcess]])

In [3]:
# Decode the first finished chunk with the typed msgpack decoder `dec`.
with open('/cache/data/cycle0123_succeeded/done_chunk_0.msgp', 'rb') as f:
    data = dec.decode(f.read())

In [7]:
# Inspect the `dependencies` field of one decoded record (printed below as a list of integer pairs).
print(data[3].dependencies)

[(0, 5), (0, 9), (0, 11), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (5, 9), (5, 11), (6, 9), (6, 10), (6, 11), (7, 9), (7, 10), (7, 11), (8, 9), (8, 10), (8, 11), (9, 10), (11, 12)]


In [13]:
# Solving server used by the dependency-analysis cells (`server.tactic_server`, `server.init_forward_reasoning_state_async`).
# NOTE(review): `project_path` is an absolute, user-specific path — consider a configurable constant.
server = PersistentPropSolvingServer(
    imports=["Mathlib", "Aesop"],
    project_path='/home/ma-user/workspace/formal_problem_generation/formal_problem_generation/data/MiniF2F',
    timeout=120,
    _sync_init=False,
)

server.set_tag(f'test')

## Load from CoPA Data (Cycle 0~3)

In [4]:
# Read the raw msgpack bytes of the merged cycle 0-3 dump; decoding happens in the next cell.
with open('/cache/data_succeeded_0123.msgp', 'rb') as f:
    object_msgp = f.read()

In [6]:
# Decode the merged dump.
# NOTE(review): the element type is whatever `dec` was constructed with at run time — confirm it matches this file.
data_succeeded_all = dec.decode(object_msgp)

In [8]:
# Summary: container type, number of records, and a histogram of element types.
type(data_succeeded_all), len(data_succeeded_all), C.Counter([type(d) for d in data_succeeded_all])

(list,
 214442,
 Counter({common.pantograph.dataclasses.SolutionAutoformalizationResult: 214442}))

In [7]:
# Split the merged list with `chunk_list(..., 1024)` and pickle each chunk to its own
# `raw_chunk_<base>.pkl` file; `splits` collects the first element of every yielded pair.
# NOTE: `split_base_cnt` and `data_chunk` remain defined in the kernel after this loop.
splits = []
for (split_base_cnt, data_chunk) in (chunk_list(data_succeeded_all, 1024)):
    splits.append(split_base_cnt)
    with open(osp.join('/cache/data/cycle0123_succeeded', f'raw_chunk_{split_base_cnt}.pkl'), 'wb') as f:
        pickle.dump(data_chunk, f)

In [2]:
# Reload the first pickled chunk (overwrites `data_chunk`).
# SECURITY: `pickle.load` can execute arbitrary code — only load files produced by this pipeline.
with open('/cache/data/cycle0123_succeeded/raw_chunk_0.pkl', 'rb') as f:
    data_chunk = pickle.load(f)

In [6]:
# Display the metainfo dict of the first record (large output; includes `solution_state_transition`).
data_chunk[0].metainfo

{'subject': 'Intermediate Algebra',
 'level': 'Level 5',
 'solution_state_transition': [[{'state_id': -1,
    'goals': [{'variables': [{'t': 'ℝ', 'v': None, 'name': 'answer'},
       {'t': 'ℝ', 'v': None, 'name': 'x'},
       {'t': 'ℝ', 'v': None, 'name': 'y'},
       {'t': 'x ^ 3 - 3 * x ^ 2 + 5 * x = 1', 'v': None, 'name': 'h_x'},
       {'t': 'y ^ 3 - 3 * y ^ 2 + 5 * y = 5', 'v': None, 'name': 'h_y'},
       {'t': 'x + y = answer', 'v': None, 'name': 'h_answer'}],
      'target': '?w',
      'sibling_dep': [2],
      'name': 'h.mp',
      'is_conversion': False},
     {'variables': [{'t': 'ℝ', 'v': None, 'name': 'answer'},
       {'t': 'ℝ', 'v': None, 'name': 'x'},
       {'t': 'ℝ', 'v': None, 'name': 'y'},
       {'t': 'x ^ 3 - 3 * x ^ 2 + 5 * x = 1', 'v': None, 'name': 'h_x'},
       {'t': 'y ^ 3 - 3 * y ^ 2 + 5 * y = 5', 'v': None, 'name': 'h_y'},
       {'t': '?w', 'v': None, 'name': 'h_submission'}],
      'target': 'x + y = answer',
      'sibling_dep': [],
      'name': 'h.mp

## Load from CoPA Data

In [None]:
# Gather every `processed_*` pickle of the first data source into one flat list.
# SECURITY: `pickle.load` can execute arbitrary code — only load trusted files.
base_dir = '/sfs/liuqi/data/formal_problem_solving/data/metamath-qwen2-math/numina_cot_qwen2.verified.valid.all/cycle1/all.train.leanv4.15.0.0321/process_0411'
data_processed = []

for d in tqdm(list(filter(lambda name: name.startswith('processed_'), os.listdir(base_dir)))):
    with open(osp.join(base_dir, d), 'rb') as f:
        data_processed += pickle.load(f)

In [None]:
# Alias (not a copy) of the list loaded above, named consistently with the other cycles.
data_processed_cycle0 = data_processed

In [30]:
# Every fifth entry (offset 4) must be an iterable containing only SolutionAutoformalizationResult items.
assert all(
    isinstance(dd, SolutionAutoformalizationResult)
    for d in data_processed_cycle0[4::5]
    for dd in d
)

In [None]:
# Gather every `processed_*` pickle of cycle 1 into one flat list.
# SECURITY: `pickle.load` can execute arbitrary code — only load trusted files.
base_dir = '/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle1/processed_all'
data_processed_cycle1 = []

for d in tqdm(list(filter(lambda name: name.startswith('processed_'), os.listdir(base_dir)))):
    with open(osp.join(base_dir, d), 'rb') as f:
        data_processed_cycle1 += pickle.load(f)

In [31]:
# Every fifth entry (offset 4) must be an iterable containing only SolutionAutoformalizationResult items.
assert all(
    isinstance(dd, SolutionAutoformalizationResult)
    for d in data_processed_cycle1[4::5]
    for dd in d
)

In [None]:
# Gather every `processed_*` pickle of cycle 2 into one flat list.
# SECURITY: `pickle.load` can execute arbitrary code — only load trusted files.
base_dir = '/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle2/dones_all'
data_processed_cycle2 = []

for d in tqdm(list(filter(lambda name: name.startswith('processed_'), os.listdir(base_dir)))):
    with open(osp.join(base_dir, d), 'rb') as f:
        data_processed_cycle2 += pickle.load(f)

In [32]:
# Every fifth entry (offset 4) must be an iterable containing only SolutionAutoformalizationResult items.
assert all(
    isinstance(dd, SolutionAutoformalizationResult)
    for d in data_processed_cycle2[4::5]
    for dd in d
)

In [None]:
# Gather every `done_*` pickle of cycle 3 into one flat list (note the different file prefix).
# SECURITY: `pickle.load` can execute arbitrary code — only load trusted files.
base_dir = '/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle3/dones_all'
data_processed_cycle3 = []

for d in tqdm(list(filter(lambda name: name.startswith('done_'), os.listdir(base_dir)))):
    with open(osp.join(base_dir, d), 'rb') as f:
        data_processed_cycle3 += pickle.load(f)

In [56]:
# Keep the cycle-0 results that succeeded, carry a formal statement and answer type,
# and contain none of the banned tokens.
data_succeeded_cycle0 = [
    result
    for result in I.chain.from_iterable(data_processed_cycle0[4::5])
    if (
        isinstance(result, SolutionAutoformalizationResult)
        and result.success
        and 'found' not in result.informal_answer
        # The `problem_is_valid` / `solution_is_valid` metainfo checks are disabled for this cycle.
        and result.formal_statement is not None
        and result.formal_answer_type is not None
        and all(token not in result.formal_answer_type for token in BANNED_TOKENS_IN_ANSWER_TYPE)
        and all(
            token not in str(var)
            for var in result.intros + result.outros
            for token in BANNED_TOKENS_IN_SOLVING_STATE
        )
    )
]
len(data_succeeded_cycle0)

23988

In [57]:
# Keep the cycle-1 results that succeeded, were judged valid (problem and solution),
# carry a formal statement and answer type, and contain none of the banned tokens.
data_succeeded_cycle1 = [
    result
    for result in I.chain.from_iterable(data_processed_cycle1[4::5])
    if (
        isinstance(result, SolutionAutoformalizationResult)
        and result.success
        and 'found' not in result.informal_answer
        and result.metainfo['problem_is_valid'] == 'Yes'
        and result.metainfo['solution_is_valid'] == 'Yes'
        and result.formal_statement is not None
        and result.formal_answer_type is not None
        and all(token not in result.formal_answer_type for token in BANNED_TOKENS_IN_ANSWER_TYPE)
        and all(
            token not in str(var)
            for var in result.intros + result.outros
            for token in BANNED_TOKENS_IN_SOLVING_STATE
        )
    )
]
len(data_succeeded_cycle1)

111635

In [58]:
# Keep the cycle-2 results that succeeded, were judged valid (problem and solution),
# carry a formal statement and answer type, and contain none of the banned tokens.
data_succeeded_cycle2 = [
    result
    for result in I.chain.from_iterable(data_processed_cycle2[4::5])
    if (
        isinstance(result, SolutionAutoformalizationResult)
        and result.success
        and 'found' not in result.informal_answer
        and result.metainfo['problem_is_valid'] == 'Yes'
        and result.metainfo['solution_is_valid'] == 'Yes'
        and result.formal_statement is not None
        and result.formal_answer_type is not None
        and all(token not in result.formal_answer_type for token in BANNED_TOKENS_IN_ANSWER_TYPE)
        and all(
            token not in str(var)
            for var in result.intros + result.outros
            for token in BANNED_TOKENS_IN_SOLVING_STATE
        )
    )
]
len(data_succeeded_cycle2)

51949

In [59]:
# Keep the cycle-3 results that succeeded, were judged valid (problem and solution),
# carry a formal statement and answer type, and contain none of the banned tokens.
# Unlike the earlier cycles, the loaded list is iterated directly.
data_succeeded_cycle3 = [
    result
    for result in data_processed_cycle3
    if (
        isinstance(result, SolutionAutoformalizationResult)
        and result.success
        and 'found' not in result.informal_answer
        and result.metainfo['problem_is_valid'] == 'Yes'
        and result.metainfo['solution_is_valid'] == 'Yes'
        and result.formal_statement is not None
        and result.formal_answer_type is not None
        and all(token not in result.formal_answer_type for token in BANNED_TOKENS_IN_ANSWER_TYPE)
        and all(
            token not in str(var)
            for var in result.intros + result.outros
            for token in BANNED_TOKENS_IN_SOLVING_STATE
        )
    )
]
len(data_succeeded_cycle3)

26870

In [60]:
# Concatenate the four filtered cycles, in cycle order, into one list.
data_succeeded_all = [
    *data_succeeded_cycle0,
    *data_succeeded_cycle1,
    *data_succeeded_cycle2,
    *data_succeeded_cycle3,
]
len(data_succeeded_all)

214442

In [61]:
# Sanity check: the merged list is homogeneous.
assert all(isinstance(d, SolutionAutoformalizationResult) for d in data_succeeded_all)

In [64]:
# Persist the merged list as MessagePack using the shared encoder `enc`.
with open('/cache/data_succeeded_0123.msgp', 'wb') as f:
    f.write(enc.encode(data_succeeded_all))

In [None]:
# Scratch cell: bare path string (same value as the cycle-0 `base_dir`); has no effect beyond being displayed.
'/sfs/liuqi/data/formal_problem_solving/data/metamath-qwen2-math/numina_cot_qwen2.verified.valid.all/cycle1/all.train.leanv4.15.0.0321/process_0411'

In [None]:
# Scratch cell: bare path string (same value as the cycle-1 `base_dir`); has no effect beyond being displayed.
'/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle1/processed_all'

In [None]:
# Scratch cell: bare path string (same value as the cycle-2 `base_dir`); has no effect beyond being displayed.
'/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle2/dones_all'

In [None]:
def worker_cycle1(args):
    """Load one pickle file and return the last item of the 5-tuple stored in it.

    Args:
        args: Path of the pickle file. A single positional argument is used so the
            function can be handed directly to ``Pool.map``.

    Returns:
        The fifth element of the unpickled object.

    Raises:
        ValueError: If the unpickled object does not unpack into exactly five items.
    """
    # SECURITY: pickle.load can execute arbitrary code — only load files produced by this pipeline.
    with open(args, 'rb') as f:
        _, _, _, _, data = pickle.load(f)
    return data


In [None]:
# Cycle 1
# NOTE(review): the heading says "Cycle 1" but `data_root` points at `cycle2/dones_all` — confirm which is intended.
data_root = '/sfs/liuqi/data/AI-MO/NuminaMath-1.5/cycle2/dones_all'

# Full paths of every `processed_*.pkl` file directly under `data_root`.
ps = [
    osp.join(data_root, name)
    for name in os.listdir(data_root)
    if name.startswith('processed_') and name.endswith('.pkl')
]
print(len(ps))

In [None]:
# The parallel load over all files is disabled; only the first file is loaded for inspection.
# with mp.Pool(processes=128) as pool:
#     results = pool.map(worker_cycle1, ps)
samples = [worker_cycle1(ps[0])]

In [None]:
# Flatten one nesting level. NOTE: rebinding `samples` from itself is not idempotent — re-running this cell flattens again.
samples = list(I.chain(*samples))

## Load from Kangjie Examples

In [14]:
# Split the Lean file on the literal 'example\n' and drop the first piece (the text before the first occurrence).
with open('data/MiniF2F/example_deductive_proof.lean', 'r') as f:
    examples = f.read().split('example\n')[1:]

In [28]:
# Disabled builder kept for reference: it turned each Lean example into a loaded problem and
# pickled the result. Only the cached pickle at the bottom of this cell is actually read.
# samples = []

# for (i, ex) in enumerate(examples):
#     elems = ex.split('\n\n')
#     assert elems[0].endswith(':= by')
#     formal_statement = 'example\n' + elems[0][:-len(':= by')] + ':= sorry'
#     solution_steps = [
#         remove_min_whitespace(e) for e in elems[1:] if len(remove_comments(e).strip()) > 0
#     ]
#     assert solution_steps[-1] == 'exact h_answer'
#     sample = await server.load_problem_async(SolutionAutoformalizationResult(
#         header=OPEN_HEADER,
#         formal_statement=formal_statement
#     ))
#     sample.metainfo['solution_state_transition'] = [(None, s) for s in solution_steps]
#     samples.append(
#         sample
#     )
#     print(f'{i}/{len(examples)}')
#     print(formal_statement)

# with open('data/MiniF2F/example_deductive_proof.pkl', 'wb') as f:
#     pickle.dump(samples, f)

# SECURITY: `pickle.load` can execute arbitrary code — only load trusted files.
with open('data/MiniF2F/example_deductive_proof.pkl', 'rb') as f:
    samples = pickle.load(f)

## Choose Sample

In [29]:
# Pick the second-to-last sample as the working example for the analysis cells.
sample = samples[-2]

# Shallow copies, so later cells can still read these even if `sample` is modified.
solution_transitions = sample.metainfo['solution_state_transition'][:]
formal_proofs = sample.formal_proofs[:]

print(sample.formal_statement)


example
  -- $x, y$ are two real numbers
  (x y : Real)
  -- $x, y$ are both greater than 1
  (hx : x > 1)
  (hy : y > 1)
  -- $\\log_x (y^x) = \\log_y (x^(4y)) = 10$
  (hxy1 : Real.logb x (y^x) = 10)
  (hxy2 : Real.logb y (x^(4*y)) = 10)
  -- find the value of $xy$
  (answer : Real)
  (h_answer : answer = x*y)
  -- # Answer 25
: answer = 25
:= sorry


## Dependency Analysis

In [None]:
# Load data point
if 'solution_state_transition' in sample.metainfo.keys():   # Cycle 0, 1, 2
    solution_blocks = [s[1] for s in sample.metainfo['solution_state_transition']]
    sample.metainfo.pop('solution_state_transition')
else:
    solution_blocks = solution_decompose(sample['formal_solution_draft'])   # Cycle 3
formal_proofs = sample.formal_proofs[:]
# logger.debug(f'async_worker({split_base_cnt+idx}): sample.formal_statement:\n{sample.formal_statement}')

# Sanity check
solution_draft_normalized = normalize_draft('\n'.join([s for s in solution_blocks]))
matches = list(re.finditer(':= sorry', solution_draft_normalized))
assert len(matches) == len(formal_proofs), f'`len(matches) == len(formal_proofs)` failed because {len(matches)} != {len(formal_proofs)}, unable to prune'

# Parse submission
action = remove_comments(solution_blocks[-1]).strip()
assert action.startswith('exact '), action
submission_name = action[len('exact '):]
# logger.debug(f'async_worker({split_base_cnt+idx}): submission_name: {submission_name}')

# Initialize
forward_state = await server.init_forward_reasoning_state_async(sample)
assert len(forward_state.goals) == 1, str(forward_state)
assert all('✝' not in v.name for v in forward_state.goals[0].variables), str(forward_state)

dependency_graph = nx.DiGraph()
hard_dependencies_global = []
parsed_steps = [
    ProblemGenerationStep(
        step_draft=f'have {v.name} : {v.t} := sorry' if v.v is None else f'let {v.name} : {v.t} := {v.v}',
        proof=None,
        new_contexts=[v]
    ) for v in forward_state.goals[0].variables
]
fvarid_to_istep = {
    v.raw_name : i for (i, v) in enumerate(forward_state.goals[0].variables)
}
i_proof = 0

# Add dependencies between current `parsed_steps` (hypotheses)
dependency_graph.add_nodes_from(range(len(parsed_steps)))
for (i, v) in enumerate(parsed_steps):
    idents = parse_idents(v.new_contexts[0].t)
    for (j, u) in enumerate(parsed_steps[:i]):
        if u.new_contexts[0].name in idents:
            # edge (u, v): v depends on u
            dependency_graph.add_edge(j, i)

# Depednency between proof scripts
for i_step, draft_step in enumerate(solution_blocks[:-1]):
    # 1. Execute current step
    normalized_draft_step = normalize_draft(draft_step)
    if 'sorry' in parse_idents(normalized_draft_step):
        new_forward_state = await server.tactic_server.goal_tactic_async(forward_state, 0, TacticDraft('by\n' + normalized_draft_step + '\nsorry'))
    else:
        new_forward_state = await server.tactic_server.goal_tactic_async(forward_state, 0, normalized_draft_step)

    assert new_forward_state.goals[-1].target == 'False', str(new_forward_state)
    n_sorries = len(new_forward_state.goals) - 1
    for p in sample.formal_proofs[i_proof:i_proof+n_sorries]:
        new_forward_state = await server.tactic_server.goal_tactic_async(new_forward_state, 0, '{\n' + '\n'.join([remove_min_whitespace(s[1]) for s in p.proof]) + '\n}')
    
    assert len(new_forward_state.goals) == 1 and new_forward_state.goals[0].target == 'False', str(new_forward_state)
    
    # 2. Analyze state difference
    new_contexts = [
        v for v in new_forward_state.goals[0].variables if
            v.raw_name not in {vv.raw_name for vv in forward_state.goals[0].variables}
            # v not in forward_state.goals[0].variables
    ]
    if len(new_contexts) == 0:
        logger.warning(f'async_worker({split_base_cnt+idx}): Unused step: {[normalized_draft_step]}')
    for v in new_contexts:
        # assert v.raw_name not in fvarid_to_istep.keys()
        fvarid_to_istep[v.raw_name] = len(parsed_steps) # Maybe override!
    
    # 3.1 Add parsed step
    cur_step = ProblemGenerationStep(
        step_draft=draft_step,
        proof=['\n'.join([remove_min_whitespace(s[1]) for s in p.proof]) for p in sample.formal_proofs[i_proof:i_proof+n_sorries]],
        new_contexts=new_contexts
    )
    # logger.debug(f'async_worker({split_base_cnt+idx}): Step: {cur_step.step}')
    parsed_steps.append(cur_step)
    dependency_graph.add_node(len(parsed_steps)-1)
    i_proof += n_sorries
    # 3.2 Coarse-grained dependency
    # - Case 1. types in new_contexts
    # - Case 2. proofs
    
    # 4. (Optional) Validate assumption: forward_state.goals[0].variables is topologically sorted
    tmp_parsing_state = forward_state
    while len(tmp_parsing_state.goals[0].variables) > 0:
        tmp_parsing_state = await server.tactic_server.goal_tactic_async(tmp_parsing_state, 0, f'clear! {tmp_parsing_state.goals[0].variables[-1].name}')
    assert str(tmp_parsing_state) == '⊢ False', str(tmp_parsing_state)
    
    # 5. Analyze dependency
    soft_dependencies = set()    # Set of fVarId. Removing which will corrupt other variables
    hard_dependencies = set()    # Set of fVarId. Removing which will make the current step unable to prove
    # Try removing `v` and re-executing cur_step
    # Assumption: tmp_parsing_state.goals[0].variables is topologically sorted
    tmp_parsing_state = forward_state
    
    for v in forward_state.goals[0].variables:
        assert v.raw_name not in soft_dependencies and v.raw_name not in hard_dependencies, f'v.raw_name={v.raw_name}, soft_dependencies={soft_dependencies}, hard_dependencies={hard_dependencies}'
        
        # Shall we try clearing steps introducing `v` and all variables dependent on it?
        # No. Because this clearing is in reversed order. If some variable `u` is dependent on `v`
        # - Case 1. `s` does not depend on `u`: `u` is already removed
        # - Case 2. `s` depends on `u`: it does not matter if we still connect `v` with `u`.
        # TODO: 08.05 - Current impl. is not reversed order. try clear!

        # 5.1. Find v
        v_to_remove = [vv for vv in tmp_parsing_state.goals[0].variables if vv.raw_name == v.raw_name]
        if len(v_to_remove) == 0:
            continue
        assert len(v_to_remove) == 1, str(v_to_remove)    # `tmp_parsing_state` is constructed by iteratively removing variables in forward_state, thus must find exactly one
        v_to_remove = v_to_remove[0]
        
        # 5.2. Try removing `v`
        if '✝' not in v.name:
            try:
                new_tmp_parsing_state = await server.tactic_server.goal_tactic_async(tmp_parsing_state, 0, f'clear! {v.name}')
            except TacticFailure as e:
                soft_dependencies.add(v.raw_name)
                import pdb; pdb.set_trace()
                logger.warning(f'async_worker({split_base_cnt+idx}): Cannot remove {v} ({[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]})')
                continue
        else:
            n_inaccessible_after = 0
            for vv in reversed(tmp_parsing_state.goals[0].variables):
                if vv.raw_name == v.raw_name:
                    break
                else:
                    if '✝' in vv.name:
                        n_inaccessible_after += 1
            assert all(vv.name != '_TMP_NAME_TO_REMOVE' for vv in tmp_parsing_state.goals[0].variables), str(tmp_parsing_state)
            new_tmp_parsing_state = await server.tactic_server.goal_tactic_async(tmp_parsing_state, 0, f'rename_i _TMP_NAME_TO_REMOVE' + ' _' * n_inaccessible_after)
            
            all_to_temove = [vv for vv in new_tmp_parsing_state.goals[0].variables if vv.name == '_TMP_NAME_TO_REMOVE']
            assert len(all_to_temove) == 1 and all_to_temove[0].raw_name == v_to_remove.raw_name, f'all_to_temove={all_to_temove}, v_to_remove={v_to_remove}'
            
            try:
                new_tmp_parsing_state = await server.tactic_server.goal_tactic_async(new_tmp_parsing_state, 0, f'clear! _TMP_NAME_TO_REMOVE')
                # Try clear!
            except TacticFailure as e:
                soft_dependencies.add(v.raw_name)
                import pdb; pdb.set_trace()
                logger.warning(f'async_worker({split_base_cnt+idx}): Cannot remove {v} ({[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]})')
                continue
        
        # 5.3. Try executing cur_step
        try:
            test_tmp_parsing_state = await server.tactic_server.goal_tactic_async(new_tmp_parsing_state, 0, cur_step.step)
            tmp_parsing_state = new_tmp_parsing_state
        except TacticFailure as e:
            hard_dependencies.add(v.raw_name)
            hard_dependencies_global.append((parsed_steps[fvarid_to_istep[v.raw_name]], cur_step))
            # logger.debug(f'async_worker({split_base_cnt+idx}): {[vv.name for vv in cur_step.new_contexts]} depends on {[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]}')
            continue
        # logger.info(f'Removed {v} ({[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]})')
    # logger.info(f'Final removing state: {test_tmp_parsing_state}')
    
    # 6. Iteration end
    if len(soft_dependencies) > 0:
        import pdb; pdb.set_trace()
        logger.warning(f'async_worker({split_base_cnt+idx}): len(soft_dependencies) > 0: {soft_dependencies}')
    for d in I.chain(soft_dependencies, hard_dependencies):
        # edge (u, v): v depends on u
        # logger.info(f'async_worker({split_base_cnt+idx}): Adding dependency: {[vv.name for vv in parsed_steps[fvarid_to_istep[d]].new_contexts]} -> {[vv.name for vv in cur_step.new_contexts]}')
        dependency_graph.add_edge(fvarid_to_istep[d], len(parsed_steps)-1)
    
    forward_state = new_forward_state

assert i_proof == len(formal_proofs), f'i_proof={i_proof}, len(formal_proofs)={len(formal_proofs)}'
assert submission_name in [v.name for v in parsed_steps[-1].new_contexts], f'submission_name={submission_name}, new_context={[v.name for v in parsed_steps[-1].new_contexts]}'

# Add submission step
submission_step = ProblemGenerationStep(
    step_draft=f'submit_answer {submission_name}',
    proof=None,
    new_contexts=None
)
dependency_graph.add_node(len(parsed_steps))
for (i, s) in reversed(list(enumerate(parsed_steps))):
    if submission_name in [v.name for v in s.new_contexts]:
        dependency_graph.add_edge(i, len(parsed_steps))
        break
assert dependency_graph.in_degree(len(parsed_steps)) == 1, f'dependency_graph.in_degree(submission_step)={dependency_graph.in_degree(len(parsed_steps))}'
parsed_steps.append(submission_step)

# Reduce transitive edges; Compute depths
reduced_dependency_graph = nx.algorithms.dag.transitive_reduction(dependency_graph)
depth_dict = {n : 0 for n in range(len(parsed_steps))}
for u in nx.topological_sort(reduced_dependency_graph):
    for v in reduced_dependency_graph.successors(u):
        depth_dict[v] = max(depth_dict[v], depth_dict[u]+1)

# Reassemble trajectories
reassembled_trajectory = []
G = reduced_dependency_graph.copy()
deductive_state = await server.tactic_server.load_statement_async('False')

# TODO: Shall we conduct backward-dfs to collect all nodes that `answer` needs?
# TODO: the current setting (depth-first) can encourage models to explore!
# TODO: Ablation on this: Graph pruning

while True:
    available_actions = sorted([n for (n, d) in G.in_degree() if d == 0], key=lambda n : (-depth_dict[n], parsed_steps[n].is_introducing))
    chosen_action = parsed_steps[available_actions[0]]
    reassembled_trajectory.append((deductive_state.goals[0].variables, available_actions[0]))
    if chosen_action.is_submitting:
        assert submission_name in [v.name for v in deductive_state.goals[0].variables], f'submission_name={submission_name}, deductive_state={deductive_state}'
        if not set(deductive_state.goals[0].variables).issubset(set(forward_state.goals[0].variables)):
            logger.warning(f'¬(deductive_state ⊆ forward_state): {deductive_state.goals[0].variables}, {forward_state.goals[0].variables}')
        break
    deductive_state = await server.tactic_server.goal_tactic_async(deductive_state, 0, chosen_action.step)
    G.remove_node(available_actions[0])

ret = ProblemGenerationProcess(
    informal_problem=sample.informal_problem,
    informal_answer=sample.informal_answer,
    informal_solution=sample.informal_solution,
    header=sample.header,
    formal_statement=sample.formal_statement,
    formal_solution_draft=sample.formal_solution_draft,
    formal_proofs=[
        '\n'.join([remove_min_whitespace(s[1]) for s in p.proof]) for p in sample.formal_proofs
    ],
    steps=parsed_steps,
    dependencies=[e for e in dependency_graph.edges],
    trajectory=reassembled_trajectory,
    metainfo=json.dumps(sample.metainfo)
)

with open(f'./reassembled_trajectory.txt', 'w') as f:
    for state_vars, chosen_action in reassembled_trajectory:
        f.write('### State\n```lean4\n' + str(Goal(
            variables=state_vars,
            target=deductive_state.goals[0].target,
            sibling_dep=None,
            name=None)) + '\n```\n### Action\n```lean4\n' + parsed_steps[chosen_action].step + '\n```\n\n\n')

In [27]:
# Display the hypothesis name parsed from the final `exact ...` step.
submission_name

'h_answer'

### Legacy

In [10]:
# Legacy: the last transition must be `exact <name>`; `<name>` is the submitted hypothesis.
state, action = solution_transitions[-1]
action = remove_comments(action).strip()
assert action.startswith('exact ')
submission_name = action[len('exact '):]
print(submission_name)

h_answer


In [None]:
# Legacy: start the forward-reasoning state for `sample`; exactly one goal is expected.
forward_state = await server.init_forward_reasoning_state_async(sample)
assert len(forward_state.goals) == 1


In [None]:
# No initial variable name may contain the '✝' marker (the later cells treat such names as inaccessible).
assert all('✝' not in v.name for v in forward_state.goals[0].variables)

In [None]:
# Sanity check: the normalized draft must contain exactly one `:= sorry` per formal proof.
solution_draft_normalized = normalize_draft('\n'.join([s[1] for s in solution_transitions]))
matches = list(re.finditer(':= sorry', solution_draft_normalized))
assert len(matches) == len(formal_proofs), f'`len(matches) == len(formal_proofs)` failed because {len(matches)} != {len(formal_proofs)}, unable to prune'

In [None]:
# Legacy graph setup: unlike the index-based cell above, graph nodes here are the
# `ProblemGenerationStep` objects themselves (hence the tuple-valued `new_contexts`).
dependency_graph = nx.DiGraph()
hard_dependencies_global = []
# One introducing step per initial context variable.
parsed_steps = [
    ProblemGenerationStep(
        step_draft=f'have {v.name} : {v.t} := sorry' if v.v is None else f'let {v.name} : {v.t} := {v.v}',
        proof=None,
        new_contexts=tuple([v])
    ) for v in forward_state.goals[0].variables
]
# Maps a variable's raw name to the index of the step that introduced it.
fvarid_to_istep = {
    v.raw_name : i for (i, v) in enumerate(forward_state.goals[0].variables)
}
i_proof = 0


# Add dependencies between current `parsed_steps` (hypotheses)
dependency_graph.add_nodes_from(parsed_steps)
for (i, v) in enumerate(parsed_steps):
    idents = parse_idents(v.new_contexts[0].t)
    for u in parsed_steps[:i]:
        if u.new_contexts[0].name in idents:
            # edge (u, v): v depends on u
            dependency_graph.add_edge(u, v)

In [None]:
for i_step, (_, draft_step) in enumerate(solution_transitions[:-1]):
    # 1. Execute current step
    normalized_draft_step = normalize_draft(draft_step)
    if 'sorry' in parse_idents(normalized_draft_step):
        new_forward_state = await server.tactic_server.goal_tactic_async(forward_state, 0, TacticDraft('by\n' + normalized_draft_step + '\nsorry'))
    else:
        new_forward_state = await server.tactic_server.goal_tactic_async(forward_state, 0, normalized_draft_step)

    assert new_forward_state.goals[-1].target == 'False'
    n_sorries = len(new_forward_state.goals) - 1
    for p in sample.formal_proofs[i_proof:i_proof+n_sorries]:
        new_forward_state = await server.tactic_server.goal_tactic_async(new_forward_state, 0, '{\n' + '\n'.join([remove_min_whitespace(s[1]) for s in p.proof]) + '\n}')
    
    assert len(new_forward_state.goals) == 1 and new_forward_state.goals[0].target == 'False'
    
    # 2. Analyze state difference
    new_contexts = [
        v for v in new_forward_state.goals[0].variables if
            v.raw_name not in {vv.raw_name for vv in forward_state.goals[0].variables}
            # v not in forward_state.goals[0].variables
    ]
    if len(new_contexts) == 0:
        logger.warning(f'Unused step: {[normalized_draft_step]}')
    for v in new_contexts:
        # assert v.raw_name not in fvarid_to_istep.keys()
        fvarid_to_istep[v.raw_name] = len(parsed_steps) # Maybe override!
    
    # 3.1 Add parsed step
    cur_step = ProblemGenerationStep(
        step_draft=draft_step,
        proof=tuple(['\n'.join([remove_min_whitespace(s[1]) for s in p.proof]) for p in sample.formal_proofs[i_proof:i_proof+n_sorries]]),
        new_contexts=tuple(new_contexts)
    )
    logger.info(f'Step: {cur_step.step}')
    parsed_steps.append(cur_step)
    dependency_graph.add_node(cur_step)
    i_proof += n_sorries
    # 3.2 Coarse-grained dependency
    # - Case 1. types in new_contexts
    # - Case 2. proofs
    
    # 4. (Optional) Validate assumption: forward_state.goals[0].variables is topologically sorted
    tmp_parsing_state = forward_state
    while len(tmp_parsing_state.goals[0].variables) > 0:
        tmp_parsing_state = await server.tactic_server.goal_tactic_async(tmp_parsing_state, 0, f'clear {tmp_parsing_state.goals[0].variables[-1].name}')
    assert str(tmp_parsing_state) == '⊢ False'
    
    # 5. Analyze dependency
    soft_dependencies = set()    # Set of fVarId. Removing which will corrupt other variables
    hard_dependencies = set()    # Set of fVarId. Removing which will make the current step unable to prove
    # Try removing `v` and re-executing cur_step
    # Assumption: tmp_parsing_state.goals[0].variables is topologically sorted
    tmp_parsing_state = forward_state
    
    for v in forward_state.goals[0].variables:  # TODO: Debug, removed 'reversed'
        assert v.raw_name not in soft_dependencies and v.raw_name not in hard_dependencies
        
        # Shall we try clearing steps introducing `v` and all variables dependent on it?
        # No. Because this clearing is in reversed order. If some variable `u` is dependent on `v`
        # - Case 1. `s` does not depend on `u`: `u` is already removed
        # - Case 2. `s` depends on `u`: it does not matter if we still connect `v` with `u`.
        # TODO: Maybe add an extra edge pruning to remove all non-direct dependencies? (i.e. dependencies that can be constructed by transitive dependency)

        # 5.1. Find v
        v_to_remove = [vv for vv in tmp_parsing_state.goals[0].variables if vv.raw_name == v.raw_name]
        assert len(v_to_remove) == 1    # `tmp_parsing_state` is constructed by iteratively removing variables in forward_state, thus must find exactly one
        v_to_remove = v_to_remove[0]
        
        # 5.1. Try removing `v`
        if '✝' not in v.name:
            try:
                new_tmp_parsing_state = await server.tactic_server.goal_tactic_async(tmp_parsing_state, 0, f'clear {v.name}')
            except TacticFailure:
                soft_dependencies.add(v.raw_name)
                logger.info(f'Cannot remove {v} ({[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]})')
                continue
        else:
            n_inaccessible_after = 0
            for vv in reversed(tmp_parsing_state.goals[0].variables):
                if vv.raw_name == v.raw_name:
                    break
                else:
                    if '✝' in vv.name:
                        n_inaccessible_after += 1
            assert all(vv.name != '_TMP_NAME_TO_REMOVE' for vv in tmp_parsing_state.goals[0].variables)
            new_tmp_parsing_state = await server.tactic_server.goal_tactic_async(tmp_parsing_state, 0, f'rename_i _TMP_NAME_TO_REMOVE' + ' _' * n_inaccessible_after)
            
            all_to_temove = [vv for vv in new_tmp_parsing_state.goals[0].variables if vv.name == '_TMP_NAME_TO_REMOVE']
            assert len(all_to_temove) == 1 and all_to_temove[0].raw_name == v_to_remove.raw_name
            
            try:
                new_tmp_parsing_state = await server.tactic_server.goal_tactic_async(tmp_parsing_state, 0, f'clear _TMP_NAME_TO_REMOVE')
            except TacticFailure:
                soft_dependencies.add(v.raw_name)
                logger.info(f'Cannot remove {v} ({[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]})')
                continue
        
        # Step 2. Try executing cur_step
        try:
            test_tmp_parsing_state = await server.tactic_server.goal_tactic_async(new_tmp_parsing_state, 0, cur_step.step)
            tmp_parsing_state = new_tmp_parsing_state
        except TacticFailure:
            hard_dependencies.add(v.raw_name)
            hard_dependencies_global.append((parsed_steps[fvarid_to_istep[v.raw_name]], cur_step))
            logger.info(f'{[vv.name for vv in cur_step.new_contexts]} depends on {[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]}')
            continue
        logger.info(f'Removed {v} ({[vv.name for vv in parsed_steps[fvarid_to_istep[v.raw_name]].new_contexts]})')
    logger.info(f'Final removing state: {test_tmp_parsing_state}')
    # 6. Iteration end
    for d in I.chain(soft_dependencies, hard_dependencies):
        # edge (u, v): v depends on u
        logger.info(f'Adding dependency: {[vv.name for vv in parsed_steps[fvarid_to_istep[d]].new_contexts]} -> {[vv.name for vv in cur_step.new_contexts]}')
        dependency_graph.add_edge(parsed_steps[fvarid_to_istep[d]], cur_step)
    
    forward_state = new_forward_state

assert i_proof == len(formal_proofs)

In [None]:
# Legacy: append the final `submit_answer` step and link it to the step that produced the answer.
# NOTE: not idempotent — this cell appends to `parsed_steps` and adds a graph node on every run.
assert submission_name in [v.name for v in parsed_steps[-1].new_contexts]

submission_step = ProblemGenerationStep(
    step_draft=f'submit_answer {submission_name}',
    proof=None,
    new_contexts=None
)
dependency_graph.add_node(submission_step)
# Walk backwards so the latest step introducing `submission_name` gets the edge.
for s in reversed(parsed_steps):
    if submission_name in [v.name for v in s.new_contexts]:
        dependency_graph.add_edge(s, submission_step)
        break
assert dependency_graph.in_degree(submission_step) == 1
parsed_steps.append(submission_step)

In [None]:
# Dump each step's normalized draft followed by the variables it introduces (raw name + rendering).
for step in parsed_steps:
    introduced = step.new_contexts or []
    print(normalize_draft(step.step))
    print('\n'.join([v.raw_name + ' ' + str(v) for v in introduced]))
    print('\n')

In [None]:
# Compare the number of graph nodes with the number of parsed steps.
len(dependency_graph.nodes), len(parsed_steps)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Draw the transitive reduction of the dependency graph using the Graphviz `dot` layout.
reduced_dependency_graph = nx.algorithms.dag.transitive_reduction(dependency_graph)

pos = nx.nx_agraph.graphviz_layout(reduced_dependency_graph, prog="dot", args="")
plt.figure(figsize=(12, 8))

# Alternative layouts kept for reference:
# pos = nx.bfs_layout(dependency_graph, parsed_steps[0])
# pos = forest_bfs_layout(dependency_graph, [node for node, in_degree in dependency_graph.in_degree() if in_degree == 0])

# Node colours, in graph iteration order: orange = submitting, cyan = deducing, green = otherwise.
color_map = []
for node in reduced_dependency_graph.nodes:
    if node.is_submitting:
        color_map.append('orange')
    elif node.is_deducing:
        color_map.append('cyan')
    else:
        color_map.append('green')
# edge_styles = [
#     ('--' if e not in hard_dependencies_global else '-') for e in reduced_dependency_graph.edges
# ]

# Label each node with the variables it introduces, or with its step text when it introduces none.
labels = {}
for node in reduced_dependency_graph:
    shown_items = node.new_contexts or [node.step]
    labels[node] = '\n'.join(map(str, shown_items))

nx.draw(
    reduced_dependency_graph,
    pos,
    with_labels=True,
    labels=labels,
    node_size=800,
    font_size=6,
    node_color=color_map,
)
plt.show()

In [None]:
#* For visualization only
# Start from the transitive reduction, then re-insert the hard dependency edges
# so that they stay visible even when they are implied transitively.
direct_dependency_graph = nx.algorithms.dag.transitive_reduction(dependency_graph)
direct_dependency_graph.add_edges_from(hard_dependencies_global) # Add this -> direct dependency graph

pos = nx.nx_agraph.graphviz_layout(direct_dependency_graph, prog="dot", args="")
plt.figure(figsize=(12, 8))

# pos = nx.bfs_layout(dependency_graph, parsed_steps[0])
# pos = forest_bfs_layout(dependency_graph, [node for node, in_degree in dependency_graph.in_degree() if in_degree == 0])

# Colours must follow the node order of the graph that is actually drawn. Building them from
# `direct_dependency_graph` (not from `reduced_dependency_graph`, as before) keeps the colour
# list aligned with the drawn nodes and removes the dependency on another cell's variable.
# orange = submitting, cyan = deducing, green = otherwise.
color_map = ['orange' if node.is_submitting else 'cyan' if node.is_deducing else 'green' for node in direct_dependency_graph.nodes]
# edge_styles = [
#     ('--' if e not in hard_dependencies_global else '-') for e in direct_dependency_graph.edges
# ]

# Label each node with the variables it introduces, or with its step text when it introduces none.
labels = {node: '\n'.join(str(v) for v in (node.new_contexts or [node.step])) for node in direct_dependency_graph}

nx.draw(direct_dependency_graph, pos, with_labels=True, labels=labels, node_size=800, font_size=6, node_color=color_map)
plt.show()

## Exploratory Action Sequence Reassembling

In [None]:
# Every step starts at depth 0; the depths are raised later along dependency edges.
depth_dict = dict.fromkeys(parsed_steps, 0)

In [None]:
# Longest-path depth: visiting nodes in topological order means a node's depth is final
# before it is pushed to its successors.
for parent in nx.topological_sort(reduced_dependency_graph):
    depth_via_parent = depth_dict[parent] + 1
    for child in reduced_dependency_graph.successors(parent):
        if depth_dict[child] < depth_via_parent:
            depth_dict[child] = depth_via_parent

In [None]:
# Redraw the reduced dependency graph with each node's depth as the first label line.
pos = nx.nx_agraph.graphviz_layout(reduced_dependency_graph, prog="dot", args="")
plt.figure(figsize=(12, 8))

# orange = submitting, cyan = deducing, green = otherwise.
color_map = []
for node in reduced_dependency_graph.nodes:
    if node.is_submitting:
        color_map.append('orange')
    elif node.is_deducing:
        color_map.append('cyan')
    else:
        color_map.append('green')

# Label = depth, then the introduced variables (or the step text when there are none).
labels = {}
for n in reduced_dependency_graph.nodes:
    body_lines = [str(v) for v in (n.new_contexts or [n.step])]
    labels[n] = str(depth_dict[n]) + '\n' + '\n'.join(body_lines)

nx.draw(
    reduced_dependency_graph,
    pos,
    with_labels=True,
    labels=labels,
    node_size=800,
    font_size=6,
    node_color=color_map,
)
plt.show()

In [None]:
reassembled_trajectory = []
G = reduced_dependency_graph.copy()
deductive_state = await server.tactic_server.load_statement_async('False')

while True:
    available_actions = sorted([n for (n, d) in G.in_degree() if d == 0], key=lambda n : (-depth_dict[n], n.is_introducing))
    chosen_action = available_actions[0]
    reassembled_trajectory.append((deductive_state.goals[0].variables, chosen_action))
    if chosen_action.is_submitting:
        assert submission_name in [v.name for v in deductive_state.goals[0].variables]
        if not set(deductive_state.goals[0].variables).issubset(set(forward_state.goals[0].variables)):
            logger.warning(f'¬(deductive_state ⊆ forward_state): {deductive_state.goals[0].variables}, {forward_state.goals[0].variables}')
        break
    deductive_state = await server.tactic_server.goal_tactic_async(forward_state, 0, chosen_action.step)
    G.remove_node(chosen_action)

In [None]:
# Redraw the reduced dependency graph, labelling each node with "<execution order>, <depth>".
pos = nx.nx_agraph.graphviz_layout(reduced_dependency_graph, prog="dot", args="")
plt.figure(figsize=(12, 8))


# orange = submitting, cyan = deducing, green = otherwise.
color_map = ['orange' if node.is_submitting else 'cyan' if node.is_deducing else 'green' for node in reduced_dependency_graph.nodes]

# Position of each action in the reassembled trajectory.
order_dict = {
    n : i for i, (s, n) in enumerate(reassembled_trajectory)
}
# The trajectory ends at the submission step, so some nodes may never have been scheduled.
# They are shown with '-' in place of an order index rather than raising KeyError.
labels = {n: f'{order_dict.get(n, "-")}, {depth_dict[n]}' + '\n' + '\n'.join(str(v) for v in (n.new_contexts or [n.step])) for n in reduced_dependency_graph.nodes}

nx.draw(reduced_dependency_graph, pos, with_labels=True, labels=labels, node_size=800, font_size=6, node_color=color_map)
plt.show()