In [21]:
import pexpect
import json
import argparse
import os


class ProofSearch:
    """Drives a Lean 4 REPL subprocess (`lake env lean --run REPL/Main.lean`) through pexpect."""

    def __init__(self, path_to_repl):
        """Start the REPL with ``path_to_repl`` (the REPL checkout) as working directory."""
        self.proc = pexpect.spawn(
            "lake env lean --run REPL/Main.lean", cwd=path_to_repl, encoding="utf-8"
        )
        # NOTE(review): `debug` does not appear to be a pexpect setting; kept for compatibility.
        self.proc.debug = True

    def run_code(self, code, env=None, verbose=False, timeout=20):
        """Send one command to the REPL and return its parsed JSON reply.

        Args:
            code (str): Lean source to run.
            env (int | None): REPL environment id to run in. ``0`` is a valid id,
                so only ``None`` means "fresh environment".
            verbose (bool): print the command sent and the raw reply.
            timeout (float): seconds to wait for the reply (default 20).

        Returns:
            dict | None: the decoded reply, or None if the REPL timed out.
        """
        request = {"cmd": code}
        if env is not None:
            request["env"] = env
        # json.dumps does proper JSON escaping (quotes, backslashes, control chars);
        # ensure_ascii=False keeps Unicode such as ℕ/ℂ literal, as the REPL accepts it.
        command = json.dumps(request, ensure_ascii=False)

        if verbose:
            print(command)
        self.proc.sendline(command)
        # The terminal echoes the command back; consume it so it is not parsed as output.
        self.proc.expect_exact(command + "\r\n")

        # A blank line terminates the command for the REPL.
        self.proc.sendline()
        self.proc.expect_exact("\r\n")
        try:
            # Every reply ends with `"env": <n>}`.
            self.proc.expect(r'env": \d+\}', timeout=timeout)
        except pexpect.exceptions.TIMEOUT:
            print("FAILED DUE TO TIMEOUT")
            return None
        output = self.proc.before + self.proc.match.group()
        if verbose:
            print(output)
        return json.loads(output)


# def main():
#     """For testing purposes"""
#     path = os.environ.get("PATH_TO_LEAN_REPL")
    
#     print("lean repl path: ", path)

#     proofsearch = ProofSearch(path)

#     # should return empty sorries and goals
#     proofsearch.run_code("import Mathlib.Data.List.Basic\ndef f := 2", verbose=True)
#     # should return goal state
#     proofsearch.run_code("example : 2 = 2 := by", verbose=True)
#     # should return error
#     proofsearch.run_code("example : 2 = 3 := rfl", verbose=True)

#     # should return goal state
#     feedback = proofsearch.run_code("def f := 37", verbose=True)
#     env = feedback["env"]
#     proofsearch.run_code("#check (rfl: f = 37)", env=env, verbose=True)


# if __name__ == "__main__":
#     main()


In [9]:
import requests
from requests import HTTPError
from dataclasses import dataclass
from typing import List
from tiktoken import get_encoding
import os
from pydantic import BaseModel
import backoff
import tiktoken

# System prompt sent with every request. The text is part of the model's input,
# so keep code fences balanced and Lean names accurate.
SYSTEM_MESSAGE: str = """\
You are a pure mathematician who is an expert in the Lean 4 theorem prover. 
Your job is to help your user write Lean proofs.
I want to remind you that we're using Lean 4, not the older Lean 3, and there have been some syntax changes. 

In particular:
- Type constants are now UpperCamelCase, eg `Nat`, `List`.
- Term constants and variables are now `lowerCamelCase` rather than `snake_case`. 
  For example, we now have `NumberTheory.Divisors.properDivisors` instead of `number_theory.divisors.proper_divisors`.
- Pure functions are now written with the syntax `fun x => f x`. The old `λ x, f x` syntax will not work.
- We now enter tactic mode using the `by` keyword. The syntax `begin... end` will not work.
- Instead of being separated by a comma, tactics are separated by a newline. 
    For example, we could write.
    ```lean
    theorem test (p q : Prop) (hp : p) (hq : q) : p ∧ q ∧ p := by
      apply And.intro hp
      exact And.intro hq hp
    ```
- In the `rw` tactic you must enclose the lemmas in square brackets, even if there is just one. 
  For example `rw h1` is now `rw [h1]`.
- The `induction` tactic now uses a structured format, like pattern matching. 
    For example, in Lean 4 we can write
    ```lean
    theorem zero_add (n : Nat) : 0 + n = n := by
      induction n with
      | zero => rfl
      | succ n ih => rw [Nat.add_succ, ih]
    ```

Alternatively you can still use `induction' n with n ih`, like in Lean 3.
- The `cases` tactic now uses a structured format, like pattern matching. For example, in Lean 4 we can write
```lean
example (p q : Prop) : p ∨ q → q ∨ p := by
  intro h
  cases h with
  | inl hp => apply Or.inr; exact hp
  | inr hq => apply Or.inl; exact hq
```

The following is a description of some commonly used tactics. 
Of course, feel free to use tactics outside of this list. 
Remember that it is good style to use high-level automations like `simp` and `ring` instead of manually performing low-level manipulations. 
- `abel`: reduces expressions in additive, commutative monoids/groups to a normal form. 
- `apply`: the tactic `apply e` matches the current goal against the conclusion of `e`. If it succeeds, the new goal states are the premises of `e`.
- `continuity`: attempts to prove goals of the form `Continuous f` by applying lemmas tagged with the `continuity` attribute. 
- `contrapose`: transforms the goal into its contrapositive.
- `convert`: The tactic `convert e` is similar to `refine e`, except the type of `e` is not required to exactly match the goal. Any rewrites required to transform `e` into the goal become the new goal state.
- `group`: normalizes expressions in multiplicative groups, without assuming commutativity.
- `have`: `have h : t := p` adds the hypothesis `h : t` to the current goal. If you want to prove `h` in tactic mode, use the syntax `have h : t := by --tactic proof goes here`. 
- `linarith`: proves any goal that consists of linear arithmetic.
- `nlinarith`: version of `linarith` that can tolerate some nonlinearity.
- `norm_num`: normalizes numerical expressions.
- `polyrith`: proves polynomial equalities.
- `push_neg`: pushes negations through quantifiers.
- `simp`: uses lemmas and hypotheses tagged with the `simp` attribute to simplify the goal. Use `simp [h1, h2,..., hn]` to add `h1, h2,..., hn` to the list of lemmas used by simp.
- `ring`: tactic for solving goals involving expressions in commutative rings and normalizing expressions in commutative rings.
"""

