#### Neural next-step prediction | part 3: proof search
Tutorial on neural theorem proving\
Author: Sean Welleck

----------------

#### High-level goal

Our next goal is to prove theorems with our neural next-step predictor and to check whether the resulting proofs are correct.

Proving and checking a theorem involves generating a next-step candidate with our model, giving it to Lean, and receiving a next state from Lean (or an error message). \
To do so, we will need two components:

1. **Interacting** with Lean: an automated way to give a next step to Lean and receive the next state (or an error).
<!--  -->
2. A **search strategy** that uses the next-step model and Lean to find a proof (e.g., generate one next step, get the next state, repeat).
<!-- For example, a naive algorithm alternates between generating a single step, giving it to Lean, and continuing until a proof is complete or an error message is reached. One can imagine many other strategies, e.g. generating *multiple* next steps and choosing the 'best' one according to some criterion, backtracking upon receiving an error message, etc. -->

Below, we'll walk through a simple example of each. 

-------------------

### 1. Interaction

To start, we'll walk through proving this theorem:

```lean4
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by 
  rw [Nat.coprime] at h  
  exact h  
```

#### Interaction with `pylean`

The [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master) library gives us a lightweight interface to a Lean REPL.

##### Installation

In [21]:
!cd ../../../ && git clone https://github.com/zhangir-azerbayev/repl
%cd /workspace/repl
!git checkout bddf452deda0df2240b248e651bcc37fb8e59d01
%cd /workspace/repl/pylean
!python setup.py develop


In [25]:
%%writefile /workspace/repl/lakefile.lean
import Lake
open Lake DSL

package REPL {
  -- add package configuration options here
}

lean_lib REPL {
  -- add library configuration options here
}

require mathlib from git
  "https://github.com/leanprover-community/mathlib4.git" @ "38dbcd8285bc4b1391619c12f158a7409f3dfc12"

-- Unfortunately the compiled version doesn't work: `unknown package 'Init'`
@[default_target]
lean_exe repl where
  root := `REPL.Main
  supportInterpreter := true

Overwriting /workspace/repl/lakefile.lean


In [46]:
%cd /workspace/ntptutorial
!cd ../repl && lake exe cache get && lake build

/workspace/ntptutorial
No files to download
Decompressing 2301 file(s)


##### Interact

We can pass `pylean` the import and theorem statement:

In [47]:
from pylean import LeanServer
from pprint import pprint

code = """
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}
"""

lean = LeanServer()
state = lean.run_code(code)
lean.proc.close()
pprint(state)


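The returned `state` is a dictionary containing a list of Lean `'messages'`. For an unfinished proof, the `'data'` field of the `unsolved goals` message holds the current proof state $x_t$. Here is basic parsing code: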

In [2]:
def get_goal(state):
    goal = None
    for msg in state['messages']:
        if msg['data'].startswith('unsolved goals\n'):
            goal = '\n'.join(msg['data'].split('\n')[1:])
        elif msg['severity'] == 'error':
            return None
    return goal

print(get_goal(state))

m n : ℕ
h : Nat.coprime m n
⊢ Nat.gcd m n = 1

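To make the assumed structure of `state` explicit, here is a small self-contained sketch that works on hand-written mock states (no Lean needed). The layout (`'messages'` entries with `'severity'` and `'data'`, plus `'sorries'`) mirrors the outputs printed in this notebook but is simplified: real messages also carry position fields. The `is_complete` helper is hypothetical, a guess at the kind of check a function like `proofsearch.is_done` performs, not its actual implementation.

```python
def is_complete(state):
    """Hypothetical completion check: no error messages and no `sorry`s left."""
    no_errors = all(msg['severity'] != 'error' for msg in state['messages'])
    return no_errors and not state.get('sorries')


# Mock states shaped like the `pylean` outputs in this notebook (simplified).
# Lean 4 reports an unfinished proof as an error whose text starts with "unsolved goals".
unfinished = {
    'env': 0,
    'messages': [{
        'severity': 'error',
        'data': 'unsolved goals\nm n : ℕ\nh : Nat.coprime m n\n⊢ Nat.gcd m n = 1',
    }],
    'sorries': [],
}
finished = {'env': 0, 'messages': [], 'sorries': []}

# Extract the goal the same way `get_goal` does: drop the first line of the message.
goal_text = '\n'.join(unfinished['messages'][0]['data'].split('\n')[1:])
print(goal_text)
print(is_complete(unfinished), is_complete(finished))  # False True
```

Note that the "unsolved goals" message itself has severity `'error'`, which is why `get_goal` checks for it before treating other errors as failures.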

We can use $x_t$ as input to our model $p_\theta(y_t|x_t)$.\
Next, we load the trained model and generate a next step, $\hat y_t\sim q(p_\theta(y_t|x_t))$.

(Here $q(\cdot)$ is a decoding algorithm such as greedy decoding or temperature sampling.)
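To make $q(\cdot)$ concrete, here is a toy sketch of the two decoding rules over a made-up four-token vocabulary with invented logits (no model involved). Greedy decoding always takes the highest-scoring token; temperature sampling draws from $\mathrm{softmax}(z/T)$ for logits $z$, which concentrates on the top token as $T \to 0$ and flattens as $T$ grows. In the real model, this choice is made once per generated token inside `model.generate`.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy(logits):
    """Greedy decoding: pick the index of the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sample_with_temperature(logits, temperature, rng):
    """Temperature sampling: draw one index from softmax(logits / temperature)."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


vocab = ['rw', 'exact', 'apply', 'simp']  # toy vocabulary (invented)
logits = [2.0, 1.5, 0.5, -1.0]            # toy scores (invented)

print(vocab[greedy(logits)])  # prints rw
rng = random.Random(0)
low_t = [vocab[sample_with_temperature(logits, 0.1, rng)] for _ in range(5)]
high_t = [vocab[sample_with_temperature(logits, 2.0, rng)] for _ in range(5)]
print(low_t)   # almost always 'rw'
print(high_t)  # more varied
```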

In [None]:
# Load model and tokenizer
import os
import transformers
model_name = 'wellecks/llmstep-mathlib4-pythia2.8b'
model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)
tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # prevents an annoying warning


def generate(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    out = model.generate(
        input_ids,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id
    )
    text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
    return text

In [4]:
# Generate a next step
prompt = f"[GOAL]{get_goal(state)}[PROOFSTEP]"

next_step = generate(prompt)
print(next_step)

rw [← h.gcd_eq_one]


Finally, we can give the generated next step to Lean and receive the next state.

In [5]:
code = """
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by 

""" + next_step

lean = LeanServer()
state = lean.run_code(code)
lean.proc.close()

pprint(state)

{'env': 0, 'messages': [], 'sorries': []}


There are no error messages and no remaining goals, so the proof is complete! If you want, paste this into VS Code to convince yourself that it is complete:

```lean4
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by
    rw [← h.gcd_eq_one]
```

Notice that the machine-generated proof differs from the human-written one shown at the start of this section.
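For multi-step proofs, the Lean code has to be rebuilt from the header, the theorem statement, and the proof-so-far after every step. A small hypothetical helper for this is sketched below; it assumes the statement ends in ` {}` as in this notebook. `prove_simple` in the next section inlines similar string handling; the two-space indentation here is a readability choice that matches how the final proofs are printed.

```python
def build_code(header, theorem_statement, steps):
    """Assemble Lean code: header + statement (without ` {}`) + indented proof steps."""
    statement = theorem_statement.replace(" {}", "")
    body = "\n".join("  " + step for step in steps)
    return header + statement + "\n" + body


header = "\nimport Mathlib.Data.Nat.Prime\n\n"
statement = "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}"
code = build_code(header, statement, ["rw [← h.gcd_eq_one]"])
print(code)
```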

-----------------

### 2. Search strategy

In the proof above, we simply generated one next step and the proof was complete.

In general, proofs consist of multiple steps, so we need an algorithm for generating a multi-step proof, which we refer to as a *search algorithm*.


First, let's consider a naive algorithm that generates a next step and continues from the resulting state. Upon receiving an error message, the algorithm discards the failing step and generates another next step from the previous state.

In [6]:
import sys
sys.path.append('../ntp_python/')

import proofsearch_pylean as proofsearch # some utilities for running code (as we did above) and parsing states/model outputs

In [7]:
transformers.set_seed(43)

def prove_simple(model, tokenizer, header, theorem_statement, search_budget):
    success = False

    code = header + theorem_statement
    steps = []
    proof = ''

    for i in range(search_budget):
        print("== Current (%d): " % i, theorem_statement[:-3] + '\n' + proof, sep='\n')

        # Run the code (header + proof-so-far)
        state = proofsearch.run_code(code)
        
        # Stop if the proof is complete.
        if proofsearch.is_done(state):
            success = True
            break

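        # Get the new state. On an error, drop the failing step and retry from the last good goal.
        goal_candidate = proofsearch.get_goal(state)
        if goal_candidate is None:
            if not steps:
                # The statement itself failed, so there is no earlier goal to retry from.
                print("-- Error in the theorem statement")
                break
            print("-- Error: backtracking")
            steps = steps[:-1]
        else:
            goal = goal_candidate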

        print("-- Goal: ", goal, sep='\n')

        # Generate a next-step
        prompt = f"[GOAL]{goal}[PROOFSTEP]"
        texts, _ = proofsearch.generate(prompt, model, tokenizer, temperatures=[0.5], num_samples=1)
        step = proofsearch.parse_step(texts[0])

        # Add the next-step to the proof-so-far
        steps.append(step)
        proof = '\n'.join(steps)
        code = header + theorem_statement.replace(" {}", "") + '\n' + proof
        print()

    if success:
        print("\nSUCCESS!")
    else:
        print("\nFAILED")
    
    print(theorem_statement.replace(" {}", ""))
    print('  ' + proof.replace('\n', '\n  '))
    
    return {'theorem_statement': theorem_statement, 'proof': proof, 'success': success}


header = """
import Mathlib.Data.Nat.Prime

"""
theorem_statement = """theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}"""


out = prove_simple(
    model, 
    tokenizer,
    header, 
    theorem_statement, 
    search_budget=100
)

== Current (0): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by

-- Goal: 
a b c : ℕ
⊢ a + b = c → a ≤ c

== Current (1): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
rintro rfl
-- Goal: 
a b : ℕ
⊢ a ≤ a + b

== Current (2): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
rintro rfl
exact le_add_left _ _
-- Error: backtracking
-- Goal: 
a b : ℕ
⊢ a ≤ a + b

== Current (3): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
rintro rfl
apply Nat.le_add_right sperr a
-- Error: backtracking
-- Goal: 
a b : ℕ
⊢ a ≤ a + b

== Current (4): 
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
rintro rfl
apply Nat.le_add_right

SUCCESS!
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
  rintro rfl
  apply Nat.le_add_right


Above (with `transformers.set_seed(43)` for reproducibility), the model first generates `rintro rfl`. \
Next it generates `exact le_add_left _ _`, which produces an error, so the search discards that step and asks the model for another one (backtracking). \
The next attempt, `apply Nat.le_add_right sperr a`, also fails; after backtracking once more, the model generates `apply Nat.le_add_right` and the proof is complete.

### Best-first search

In practice, a less naive search procedure is used. These searches are usually variants of tree search, in which nodes are states and edges are next steps.

The most common search in neural theorem proving is *best-first search*. This search:

- generates multiple next-step suggestions for the current state, forming (proof-so-far + next-step) *candidates*
- scores all candidates collected so far
- selects the highest-scoring candidate as the next state to expand, and repeats

A typical scoring function is the model's log probability, $\log p_\theta(y_t|x_t)$, summed across steps. Next steps that lead to an error receive a score of $-\infty$ (in practice, we discard these steps). In the literature, the scoring function is often called a *value function* $v(y_{\leq t}, x_t)$.
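Here is a minimal sketch of best-first search with this scoring rule (sum of step log-probabilities; failing steps are discarded). It is not the `proofsearch_pylean.py` implementation: `propose`, `apply_step`, and `is_done` are hypothetical stand-ins for the model and Lean, and the example plugs in toy lookup tables so it runs on its own. Python's `heapq` is a min-heap, so the sketch stores negated scores.

```python
import heapq
import itertools
import math


def best_first_search(init_state, propose, apply_step, is_done, max_iters=32):
    """Minimal best-first search over proof states.

    propose(state)          -> list of (step, log_prob) suggestions
    apply_step(state, step) -> next state, or None if the checker rejects the step
    is_done(state)          -> True when no goals remain
    """
    tie = itertools.count()  # tie-breaker so heapq never has to compare states
    queue = [(0.0, next(tie), init_state, [])]  # (negated cumulative score, ...)
    visited = set()
    for _ in range(max_iters):
        if not queue:
            break
        neg_score, _, state, steps = heapq.heappop(queue)  # best candidate so far
        if is_done(state):
            return steps, -neg_score
        if state in visited:
            continue
        visited.add(state)
        for step, log_prob in propose(state):
            next_state = apply_step(state, step)
            if next_state is None:  # error: score -inf, so drop the candidate
                continue
            heapq.heappush(queue, (neg_score - log_prob, next(tie), next_state, steps + [step]))
    return None, -math.inf


# Toy stand-ins (invented): one table plays the model, another plays Lean.
SUGGESTIONS = {
    's0': [('step_a', -0.1), ('step_b', -0.3)],
    's1': [('step_c', -2.0), ('step_bad', -0.05)],
    's2': [('step_d', -0.2)],
}
VALID = {('s0', 'step_a'): 's1', ('s0', 'step_b'): 's2',
         ('s1', 'step_c'): 'done', ('s2', 'step_d'): 'done'}

proof, score = best_first_search(
    's0',
    propose=lambda s: SUGGESTIONS.get(s, []),
    apply_step=lambda s, step: VALID.get((s, step)),
    is_done=lambda s: s == 'done',
)
print(proof, round(score, 3))  # ['step_b', 'step_d'] -0.5
```

In the toy run, `step_bad` has the highest log-probability but is rejected, and the more likely first step `step_a` only leads to an expensive finish (`step_c`, -2.0), so the search backtracks to `step_b` and returns a proof with total score -0.5. This sketch checks for completion when a candidate is popped, which guarantees the returned proof is the best-scoring one found.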

#### Intuition

A key idea is generating multiple suggestions at each step, $\{y_t^{(1)},\ldots,y_t^{(k)}\}\sim p_\theta(\cdot|x_t)$. Intuitively, the goal is to select a next step that will lead to a correct proof. In general, we do not know whether a next step will lead to a correct proof, so we use a heuristic value function to select one.

Here's what multiple suggestions and their (normalized) log-probabilities look like in our example:

In [8]:
prompt = '[GOAL]m n : ℕ\nh : Nat.coprime m n\n⊢ Nat.gcd m n = 1[PROOFSTEP]'
texts, scores = proofsearch.generate(prompt, model, tokenizer, temperatures=[0.0], num_samples=5)
for text, score in zip(texts, scores):
    print('%.3f' % score, text, sep='\t')

-0.277	rw [Nat.coprime, gcd_comm] at h
-0.279	rw [← h.gcd_eq_one]
-0.335	apply Nat.eq_one_of_dvd_dvd
-0.349	rw [Nat.coprime] at h
-0.350	rw [gcd_comm]

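The scores above are described as normalized. One common choice, assumed here rather than taken from `proofsearch.generate`, is to divide a step's summed token log-probabilities by its number of tokens, so that a step is not penalized merely for being long. A toy comparison with invented numbers:

```python
def step_scores(token_logprobs):
    """Return (summed, length-normalized) log-probability of one generated step.

    Assumption: "normalized" = summed token log-probabilities divided by the
    number of generated tokens.
    """
    total = sum(token_logprobs)
    return total, total / len(token_logprobs)


# Toy per-token log-probabilities for two candidate steps (invented numbers).
short_step = [-0.2, -0.4]                    # 2 tokens
long_step = [-0.1, -0.3, -0.2, -0.2, -0.2]   # 5 tokens

total_short, norm_short = step_scores(short_step)
total_long, norm_long = step_scores(long_step)
print(f"short: sum={total_short:.2f}, normalized={norm_short:.2f}")
print(f"long:  sum={total_long:.2f}, normalized={norm_long:.2f}")
```

The raw sum prefers the short step (-0.60 vs. -1.00), while the normalized score prefers the long one (-0.20 vs. -0.30).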

### Implementation

A minimal implementation of best-first search is available in `proofsearch_pylean.py`.\
A version that uses LeanDojo for interaction is in `proofsearch_dojo.py`.

We will use these in the next notebook to evaluate our model on a set of evaluation theorems.\
Below, we run best-first search and print out the search trajectory:

In [9]:
proofsearch.best_first_search(
    model, tokenizer, header, theorem_statement, 
    max_iters=32,
    num_samples=4,
    temperatures=[0.0],
    verbose=True
)

--- current:
	theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by 
	




--- type-checked candidates:
	(-0.066) rintro rfl
	(-0.307) rintro ⟨rfl, rfl⟩
	(-0.035) intro h
	(-0.230) rintro ⟨d, rfl⟩
--- current:
	theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by 
	intro h




--- type-checked candidates:
	(-0.172) apply le_of_add_le_add_right
	(-0.093) rw [← h]
	(-0.453) cases c
--- current:
	theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by 
	rintro rfl



--- type-checked candidates:
	(-0.109) apply Nat.le_add_right
	(-0.173) exact Nat.le_add_right _ _





{'theorem_statement': 'theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}',
 'proof': ['rintro rfl', 'apply Nat.le_add_right'],
 'state': {'sorries': [], 'messages': [], 'env': 0},
 'score': 0.1747819110751152,
 'success': True}

At each iteration, the search selects the highest-scoring candidate and generates 4 next-step suggestions (`num_samples=4`), keeping only those that type-check.\
`intro h` (-0.035) is selected first. Its best expansion, `rw [← h]`, scores -0.093, giving a cumulative score of -0.035 + (-0.093) = -0.128. \
This is lower than the score of `rintro rfl` (-0.066), so `rintro rfl` is selected next. This is backtracking, since `intro h` is no longer part of the proof.\
Expanding `rintro rfl` yields `apply Nat.le_add_right`, and the proof is complete.
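The backtracking in this trace can be replayed by hand with a priority queue. This sketch copies the step scores printed in the trace and assumes, as described earlier, that a candidate's score is the sum of its step scores; since `heapq` is a min-heap, negated scores are stored.

```python
import heapq

queue = []  # min-heap of (negated cumulative score, proof-so-far as a tuple)


def push(proof, score):
    """Add a candidate; heapq pops the smallest key, so store -score."""
    heapq.heappush(queue, (-score, proof))


# Iteration 1: suggestions for the empty proof (scores from the trace).
for step, s in [('rintro rfl', -0.066), ('rintro ⟨rfl, rfl⟩', -0.307),
                ('intro h', -0.035), ('rintro ⟨d, rfl⟩', -0.230)]:
    push((step,), s)
neg, best = heapq.heappop(queue)
print(best, -neg)  # ('intro h',) -0.035

# Iteration 2: expand `intro h`; each child's step score is added to -0.035.
for step, s in [('apply le_of_add_le_add_right', -0.172), ('rw [← h]', -0.093),
                ('cases c', -0.453)]:
    push(best + (step,), -neg + s)
neg, best = heapq.heappop(queue)
print(best, -neg)  # ('rintro rfl',) -0.066, which beats -0.035 + (-0.093) = -0.128

# Iteration 3: expanding `rintro rfl` with `apply Nat.le_add_right` closes the goal.
final_score = -neg + (-0.109)
print(best + ('apply Nat.le_add_right',), round(final_score, 3))
```

The final sum, -0.175, matches the reported `'score'` of 0.1748 up to sign and rounding, which suggests the implementation reports the negated cumulative log-probability.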


--------------------


## Extensions

Several works have proposed to improve the search strategy, either with a learned value function or a sophisticated search:

- [Polu & Sutskever 2020](https://arxiv.org/pdf/2009.03393.pdf) propose to learn a value function $v(y_{\leq t}, x_t)$ that estimates the probability of successfully proving the theorem with the model $p_\theta$ starting at state $x_t$. To train it, they use trajectories collected by running proof search with the model.

- [Polu et al. ICLR 2023](https://openreview.net/pdf?id=-P7G-8dmSh4) train the value function to predict the eventual length of the proof (or 0 if it is predicted to fail). The learned value function improves the pass rate by ~10\% on mathlib theorems compared to log-probability, with a ~1\% improvement over the learned value function from [Polu & Sutskever 2020].

- [Lample et al. NeurIPS 2022](https://openreview.net/pdf?id=J4pX8Q8cxHH) propose a sophisticated MCTS-like search that explores multiple trajectories in parallel, collecting statistics on visited states to prioritize search trajectories.

Reproducing, analyzing, and improving these search algorithms remains an open area for future work in neural theorem proving (for instance, these works did not release their code).

Search algorithms are also an active area of research in LLMs, including methods like [tree-of-thought](https://arxiv.org/abs/2305.10601), [stepwise beam search](https://arxiv.org/pdf/2205.12910.pdf), [self-consistency](https://arxiv.org/pdf/2203.11171.pdf), and search with [learned stepwise verifiers](https://arxiv.org/pdf/2305.20050.pdf). In theorem proving, the final output is verifiable, but the quality of intermediate steps is difficult to evaluate.