In [1]:
import os
import os.path as osp
import json
import pickle
import collections as C
import itertools as I
import random
import regex as re
from typing import List, Optional, Dict, Tuple
import time

import msgspec
from tqdm import tqdm
from loguru import logger
from dacite import from_dict
import dacite

from common.constants import SYSTEM_PROMPT_FPG, CORE_OPTIONS, BANNED_TOKENS
from common.utils import remove_comments, replace_sorry, replace_calc, remove_multiline_comments, remove_singleline_comments, parse_idents, normalize_spaces
from common.pantograph.dataclasses import ProblemGenerationProcess, ProblemGenerationStep, Variable, normalize_draft, replace_span, Goal, GoalState, ProblemGenerationStep, ProblemGenerationProcess, TacticDraft
from common.pantograph.server import PersistentServer, TacticFailure, ServerError
from agent.problem_generation import AutoregressiveProblemGenerationAgent



In [4]:
dones_root = '/home/ma-user/workspace/formal_problem_generation/data/FineLeanCorpus/dones'
dones_different_root = '/home/ma-user/workspace/formal_problem_generation/data/FineLeanCorpus/different_dones_32672'
n_chunks = 498

In [5]:
data_main = []
data_extra = []

for i_chunk in tqdm(list(range(n_chunks))):
    with open(osp.join(dones_root, f'done_chunk_{1024 * i_chunk}.pkl'), 'rb') as f:
        chunk = pickle.load(f)
    with open(osp.join(dones_root, f'reassembled_fixed_chunk_{1024 * i_chunk}.pkl'), 'rb') as f:
        chunk_fixed = pickle.load(f)
    assert len(chunk) == len(chunk_fixed)
    for (i, c) in enumerate(chunk_fixed):
        if c is not None:
            assert chunk[i].informal_problem == c.informal_problem
            chunk[i] = c
    
    if osp.exists(osp.join(dones_different_root, f'reassembled_fixed_chunk_{1024 * i_chunk}.pkl')):
        with open(osp.join(dones_different_root, f'done_chunk_{1024 * i_chunk}.pkl'), 'rb') as f:
            chunk_extra = pickle.load(f)
        with open(osp.join(dones_different_root, f'reassembled_fixed_chunk_{1024 * i_chunk}.pkl'), 'rb') as f:
            chunk_fixed = pickle.load(f)
        assert len(chunk_extra) == len(chunk_fixed)
        for (i, c) in enumerate(chunk_fixed):
            if c is not None:
                assert chunk_extra[i].informal_problem == c.informal_problem
                chunk_extra[i] = c
        
        diff_cnt = 0
        for i, (d1, d2) in enumerate(zip(chunk, chunk_extra)):
            # if d1 is None:
            #     data_reassembled.append(d2)
            # elif d2 is None:
            #     data_reassembled.append(d1)
            # else:
            assert d1.informal_problem == d2.informal_problem
            meta1 = d1.metainfo if isinstance(d1.metainfo, dict) else json.loads(d1.metainfo)
            meta2 = d2.metainfo if isinstance(d2.metainfo, dict) else json.loads(d2.metainfo)
            assert meta1['id'] == meta2['id']
            
            # Whether falsified
            if 'falsified_model' in meta1.keys():
                continue
            elif 'falsified_model' in meta2.keys():
                diff_cnt += 1
                chunk[i] = d2
            else:
                # Whether reassembled
                if 'original_trajectory' in meta1.keys():
                    continue
                elif 'original_trajectory' in meta2.keys():
                    diff_cnt += 1
                    chunk[i] = d2
                else:
                    # Whether decomposed
                    if len(d1.trajectory) > 0:
                        continue
                    elif len(d2.trajectory) > 0:
                        diff_cnt += 1
                        chunk[i] = d2
                    else:
                        if 'proven_model' in meta1.keys():
                            continue
                        elif 'proven_model' in meta2.keys():
                            diff_cnt += 1
                            chunk[i] = d2
                        else:
                            pass
        print(f'Chunk({i_chunk * 1024}): diff_cnt={diff_cnt}')
    data_main.extend(chunk)

 58%|█████▊    | 291/498 [00:28<00:10, 19.35it/s]

Chunk(297984): diff_cnt=1
Chunk(299008): diff_cnt=0


 59%|█████▉    | 294/498 [00:28<00:14, 14.20it/s]