# Step instructions appended to the step-by-step proving prompts
# (`f2f_initial_prompt` and `prove_unsolved_goals_prompt` end with it).
# Whitespace inside the literal is part of the text sent to the model.
PROOF_INSTRUCTION = """\
1. Please write out a plan for proceeding with the proof. 
   Write your plan in English (with LaTeX).
   
2. Please add the next tactic step to the proof. 
   Include the new version of your (possibly incomplete) proof in a lean code block. 
   Make sure the code block is self-contained and runs. Do not add more than one new tactic step."""

# Step instructions appended to the autoformalization prompts. Worded like
# PROOF_INSTRUCTION but also allows the natural-language proof as a guide.
# No trailing newline: the final line ends with a line-continuation backslash.
AUTOFORMALIZE_PROOF_INSTRUCTION = """\
1. Please write out a plan for your formal proof. 
   You can use the natural language proof as a guide, but there is no need to follow it exactly, or at all.

2. Please add the next tactic step to the proof. 
   Include the new version of your (possibly incomplete) proof in a lean code block. 
   Make sure the code block is self-contained and runs. 
   Do not add more than one new tactic step. 
   If you introduce a new lemma in a `have` statement, only supply one tactic step in the proof of the lemma.\
"""

def f2f_initial_prompt(code):
    """Build the opening user prompt for step-by-step formal-to-formal proving.

    Args:
        code: the (incomplete) Lean source to show the model.

    Returns:
        str: a two-sentence framing, ``code`` inside a ```lean fence, then
        PROOF_INSTRUCTION.
    """
    intro = (
        "I am going to show you an incomplete proof and the accompanying goal state. \n"
        "I will ask you to complete the proof step by step, adding one tactic step in each response. \n"
        "\n"
        "Here is my Lean code so far: \n"
    )
    fenced_code = "```lean\n" + format(code) + "\n```\n"
    return intro + fenced_code + format(PROOF_INSTRUCTION)
    

def autoformalize_proof_initial_prompt(nl_statement, nl_proof, code):
    """Build the prompt for proving a given formal statement from an informal proof.

    Args:
        nl_statement: informal theorem statement, wrapped in a LaTeX theorem environment.
        nl_proof: informal proof text, placed after the theorem environment.
        code: Lean code containing the formal statement to complete.

    Returns:
        str: the prompt, ending with AUTOFORMALIZE_PROOF_INSTRUCTION.
    """
    sections = [
        "I am going to show you a natural language proof of a theorem and a corresponding formal theorem statement in Lean 4. ",
        "Your job will be to write a formal proof of the formal theorem statement, using the natural language proof as a hint.",
        "",
        "Here are the natural language theorem and proof:",
        "\\begin{theorem}",
        "    " + format(nl_statement),
        "\\end{theorem}",
        format(nl_proof),
        "",
        "Below is the Lean code I would like you to complete.",
        "```lean",
        format(code),
        "```",
        format(AUTOFORMALIZE_PROOF_INSTRUCTION),
    ]
    return "\n".join(sections)

