In [9]:
import json, random, re, os, glob, csv, hashlib
import os.path as osp
from typing import List, Tuple, Any, Dict, Optional
from rouge_score import rouge_scorer

import multiprocessing as mp
from tqdm import tqdm
import numpy as np

In [3]:
_GLOBAL_TEXTS = None
_GLOBAL_METRIC = None
_GLOBAL_SCORER: rouge_scorer.RougeScorer = None

def _mp_init(metric: str, use_stemmer: bool, texts: list[str]):
    """Pool ÂàùÂßãÂåñÔºöÂú®Â≠êËøõÁ®ãÈáåÊûÑÈÄ†‰∏ÄÊ¨° RougeScorerÔºåÂπ∂Â≠òÂÖ•ÂÖ®Â±Ä„ÄÇ"""
    global _GLOBAL_TEXTS, _GLOBAL_METRIC, _GLOBAL_SCORER
    _GLOBAL_TEXTS = texts
    _GLOBAL_METRIC = metric
    _GLOBAL_SCORER = rouge_scorer.RougeScorer([metric], use_stemmer=use_stemmer)

In [4]:
def _score_pair_ij(pair: tuple[int, int]) -> float:
    """Â≠êËøõÁ®ãÔºöËÆ°ÁÆó‰∏ÄÂØπ (i, j) ÁöÑ ROUGE-L/ROUGE-Lsum F1"""
    i, j = pair
    s = _GLOBAL_SCORER.score(_GLOBAL_TEXTS[i], _GLOBAL_TEXTS[j])[_GLOBAL_METRIC]
    return s.fmeasure

In [5]:
jsonl_path="sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch-valid_samples.jsonl"
field="formal_statement"


In [6]:
data_root = '/sfs/liuqi/data/fpg_valid_fixed_evaluated/'

print('exp_name', 'mean', 'std', sep='\t')
for exp_name in [
    'autoformalization_pg_kimina7b-PromptCoT-DS_kimina7b-valid_samples.jsonl',
    'autoformalization_pg_kimina7b-PromptCoT-QwQ_kimina7b-valid_samples.jsonl',
    'autoformalization_pg_kimina7b-ScaleQuest-Math_kimina7b-valid_samples.jsonl',
    'MUSTARDSauce_lean4_parsed-valid_samples.jsonl',
    'sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch.0913.staged-valid_samples.jsonl',
    'sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch-valid_samples.jsonl',
    'sft_wg_starified-Goedel-Prover-V2-8B.Numina-Lean.whole_statement_generatior.nopack-valid_samples.jsonl',
]:
    jsonl_path: str = osp.join(data_root, exp_name)
    # field="formal_statement"
    metric: str = "rougeL"          # "rougeL" Êàñ "rougeLsum"
    rounds: int = 3
    sample_size: int = 100
    use_stemmer: bool = True       # Ëã±ÊñáÂª∫ËÆÆ True
    seed: int | None = 42
    ensure_unique_rounds: bool = False
    num_workers: int = 64            # 0/1 = ÂçïËøõÁ®ãÔºõ>1 ‰ΩøÁî®Â§öËøõÁ®ã
    # mp_chunksize: int = 256          # map ÁöÑ‰ªªÂä°ÂàÜÂùóÂ§ßÂ∞è

    assert metric in ("rougeL", "rougeLsum")

    # ËØªÂèñÊñáÊú¨
    texts: List[str] = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        data = [json.loads(l) for l in f.readlines()]
        texts = [d['formal_statement'] + d['formal_proof'] for d in data]
        # for d in data:
        #     assert d['formal_statement'].strip().endswith(':= sorry')
        #     texts.append(
        #         d['formal_statement'].strip()[-len(':= sorry')] + ':= by\n' + d['formal_proof']
        #     )

    n = len(texts)
    if n == 0:
        raise

    k = min(sample_size, n)
    rnd = random.Random(seed) if seed is not None else random.Random()
    # ÂçïËøõÁ®ãÂ§áÁî® scorer
    single_scorer = rouge_scorer.RougeScorer([metric], use_stemmer=use_stemmer)

    per_round: List[float] = []
    seen_rounds = set()

    for i_round in range(rounds):
        # print(f'Round {i_round}')
        idxs = rnd.sample(range(n), k)
        if ensure_unique_rounds:
            tries = 0
            key = tuple(sorted(idxs))
            while key in seen_rounds and tries < 5:
                idxs = rnd.sample(range(n), k)
                key = tuple(sorted(idxs)); tries += 1
            seen_rounds.add(key)

        # üîÅ Êú¨ËΩÆÊâÄÊúâ (i, j)Ôºài Âú®ÊäΩÊ†∑ÈõÜÂêàÔºåj ÈÅçÂéÜÂÖ®ÈõÜ‰∏î j!=iÔºâ
        # ‚úÖ ‰∏∫‰∫ÜÁúÅÂÜÖÂ≠òÔºåÁî®ÁîüÊàêÂô®ËÄå‰∏çÊòØÊääÊâÄÊúâ pair ÊîæËøõÂ§ßÂàóË°®
        pairs_iter = ((i, j) for i in idxs for j in idxs)

        if num_workers and num_workers > 1:
            # ‚úÖ Â§öËøõÁ®ãÔºöÊØè‰∏™ËøõÁ®ãÈáåÂàùÂßãÂåñ‰∏ÄÊ¨° scorer & texts
            with mp.Pool(
                processes=50,
                initializer=_mp_init,
                initargs=(metric, use_stemmer, texts)
            ) as pool:
                # Áî® imap_unordered ÊãâÊµÅÂºèÂèñÂõûÁªìÊûúÔºåÈÅøÂÖç‰∏ÄÊ¨°ÊÄß materialize
                f1s_iter = pool.map(_score_pair_ij, pairs_iter, chunksize=200)
                pair_f1s = list(f1s_iter)
        else:
        # ÂçïËøõÁ®ãÂêéÂ§á
            pair_f1s = []
            for (i, j) in pairs_iter:
                s = single_scorer.score(texts[i], texts[j])[metric]
                pair_f1s.append(s.fmeasure)

        per_round.append(sum(pair_f1s) / len(pair_f1s) if pair_f1s else float("nan"))
    # break

    # overall = sum(per_round)/len(per_round) if per_round else float("nan")
    print(exp_name, np.mean(per_round), np.std(per_round), sep='\t')


