In [1]:
import re

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


model_name = "deepseek-ai/DeepSeek-Prover-V1.5-RL"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LLM(model=model_name, max_num_batched_tokens=8192, seed=1, trust_remote_code=True)


prompt = r'''Complete the following Lean 4 code:

```lean4
'''

code_prefix = r'''import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?
Show that it is $\frac{2\sqrt{3}}{3}$.-/
theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)
  (h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by
'''

sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=2048,
    top_p=0.95,
    n=1,
)
model_inputs = [prompt + code_prefix]
model_outputs = model.generate(
    model_inputs,
    sampling_params,
    use_tqdm=False,
)
result = prompt + code_prefix + model_outputs[0].outputs[0].text
print(result)

# Expected output:
'''  simp_all only [Nat.one_eq_succ_zero, Nat.zero_eq, zero_add, Nat.add_succ, Nat.add_zero,
    Nat.succ_add]
  have h₁' : a * r = 2 := by simpa [h₀] using h₁
  have h₂' : a * r ^ 3 = 6 := by simpa [h₀] using h₂
  have h₃ : r ^ 2 = 3 := by
    nlinarith
  have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by
    apply eq_or_eq_neg_of_sq_eq_sq <;>
    field_simp <;>
    nlinarith
  simpa [h₀] using h₄
```
'''

  from .autonotebook import tqdm as notebook_tqdm
2025-02-13 13:55:28,357	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 02-13 13:55:28 llm_engine.py:98] Initializing an LLM engine (v0.4.1) with config: model='deepseek-ai/DeepSeek-Prover-V1.5-RL', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-Prover-V1.5-RL', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=1)
INFO 02-13 13:55:29 utils.py:608] Found nccl from library /home/awhe/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 02-13 13:55:29 selector.py:28] Using FlashAttention backend.
INFO 02-13 13:55:30 weight_utils.py:193] Using model weights format ['*.safetensors']
INFO 02-13 13:57:37 model_runner.py:173] Loading model weights took 12.8725 GB
INFO 02-13 13:57:38 

"  simp_all only [Nat.one_eq_succ_zero, Nat.zero_eq, zero_add, Nat.add_succ, Nat.add_zero,\n    Nat.succ_add]\n  have h₁' : a * r = 2 := by simpa [h₀] using h₁\n  have h₂' : a * r ^ 3 = 6 := by simpa [h₀] using h₂\n  have h₃ : r ^ 2 = 3 := by\n    nlinarith\n  have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by\n    apply eq_or_eq_neg_of_sq_eq_sq <;>\n    field_simp <;>\n    nlinarith\n  simpa [h₀] using h₄\n```\n"

In [1]:
import threading
import pexpect
import json
import os
import tempfile
import re

SUBMIT_TIMEOUT = 50