def autoformalize_statement_and_proof_initial_prompt(nl_statement, nl_proof, code):
    """Build the prompt for formalizing both a theorem statement and its proof.

    Args:
        nl_statement: informal theorem statement, wrapped in a LaTeX theorem environment.
        nl_proof: informal proof text, placed after the theorem environment.
        code: Lean template the model should fill in.

    Returns:
        str: the prompt, ending with AUTOFORMALIZE_PROOF_INSTRUCTION and a newline.
    """
    sections = [
        "I am going to show you a natural language theorem statement and natural language proof of that theorem. ",
        "Your job will be to formalize the statement of the theorem in Lean 4 and formally prove the statement. ",
        "",
        "Here are the natural language theorem and proof:",
        "\\begin{theorem}",
        "    " + format(nl_statement),
        "\\end{theorem}",
        format(nl_proof),
        "",
        "Here is the code template for your formalization. ",
        "```lean",
        format(code),
        "```",
        format(AUTOFORMALIZE_PROOF_INSTRUCTION),
        "",  # yields the trailing newline after the instructions
    ]
    return "\n".join(sections)


def prove_unsolved_goals_prompt(goal_state):
    """Build the follow-up prompt that shows the new goal state and repeats PROOF_INSTRUCTION."""
    header = "Here is the new goal state:"
    fenced_goal = "\n".join(("```lean", format(goal_state), "```"))
    return f"{header}\n{fenced_goal}\n{PROOF_INSTRUCTION}"


def sorry_prompt():
    """Return the fixed message asking the model to remove a `sorry` without adding tactics."""
    return "\n".join([
        "There is a sorry in your code. ",
        "Please do not write any code that contains sorries. ",
        "Instead, finish typing at the location where you want to see the goal state. ",
        "Remove the sorry, but do not add any new tactic steps.",
    ])

class ChatMessage(BaseModel):
    """A single chat turn; serialized into the `messages` list of the chat-completions payload."""

    role: str  # values used in this notebook: "system", "user", "assistant"
    content: str

    def __str__(self):
        banner = ">>>" + self.role.upper()
        return "\n".join((banner, self.content))


class ChatState(BaseModel):
    """An ordered conversation: the full message history sent to the model."""

    messages: List[ChatMessage]

    def __str__(self):
        rendered = [str(message) for message in self.messages]
        return "\n".join(rendered)


def _is_fatal_http_error(exc) -> bool:
    """Return True for HTTP errors that retrying cannot fix (4xx other than 429 rate limits)."""
    response = getattr(exc, "response", None)
    if response is None:
        return False
    status = response.status_code
    return 400 <= status < 500 and status != 429


@backoff.on_exception(
    backoff.expo,
    (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout),
    max_tries=10,
    giveup=_is_fatal_http_error,
)
def generate_message(chat_state: ChatState, 
                     temperature=0.4, 
                     top_p=0.95, 
                     max_tokens=2048, 
                     model: str = "gpt-4",
                     request_timeout: float = 600) -> str:
    """Request one chat completion from the OpenAI API and return its text.

    Transient failures (5xx, 429, connection errors, timeouts) are retried with
    exponential backoff, up to 10 attempts. Client errors such as a bad API key
    or an over-long prompt are raised immediately, because retrying cannot fix them.

    Args:
        chat_state: conversation sent as the `messages` field.
        temperature, top_p, max_tokens, model: forwarded in the request payload.
        request_timeout: seconds to wait for the HTTP response.

    Returns:
        The `content` of the first returned choice.

    Raises:
        HTTPError: on a non-200 response (with `.response` set) once retries stop.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    
    payload = dict(
        model=model,
        messages=chat_state.dict()["messages"],
        max_tokens=max_tokens,
        stream=False,
        top_p=top_p,
        temperature=temperature,
    )
    
    r = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=request_timeout,
    )
    if r.status_code != 200: 
        print(chat_state)

        try:
            enc = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: fall back to cl100k_base for a rough count so
            # the real HTTP error below is still reported.
            enc = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(enc.encode(str(chat_state)))
        print(f"Chat contains approx {num_tokens} tokens")

        print(r, "request failed")

        raise HTTPError(f"OpenAI API returned status {r.status_code}: {r.text}", response=r)

    return r.json()["choices"][0]["message"]["content"]

def complete_chat(chat_state: ChatState, **kwargs): 
    """Ask the model for the next assistant turn and return the extended conversation.

    `kwargs` are forwarded to `generate_message`. A new ChatState is built from
    the old messages plus the reply; `chat_state.messages` itself is not appended to.
    """
    print("waiting on api...")
    reply_text = generate_message(chat_state, **kwargs)
    print("GOT RESPONSE")
    reply = ChatMessage(role="assistant", content=reply_text)
    history = list(chat_state.messages)
    history.append(reply)
    return ChatState(messages=history)

def generate_message_lean_single(input: str):
    """One-shot query: SYSTEM_MESSAGE as the system prompt, `input` as the only user turn.

    Returns the model's reply text.
    """
    system_turn = ChatMessage(role="system", content=SYSTEM_MESSAGE)
    user_turn = ChatMessage(role="user", content=input)
    conversation = ChatState(messages=[system_turn, user_turn])
    return generate_message(conversation)


In [11]:
import os
from dataclasses import dataclass, field


@dataclass
class Config:
    """Runtime configuration.

    Secrets are read from the environment so they are not committed with the
    notebook. Set `CAPABILITIES_API_KEY` before creating a Config; if it is
    unset, the key is an empty string.
    """

    # NOTE(review): a key was previously hardcoded here and should be treated as
    # leaked (rotate it).
    CAPABILITIES_API_KEY: str = field(
        default_factory=lambda: os.environ.get("CAPABILITIES_API_KEY", "")
    )

CONFIG = Config()

In [24]:
# from pysagredo.util import CONFIG
# from pysagredo.llm import *
# from pysagredo.gym import ProofSearch

from dataclasses import dataclass, asdict
from typing import List, Dict

from pydantic import BaseModel

import tiktoken

@dataclass
class ProverState:
    """One snapshot of the proof search.

    Fields:
        sketch: natural-language proof plan shown to the model.
        code: current Lean source.
        goals: open-goal strings.
        errors: Lean REPL error messages (dicts).
    """

    sketch: str
    code: str
    goals: List[str]
    errors: List[Dict]

    def __str__(self):
        divider = ">" * 30
        sections = []
        for field_name, value in asdict(self).items():
            sections.append(f">>>{field_name.upper()}{divider}\n{value}")
        return "ProverState:\n" + "\n".join(sections)


def code_of_chat_state(state: "ChatState", lang="lean"):
    """
    Extracts the contents of the last ```<lang> code block in the final message.

    Requires that state.messages[-1] has role "assistant".

    The block ends at the first closing fence *after* its opening fence, so later
    blocks in other languages are not swallowed. If the block is never closed
    (e.g. a reply cut off at max_tokens), the rest of the message is used. The
    remainder of the opening-fence line (e.g. the "4" in "```lean4") is skipped,
    and surrounding whitespace is stripped.

    Raises:
        AssertionError: if the last message is not from the assistant.
        ValueError: if the message contains no ```<lang> fence.
    """
    assert state.messages[-1].role == "assistant"

    text = state.messages[-1].content
    fence = "```"
    opening = fence + lang
    try:
        open_idx = text.rindex(opening)
    except ValueError:
        raise ValueError(f"no {opening} code block in the assistant message") from None

    body_start = open_idx + len(opening)
    close_idx = text.find(fence, body_start)
    if close_idx == -1:
        close_idx = len(text)
    block = text[body_start:close_idx]

    # Skip the rest of the fence line when it is empty or a bare info-string suffix.
    first_line, sep, rest = block.partition("\n")
    if sep and not any(ch.isspace() for ch in first_line):
        block = rest
    return block.strip()