exp_name	mean	std
autoformalization_pg_kimina7b-PromptCoT-DS_kimina7b-valid_samples.jsonl	0.17490867375527755	0.004456541693394542
autoformalization_pg_kimina7b-PromptCoT-QwQ_kimina7b-valid_samples.jsonl	0.1649247955218178	0.0012984803563522314
autoformalization_pg_kimina7b-ScaleQuest-Math_kimina7b-valid_samples.jsonl	0.1730722285964148	0.0023255072629155067
MUSTARDSauce_lean4_parsed-valid_samples.jsonl	0.1901680421417211	0.0015621851996097343
sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch.0913.staged-valid_samples.jsonl	0.17445865763989446	0.004133605225044185
sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch-valid_samples.jsonl	0.17618002415371561	0.006359499727534762
sft_wg_starified-Goedel-Prover-V2-8B.Numina-Lean.whole_statement_generatior.nopack-valid_samples.jsonl	0.162224598557159	0.0077067709157789014


In [20]:
def format_informalization_response(
    problem_type: str,
    informal_problem: str,
    informal_solution: str,
    informal_answer: Optional[str]
) -> str:
    # assert all(split not in field
    #     for split in [f'## Problem-Solving Question', '## Proof Question', '## Answer', '## Solution', '## Proof']
    #     for field in [problem_type, informal_problem, informal_solution, (informal_answer or '')]
    # )
    response = f'## {problem_type.strip()}\n{informal_problem.strip()}\n\n'
    if problem_type == 'Problem-Solving Question':
        assert informal_answer is not None
        response += f'## Answer\n{informal_answer.strip()}\n\n'
        response += f'## Solution\n{informal_solution.strip()}\n\n'
    elif problem_type == 'Proof Question':
        response += f'## Proof\n{informal_solution.strip()}\n\n'
    else:
        raise ValueError(f'Invalid problem_type: "{problem_type}"')
    return response

def format_informal_code(d_orig: Dict) -> str:
    d = d_orig['informalization']
    return format_informalization_response(problem_type=d['problem_type'], informal_problem=(d.get('informal_problem') or d.get('informal_statement')), informal_solution=d.get('informal_solution') or d.get('informal_proof'), informal_answer=d.get('informal_answer')).strip()

In [None]:
data_root = '/sfs/liuqi/data/fpg_valid_fixed_evaluated/'

