In [151]:
import pathlib
from collections import defaultdict
import json
from autoformalizer.eval_utils import lean_feedback
import pandas as pd
from tqdm import tqdm


In [170]:
# Root directory that holds the cached verification responses.
cache_dir = pathlib.Path("/home/jia/auto_proofs_v2")

# Collect every *.json file located exactly one directory level below the root.
json_files = [*cache_dir.glob("*/*.json")]

In [171]:
# Number of cached response files discovered by the glob.
len(json_files)

19963715

In [None]:
# Walk every cached response file and sort it into one of three buckets:
#   * `df`               - one record per response that carries a Lean result
#                          (a plain list of dicts at this point),
#   * `errors`           - the raw "error" string of every response that reported
#                          a failure instead of a Lean result,
#   * `unreadable_files` - files whose content is not a JSON object; they are left
#                          out of the statistics but remembered so the amount of
#                          dropped data stays visible.
df = []
errors = []
unreadable_files = []
for filepath in tqdm(json_files):
    # Read as UTF-8 explicitly instead of depending on the locale's default encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        try:
            response = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            unreadable_files.append(filepath)
            continue
    if not isinstance(response, dict):
        # Valid JSON, but not the expected object with "error"/"response" keys.
        unreadable_files.append(filepath)
        continue

    error_message = response.get("error", None)
    if error_message:
        # Failed responses are only collected for later inspection; they do not
        # become records in `df`.
        errors.append(error_message)
        continue

    # The parent directory name is the uuid shared by all files stored in it;
    # the file stem identifies the individual response.
    uuid = filepath.parent.name
    json_response = response.get("response", None)
    df.append(
        {
            "uuid": uuid,
            # Valid means: no Lean error, with `sorry` not accepted.
            "is_valid_no_sorry": not lean_feedback.has_error(
                json_response, accept_sorry=False
            ),
            "name": filepath.stem,
            # Always False: responses with an error message were skipped above.
            # The column is kept so code that reads it keeps working.
            "has_connection_error": False,
        }
    )

if unreadable_files:
    print(f"Skipped {len(unreadable_files)} unreadable JSON files")

In [None]:

# Wrap the per-response records in a DataFrame for the aggregate statistics.
df = pd.DataFrame(df)
print(df)

# calculate valid rate (share of recorded responses flagged `is_valid_no_sorry`)
valid_rate = df["is_valid_no_sorry"].sum() / len(df)
print(f"valid rate: {valid_rate}")

# connection error rate
# Responses that reported an error are collected in `errors` and never become rows
# of `df`, so `df["has_connection_error"]` is False everywhere and cannot be used
# here. Derive the rate from `errors` instead: every reported error other than a
# Lean timeout counts as a connection error, relative to all parsed responses.
n_connection_errors = sum(
    "Lean process timed out" not in message for message in errors
)
n_responses = len(df) + len(errors)
connection_error_rate = (
    n_connection_errors / n_responses if n_responses else float("nan")
)
print(f"connection error rate: {connection_error_rate}")

# number of valid proofs per uuid
uuid_group = df.groupby("uuid")["is_valid_no_sorry"].sum()

# find all uuids with at least one valid proof
valid_uuids = uuid_group[uuid_group > 0].index
print(f"Number of uuids: {len(uuid_group)}")
print(f"Number of uuids with at least one valid proof: {len(valid_uuids)}")


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19963715/19963715 [08:47<00:00, 37880.30it/s]


                                          uuid  is_valid_no_sorry  \
0         f3babb50-dfa3-560b-99ef-5652fca0ee4b              False   
1         f3babb50-dfa3-560b-99ef-5652fca0ee4b              False   
2         f3babb50-dfa3-560b-99ef-5652fca0ee4b              False   
3         f3babb50-dfa3-560b-99ef-5652fca0ee4b              False   
4         f3babb50-dfa3-560b-99ef-5652fca0ee4b              False   
...                                        ...                ...   
18700537  6e0e1956-b477-57a1-bb3a-912d4f0d48ce               True   
18700538  6e0e1956-b477-57a1-bb3a-912d4f0d48ce               True   
18700539  6e0e1956-b477-57a1-bb3a-912d4f0d48ce               True   
18700540  6e0e1956-b477-57a1-bb3a-912d4f0d48ce              False   
18700541  6e0e1956-b477-57a1-bb3a-912d4f0d48ce               True   

                                                name  has_connection_error  
0         f3babb50-dfa3-560b-99ef-5652fca0ee4b_4_218                 False  
1         f3babb5

In [174]:
# Error messages that are already understood and need no further inspection.
know_errors = {"Connection Error", "Lean process timed out"}
# Print the first error message that is neither a known one nor a JSON decode
# failure, then stop scanning.
for error in errors:
    if error in know_errors or error.startswith("JSONDecodeError with text:"):
        continue
    print(error)
    break

In [175]:
from collections import Counter

# Tally the collected error messages by their first 20 characters so that
# messages sharing a prefix fall into the same bucket.
Counter(message[:20] for message in errors)