def goals_errors_of_lean_state(lean_state):
    """
    Splits a Lean REPL reply into open goals and errors.

    Args:
        lean_state (dict): parsed REPL reply. "messages" and "sorries" are both
            optional; a clean run comes back as just {"env": N}.

    Returns:
        (goals, errors):
          goals: the full text of every message containing "unsolved goals",
              followed by the "goal" of every sorry that has one.
          errors: every other message whose severity is "error".
    """
    messages = lean_state.get("messages", [])
    goals = [m["data"] for m in messages if "unsolved goals" in m["data"]]
    goals += [s["goal"] for s in lean_state.get("sorries", []) if "goal" in s]
    errors = [
        m for m in messages
        if "unsolved goals" not in m["data"] and m["severity"] == "error"
    ]
    return goals, errors


def sketch_prompt(code: str) -> str:
    """Build the prompt asking for a natural-language/LaTeX proof plan inside a ```latex block."""
    lines = (
        "I am trying to write a formal proof of the following theorem in Lean 4: ",
        "```lean",
        format(code),
        "```",
        "I am going to ask you to finish this Lean proof. But first, plan out your proof in natural language and LaTeX. ",
        "",
        "Formatting instructions: enclose your plan in a ```latex code block```",
    )
    return "\n".join(lines)


def sketch_of_code(code: str, verbose=False) -> str:
    """
    Asks the model for a natural-language proof sketch of a Lean theorem stub.

    Args:
        code (str): Lean source containing the theorem statement.
        verbose (bool): print the whole conversation. Defaults to False.

    Returns:
        str: contents of the last ```latex block in the model's reply.
    """
    conversation = ChatState(
        messages=[
            ChatMessage(role="system", content=SYSTEM_MESSAGE),
            ChatMessage(role="user", content=sketch_prompt(code)),
        ]
    )
    conversation = complete_chat(conversation)
    if verbose:
        print(conversation)
    return code_of_chat_state(conversation, lang="latex")


def next_tactic_prompt(proverstate: ProverState) -> str:
    """Build the prompt asking for a plan plus exactly one more tactic step.

    Shows the current Lean code, the proof sketch, and the open goals
    (separated by blank lines), each in its own fenced block.
    """
    open_goals = "\n\n".join(proverstate.goals)
    prompt_lines = [
        "I am trying to complete this proof in Lean 4: ",
        "```lean",
        format(proverstate.code),
        "```",
        "",
        "I am following this natural language proof sketch: ",
        "```latex",
        format(proverstate.sketch),
        "```",
        "These are the open goals in my Lean code: ",
        "```",
        open_goals,
        "```",
        "1. Please write out a plan for completing the formal proof. Write your plan in English (with LaTeX). The above proof sketch may be helpful, but you do not have to follow it exactly.",
        "2. Please add the next tactic step to the proof. Do not add more than one new tactic step",
        "",
        "Formatting instructions: include the new version of your (possibly incomplete) proof in a ```lean code block```. Make sure the code block is self-contained and runs.",
    ]
    return "\n".join(prompt_lines)