Chunk(300032): diff_cnt=0
Chunk(301056): diff_cnt=1


 59%|█████▉    | 296/498 [00:29<00:16, 12.57it/s]

Chunk(302080): diff_cnt=0
Chunk(303104): diff_cnt=1


 60%|█████▉    | 298/498 [00:32<01:30,  2.20it/s]

Chunk(304128): diff_cnt=3
Chunk(305152): diff_cnt=0
Chunk(306176): diff_cnt=2


 61%|██████    | 302/498 [00:32<00:57,  3.39it/s]

Chunk(307200): diff_cnt=2
Chunk(308224): diff_cnt=0
Chunk(309248): diff_cnt=2


 61%|██████    | 304/498 [00:33<00:47,  4.09it/s]

Chunk(310272): diff_cnt=2
Chunk(311296): diff_cnt=0


 62%|██████▏   | 307/498 [00:33<00:37,  5.11it/s]

Chunk(312320): diff_cnt=1
Chunk(313344): diff_cnt=1


 62%|██████▏   | 309/498 [00:33<00:31,  6.07it/s]

Chunk(314368): diff_cnt=1
Chunk(315392): diff_cnt=0


 62%|██████▏   | 311/498 [00:33<00:27,  6.81it/s]

Chunk(316416): diff_cnt=2
Chunk(317440): diff_cnt=1


 63%|██████▎   | 313/498 [00:34<00:25,  7.21it/s]

Chunk(318464): diff_cnt=1
Chunk(319488): diff_cnt=2


 63%|██████▎   | 315/498 [00:34<00:22,  8.00it/s]

Chunk(320512): diff_cnt=0
Chunk(321536): diff_cnt=1


 64%|██████▎   | 317/498 [00:34<00:21,  8.51it/s]

Chunk(322560): diff_cnt=1
Chunk(323584): diff_cnt=0


 64%|██████▍   | 319/498 [00:34<00:21,  8.38it/s]

Chunk(324608): diff_cnt=0
Chunk(325632): diff_cnt=3


 64%|██████▍   | 320/498 [00:35<00:21,  8.11it/s]

Chunk(326656): diff_cnt=0


 65%|██████▍   | 322/498 [00:38<02:22,  1.24it/s]

Chunk(327680): diff_cnt=1
Chunk(328704): diff_cnt=2


 65%|██████▌   | 324/498 [00:38<01:19,  2.19it/s]

Chunk(329728): diff_cnt=2
Chunk(330752): diff_cnt=2


 65%|██████▌   | 326/498 [00:38<00:48,  3.57it/s]

Chunk(331776): diff_cnt=0
Chunk(332800): diff_cnt=1


 66%|██████▌   | 328/498 [00:39<00:32,  5.19it/s]

Chunk(333824): diff_cnt=0
Chunk(334848): diff_cnt=0


 66%|██████▋   | 330/498 [00:39<00:26,  6.43it/s]

Chunk(335872): diff_cnt=0
Chunk(336896): diff_cnt=1


100%|██████████| 498/498 [01:00<00:00,  8.28it/s]


In [6]:
data_falsified = []
data_reassembled = []
data_decomposed = []
data_proven = []
data_failed = []

In [9]:
for d1 in data_main:
    if isinstance(d1.metainfo, str):
        d1.metainfo = json.loads(d1.metainfo)
    meta1 = d1.metainfo
    
    # Whether falsified
    if 'falsified_model' in meta1.keys():
        data_falsified.append(d1)
    else:
        # Whether reassembled
        if 'original_trajectory' in meta1.keys():
            data_reassembled.append(d1)
        else:
            # Whether decomposed
            if len(d1.trajectory) > 0:
                data_decomposed.append(d1)
            else:
                if 'proven_model' in meta1.keys():
                    data_proven.append(d1)
                else:
                    data_failed.append(d1)

In [10]:
len(data_falsified), len(data_reassembled), len(data_decomposed), len(data_proven), len(data_failed), 

(2614, 82438, 349, 11375, 412582)

In [12]:
C.Counter(
    d.trajectory == [
        ([dacite.from_dict(Variable, v) for v in S], s) for (S, s) in d.metainfo['original_trajectory']
    ] for d in data_reassembled
)

Counter({True: 51187, False: 31251})

In [11]:
C.Counter(len(d.trajectory) for d in data_reassembled)