class InteractiveThread(threading.Thread):
    def __init__(
        self, session_id, repl_path, lean_env_path, initial_context=None, timeout=600
    ):
        super().__init__()
        self.session_id = session_id
        self.repl_path = os.path.abspath(repl_path)
        self.lean_env_path = os.path.abspath(lean_env_path)
        self.context = initial_context
        self.session = None

        self.cmd_response_condition = threading.Event()
        self.cmd_query_condition = threading.Event()
        self.init_complete = threading.Event()
        self.response = None
        self.timeout_occurred = False  # New flag to track timeouts

        self.stop_flag = False
        self.timer = threading.Timer(timeout, self.stop)

    def initialize_check(self):
        try:
            if self.context is None:
                initialize_check = {"cmd": "def init_check : Nat := 42"}
                self.send_cmd(initialize_check)
            self.session.expect(
                '"env": 0}\r\n\r\n', timeout=60
            )  # If context contains sorries, it will have more keys other than env
            self.init_complete.set()
        except:
            self.init_complete.set()
            print(f"Session {self.session_id}: fail to initialize lean repl")
            print(self.context)
            print(self.session.before)
            self.stop()

    def send_cmd(self, cmd):
        cmd_str = json.dumps(cmd, ensure_ascii=False)
        self.session.sendline(cmd_str + "\n")

    def submit_and_receive(self, cmd):
        if self.stop_flag or self.timeout_occurred:  # Detect if the session has died
            return {"error": "timeout"}

        self.init_complete.wait()

        self.send_cmd(cmd)

        self.cmd_query_condition.set()

        self.cmd_response_condition.wait()
        # if not self.cmd_response_condition.wait(timeout=SUBMIT_TIMEOUT):  # Prevent indefinite waiting
        #     print("Command timeout detected, stopping session.")
        #     self.timeout_occurred = True  # Mark session as dead
        #     self.response = {"error": "timeout"}  # Set special response
        #     self.stop()  # Stop the session immediately
        #     return {"error": "timeout"}

        self.cmd_response_condition.clear()

        if self.response:
            output = self.response
            self.response = None
            return output
        return None

    def process_responses(self):
        while not self.stop_flag:
            self.cmd_query_condition.wait()
            self.cmd_query_condition.clear()

            if self.stop_flag:
                break

            try:
                self.session.expect("\r\n\r\n", timeout=SUBMIT_TIMEOUT)
                self.session.expect(["\r\n\r\n", pexpect.EOF], timeout=SUBMIT_TIMEOUT)
                output = self.session.before.strip()
                output_dict = json.loads(output)

                self.response = output_dict
                self.cmd_response_condition.set()

            except pexpect.TIMEOUT:
                print("Output timeout detected, stopping session.")
                self.timeout_occurred = True  # Mark session as dead
                self.response = {"error": "timeout"}  # Set special response
                self.stop()  # Stop the session immediately
                break
                # self.cmd_response_condition.set()
                # continue

            except pexpect.EOF:
                print("Session ended unexpectedly.")
                self.stop()
                break

            except json.JSONDecodeError as e:
                print(output)
                break

    def remove_last_comment(self):
        pattern = r"/--[^/]*?-/(\n*)$"
        self.context = re.sub(pattern, "", self.context, flags=re.DOTALL)

    def run(self):
        self.timer.start()
        try:
            self.session = pexpect.spawn(
                "bash", encoding="utf-8", cwd=self.lean_env_path
            )
            if self.context is not None:
                self.remove_last_comment()
                with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp:
                    json.dump({"cmd": self.context}, temp, ensure_ascii=False)
                    temp.write("\n\n")
                    temp.flush()
                command = f"lake env {self.repl_path}/.lake/build/bin/repl < <(cat {temp.name} -)"
            else:
                command = f"lake env {self.repl_path}/.lake/build/bin/repl"

            self.session.sendline(command)
            self.initialize_check()
            self.process_responses()
            self.stop()

        except Exception as e:
            print(f"Session {self.session_id}: An error occurred: {e}")
            self.stop()

    def stop(self):
        self.stop_flag = True
        self.init_complete.set()
        self.cmd_query_condition.set()
        self.cmd_response_condition.set()
        self.timer.cancel()

        if self.session is not None:
            try:
                self.session.terminate(
                    force=True
                )  # Forcefully kill the Lean REPL process
                self.session.wait()  # Ensure the process is fully terminated
                # print(f"Session {self.session_id} terminated successfully.")
            except Exception as e:
                pass
                # print(f"Session {self.session_id}: Failed to terminate process: {e}")

        self.session = None


In [2]:
REPL_PATH = "/home/awhe/projects/minictx-eval/repl"
LEAN_PATH = "/home/awhe/projects/minictx-eval/miniF2F-lean4"

context = """import MiniF2F.Minif2fImport\n  open BigOperators Real Nat Topology\n"""

thread = InteractiveThread(1, REPL_PATH, LEAN_PATH, initial_context=context, timeout=600)

thread.start()

# theorem_statement = "theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2) (h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by"

theorem_statement = "theorem mathd_algebra_33 (x y z : ℝ) (h₀ : x ≠ 0) (h₁ : 2 * x = 5 * y) (h₂ : 7 * y = 10 * z) : z / x = 7 / 25 := by"

output = thread.submit_and_receive({"cmd": theorem_statement + " sorry", "env": 0})

print(output)

# proof = """simp_all only [Nat.one_eq_succ_zero, Nat.zero_eq, zero_add, Nat.succ_inj', one_mul]
#   have h₃ : r ^ 2 = 3 := by nlinarith
#   have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by
#     apply or_iff_not_imp_right.2
#     intro h
#     field_simp [h, h₃] at h₁ h₂ ⊢
#     nlinarith
#   simp_all"""
# tactics = proof.split("\n")

tactics = [
    """ have h₃ : x = 25 * z / 7 := by
  apply Eq.symm
  field_simp
  linarith""",
  """have h₄ : y = 2 * x / 5 := by
  apply Eq.symm
  field_simp
  linarith""",
  "field_simp",
  "linarith", 
]

proof_state = 0
for tactic in tactics:
    print("Tactic: ", tactic)
    outcome = thread.submit_and_receive({"tactic": tactic, "proofState": proof_state})
    print(outcome)
    proof_state = outcome["proofState"]