def fix_error_prompt(proverstate: ProverState) -> str:
    """Build the prompt asking the model to repair Lean errors without adding tactic steps.

    Each error is rendered as "line L col C:" followed by its message, using
    the "pos" and "data" fields of the REPL message dicts. Errors are separated
    by blank lines.
    """
    def render(error):
        pos = error["pos"]
        return f'line {pos["line"]} col {pos["column"]}:\n{error["data"]}'

    errors_string = "\n\n".join([render(error) for error in proverstate.errors])
    prompt_lines = [
        "The following is a Lean 4 proof I am working on: ",
        "",
        "```lean",
        format(proverstate.code),
        "```",
        "This proof returns the following errors. ",
        "```",
        errors_string,
        "```",
        "I am following this proof sketch: ",
        "```latex",
        format(proverstate.sketch),
        "```",
        "",
        "Please describe how you are going to fix the errors. ",
        "Modify the code to fix the error, but do not add any additional tactic steps.",
        "",
        "Formatting instructions: Write the answer in a ```lean code block```.",
    ]
    return "\n".join(prompt_lines)


def prover_kernel(proverstate: ProverState, mode: str, verbose=False) -> "tuple[ProverState, ChatState]":
    """
    Runs one search step: prompt the model, then check its code with the Lean REPL.

    Args:
        proverstate (ProverState): current state.
        mode (str): "next_tactic" (requires `proverstate.errors` to be empty) or
            "error" (requires `proverstate.errors` to be non-empty).
        verbose (bool): print the prompt, the reply, REPL traffic and the new state.

    Returns:
        (ProverState, ChatState): the state after running the model's code, and
        the conversation that produced it.

    Raises:
        ValueError: if `mode` is not recognized.
        AssertionError: if the precondition of `mode` does not hold.
        EnvironmentError: if PATH_TO_LEAN_REPL is not set.
        TimeoutError: if the Lean REPL does not answer.
    """
    if mode == "next_tactic":
        assert not proverstate.errors
        user_prompt = next_tactic_prompt(proverstate)
    elif mode == "error":
        assert proverstate.errors
        user_prompt = fix_error_prompt(proverstate)
    else:
        raise ValueError("`mode` not recognized")

    # Check the REPL location before spending an API call on this step.
    replpath = os.environ.get("PATH_TO_LEAN_REPL")
    if replpath is None:
        raise EnvironmentError("no PATH_TO_LEAN_REPL")

    chat_state = ChatState(
        messages=[
            ChatMessage(role="system", content=SYSTEM_MESSAGE),
            ChatMessage(role="user", content=user_prompt),
        ]
    )

    chat_state = complete_chat(chat_state)

    if verbose:
        print(f">>>{mode} MODE" + ">"*30)
        print(">>>USER" + ">"*30)
        print(chat_state.messages[-2].content)
        print(">>>ASSISTANT" + ">"*30)
        print(chat_state.messages[-1].content)

    new_code = code_of_chat_state(chat_state).strip()

    lean = ProofSearch(replpath)
    try:
        lean_state = lean.run_code(new_code, verbose=verbose)
    finally:
        # Every step spawns its own REPL; close it so long searches don't pile up processes.
        lean.proc.close()
    if lean_state is None:
        # ProofSearch.run_code returns None after printing a timeout message.
        raise TimeoutError("Lean REPL did not respond to the generated code")

    goals, errors = goals_errors_of_lean_state(lean_state)

    new_proverstate = ProverState(
        sketch=proverstate.sketch,
        code=new_code,
        goals=goals,
        errors=errors,
    )

    if verbose:
        print(new_proverstate)

    return new_proverstate, chat_state


def prover(proverstate: ProverState, max_api_calls=10, verbose=False) -> Dict:
    """
    Calls `prover_kernel` until the proof has no goals and no errors, or until
    `max_api_calls` model calls have been made.

    Each step uses "error" mode when the state has errors and "next_tactic" mode
    otherwise. An input that is already solved triggers no model call.

    Returns:
        dict with keys
          "proverstates": every state, starting with the input;
          "chat_states": one conversation per model call;
          "stop_reason": "done" if the final state is solved, else "max_calls";
          "num_api_calls": number of model calls made.
    """
    def is_solved(state):
        return not state.errors and not state.goals

    proverstates = [proverstate]
    chat_states = []

    num_api_calls = 0

    # Don't spend a call, or risk breaking the proof, when nothing is left to do.
    stop_reason = "done" if is_solved(proverstate) else "max_calls"
    while stop_reason != "done" and num_api_calls < max_api_calls:
        mode = "error" if proverstate.errors else "next_tactic"
        proverstate, chat_state = prover_kernel(proverstate, mode=mode, verbose=verbose)

        num_api_calls += 1

        proverstates.append(proverstate)
        chat_states.append(chat_state)

        if is_solved(proverstate):
            stop_reason = "done"

    return {
        "proverstates": proverstates,
        "chat_states": chat_states,
        "stop_reason": stop_reason,
        "num_api_calls": num_api_calls
    }