print('exp_name', 'mean', 'std', sep='\t')
for exp_name in [
    'autoformalization_pg_kimina7b-PromptCoT-DS_kimina7b-valid_samples.jsonl',
    'autoformalization_pg_kimina7b-PromptCoT-QwQ_kimina7b-valid_samples.jsonl',
    'autoformalization_pg_kimina7b-ScaleQuest-Math_kimina7b-valid_samples.jsonl',
    'MUSTARDSauce_lean4_parsed-valid_samples.jsonl',
    'sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch.0913.staged-valid_samples.jsonl',
    'sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch-valid_samples.jsonl',
    'sft_wg_starified-Goedel-Prover-V2-8B.Numina-Lean.whole_statement_generatior.nopack-valid_samples.jsonl',
]:
    jsonl_path: str = osp.join(data_root, exp_name)
    # field="formal_statement"
    metric: str = "rougeL"          # "rougeL" Êàñ "rougeLsum"
    rounds: int = 10
    sample_size: int = 100
    use_stemmer: bool = True       # Ëã±ÊñáÂª∫ËÆÆ True
    seed: int | None = 42
    ensure_unique_rounds: bool = False
    num_workers: int = 64            # 0/1 = ÂçïËøõÁ®ãÔºõ>1 ‰ΩøÁî®Â§öËøõÁ®ã
    # mp_chunksize: int = 256          # map ÁöÑ‰ªªÂä°ÂàÜÂùóÂ§ßÂ∞è

    assert metric in ("rougeL", "rougeLsum")

    # ËØªÂèñÊñáÊú¨
    texts: List[str] = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        data = [json.loads(l) for l in f.readlines()]
        texts = [format_informal_code(d) for d in data]
        # for d in data:
        #     assert d['formal_statement'].strip().endswith(':= sorry')
        #     texts.append(
        #         d['formal_statement'].strip()[-len(':= sorry')] + ':= by\n' + d['formal_proof']
        #     )

    n = len(texts)
    if n == 0:
        raise

    k = min(sample_size, n)
    rnd = random.Random(seed) if seed is not None else random.Random()
    # ÂçïËøõÁ®ãÂ§áÁî® scorer
    single_scorer = rouge_scorer.RougeScorer([metric], use_stemmer=use_stemmer)

    per_round: List[float] = []
    seen_rounds = set()

    for i_round in range(rounds):
        # print(f'Round {i_round}')
        idxs = rnd.sample(range(n), k)
        if ensure_unique_rounds:
            tries = 0
            key = tuple(sorted(idxs))
            while key in seen_rounds and tries < 5:
                idxs = rnd.sample(range(n), k)
                key = tuple(sorted(idxs)); tries += 1
            seen_rounds.add(key)

        # üîÅ Êú¨ËΩÆÊâÄÊúâ (i, j)Ôºài Âú®ÊäΩÊ†∑ÈõÜÂêàÔºåj ÈÅçÂéÜÂÖ®ÈõÜ‰∏î j!=iÔºâ
        # ‚úÖ ‰∏∫‰∫ÜÁúÅÂÜÖÂ≠òÔºåÁî®ÁîüÊàêÂô®ËÄå‰∏çÊòØÊääÊâÄÊúâ pair ÊîæËøõÂ§ßÂàóË°®
        pairs_iter = ((i, j) for i in idxs for j in idxs)

        if num_workers and num_workers > 1:
            # ‚úÖ Â§öËøõÁ®ãÔºöÊØè‰∏™ËøõÁ®ãÈáåÂàùÂßãÂåñ‰∏ÄÊ¨° scorer & texts
            with mp.Pool(
                processes=50,
                initializer=_mp_init,
                initargs=(metric, use_stemmer, texts)
            ) as pool:
                # Áî® imap_unordered ÊãâÊµÅÂºèÂèñÂõûÁªìÊûúÔºåÈÅøÂÖç‰∏ÄÊ¨°ÊÄß materialize
                f1s_iter = pool.map(_score_pair_ij, pairs_iter, chunksize=200)
                pair_f1s = list(f1s_iter)
        else:
        # ÂçïËøõÁ®ãÂêéÂ§á
            pair_f1s = []
            for (i, j) in pairs_iter:
                s = single_scorer.score(texts[i], texts[j])[metric]
                pair_f1s.append(s.fmeasure)

        per_round.append(sum(pair_f1s) / len(pair_f1s) if pair_f1s else float("nan"))
    # break

    # overall = sum(per_round)/len(per_round) if per_round else float("nan")
    print(exp_name, np.mean(per_round), np.std(per_round), sep='\t')


exp_name	mean	std
autoformalization_pg_kimina7b-PromptCoT-DS_kimina7b-valid_samples.jsonl	0.19369023963894133	0.00315670448712554
autoformalization_pg_kimina7b-PromptCoT-QwQ_kimina7b-valid_samples.jsonl	0.1941728990306546	0.004576968791866816
autoformalization_pg_kimina7b-ScaleQuest-Math_kimina7b-valid_samples.jsonl	0.2196545764288267	0.008478152781527581
MUSTARDSauce_lean4_parsed-valid_samples.jsonl	0.21061036473314515	0.006448873287891353
sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch.0913.staged-valid_samples.jsonl	0.17681872342548774	0.0067182980520842425
sft_ar_v3-Goedel-Prover-V2-8B.Numina-Lean-reasseblmed.39509.problem_generator.nopack.3epoch-valid_samples.jsonl	0.17918771777939496	0.004235547805905238
sft_wg_starified-Goedel-Prover-V2-8B.Numina-Lean.whole_statement_generatior.nopack-valid_samples.jsonl	0.18394956324510553	0.0071242939839667625


In [13]:
for d in data:
    format_informal_code(d)

AssertionError: 