thread.stop()
thread.join()


Tactic:  have h₃ : x = 25 * z / 7 := by
  apply Eq.symm
  field_simp
  linarith
{'proofState': 1, 'goals': ['x y z : ℝ\nh₀ : x ≠ 0\nh₁ : 2 * x = 5 * y\nh₂ : 7 * y = 10 * z\nh₃ : x = 25 * z / 7\n⊢ z / x = 7 / 25']}
Tactic:  have h₄ : y = 2 * x / 5 := by
  apply Eq.symm
  field_simp
  linarith
{'proofState': 2, 'goals': ['x y z : ℝ\nh₀ : x ≠ 0\nh₁ : 2 * x = 5 * y\nh₂ : 7 * y = 10 * z\nh₃ : x = 25 * z / 7\nh₄ : y = 2 * x / 5\n⊢ z / x = 7 / 25']}
Tactic:  field_simp
{'proofState': 3, 'goals': ['x y z : ℝ\nh₀ : x ≠ 0\nh₁ : 2 * x = 5 * y\nh₂ : 7 * y = 10 * z\nh₃ : x = 25 * z / 7\nh₄ : y = 2 * x / 5\n⊢ z * 25 = 7 * x']}
Tactic:  linarith
{'proofState': 4, 'goals': []}


In [20]:
proof_state = 2
tactic = """have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by    
    apply or_iff_not_imp_right.2
    intro h
    field_simp [h, h₃] at h₁ h₂ ⊢
    nlinarith"""
outcome = thread.submit_and_receive({"tactic": tactic, "proofState": proof_state})
print(outcome)

{'proofState': 3, 'messages': [{'severity': 'error', 'pos': {'line': 0, 'column': 0}, 'endPos': {'line': 0, 'column': 0}, 'data': 'linarith failed to find a contradiction\ncase h1.h\na r : ℝ\nu : ℕ → ℝ\nh₀ : ∀ (k : ℕ), u k = a * r ^ k\nh₂ : a * r ^ 3 = 6\nh₃ : r ^ 2 = 3\nh : ¬a = -(2 / Real.sqrt 3)\nh₁ : a * r = 2\na✝ : a * Real.sqrt 3 < 2\n⊢ False\nfailed'}], 'goals': ['a r : ℝ\nu : ℕ → ℝ\nh₀ : ∀ (k : ℕ), u k = a * r ^ k\nh₁ : a * r ^ succ 0 = 2\nh₂ : a * r ^ 3 = 6\nh₃ : r ^ 2 = 3\nh₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3)\n⊢ a * r ^ 0 = 2 / Real.sqrt 3 ∨ a * r ^ 0 = -(2 / Real.sqrt 3)']}


In [21]:
proof_state = 3
tactic = """simp_all"""
outcome = thread.submit_and_receive({"tactic": tactic, "proofState": proof_state})
print(outcome)

{'proofState': 4, 'goals': []}


In [4]:
REPL_PATH = "/home/awhe/projects/minictx-eval/repl"
LEAN_PATH = "/home/awhe/projects/minictx-eval/miniF2F-lean4"

context = """import MiniF2F.Minif2fImport\n  open BigOperators Real Nat Topology\n"""

thread = InteractiveThread(1, REPL_PATH, LEAN_PATH, initial_context=context, timeout=600)

thread.start()

theorem_statement = "theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2) (h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by"

output = thread.submit_and_receive({"cmd": theorem_statement + " sorry", "env": 0})

print(output)

proof = """simp_all only [Nat.one_eq_succ_zero, Nat.zero_eq, zero_add, Nat.succ_inj', one_mul]
  have h₃ : r ^ 2 = 3 := by nlinarith
  have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by
    apply or_iff_not_imp_right.2
    intro h
    field_simp [h, h₃] at h₁ h₂ ⊢
    nlinarith
  simp_all"""
tactics = ["\n".join(proof.split("\n")[0:2])]

proof_state = 0
for tactic in tactics:
    print("Tactic: ", tactic)
    outcome = thread.submit_and_receive({"tactic": tactic, "proofState": proof_state})
    print(outcome)
    proof_state = outcome["proofState"]

thread.stop()
thread.join()


Tactic:  simp_all only [Nat.one_eq_succ_zero, Nat.zero_eq, zero_add, Nat.succ_inj', one_mul]
  have h₃ : r ^ 2 = 3 := by nlinarith
{'message': 'Lean error:\n<input>:2:2: expected end of input'}


KeyError: 'proofState'