def autoformalize_sketch(code: str, sketch: str, max_api_calls=10, verbose=False) -> Dict:
    """
    Checks `code` with Lean once, then runs `prover` from that state.

    Args:
        code (str): Lean source with the statement (typically ending in `sorry`).
        sketch (str): natural-language proof plan passed to every prompt.
        max_api_calls (int): cap on model calls made by `prover`.
        verbose (bool): print progress.

    Returns:
        The summary dict produced by `prover`.

    Raises:
        EnvironmentError: if PATH_TO_LEAN_REPL is not set.
        TimeoutError: if the Lean REPL does not answer.
    """
    replpath = os.environ.get("PATH_TO_LEAN_REPL")
    if replpath is None:
        raise EnvironmentError("no PATH_TO_LEAN_REPL")
    lean = ProofSearch(replpath)
    try:
        lean_state = lean.run_code(code.strip() + "\n", verbose=verbose)
    finally:
        # The search steps start their own REPLs; don't leave this one running.
        lean.proc.close()
    if lean_state is None:
        raise TimeoutError("Lean REPL did not respond to the initial code")
    goals, errors = goals_errors_of_lean_state(lean_state)

    proverstate = ProverState(
            sketch=sketch,
            code=code, 
            goals=goals, 
            errors=errors, 
    )

    return prover(proverstate, max_api_calls=max_api_calls, verbose=verbose)

def f2f_prove(code: str, max_api_calls=10, verbose=False) -> Dict:
    """Prove a Lean theorem stub end to end: ask for a sketch, then search following it."""
    return autoformalize_sketch(
        code,
        sketch_of_code(code, verbose=verbose),
        max_api_calls=max_api_calls,
        verbose=verbose,
    )


In [25]:
import fire
import os
import sys
import time
import json
from pathlib import Path
from tqdm import tqdm
from itertools import islice

# from pysagredo.util import CONFIG
# from pysagredo.llm import *
# from pysagredo.gym import ProofSearch
# from pysagredo.prove import f2f_prove, autoformalize_sketch

from pydantic import BaseModel

import tiktoken
from datasets import load_dataset

import code 

import argparse

def test_proofsearch(source): 
    """Smoke-test the Lean REPL by running `source` once and printing the reply.

    Reads PATH_TO_LEAN_REPL (KeyError if unset). The spawned REPL is closed afterwards.

    Raises:
        RuntimeError: if the REPL does not answer, so a broken setup fails loudly.
    """
    print("testing repl...")
    path = os.environ["PATH_TO_LEAN_REPL"]
    print(f"path to repl: {path}")
    proofsearch = ProofSearch(path_to_repl=path)
    try:
        out = proofsearch.run_code(source, verbose=True)
    finally:
        proofsearch.proc.close()
    print(out)
    if out is None:
        raise RuntimeError("Lean REPL did not respond")
    print("repl worked...")

def _main(args):
    """
    Only for debugging/testing purposes.

    Runs one of three modes, checked in this order:
      * `args.prove` and `args.input` is a Lean file path: sketch and prove it.
      * `args.prove` and `args.input == "minif2f"`: prove the first two minif2f
        validation problems, saving one `<id>.json` summary per problem in
        the directory `args.logdir`.
      * `args.autoformalize`: `args.input` is a JSON file with "code" and
        "sketch" keys; search for a proof following that sketch.
    Then prints the stop reason, timing and final code of the last run.
    """
    from dataclasses import asdict, is_dataclass

    def to_jsonable(obj):
        # json.dump fallback: ProverState is a dataclass, ChatState a pydantic model.
        if is_dataclass(obj) and not isinstance(obj, type):
            return asdict(obj)
        if hasattr(obj, "dict"):
            return obj.dict()
        return str(obj)

    path = args.input
    summary = None

    start_time = time.time()

    if args.prove and path != "minif2f": 
        with open(path) as f:
            source = f.read()
        test_proofsearch(source)
        summary = f2f_prove(source, max_api_calls=10, verbose=True)  

    elif args.prove and path == "minif2f": 
        logdir = Path(args.logdir)
        logdir.mkdir(parents=True, exist_ok=True)
        dataset = load_dataset("hoskinson-center/minif2f-lean4", split="validation")
        for x in tqdm(islice(dataset, 2)): 
            print(x)
            eyed = x["id"]
            source = x["header"] + "\n\n" + x["formal_statement"]
            test_proofsearch(source) # sanity check to make sure repl works
            summary = f2f_prove(source, max_api_calls=10, verbose=True)  
            
            print(f"saving summary for {eyed}...")
            # args.logdir is a directory (created above): write one file per problem.
            with open(logdir / f"{eyed}.json", "w") as f: 
                json.dump(summary, f, default=to_jsonable)
                
    elif args.autoformalize: 
        with open(path) as f:
            source = json.load(f)
        sketch = source["sketch"]
        code = source["code"]
        test_proofsearch(code)  # the REPL needs Lean source, not the whole JSON dict
        summary = autoformalize_sketch(code, sketch, max_api_calls=10, verbose=True)

    end_time = time.time()

    if summary is None:
        print("nothing was run: enable --prove or --autoformalize (and check --input)")
        return

    print("\nSUMMARY\nSTOP REASON: ", summary["stop_reason"])
    print(f'{summary["num_api_calls"]} interactions in {end_time-start_time:.2f} seconds')

    program = summary["proverstates"][-1].code

    print(program)