Counter({2: 18512,
         4: 12370,
         3: 9425,
         6: 8156,
         8: 6274,
         5: 5253,
         10: 4112,
         7: 3650,
         11: 3275,
         9: 3220,
         12: 2586,
         14: 1403,
         13: 1346,
         15: 819,
         16: 604,
         17: 389,
         18: 283,
         19: 216,
         20: 146,
         21: 107,
         22: 81,
         23: 51,
         24: 35,
         26: 23,
         25: 19,
         27: 17,
         28: 11,
         31: 11,
         30: 10,
         29: 8,
         33: 5,
         38: 3,
         32: 3,
         35: 2,
         54: 2,
         57: 2,
         43: 1,
         53: 1,
         61: 1,
         36: 1,
         105: 1,
         39: 1,
         42: 1,
         110: 1,
         34: 1})

In [17]:
with open(osp.join('/home/ma-user/local_cache/m-a-p/FineLeanCorpus', 'FineLeanCorpus_v2.jsonl'), 'r') as f:
    data_informal = [json.loads(l) for l in f.readlines()]
print(len(data_informal))

509358


In [20]:
data_informal[0].keys()
d = data_informal[0]

In [23]:
id_to_idx = {i : d['id'] for (i, d) in enumerate(data_informal)}
print(len(id_to_idx))

509358


In [25]:
d['domain'], d['difficulty'], d['source'], d['domain_summary'], d['id']

(['Algebra -> Intermediate Algebra -> Functional Equations'],
 1,
 'AoPs',
 'The problem involves finding the value of a constant in a linear relationship between two variables and then using it to find the value of one variable given the other.',
 1)

In [26]:
C.Counter(
    d['source'] for d in data_informal
)

Counter({'AoPs': 350714,
         'DeepMath-103k': 45853,
         'NuminaMath-TIR': 45152,
         'DeepTheorem': 31409,
         'DeepScaleR': 22360,
         'DAPO-Math-17k': 8868,
         'Omni-MATH': 1181,
         'IneqMath': 1180,
         'BlueMO': 1099,
         'Multi-Source_Math_Competition': 993,
         'TAL-SCQ5K': 393,
         'OnlineMathContest': 156})

In [27]:
def format_forward_solution_step_prompt(idx: int, introduced_fvars: List[str], state: List[Variable]) -> str:
    d = data_informal[id_to_idx[idx]]
    context = ''
    vars_to_format = [v for v in state]
    while len(vars_to_format) > 0:
        for i in range(len(vars_to_format)):
            if i + 1 == len(vars_to_format) or not (vars_to_format[i].t == vars_to_format[i+1].t and vars_to_format[i].v is None and vars_to_format[i+1].v is None):
                break
        if i == 0:
            context += str(vars_to_format[0]) + '\n'
            vars_to_format.pop(0)
        else:
            context += ' '.join([v.name if v.name is not None else "_" for v in vars_to_format[:i+1]]) + f' : {vars_to_format[0].t}\n'
            vars_to_format = vars_to_format[i+1:]
    
    introduced_fvars = '\n'.join(introduced_fvars)
    prompt = f'''Given the introduced variables/hypotheses and the current context in Lean 4, propose the single most natural next step to explore toward a beautiful conclusion — either
- derive a new intermediate fact,
- introduce a fresh variable or hypothesis, or
- submit one of the local facts as the final answer.

Requirements
1. Flavoured {d['domain']} and of difficulty level {d['difficulty']}.
2. Fully formal Lean 4 code (inline comments in natural language are fine for planning and reasoning). Assume `import Mathlib`.

# Introduced Variables/Hypotheses
```lean4
{introduced_fvars}
```

# Lean 4 Context
```lean4
{context.rstrip()}
```
'''.strip()
    return prompt

def format_step(self):
    if self.proof is None:
        return self.step_draft  # Here do not remove comment
    else:
        normalized_step_draft = normalize_draft(self.step_draft)
        matches = list(re.finditer(':= sorry', normalized_step_draft))
        assert len(matches) == len(self.proof)
        for (m, p) in reversed(list(zip(matches, self.proof))):
            normalized_step_draft = replace_span(m.span(), ':= by {\n' + '\n'.join('  ' + l for l in p.splitlines()) + '\n}', normalized_step_draft)
        return normalized_step_draft

def format_forward_solution_step_response(step: ProblemGenerationStep):
    step_type = 'Derive' if step.is_deducing else 'Introduce' if step.is_introducing else 'Submit'
    response = f'''# Step {step_type}
```lean4
{format_step(step).rstrip()}
```
'''.strip()
    return response

