#### Neural next-step prediction | part 3: proof search
Tutorial on neural theorem proving\
Author: Sean Welleck

----------------

#### High-level goal

Our next goal is to prove theorems with our neural next-step predictor and to check whether the resulting proofs are correct.

Proving and checking a theorem involves generating a next-step candidate with our model, giving it to Lean, and receiving a next state from Lean (or an error message). \
To do so, we will need two components:

1. **Interacting** with Lean:  an automated way to give a next-step to Lean and receive a next state (or an error).
<!--  -->
2. A **search strategy** that uses the next-step model and Lean to find a proof (e.g. generate one next-step, get the next state, repeat).
<!-- For example, a naive algorithm alternates between generating a single step, giving it to Lean, and continuing until a proof is complete or an error message is reached. One can imagine many other strategies, e.g. generating *multiple* next steps and choosing the 'best' one according to some criterion, backtracking upon receiving an error message, etc. -->

Below, we'll walk through a simple example of each. 

-------------------

### 1. Interaction

To start, we'll walk through proving this theorem:

```lean4
import Mathlib

example (x y z : ℝ) (h₀ : x ≤ y) (h₁ : y ≤ z) : x ≤ z := by
  apply le_trans h₀
  apply h₁
```

#### Interaction with `Lean REPL`

The [`Lean REPL`](https://github.com/zhangir-azerbayev/repl/tree/master) gives us a programmatic interface to communicate with Lean.

We make a lightweight Python wrapper (based on code from the [pylean](https://github.com/zhangir-azerbayev/repl/tree/master) repo).


Set `PATH_TO_REPL` to the `ntp-interact/repl` directory (contained in this repository):

```python
PATH_TO_REPL = '/Users/wellecks/projects/ntptutorial/partI_nextstep/ntp-interact/repl'
```

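```python
import json

import pexpect


class LeanServer:
    """Thin wrapper around the Lean REPL process.

    Based on code from the pylean repo (https://github.com/zhangir-azerbayev/repl/tree/master).
    """

    def __init__(self, path_to_repl):
        # Launch the REPL inside the Lake environment so that Mathlib is available.
        self.proc = pexpect.spawn(
            "lake env lean --run REPL/Main.lean", cwd=path_to_repl, encoding="utf-8")

    def run_code(self, code, env=None, verbose=False):
        # json.dumps escapes quotes and newlines, so the command stays on one line.
        # Compare against None explicitly: environment 0 is a valid environment id.
        payload = dict(cmd=code) if env is None else dict(cmd=code, env=env)
        command = json.dumps(payload, ensure_ascii=False)

        if verbose:
            print(command)
        self.proc.sendline(command)
        # The terminal echoes our input back; consume the echo before reading the reply.
        self.proc.expect_exact(command + "\r\n")
        # A blank line tells the REPL that the command is complete.
        self.proc.sendline()
        self.proc.expect_exact("\r\n")
        # Read up to the end of the JSON reply (`"env": <n>}`); raises pexpect.TIMEOUT after 20s.
        self.proc.expect(r'env": \d+\}', timeout=20)
        output = self.proc.before + self.proc.match.group()
        if verbose:
            print(output)
        return json.loads(output)
```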
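Each request that `run_code` sends is a single line of JSON. The `cmd` field holds Lean source. An optional `env` field names an environment returned by an earlier response, so a later command can, for example, continue in an environment where `import Mathlib` has already been processed.

The helper below, `make_command`, is hypothetical. It is not part of the REPL's API. It only isolates the request-building step so that its escaping can be checked on its own.

```python
import json


def make_command(code, env=None):
    """Hypothetical helper: build one Lean REPL request as a single line of JSON."""
    payload = {"cmd": code}
    # Test against None explicitly: env 0 is a valid environment id.
    if env is not None:
        payload["env"] = env
    # json.dumps escapes newlines and double quotes, so the request stays on one line.
    return json.dumps(payload, ensure_ascii=False)


# A multi-line Lean snippet that also contains a string literal with double quotes.
lean_src = 'example : True := by\n  trivial\n#eval "hi"'
request = make_command(lean_src, env=0)
print(request)
```

Building the request with `json.dumps` means that quotes, backslashes, and newlines in Lean source are always escaped in a way the REPL's JSON parser accepts.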

Now we can submit Lean code (e.g., imports, theorem declarations) and receive messages from Lean:

```python
from pprint import pprint

code = """
import Mathlib

open Real

example (x y z : ℝ) (h₀ : x ≤ y) (h₁ : y ≤ z) : x ≤ z := by {}
"""


lean = LeanServer(PATH_TO_REPL)
msg = lean.run_code(code)
lean.proc.close()
pprint(msg)
```

```
{'env': 0,
 'messages': [{'data': 'unsolved goals\n'
                       'x y z : ℝ\n'
                       'h₀ : x ≤ y\n'
                       'h₁ : y ≤ z\n'
                       '⊢ x ≤ z',
               'endPos': {'column': 62, 'line': 6},
               'pos': {'column': 61, 'line': 6},
               'severity': 'error'}]}
```


Inside `'data'`, the Lean REPL returns the current proof state $x_t$. Here is some basic parsing code:

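```python
def get_goal(msg):
    """Return the proof state from a REPL response, or None if a non-goal error occurred."""
    goal = None
    # A completed proof has no 'messages' key, so default to an empty list.
    for msg_ in msg.get('messages', []):
        if msg_['data'].startswith('unsolved goals\n'):
            # Drop the 'unsolved goals' header line and keep the state itself.
            goal = '\n'.join(msg_['data'].split('\n')[1:])
        elif msg_['severity'] == 'error':
            return None
    return goal

print(get_goal(msg))
```

```
x y z : ℝ
h₀ : x ≤ y
h₁ : y ≤ z
⊢ x ≤ z
```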
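A single-goal state like this one has a simple shape: lines of hypotheses, then the target after the turnstile `⊢`.

The hypothetical `split_state` below separates the two. It is only a sketch, and it assumes the state contains exactly one goal. States with several goals print one block per goal and would need to be split into blocks first.

```python
def split_state(goal):
    """Hypothetical helper: split a single-goal Lean state into (hypotheses, target)."""
    # Everything before the last turnstile is context; everything after is the target.
    context, target = goal.rsplit("⊢", 1)
    hypotheses = [line for line in context.splitlines() if line.strip()]
    return hypotheses, target.strip()


state = "x y z : ℝ\nh₀ : x ≤ y\nh₁ : y ≤ z\n⊢ x ≤ z"
hyps, target = split_state(state)
print(hyps)    # ['x y z : ℝ', 'h₀ : x ≤ y', 'h₁ : y ≤ z']
print(target)  # x ≤ z
```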


We can use $x_t$ as input to our model $p_\theta(y_t|x_t)$.\
Next, we load the trained model and generate a next step, $\hat y_t\sim q(p_\theta(y_t|x_t))$.

(Here $q(\cdot)$ is a decoding algorithm such as greedy decoding or temperature sampling.)
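In the notebook, decoding happens token by token inside Hugging Face's `model.generate`. As a self-contained illustration, the sketch below applies greedy decoding and temperature sampling to a single made-up vector of next-token logits. The function names and numbers are illustrative only.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; a lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy(logits):
    """Greedy decoding: always pick the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sample(logits, temperature, rng):
    """Temperature sampling: draw a token index from softmax(logits / temperature)."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]  # made-up scores for a 3-token vocabulary
rng = random.Random(0)
print(greedy(logits))                                 # 1
print([sample(logits, 1.0, rng) for _ in range(10)])  # random draws, biased toward 1
print(softmax(logits, 0.01))                          # ≈ [0.0, 1.0, 0.0]
```

As the temperature approaches 0, sampling behaves like greedy decoding. Higher temperatures produce more diverse next steps, which becomes useful once we generate several candidates per state.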

```python
# Load model and tokenizer
import os
import transformers
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # prevents an annoying warning

MODEL = 'l3lab/ntp-mathlib-st-deepseek-coder-1.3b'
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)

def generate(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    out = model.generate(
        input_ids,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens (drop the prompt) ...
    text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
    # ... and keep only the text before the closing [/TAC] tag.
    return text.split('[/TAC')[0].strip()
```



```python
# Generate a next step
prompt = """/- You are proving a theorem in Lean 4.
You are given the following information:
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/
[STATE]
%s
[/STATE]
[TAC]
""" % get_goal(msg)

prefix = ''
next_step = prefix + ' ' + generate(prompt + prefix)
print(next_step)
```

```
 linarith
```


Finally, we can give the generated next step to Lean and receive the next state.

```python
code = """
import Mathlib

open Real

example (x y z : ℝ) (h₀ : x ≤ y) (h₁ : y ≤ z) : x ≤ z := by
""" + next_step

lean = LeanServer(PATH_TO_REPL)
state = lean.run_code(code)
lean.proc.close()

pprint(state)
```

```
{'env': 0}
```


There are no error messages and no remaining goals, so the proof is complete! If you want, paste this into VS Code to convince yourself:

```lean4
import Mathlib

example (x y z : ℝ) (h₀ : x ≤ y) (h₁ : y ≤ z) : x ≤ z := by
  linarith
```

Notice also that the machine-generated proof differs from the human-written one shown at the start of this section.
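Here the generated step is concatenated directly after `by`. That works for a single step. For multi-step proofs, it can help to put each step on its own consistently indented line, as in the VS Code version.

The hypothetical `assemble_proof` below is one way to do that. It assumes each step is a single tactic line, and it is not taken from this repository's utilities.

```python
def assemble_proof(statement, steps, indent="  "):
    """Hypothetical helper: place each tactic on its own indented line after `by`."""
    # Strip stray whitespace the model may emit (e.g. the leading space in " linarith").
    body = "\n".join(indent + step.strip() for step in steps)
    return statement.rstrip() + "\n" + body


statement = "example (x y z : ℝ) (h₀ : x ≤ y) (h₁ : y ≤ z) : x ≤ z := by"
print(assemble_proof(statement, [" linarith"]))
print(assemble_proof(statement, ["apply le_trans h₀", "apply h₁"]))
```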


**Exercise I:** modify the `prefix` variable above to obtain an invalid proof.

**Exercise II:** modify the `prefix` variable above to obtain an alternative, valid one-step proof.

-----------------

### 2. Search strategy

In the proof above, we simply generated one next step and the proof was complete.

In general, proofs consist of multiple steps, so we need an algorithm for generating a multi-step proof, which we refer to as a *search algorithm*.


First, let's consider a naive algorithm that generates a next step and then continues from the resulting state. Upon receiving an error message, the algorithm discards the failing step and generates another next step from the previous state.

```python
import sys
sys.path.append('../ntp-interact/')

import proofsearch  # some utilities for running code (as we did above) and parsing states/model outputs
```

```python
transformers.set_seed(42)

def _prompt_fn(goal):
    return """/- You are proving a theorem in Lean 4.
You are given the following information:
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/
[STATE]
%s
[/STATE]
[TAC]
""" % goal

def prove_simple(model, tokenizer, header, theorem_statement, search_budget):
    success = False

    code = header + theorem_statement
    steps = []
    proof = ''
    goal = None  # most recent valid proof state

    for i in range(search_budget):
        print("== Current (%d): " % i, theorem_statement[:-3] + '\n' + proof, sep='\n')

        # Run the code (header + proof-so-far)
        state = proofsearch.run_code(code, path_to_repl=PATH_TO_REPL)

        # Stop if the proof is complete.
        if proofsearch.is_done(state):
            success = True
            break

        # Get the new state.
        goal_candidate = proofsearch.get_goal(state)
        if goal_candidate is None:
            if goal is None:
                # The theorem statement itself produced an error: there is no state to search from.
                print("-- Error: no initial goal")
                break
            # The last step produced an error: drop it and retry from the previous goal.
            print("-- Error: backtracking")
            steps = steps[:-1]
        else:
            goal = goal_candidate

        print("-- Goal: ", goal, sep='\n')

        # Generate a next-step
        prompt = _prompt_fn(goal)
        texts, _ = proofsearch.generate(
            prompt, model, tokenizer, temperatures=[0.5], num_samples=1,
            stop=['[/TAC]']
        )
        step = proofsearch.parse_step(texts[0])

        # Add the next-step to the proof-so-far
        steps.append(step)
        proof = '\n'.join(steps)
        code = header + theorem_statement.replace(" {}", "") + '\n' + proof
        print()

    if success:
        print("\nSUCCESS!")
    else:
        print("\nFAILED")

    print(theorem_statement.replace(" {}", ""))
    print('  ' + proof.replace('\n', '\n  '))

    return {'theorem_statement': theorem_statement, 'proof': proof, 'success': success}


header = """
import Mathlib.Data.Nat.Prime

"""
theorem_statement = """theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}"""


out = prove_simple(
    model,
    tokenizer,
    header,
    theorem_statement,
    search_budget=100
)
```

```
== Current (0): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by

-- Goal: 
a b c : ℕ
⊢ a + b = c → a ≤ c

== Current (1): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
rintro rfl
-- Goal: 
a b : ℕ
⊢ a ≤ a + b

== Current (2): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
rintro rfl
simp

SUCCESS!
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
  rintro rfl
  simp
```


Above, the model first generates `rintro rfl` and receives a new proof state.\
Then it generates `simp`, and the proof is complete.

### Best-first search

In practice, a less naive search procedure is typically used. Such procedures are usually variants of tree search, in which nodes are states and edges are next steps.

The most common search in neural theorem proving is *best-first search*. This search:

- generates multiple next-step suggestions to form (proof-so-far + next-step) *candidates*
- scores all candidates so far
- selects the highest scoring candidate

A typical scoring function is the model's log probability summed across steps, $\sum_{t'\leq t}\log p_\theta(y_{t'}|x_{t'})$. Next steps that lead to an error receive a score of $-\infty$ (in practice, we simply discard them). In the literature, the scoring function is called a *value function*, $v(y_{\leq t}, x_t)$.
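Here is a self-contained toy version of best-first search in Python that makes the procedure concrete. It is a sketch under stated assumptions, not the `proofsearch.py` implementation:

- the "states" are non-negative integers, and a "proof" is complete when the state reaches 0;
- the hypothetical `propose` function stands in for the model;
- the hypothetical `apply_step` function stands in for Lean.

Candidates are scored by summed log-probability, and steps that "error" are discarded, as described above.

```python
import heapq
import itertools


def propose(n):
    """Hypothetical stand-in for the model: (step, log-probability) suggestions."""
    return [("sub1", -0.1), ("half", -0.5), ("sub3", -1.0)]


def apply_step(n, step):
    """Hypothetical stand-in for Lean: the next state, or None if the step errors."""
    if step == "sub1" and n >= 1:
        return n - 1
    if step == "half" and n > 0 and n % 2 == 0:
        return n // 2
    if step == "sub3" and n >= 3:
        return n - 3
    return None


def best_first_search(start, max_iters=100):
    """Return (steps, score) for the first completed proof found, or (None, -inf)."""
    tie = itertools.count()  # tie-breaker so heapq never compares step lists
    # heapq is a min-heap, so store the negated score (score = summed log-probs).
    queue = [(0.0, next(tie), start, [])]
    expanded = set()
    for _ in range(max_iters):
        if not queue:
            break
        neg_score, _tie, state, steps = heapq.heappop(queue)  # highest-scoring candidate
        if state == 0:  # no goals left: the proof is complete
            return steps, -neg_score
        if state in expanded:
            continue
        expanded.add(state)
        for step, logprob in propose(state):
            new_state = apply_step(state, step)
            if new_state is None:
                continue  # score -inf: discard steps that error
            heapq.heappush(queue, (neg_score - logprob, next(tie), new_state, steps + [step]))
    return None, float("-inf")


print(best_first_search(3))  # (['sub1', 'sub1', 'sub1'], ≈ -0.3)
```

Note that `sub3` reaches 0 in a single step. However, its log-probability (-1.0) is lower than the summed log-probability of three `sub1` steps (about -0.3), so the search returns the longer proof.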

#### Intuition

A key idea is generating multiple suggestions at each step, $\{y_t^{(1)},\ldots,y_t^{(k)}\}\sim p_\theta(\cdot|x_t)$. Intuitively, the goal is to select a next step that will lead to a correct proof. In general, we do not know whether a next step will lead to a correct proof, so we use a heuristic value function to select one.

Here's what multiple suggestions and their (normalized) log-probabilities look like in our example:

```python
transformers.set_seed(40)

prompt = _prompt_fn(goal="""a b c : ℕ\n⊢ a + b = c → a ≤ c""")
texts, scores = proofsearch.generate(prompt, model, tokenizer, temperatures=[1.0], num_samples=10, stop=['[/TAC]'])
for text, score in zip(texts, scores):
    text = text.strip()
    print('%.3f' % score, text, sep='\t')
```

```
-0.131	rintro rfl
[/TAC]
-0.346	intro h
[/TAC]
-0.372	rintro ⟨⟩
[/TAC]
-0.580	simpa only [Nat.zero_add] using le_of_add_le_add_left
[/TAC]
-0.604	rw [add_comm, add_le_add_iff_right]
[/TAC]
-0.686	let g : ℕ → ℕ → Prop := fun a b ↦ (a + b = c)
[/TAC]
```
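The scores above are described as normalized log-probabilities. One common normalization divides a sample's summed token log-probability by its number of tokens, so that long tactics are not penalized just for their length. This is an assumption here, not taken from `proofsearch.py`.

Ten samples produced only six lines above, which suggests (again as an assumption) that duplicate suggestions are merged. The hypothetical `rank_candidates` below does such a merge by keeping the best score for each distinct tactic.

```python
def sequence_score(token_logprobs, normalize=True):
    """Sum of token log-probabilities, optionally averaged over the number of tokens."""
    total = sum(token_logprobs)
    return total / len(token_logprobs) if normalize else total


def rank_candidates(samples, normalize=True):
    """Hypothetical helper: merge duplicate tactics and sort by score, best first.

    `samples` is a list of (tactic_text, token_logprobs) pairs.
    """
    best = {}
    for text, token_logprobs in samples:
        tactic = text.strip()
        score = sequence_score(token_logprobs, normalize)
        # Keep the highest score seen for each distinct tactic string.
        if tactic not in best or score > best[tactic]:
            best[tactic] = score
    return sorted(best.items(), key=lambda kv: kv[1], reverse=True)


# Made-up token log-probabilities, for illustration only.
samples = [
    ("rintro rfl", [-0.1, -0.2]),
    ("intro h", [-0.3, -0.4]),
    ("rintro rfl ", [-0.1, -0.1]),
    ("simp", [-0.9]),
]
for tactic, score in rank_candidates(samples):
    print("%.3f\t%s" % (score, tactic))
```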


### Implementation

A minimal implementation of best-first search is available in `proofsearch.py`.

We will use it in the next notebook to evaluate our model on a set of evaluation theorems.\
Below, we run best-first search and print the search trajectory:

```python
proofsearch.best_first_search(
    model, tokenizer, header, theorem_statement,
    max_iters=32,
    num_samples=4,
    temperatures=[0.0],
    verbose=True,
    path_to_repl=PATH_TO_REPL
)
```

```
--- current:
	theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by

--- type-checked candidates:
	(-0.131) rintro rfl
	(-0.131) rintro rfl
	(-0.346) intro h
	(-0.372) rintro ⟨⟩
--- current:
	theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
	rintro rfl

--- type-checked candidates:
	(-0.325) simp

{'theorem_statement': 'theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}',
 'proof': ['rintro rfl', 'simp'],
 'state': {'env': 0},
 'score': 0.4551752954721451,
 'success': True}
```

At each iteration, the search selects the highest-scoring candidate trajectory and generates 4 next-step suggestions for it.\
`rintro rfl` is selected at the first step because it has the best score.\
The best expansion of `rintro rfl` is `simp`, which completes the proof, so the search terminates.


**Exercise:** suppose that `simp` did not complete the proof. Which tactic would be expanded next?

--------------------


### Extensions

Several works have proposed improving the search strategy, either with a learned value function or with a more sophisticated search algorithm:

- [Polu & Sutskever 2020](https://arxiv.org/pdf/2009.03393.pdf) propose learning a value function $v(y_{\leq t}, x_t)$ that estimates the probability of successfully proving the theorem with the model $p_\theta$, starting from state $x_t$. The value function is trained on trajectories collected by running proof search with the model.

- [Polu et al., ICLR 2023](https://openreview.net/pdf?id=-P7G-8dmSh4) train the value function to predict the eventual length of the proof (or 0 if the proof is predicted to fail). This learned value function improves the pass rate on mathlib theorems by ~10\% compared to log-probability scoring, and by ~1\% over the learned value function from [Polu & Sutskever 2020].

- [Lample et al., NeurIPS 2022](https://openreview.net/pdf?id=J4pX8Q8cxHH) propose a sophisticated MCTS-like search that explores multiple trajectories in parallel, collecting statistics on visited states to prioritize search trajectories.
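The sketch below makes "learning a value function from search trajectories" concrete by turning one finished search into training labels. It is a simplified illustration in the spirit of the works above, not their exact recipe. It provides two labelings:

- `value_targets` labels a visited state 1 if it lies on the found proof and 0 otherwise.
- `remaining_length_targets` follows the description above: it uses the remaining proof length, with 0 for states that did not lead to the proof.

Both function names are hypothetical, and states are represented as plain strings.

```python
def value_targets(visited_states, proof_path):
    """Binary labels: 1.0 for states on the successful proof path, else 0.0."""
    on_path = set(proof_path)
    return {s: 1.0 if s in on_path else 0.0 for s in visited_states}


def remaining_length_targets(visited_states, proof_path):
    """Remaining proof length from each state on the path; 0 for off-path states.

    `proof_path` lists the states x_0, ..., x_{T-1} in the order the proof
    visited them, so T - i steps remain from state x_i.
    """
    T = len(proof_path)
    remaining = {s: T - i for i, s in enumerate(proof_path)}
    return {s: remaining.get(s, 0) for s in visited_states}


# Toy search: states A and C are on the proof; B was a dead end.
visited = ["A", "B", "C"]
path = ["A", "C"]
print(value_targets(visited, path))             # {'A': 1.0, 'B': 0.0, 'C': 1.0}
print(remaining_length_targets(visited, path))  # {'A': 2, 'B': 0, 'C': 1}
```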

Reproducing, analyzing, and improving these search algorithms remains an open area for future work in neural theorem proving (notably, these works were not open-sourced).

Search algorithms are also an active area of research in LLMs, including methods like [tree-of-thought](https://arxiv.org/abs/2305.10601), [stepwise beam search](https://arxiv.org/pdf/2205.12910.pdf), [self-consistency](https://arxiv.org/pdf/2203.11171.pdf), and search with [learned stepwise verifiers](https://arxiv.org/pdf/2305.20050.pdf). In theorem proving, the final output is verifiable, but the quality of intermediate steps is difficult to evaluate.