if __name__ == "__main__":
    def _str_to_bool(value):
        """argparse `type` for flags: bool("False") is True, so parse the text instead."""
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument('-a', '--autoformalize', type=_str_to_bool, default='True')
    parser.add_argument('-p', '--prove', type=_str_to_bool, default='True')
    parser.add_argument('-i', '--input', type=str, default="minif2f")
    parser.add_argument('-l', '--logdir', type=str, default="./logs")
    # Parse an empty argument list: in a notebook, sys.argv holds the kernel
    # launcher's own flags rather than ours.
    args = parser.parse_args('')
    _main(args)


True


0it [00:00, ?it/s]

{'id': 'amc12a_2019_21', 'header': 'import Mathlib.Algebra.BigOperators.Basic\nimport Mathlib.Data.Real.Basic\nimport Mathlib.Data.Complex.Basic\nimport Mathlib.Data.Nat.Log\nimport Mathlib.Data.Complex.Exponential\nimport Mathlib.NumberTheory.Divisors\nimport Mathlib.Data.ZMod.Defs\nimport Mathlib.Data.ZMod.Basic\nimport Mathlib.Topology.Basic\nimport Mathlib.Data.Nat.Digits\n\nopen BigOperators\nopen Real\nopen Nat\nopen Topology', 'formal_statement': 'theorem amc12a_2019_p21\n  (z : ℂ)\n  (h₀ : z = (1 + Complex.I) / Real.sqrt 2) :\n  (∑ k in Finset.Icc 1 12, (z^(k^2))) * (∑ k in Finset.Icc 1 12, (1 / z^(k^2))) = 36 := sorry', 'informal_stmt': 'Let $z=\\frac{1+i}{\\sqrt{2}}.$What is $\\left(z^{1^2}+z^{2^2}+z^{3^2}+\\dots+z^{{12}^2}\\right) \\cdot \\left(\\frac{1}{z^{1^2}}+\\frac{1}{z^{2^2}}+\\frac{1}{z^{3^2}}+\\dots+\\frac{1}{z^{{12}^2}}\\right)?$\n\n$\\textbf{(A) } 18 \\qquad \\textbf{(B) } 72-36\\sqrt2 \\qquad \\textbf{(C) } 36 \\qquad \\textbf{(D) } 72 \\qquad \\textbf{(E) } 72+36

0it [04:03, ?it/s]


KeyboardInterrupt: 

In [28]:
from pylean import LeanServer
from pprint import pprint

code = """
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) : Nat.gcd m n = 1 := by {}
"""

lean = LeanServer()
try:
    state = lean.run_code(code)
finally:
    # Shut the REPL down even if run_code raises (e.g. on timeout).
    lean.proc.close()
pprint(state)

/localscratch/hsun409/github/repl
{'env': 0,
 'messages': [{'data': 'unsolved goals\nm n : ℕ\n⊢ Nat.gcd m n = 1',
               'endPos': {'column': 55, 'line': 4},
               'pos': {'column': 54, 'line': 4},
               'severity': 'error'}]}


In [30]:
from pylean import LeanServer
from pprint import pprint

code = """
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.Coprime n) : m.gcd n = 1 := by {}
"""

lean = LeanServer()
try:
    state = lean.run_code(code)
finally:
    # Shut the REPL down even if run_code raises (e.g. on timeout).
    lean.proc.close()
pprint(state)

/localscratch/hsun409/github/repl
{'env': 0,
 'messages': [{'data': 'unsolved goals\n'
                       'm n : ℕ\n'
                       'h : Nat.Coprime m n\n'
                       '⊢ Nat.gcd m n = 1',
               'endPos': {'column': 69, 'line': 4},
               'pos': {'column': 68, 'line': 4},
               'severity': 'error'}]}


In [31]:
def get_goal(state):
    """Return the goal display of the last "unsolved goals" message in a REPL reply.

    Returns None when any other error-severity message is present, or when there
    is no open goal. This includes a reply without a "messages" key, such as
    {'env': 0} for a finished proof.
    """
    goal = None
    for msg in state.get('messages', []):
        if msg['data'].startswith('unsolved goals\n'):
            # Drop the "unsolved goals" header line and keep the goal display.
            goal = msg['data'].split('\n', 1)[1]
        elif msg['severity'] == 'error':
            return None
    return goal

print(get_goal(state=state))  # goal text of the last REPL reply (None if none)


m n : ℕ
h : Nat.Coprime m n
⊢ Nat.gcd m n = 1


In [32]:
# Load model and tokenizer
import os
import transformers

# Set before any tokenizer is constructed so it is already in effect for all later tokenizer use.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # prevents an annoying warning

model_name = 'wellecks/llmstep-mathlib4-pythia2.8b'
model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)
tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)