print(SYSTEM_PROMPT_FPG)

You are an Olympiad problem setter and a Lean 4 expert.
You revel in conjuring elegant problems — starting from a spare set of hypotheses, you let rigorous deduction lead you to surprising and beautiful conclusions.


In [29]:
data_problem_generation = []

for result in data_reassembled:
    result: ProblemGenerationProcess
    steps = result.steps

    introduced_fvars = []
    data_problem_generation_chunk = []
    is_success = True
    
    for i, (context_fvars, step_id) in enumerate(result.trajectory):
        step: ProblemGenerationStep = steps[step_id]
        step.step_draft = step.step_draft.replace(' :  ', ' : ')
        
        if step.is_deducing:
            idents = set(step.step.split())
            for banned_token in BANNED_TOKENS:
                if banned_token in idents:
                    if any(v.name == banned_token for v in context_fvars):
                        logger.warning(f'Banned token "{banned_token}" in step "{step.step}", but is also in context.')
                    else:
                        logger.error(f'Banned token "{banned_token}" in step "{step.step}"')
                        is_success = False
                        break
        else:
            idents = set(step.step.split())
            for banned_token in BANNED_TOKENS[1:]:
                if banned_token in idents:
                    if any(v.name == banned_token for v in context_fvars):
                        logger.warning(f'Banned token "{banned_token}" in step "{step.step}", but is also in context.')
                    else:
                        logger.error(f'Banned token "{banned_token}" in step "{step.step}"')
                        is_success = False
                        break
        
        data_problem_generation_chunk.append({
            "conversation":[
                {
                    "system": SYSTEM_PROMPT_FPG,
                    "input": format_forward_solution_step_prompt(result.metainfo['id'], introduced_fvars, context_fvars),
                    "output": format_forward_solution_step_response(step)
                }
            ]
        })

        if step.is_introducing:
            lines = step.step_draft.splitlines()
            while len(lines) > 0 and lines[0].split()[0] in ['open', 'set_option']:
                lines.pop(0)
            step_code = '\n'.join(lines)
            assert step_code.startswith('have ') and step_code.endswith(' := sorry')
            introduced_fvars.append(step_code[len('have '):-len(' := sorry')].strip())

        if step.is_submitting:
            assert i == len(result.trajectory) - 1
    
    if is_success:
        data_problem_generation.extend(data_problem_generation_chunk)
    

In [30]:
len(data_problem_generation)

490644

In [31]:
def print_xtuner_sample(data: list):
    if 'system' in data['conversation'][0].keys():
        print('<SYSTEM>')
        print(data['conversation'][0]['system'])
        print('</SYSTEM>')
    print('<INPUT>')
    print(data['conversation'][0]['input'])
    print('</INPUT>\n<OUTPUT>')
    print(data['conversation'][0]['output'])
    print('</OUTPUT>')


In [35]:
print_xtuner_sample(data_problem_generation[0])

<SYSTEM>
You are an Olympiad problem setter and a Lean 4 expert.
You revel in conjuring elegant problems — starting from a spare set of hypotheses, you let rigorous deduction lead you to surprising and beautiful conclusions.
</SYSTEM>
<INPUT>
Given the introduced variables/hypotheses and the current context in Lean 4, propose the single most natural next step to explore toward a beautiful conclusion — either
- derive a new intermediate fact,
- introduce a fresh variable or hypothesis, or
- submit one of the local facts as the final answer.

Requirements
1. Flavoured ['Algebra -> Intermediate Algebra -> Other', 'Applied Mathematics -> Other -> Other'] and of difficulty level 1.
2. Fully formal Lean 4 code (inline comments in natural language are fine for planning and reasoning). Assume `import Mathlib`.

# Introduced Variables/Hypotheses
```lean4

```

# Lean 4 Context
```lean4

```
</INPUT>
<OUTPUT>
# Step Derive
```lean4
open Real in
have h1 : sqrt (6 * sqrt (2 * sqrt 3)) > sqrt (3 * 

In [34]:
with open('/home/ma-user/workspace/formal_problem_generation/data/FineLeanCorpus/problem_generation_steps.reasseblmed.82438.jsonl', 'w') as f:
    for s in data_problem_generation:
        f.write(json.dumps(s) + '\n')

In [None]:
print