Counter({'Connection Error': 1016880,
         'JSONDecodeError with': 185172,
         'Lean process timed o': 61085})

In [156]:
# Quick look at the row count of `df` next to an ad-hoc ratio.
# NOTE(review): the origin of 1881 and 6499 is not shown in this notebook — confirm or remove.
len(df), 1881 / 6499

(18471895, 0.2894291429450685)

In [206]:
# Keep only the rows flagged as valid (no Lean error, no `sorry`) and count them.
df_valid = df.loc[df["is_valid_no_sorry"]]
len(df_valid)

773447

In [207]:
# Order the per-uuid counts of valid proofs ascending, then drop the uuids
# that have no valid proof at all.
sorted_uuid_group = uuid_group.sort_values().loc[lambda counts: counts > 0]
# Show a slice from the low end of the ordering.
sorted_uuid_group.iloc[20:40]

uuid
97af19ed-76f9-5788-b912-3771cd954dbc    1
75eddfd5-d0d4-5472-a82e-41190047456f    1
128bbf96-67d1-53e3-935c-7795728fa685    1
8e65843e-7a52-5612-b18c-7411e7286803    1
f05085fb-bd65-5232-baed-d0e94ae5f913    1
d28a6f66-f070-55aa-97d9-6635c7adfd74    1
5165c90c-324d-5723-b57b-72c097ece55e    1
59d104ef-7a23-563b-8981-e461d9d47d24    1
89e530ae-b353-50a2-b698-c09bb2685038    1
3531f921-32cd-5be6-9bcf-f4fe5fcc4862    1
05881b4c-5a10-5668-8076-001fda0ef3bc    1
7a3a59ac-0ca6-5fc6-a7c6-cdf29d30a218    1
4c1c84db-9018-5554-9c6b-67e5f40c98f3    1
6db3fc84-2894-54b1-bb35-e708f5640e95    1
2fbacd34-b454-5ff2-88ce-a05536fb211c    1
a942379a-9b7a-5323-98f4-4bf0fcc340f5    1
85ae472d-20ee-54b3-a30c-a7e2d18fe1ef    1
54570ca9-188e-5ef6-b784-9cbca2393e41    1
26b38fa9-579e-5a43-be21-ae11b8fa4b5e    1
c07304e5-c0e5-5749-a19d-99bf055a234e    1
Name: is_valid_no_sorry, dtype: int64

In [208]:
# Load the train split of the proofs dataset.
from datasets import load_dataset

ds = load_dataset(
    path="AI-MO/auto-proofs-v2-haiming",
    split="train",
    num_proc=10,
)

Resolving data files:   0%|          | 0/88 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/88 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/90 [00:00<?, ?it/s]

In [53]:
# Build a lookup from proof id to its row position in `ds`, so that single
# samples can be fetched by id. If an id occurs more than once, the last
# position wins.
proof_id_column = ds["proof_id"]
proof_id_to_index = dict(zip(proof_id_column, range(len(proof_id_column))))

In [86]:
# Peek at one uuid from the sorted series of uuids that have at least one valid proof.
sorted_uuid_group.index[10]

'3531f921-32cd-5be6-9bcf-f4fe5fcc4862'

In [415]:
# Choose one uuid by its position in the sorted series and gather all of its rows.
uuid = sorted_uuid_group.index[530]
uuid_df = df.loc[df["uuid"].eq(uuid)]
print(len(uuid_df), uuid, sep="\n")
# Names of the rows for this uuid that are flagged as valid.
valid_names = uuid_df.loc[uuid_df["is_valid_no_sorry"], "name"]
print(len(valid_names))

6144
8e328f0c-ec6d-5cb2-b9b6-ee00f2aa4af5
26


In [416]:
# Number of distinct proof ids in the lookup table.
len(proof_id_to_index)

53619712

In [417]:
# Take the first valid proof of the selected uuid and resolve its dataset row.
# A fixed id can be inspected instead, e.g.:
# proof_id = "f05085fb-bd65-5232-baed-d0e94ae5f913_6_960"
proof_id = valid_names.iloc[0]
index = proof_id_to_index[proof_id]

In [418]:
# Fetch the dataset row for the selected proof id and print its `formal_proof` field.
sample = ds[index]
print(sample["formal_proof"])

import Mathlib

open Real

theorem algebra_1627 {f g : ℝ → ℝ}
    (hf : ∀ x, x > 0 → f x = 2 * sqrt x + 12 / sqrt x)
    (hg : ∀ x, g x = 2 * x ^ 2 - 2 * x - 3) :
    f (g 3) = 10 := by 
  have h₀ : g 3 = 2 * 3 ^ 2 - 2 * 3 - 3 := by apply hg
  have h₁ : g 3 = 9 := by linarith
  have h₂ : f (g 3) = f 9 := by rw [h₁]
  have h₃ : f 9 = 2 * sqrt 9 + 12 / sqrt 9 := by apply hf <;> linarith
  have h₄ : sqrt 9 = 3 := by rw [sqrt_eq_iff_mul_self_eq] <;> norm_num
  rw [h₄] at h₃
  linarith

