In [None]:
import json
from pathlib import Path
import regex as re
import collections as C
import itertools as I

from loguru import logger
import networkx as nx

from common.utils import remove_singleline_comments, remove_multiline_comments, remove_comments, replace_sorry, replace_calc


def rglob_full(path, pattern, _active=frozenset()):
    '''Recursively yield paths under `path` that match `pattern`, following symlinks.

    A symlink is replaced by its resolved target before it is inspected, so
    yielded paths that were reached through a symlink are the resolved ones.
    Directory contents are yielded before the directory itself.

    Args:
        path: A `pathlib.Path` to start from (file, directory or symlink).
        pattern: Pattern handed to `Path.match` (e.g. "*.jsonl").
        _active: Internal. Resolved directories currently being traversed on
            this recursion path; callers should not pass it.

    Yields:
        Every visited path for which `path.match(pattern)` is true.
    '''
    if path.is_symlink():
        path = path.resolve()
    if path.is_dir():
        # Guard against symlink cycles: a link that points back to a directory
        # we are already inside would otherwise recurse until RecursionError.
        # Only the current ancestry is tracked, so a directory reachable through
        # two unrelated links is still traversed once per link, as before.
        real_dir = path.resolve()
        if real_dir not in _active:
            nested = _active | {real_dir}
            for child in path.iterdir():
                yield from rglob_full(child, pattern, nested)
    if path.match(pattern):
        yield path

def shortest_distance_from_root(search_graph, target_node):
    '''Return the shortest path length from the graph's unique root to `target_node`.

    The root is the single node with in-degree 0; an AssertionError is raised
    when there is not exactly one such node, or when either endpoint is not a
    node of `search_graph`.
    '''
    zero_in_degree = [n for n, deg in search_graph.in_degree() if deg == 0]
    assert len(zero_in_degree) == 1
    origin = zero_in_degree[0]

    # Both endpoints have to be present in the graph.
    assert origin in search_graph.nodes and target_node in search_graph.nodes

    return nx.shortest_path_length(search_graph, source=origin, target=target_node)

def solution_decompose(formal_solution_draft: str) -> list[str]:
    '''Decompose a formal solution draft into steps (blocks).

    The draft is first normalised with `remove_multiline_comments`,
    `replace_calc` and `replace_sorry` (helpers from `common.utils`; their exact
    rewriting is not visible here). The remaining lines are then grouped:

    * a root-level line starts a block and absorbs the following lines that
      are blank or still start with a space once single-line comments are removed;
    * a run of root-level comment lines is attached to the statement after it;
    * a root-level line starting with 'exact ' absorbs everything that is left.

    Args:
        formal_solution_draft: Source text of the draft.

    Returns:
        The blocks in source order, or [] when the draft contains nothing but
        comments/whitespace.

    Raises:
        AssertionError: ('Reconstruction failed') if the non-blank lines of the
            blocks do not reproduce the non-blank lines of the normalised draft.
    '''
    raw_lines = replace_sorry(replace_calc(remove_multiline_comments(formal_solution_draft))).rstrip().split('\n')
    if len(remove_comments(formal_solution_draft).strip()) == 0:
        return []

    # Reversed so that pop() / [-1] give the next unread line.
    line_stack = list(reversed(raw_lines))
    parse_result = []

    while len(line_stack) > 0:
        line = line_stack.pop().rstrip()
        if line.strip() == '':
            # Current line is empty: skip
            continue

        # If submitted, add the rest and exit
        if line.startswith('exact '):
            cur_block = line + '\n' + '\n'.join(reversed(line_stack))
            parse_result.append(cur_block)
            break

        cur_block = line
        try:
            if remove_singleline_comments(line) == '':
                # Current line is a root-level comment: Add following lines, until empty line or another root-level comment

                # 1. Add consecutive comments
                while len(line_stack) > 0 and remove_singleline_comments(line_stack[-1]).strip() == '':
                    cur_block += '\n' + line_stack.pop().rstrip()
                # 2. When encounter tactics, add them and following lines, until empty line or another root-level comment
                # A failure here (e.g. a trailing comment with no statement after it)
                # is handled below by keeping what was collected so far.
                assert len(line_stack) > 0 and remove_singleline_comments(line_stack[-1]).strip() != '', 'Comments-probe failed'
                # Add tactic and structures

                # If submitted, add the rest and exit
                if line_stack[-1].startswith('exact '):
                    cur_block += '\n' + '\n'.join(reversed(line_stack))
                    parse_result.append(cur_block)
                    break

                # Else: add the current structure
                cur_block += '\n' + line_stack.pop().rstrip()
                while len(line_stack) > 0 and (line_stack[-1].strip() == '' or remove_singleline_comments(line_stack[-1]).startswith(' ')):
                    cur_block += '\n' + line_stack.pop().rstrip()
                parse_result.append(cur_block)
            else:
                # Tactic
                while len(line_stack) > 0 and (line_stack[-1].strip() == '' or remove_singleline_comments(line_stack[-1]).startswith(' ')):
                    cur_block += '\n' + line_stack.pop().rstrip()
                parse_result.append(cur_block)
        except Exception:
            # Best effort: keep the partial block and stop parsing. Narrowed from a
            # bare `except:` so KeyboardInterrupt / SystemExit are no longer swallowed.
            parse_result.append(cur_block)
            break

    # Sanity check: ignoring blank lines, the blocks must reproduce the normalised draft.
    assert '\n'.join(I.chain(*[[l.rstrip() for l in r.split('\n') if l.strip() != ''] for r in parse_result])) == '\n'.join(l.rstrip() for l in raw_lines if l.strip() != ''), 'Reconstruction failed'
    return parse_result

