In [69]:
import sys
import os
sys.path.append(os.path.abspath('..'))
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import re

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from prover.lean.verifier import Lean4ServerScheduler
from typing import List, Dict, Tuple, Optional
import pickle
from prover.lean.verifier import verify_lean4_file
from prover.lean.verifier import Lean4ServerScheduler
from tqdm import tqdm


In [4]:
with open("results/best_of_n_samples_imo.pkl", "rb") as f:
    imo_problems = pickle.load(f)

for question in imo_problems:
    print(question[0])
    print("--------------------------------")

with open("results/verified_outputs_imo.pkl", "rb") as f:
    verified_outputs_imo = pickle.load(f)

Complete the following Lean 4 code:

```lean4
import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- Find one pair of positive integers $a,b$ such that $ab(a+b)$ is not divisible by $7$, but $(a+b)^7-a^7-b^7$ is divisible by $7^7$.-/
theorem imo_1984_p2 (a b : ℤ) (h₀ : 0 < a ∧ 0 < b) (h₁ : ¬7 ∣ a) (h₂ : ¬7 ∣ b) (h₃ : ¬7 ∣ a + b)
  (h₄ : 7 ^ 7 ∣ (a + b) ^ 7 - a ^ 7 - b ^ 7) : 19 ≤ a + b := by
  have := h₀.2
  norm_num at h₁ h₂ h₃ h₄
  contrapose! h₄
  have h₅ : a + b < 19 := by linarith
  have h₆ : (a + b) ^ 7 - a ^ 7 - b ^ 7 < 7 ^ 7 := by
    calc
      (a + b) ^ 7 - a ^ 7 - b ^ 7 ≤ (a + b) ^ 7 - a ^ 7 - b ^ 7 := by rfl
      _ < 19 ^ 7 - 19 ^ 7 - 19 ^ 7 := by gcongr
      _ = 7 ^ 7 := by norm_num
  omega
```
--------------------------------
Complete the following Lean 4 code:

```lean4
import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- Show that for any real numbers $a$, $b$, and $c$, we 

In [3]:
imo_problem = [re.search(r'```lean4\n(.*?)\n```', attempt, re.DOTALL).group(1) for attempt in imo_problems[4]]

verified_outputs = []
for attempt in tqdm(imo_problem):
    verified_outputs.append(
        verify_lean4_file(
            attempt, 
            timeout=50, 
            allTactics=True, 
            ast=True, 
            premises=True, 
            tactics=True
        )
    )

with open("results/verified_outputs_imo_problem_4.pkl", "wb") as f:
    pickle.dump(verified_outputs, f)



100%|██████████| 32/32 [06:29<00:00, 12.18s/it]


In [49]:

def collect_proof_states(verifier_output: Dict):

    ast = verifier_output["ast"]
    states = []
    for tactic in ast["tactics"]:
        for key in ["stateBefore", "stateAfter"]:
            if tactic[key] not in states:
                states.append(tactic[key])

    # Get initial state's premises and goal
    initial_state = states[0]
    initial_premises = []
    initial_goal = None
    
    # Parse initial state
    for line in initial_state.split('\n'):
        if '⊢' in line:
            if initial_goal is not None:
                # We already found a goal
                raise ValueError("Multiple goals found in initial state")
            initial_goal = line.split('⊢')[1].strip()
        elif ':' in line and line.strip():  # Premise line
            premise = line.strip()
            if premise not in initial_premises:
                initial_premises.append(premise)
    
    new_premises = []
    new_goals = []
    
    initial_premises_set = set(initial_premises)
    # Process each state after the initial one
    for state in states[1:]:

        current_premises = set()
        current_goals = []
        
        # Handle cases by splitting on 'case'
        cases = state.split('case')
        
        for case in cases:
            lines = case.strip().split('\n')
            case_goal = None
            
            for line in lines:
                if '⊢' in line:
                    case_goal = line.split('⊢')[1].strip()
                    current_goals.append(case_goal)
                elif ':' in line and line.strip():  # Premise line
                    current_premises.add(line.strip())
            
        new_goals.append(tuple(current_goals))
        new_premises.extend(list(current_premises))
    
    new_premises = set(new_premises) - initial_premises_set

    new_goals = set(new_goals) - set([(initial_goal,)])

    return new_premises, new_goals, initial_premises, initial_goal


class Theorem:
    def __init__(
        self,
        header: str,
        premises: Optional[List[str]] = None, 
        goal: Optional[str] = None,
        name: Optional[str] = None
    ):

        if premises is None:
            self.premises = []
        else:
            self.premises = premises.copy()

        self.goal = goal
        self.name = name
        self.header = header

    def add_premise(self, premise: str):
        self.premises.append(premise)

    def set_goal(self, goal: str):
        self.goal = goal

    def to_string(self):
        if len(self.premises) == 0:
            raise ValueError("No premises")
        elif self.goal is None:
            raise ValueError("No goal")

        name = self.name if self.name is not None else "problem_384"
        stub = f"theorem {name}"
        
        for premise in self.premises:
            stub += f" ({premise})"

        stub += f" : {self.goal} := by\n"

        return self.header + stub

    def __eq__(self, other: 'Theorem') -> bool:
        if not isinstance(other, Theorem):
            return False
        
        # Compare goals
        if self.goal != other.goal:
            return False
            
        # Compare premises (order-independent)
        return set(self.premises) == set(other.premises)
    
    def __hash__(self) -> int:
        # Hash based on goal and frozen set of premises
        return hash((self.goal, frozenset(self.premises)))

def generate_theorem(
    new_premises: List[str], 
    new_goals: List[Tuple[str]], 
    original_goal: str, 
    original_premises: List[str],
    theorem_header: str,
    theorem_name: Optional[str] = None,
) -> Tuple[List[Theorem], List[Theorem]]:

    def find_unused_subscript(premises: List[str]) -> str:
        """Find an unused subscript for a goal."""
        subscripts = ["h\u2080", "h\u2081", "h\u2082", "h\u2083", "h\u2084", "h\u2085", "h\u2086", "h\u2087", "h\u2088", "h\u2089", "h\u2081\u2080", "h\u2081\u2081", "h\u2081\u2082", "h\u2081\u2083", "h\u2081\u2084", "h\u2081\u2085", "h\u2081\u2086", "h\u2081\u2087", "h\u2081\u2088", "h\u2081\u2089"]

        for sub in subscripts:
            for premise in premises:
                if sub in premise:
                    break
            else:
                # Not in any premise
                return sub
        else:
            raise ValueError("No unused subscript found")


    new_premise_theorems = []
    new_goal_theorems = []

    assert(len(original_premises) <= 6), f"Too many premises: {original_premises}"
    # For each premise in premises, we will set this up as a new goal
    for premise in new_premises:
        # Turn premise into a goal 
        assert premise.count(":") == 1, f"Premise '{premise}' should contain exactly one colon"
        premise_goal = premise.split(":")[1].strip()
        theorem = Theorem(
            header=theorem_header,
            premises=original_premises,
            goal=premise_goal,
            name=theorem_name
        )

        new_premise_theorems.append(theorem)

    # This mean we simplified the problem to solving this goal
    # So if we assume these goals, we should be able to solve the original goal
    for goal in new_goals:
        # Turn goal into a premise, to do so we need to give it a name
        premises = original_premises.copy()
        for goal in goal:
            premise_name = find_unused_subscript(premises)
            premises.append(f"{premise_name} : {goal}")
         
        theorem = Theorem(
            header=theorem_header,
            premises=premises,
            goal=original_goal,
            name=theorem_name
        )
        new_goal_theorems.append(theorem)

    return new_premise_theorems, new_goal_theorems


# Test deduplication of theorems
def test_theorem_deduplication():
    # Create three theorems
    theorem1 = Theorem(
        header="import Mathlib.Data.Real.Basic",
        premises=["h1 : x > 0", "h2 : y > 0"],
        goal="x + y > 0"
    )
    
    # Same as theorem1 but with premises in different order
    theorem2 = Theorem(
        header="import Mathlib.Data.Real.Basic",
        premises=["h2 : y > 0", "h1 : x > 0"],  # Reversed order
        goal="x + y > 0"
    )
    
    # Different theorem
    theorem3 = Theorem(
        header="import Mathlib.Data.Real.Basic",
        premises=["h1 : x > 0", "h2 : y > 0"],
        goal="x * y > 0"  # Different goal
    )
    
    # Create a list and convert to set
    theorems = [theorem1, theorem2, theorem3]
    unique_theorems = set(theorems)
    
    # Print results
    print(f"Original number of theorems: {len(theorems)}")
    print(f"Number of unique theorems: {len(unique_theorems)}")
    
    # Test individual equality
    print(f"\nTheorem 1 == Theorem 2: {theorem1 == theorem2}")  # Should be True
    print(f"Theorem 1 == Theorem 3: {theorem1 == theorem3}")  # Should be False
    
    # Print the actual theorems to verify
    print("\nTheorem 1:")
    print(theorem1.to_string())
    print("\nTheorem 2:")
    print(theorem2.to_string())
    print("\nTheorem 3:")
    print(theorem3.to_string())

# Run the test
test_theorem_deduplication()

Original number of theorems: 3
Number of unique theorems: 2

Theorem 1 == Theorem 2: True
Theorem 1 == Theorem 3: False

Theorem 1:
import Mathlib.Data.Real.Basictheorem problem_384 (h1 : x > 0) (h2 : y > 0) : x + y > 0 := by


Theorem 2:
import Mathlib.Data.Real.Basictheorem problem_384 (h2 : y > 0) (h1 : x > 0) : x + y > 0 := by


Theorem 3:
import Mathlib.Data.Real.Basictheorem problem_384 (h1 : x > 0) (h2 : y > 0) : x * y > 0 := by



In [98]:
new_premise_theorems, new_goal_theorems = generate_theorem(
    new_premises,
    new_goals,
    initial_goal,
    initial_premises,
    theorem_header = out["verified_code"].split("/-")[0]
)

for theorem in new_premise_theorems:
    print(theorem.to_string())
    print("--------------------------------")
    print()

print("============")
for theorem in new_goal_theorems:
    theorem_string = theorem.to_string()
    print(theorem.to_string())
    print()



import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat


theorem problem_384: (x : ℝ)  (h₀ : 0 ≤ x)  (h₁ : x ≤ 2 * π)  (h₂ : 2 * x.cos ≤ |√(1 + (2 * x).sin) - √(1 - (2 * x).sin)|)  (h₃ : |√(1 + (2 * x).sin) - √(1 - (2 * x).sin)| ≤ √2)  (h₄ : π / 4 ≤ x)  (h₅ : x ≤ 7 * π / 4)  (h₆ : π / 4 ≤ x)  (h₇ : x ≤ 7 * π / 4)  (h₈ : π / 4 ≤ x)  (h₉ : x ≤ 7 * π / 4)  : -(√(1 + (2 * x).sin) - √(1 - (2 * x).sin)) ≤ √2 := by
--------------------------------

import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat


theorem problem_384: (x : ℝ)  (h₀ : 0 ≤ x)  (h₁ : x ≤ 2 * π)  (h₂ : 2 * x.cos ≤ |√(1 + (2 * x).sin) - √(1 - (2 * x).sin)|)  (h₃ : |√(1 + (2 * x).sin) - √(1 - (2 * x).sin)| ≤ √2)  (h₄ : π / 4 ≤ x)  (h₅ : x ≤ 7 * π / 4)  (h₆ : π / 4 ≤ x)  (h₇ : x ≤ 7 * π / 4)  (h₈ : π / 4 ≤ x)  (h₉ : x ≤ 7 * π / 4)  : √(1 + (2 * x).sin) - √(1 - (2 * x).sin) ≤ √2 := by
--------------------------------

import Mathlib
import Aesop

# Prossessing the attempts

In [60]:
with open("results/verified_outputs_imo_problem_4.pkl", "rb") as f:
    verified_outputs_imo = pickle.load(f)

# Go through the attempts
# For each one, get the reward
# Reformat it and add the reformats to a list
# Dedpuplicate the reformats
# See how many rewrites we get from this 

# Then generate 5 proofs per reformat and see if we get any reward
# Collect the same information for these

def rewrite_problems(
    original_attempt_verifications: List[Dict]
) -> Tuple[List[Theorem], List[Theorem]]:
    """Rewrite the attempts and grade them"""

    original_theorems: List[Theorem] = []
    new_premise_theorems: List[Theorem] = []
    new_goal_theorems: List[Theorem] = []
    num_rewrites = []
    # Get the reward for each attempt
    for attempt in original_attempt_verifications:

        header = attempt["verified_code"].split("/-")[0]

        new_premises, new_goals, initial_premises, initial_goal = collect_proof_states(attempt)

        new_p, new_g = generate_theorem(
            new_premises=new_premises,
            new_goals=new_goals,
            original_goal=initial_goal,
            original_premises=initial_premises,
            theorem_header = header
        )

        new_premise_theorems.extend(new_p)
        new_goal_theorems.extend(new_g)

        unique_rewrites = set(new_premise_theorems) | set(new_goal_theorems)
        num_rewrites.append(len(unique_rewrites))

    print("Original number of samples:", len(original_attempt_verifications))
    # Depublicate the theorems
    print("Before de-duping:", len(new_premise_theorems) + len(new_goal_theorems))
    new_premise_theorems = list(set(new_premise_theorems))
    new_goal_theorems = list(set(new_goal_theorems))
    print("After de-duping:", len(new_premise_theorems) + len(new_goal_theorems))
    print(f"Rewrites: {num_rewrites}")


    return new_premise_theorems, new_goal_theorems


def generate_proofs(
    premise_theorems: List[Theorem],
    goal_theorems: List[Theorem],
    model: LLM,
    num_proofs: int = 5
) -> List[Theorem]:

    theorems_dict = {
        "premise": premise_theorems,
        "goal": goal_theorems
    }

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=2048,
        top_p=0.95,
        n=num_proofs,
    )

    prompt = r'''Complete the following Lean 4 code:

```lean4
''' 

    all_responses = []

    for key in ["premise", "goal"]:
        theorems: List[str] = [x.to_string() for x in theorems_dict[key]]

        print(f"Generating {key} theorem proofs")
        for theorem in tqdm(theorems):
            # Construct full prompt for this theorem
            full_prompt = prompt + theorem
            
            # Generate responses
            outputs = model.generate(
                prompts=[full_prompt],
                sampling_params=sampling_params,
                use_tqdm=False,
            )
            
            generated_texts = [
                completion.text 
                for request in outputs
                for completion in request.outputs
            ]
            assert len(outputs) == 1
            assert len(generated_texts) == num_proofs
            all_responses.append({
                'type': key,
                'generation': [full_prompt + x for x in generated_texts]
            })

    return all_responses

    

In [19]:
model = LLM(model="../../models/deepseek-prover-RL", max_num_batched_tokens=8192, seed=1, trust_remote_code=True, dtype="auto")

INFO 02-10 20:47:15 llm_engine.py:98] Initializing an LLM engine (v0.4.1) with config: model='../../models/deepseek-prover-RL', speculative_config=None, tokenizer='../../models/deepseek-prover-RL', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=1)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 02-10 20:47:15 utils.py:608] Found nccl from library /home/lukebailey/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 02-10 20:47:16 selector.py:77] Cannot use FlashAttention backend because the flash_attn package is not found. Please install it for better performance.
INFO 02-10 20:47:16 selector.py:33] Using XFormers backend.
INFO 02-10 20:47:26 model_runner.py:173] Loading model weights took 12.8725 GB
INFO 02-10 20:47:28 gpu_executor.py:119] # GPU blocks: 3858, # CPU blocks: 546
INFO 02-10 20:47:30 model_runner.py:976] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 02-10 20:47:30 model_runner.py:980] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory

In [61]:
new_premise_theorems, new_goal_theorems = rewrite_problems(verified_outputs_imo)

outputs = generate_proofs(
    premise_theorems=new_premise_theorems, 
    goal_theorems=new_goal_theorems,
    model=model,
    num_proofs=5
)

Original number of samples: 32
Before de-duping: 208
After de-duping: 83
Rewrites: [4, 21, 21, 22, 22, 23, 33, 37, 37, 37, 40, 40, 41, 41, 41, 41, 46, 46, 57, 57, 57, 57, 57, 62, 65, 67, 67, 67, 67, 70, 70, 82]
Generating premise theorem proofs


100%|██████████| 42/42 [03:10<00:00,  4.55s/it]


Generating goal theorem proofs


100%|██████████| 41/41 [05:14<00:00,  7.66s/it]


In [62]:
with open("results/her_imo_4_outputs.pkl", "wb") as f:
    pickle.dump(outputs, f)

In [68]:
premise_theorems = []
goal_theorems = []
for output in outputs:
    if output["type"] == "premise":
        premise_theorems.extend(output["generation"])
    else:
        goal_theorems.extend(output["generation"])
print(len(premise_theorems))
print(len(goal_theorems))

210
205


In [70]:
!gpustat

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1m[37mrnn.ist.berkeley.edu[m  Mon Feb 10 21:40:28 2025  [1m[30m535.104.12[m
[36m[0][m [34mNVIDIA RTX A6000[m |[31m 28'C[m, [32m  0 %[m | [36m[1m[33m  469[m / [33m49140[m MB |
[36m[1][m [34mNVIDIA RTX A6000[m |[31m 31'C[m, [32m  0 %[m | [36m[1m[33m44268[m / [33m49140[m MB | [1m[30mlukebailey[m([33m43782M[m)
[36m[2][m [34mNVIDIA RTX A6000[m |[31m 31'C[m, [32m  0 %[m | [36m[1m[33m  466[m / [33m49140[m MB |
[36m[3][m [34mNVIDIA RTX A6000[m |[31m 30'C[m, [32m  0 %[m | [36m[1m[33m  466[m / [33m49140[m MB |
[36m[4][m [34mNVIDIA RTX A6000[m |[31m 32'C[m, [32m  0 %[m | [36m[1m[33m 5824[m / [33m49140[m MB | [1m[30mmicah[m([33m5352M[m)
[36m[5][m [34mNVIDIA RTX A6000[m |[31m 32'C[m, [32m  0 %[m | [36m[1m[33m 6164[m / [33m49140[m MB | [1m[30mmicah[m([33m5692M[m)
[36m[6][m [34mNVIDIA RTX A6000[m |[31m 30'C[m, [32m  0 %[m | [36m[1m[33m  466[m / [33m49140[m MB |
[36m[7][m [34mNVID