def generate(prompt, max_new_tokens=256):
    """Generate a continuation of `prompt` with the loaded model.

    Uses the model's default generation settings.

    Args:
        prompt (str): e.g. "[GOAL]<goal>[PROOFSTEP]".
        max_new_tokens (int): cap on generated tokens (default 256).

    Returns:
        str: only the newly generated text (prompt tokens are sliced off),
        with special tokens removed.
    """
    encoded = tokenizer(prompt, return_tensors='pt')
    input_ids = encoded['input_ids']
    out = model.generate(
        input_ids,
        # Pass the mask explicitly: pad == eos here, so it cannot be inferred from the ids.
        attention_mask=encoded['attention_mask'],
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)

config.json: 100%|█████████████████████████████████████████████████████████████████████████████| 653/653 [00:00<00:00, 871kB/s]
pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████| 5.68G/5.68G [02:39<00:00, 35.7MB/s]
  return self.fget.__get__(instance, owner)()
generation_config.json: 100%|██████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 158kB/s]
tokenizer_config.json: 100%|███████████████████████████████████████████████████████████████████| 264/264 [00:00<00:00, 374kB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 16.0MB/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 159kB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [33]:


# Generate a next step
goal = get_goal(state)
if goal is None:
    # Otherwise the prompt would silently become "[GOAL]None[PROOFSTEP]".
    raise ValueError("no open goal in `state` (proof finished or Lean reported an error)")
prompt = f"[GOAL]{goal}[PROOFSTEP]"

next_step = generate(prompt)
print(next_step)



rw [← gcd_eq_gcd_ab, gcd_comm]


In [41]:
import pexpect
import json
import os

class LeanServer:
    """Lean 4 REPL subprocess driven through pexpect.

    The REPL checkout is taken from the PATH_TO_LEAN_REPL environment variable.
    """

    def __init__(self):
        # Get the path where you download repl
        path_to_repl = os.environ.get('PATH_TO_LEAN_REPL')

        # Run the command
        self.proc = pexpect.spawn(
            "lake env lean --run REPL/Main.lean", 
            cwd=path_to_repl, 
            encoding="utf-8")

    def run_code(self, code, env=None, verbose=False, timeout=20):
        """Send one command to the REPL and return its parsed JSON reply.

        Args:
            code (str): Lean source to run.
            env (int | None): REPL environment id. `0` is valid, so only `None`
                means "fresh environment".
            verbose (bool): print the command and the raw reply.
            timeout (float): seconds to wait for the reply (default 20).

        Raises:
            pexpect.exceptions.TIMEOUT: if the REPL does not answer in time.
        """
        request = {"cmd": code}
        if env is not None:
            request["env"] = env
        # Proper JSON escaping for both cases (repr() is Python escaping, not JSON).
        command = json.dumps(request, ensure_ascii=False)

        if verbose: print(command)
        self.proc.sendline(command)
        # Consume the terminal's echo of the command.
        self.proc.expect_exact(command + "\r\n")
        # A blank line terminates the command for the REPL.
        self.proc.sendline()
        self.proc.expect_exact("\r\n")
        # Every reply ends with `"env": <n>}`; TIMEOUT propagates with its traceback.
        self.proc.expect(r'env": \d+\}', timeout=timeout)
        output = self.proc.before + self.proc.match.group()
        if verbose: print(output)
        return json.loads(output)

In [40]:
from pprint import pprint

code = """
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.Coprime n) : m.gcd n = 1 := by 

""" + 'rw [← h.gcd_eq_one]'

# Uses whichever LeanServer was defined most recently (the local class in document order).
lean = LeanServer()
try:
    state = lean.run_code(code)
finally:
    # Shut the REPL down even if run_code raises (e.g. on timeout).
    lean.proc.close()

pprint(state)

{'env': 0}


In [None]:
state  # last REPL reply; `lean_state` only exists as a local inside the prover functions

In [14]:
!pip install fire

Collecting fire
  Using cached fire-0.5.0-py2.py3-none-any.whl
Collecting termcolor (from fire)
  Using cached termcolor-2.4.0-py3-none-any.whl.metadata (6.1 kB)
Using cached termcolor-2.4.0-py3-none-any.whl (7.7 kB)
Installing collected packages: termcolor, fire
Successfully installed fire-0.5.0 termcolor-2.4.0


In [3]:
# Check which `pip` the shell resolves to: `!pip` runs in a subshell and may
# target a different environment than this kernel, whereas `%pip` always
# targets the running kernel.
!which pip

/localscratch/hsun409/anaconda3/envs/jepa/bin/pip


In [6]:
!pip install backoff

Collecting backoff
  Downloading backoff-2.2.1-py3-none-any.whl (15 kB)
Installing collected packages: backoff
Successfully installed backoff-2.2.1