def replace_block(s: str) -> str:
    '''Replace every `:= by {...}` span in `s` with the literal text `sorry`.

    Matching is non-greedy and spans newlines, so each match ends at the first
    closing brace after its opening `:= by {`.
    '''
    braced_proof = re.compile(r':= by \{.*?\}', flags=re.DOTALL)
    return braced_proof.sub('sorry', s)

In [None]:
# Directory that is scanned for experiment result files.
exp_root = 'output/'

# Benchmarks that are evaluated; result files for any other name are skipped.
benchmarks = ['formal_math500', 'minif2f_solving']

# Per benchmark: 0-based positions in the result files that are excluded.
benchmark_rejections = {
    'formal_math500': [31, 42, 111, 119, 124, 188, 196, 204, 205],
    'minif2f_solving': [34, 46, 263, 312, 339],
}

# Number of problems left per benchmark once the rejected ones are removed.
# NOTE(review): 387 / 375 are presumably the raw benchmark sizes - confirm.
benchmark_lengths = {
    name: raw_size - len(benchmark_rejections[name])
    for name, raw_size in (('formal_math500', 387), ('minif2f_solving', 375))
}
print(benchmark_lengths)

In [None]:
# Nested accumulator, created on first access at every level:
#   results_all[cycle][(experiment name, model)][benchmark] -> list with one entry per problem,
#   each entry being the list of attempts recorded for that problem.
results_all = C.defaultdict(lambda: C.defaultdict(lambda: C.defaultdict(list)))

In [None]:
# Collect per-problem attempt records from every result file under `exp_root`.
# The last four path components are read as <cycle>/<experiment>/<model>/<benchmark>.jsonl.
for p in rglob_full(Path(exp_root), "*.jsonl"):
    print(p)
    cycle, exp_name, model, benchmark = p.parts[-4:]
    try:
        cycle = int(cycle[len('cycle'):])
    except ValueError:
        # Directory is not named 'cycle<N>': keep the raw directory name as the key.
        pass
    print(cycle, exp_name, model, benchmark)
    benchmark = benchmark.split('.')[0]
    if benchmark not in benchmarks:
        print(benchmark)
        continue
    rejections = benchmark_rejections[benchmark]

    # Pick the record layout from the experiment name; the order of the checks
    # matters because the tests are plain substring matches.
    if 'ar' in exp_name:
        method = 'ar'
    elif 'bfs' in exp_name:
        method = 'bfs'
    elif any(k in exp_name for k in ['wg', 'sa']):
        method = 'wg'
    else:
        logger.warning(f'Method not suitable for RPE: {exp_name}')
        continue

    # One JSON document per line; line i holds the attempts for problem i (or null).
    with open(p, 'r') as f:
        results = [json.loads(l) for l in f]

    for i, rs in enumerate(results):
        if i in rejections or rs is None:
            continue

        if method == 'ar':
            # (success, len(solution)) per attempt
            results_all[cycle][(exp_name, model)][benchmark].append(
                [(r.get('success', False), len(r['solution'])) for r in rs]
            )
        elif method == 'bfs':
            # (solved, number of edges in the search graph, root distance of the solving node or inf) per attempt
            correctness = []
            for r in rs:
                solution_len = float('inf')
                c = False
                search_graph = nx.readwrite.json_graph.node_link_graph(r['search_graph'])
                for n in r['final_nodes']:
                    node_data = search_graph.nodes[n]
                    # The first final node carrying a non-empty 'rpe_proof' marks the attempt as solved.
                    if len(node_data.get('rpe_proof', None) or '') > 0:
                        c = True
                        solution_len = shortest_distance_from_root(search_graph, n)
                        break
                correctness.append((c, len(search_graph.edges), solution_len))
            results_all[cycle][(exp_name, model)][benchmark].append(correctness)
        else:
            # (success, number of decomposed steps of the draft, or inf when there is no draft) per attempt
            results_all[cycle][(exp_name, model)][benchmark].append(
                [(r['success'], float('inf') if r.get('formal_solution_draft') is None else len(solution_decompose(replace_block(r['formal_solution_draft'])))) for r in rs]
            )
    print(len(results))

In [None]:
# Tab-separated summary: one row per (cycle, experiment, model, benchmark) that has data.
print('Cycle\tExperiment\tModel\tBenchmark\tAll\tSolved\tBudget\tSteps(correct)')
for c, results in results_all.items():
    for (e, m) in sorted(results.keys()):
        rs = results[(e, m)]
        # WG / H-WG runs are always walked attempt by attempt;
        # AR / BFS / HAR / H-BFS runs only when some attempt succeeded.
        always_walk = any(k in e for k in ['wg', 'sa'])
        for b in benchmarks:
            if not rs[b]:
                continue
            n_steps_list = []
            n_all = 0
            n_correct = 0
            n_steps = 0
            n_correct_steps = 0
            for r in rs[b]:
                if not always_walk and not any(v[0] for v in r):
                    # Unsolved problem: charged a flat budget.
                    # NOTE(review): 80 here vs. 81 in the per-problem list - confirm this is intended.
                    n_steps += 80
                    n_steps_list.append(81)
                    continue

                # Accumulate the budget (second field) up to and including the first successful attempt.
                spent = 0
                for v in r:
                    spent += v[1]
                    if v[0]:
                        n_correct += 1
                        n_correct_steps += v[-1]
                        break
                n_steps += spent
                n_steps_list.append(spent)

            print(c, e, m, b, len(rs[b]), n_correct, n_steps, n_correct_steps, sep='\t')