# In [None]:
import subprocess
import time
import os
import re
from openai import OpenAI

# Lean Config
# Directory used as the working directory for `lake build`.
LEAN_PROJECT_DIR = r"D:\MSc Research Project\MathlibProject"
# Lean source file that the generated code is written to (and later cleared).
LEAN_FILE_PATH = r"D:\MSc Research Project\MathlibProject\MathlibProject.lean"
# Default number of solve/verify rounds before feedback_loop gives up.
MAX_ATTEMPTS = 100

# DeepSeek API Setup
# The key is taken from the DEEPSEEK_API_KEY environment variable so a real
# credential never has to be committed with the source. The original
# placeholder stays as the fallback, so behaviour is unchanged when the
# variable is not set.
client = OpenAI(
    api_key=os.environ.get("DEEPSEEK_API_KEY", "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"),
    base_url="https://api.deepseek.com",
)

# Sample problem passed to feedback_loop when the file is run as a script.
problem = " ".join(
    [
        "Suppose the roots of the polynomial $x^2 - mx + n$ are positive prime integers (not necessarily distinct).",
        "Given that $m < 20,$ how many possible values of $n$ are there?",
    ]
)

# ----- Prompts -----
# First-round prompt; `{problem}` is filled in with str.format.
# The blank line is followed directly by "Problem:" (a stray "." that used to
# sit between them has been removed).
INITIAL_PROMPT = (
    "Given the following math word problem, return only the required steps/formula, only required explanation, and the final answer. "
    "Only main explanation, reduce the number of steps. No extra text. \n\n"
    "Problem: {problem}"
)

# Retry prompt; `{problem}` and `{previous_answer}` are filled in with
# str.format so the model sees its earlier answer next to the problem.
RETRY_SOLVE_PROMPT_TEMPLATE = (
    "Given the following math problem:\n"
    "{problem}\n\n"
    "You previously solved it like this:\n"
    "{previous_answer}\n\n"
    "There is an error in your reasoning.\n"
    "Retry solving the problem again carefully. "
    "Return only the required steps/formula, only required explanation, and the final answer. "
    "Only main explanation, reduce the number of steps. No extra text. \n\n"
)

# Few-shot prompt that asks the model to turn a natural-language solution into
# Lean code following a fixed template. `{solution_text}` is filled in with
# str.format, so literal braces inside the examples are doubled.
# This MUST be a raw string: the examples contain LaTeX such as `\boxed`, and in
# a normal string literal `\b` is the backspace escape, which would corrupt the
# prompt (and `\(`, `\l`, `\g` are invalid escapes that trigger warnings).
LEAN_CODE_GENERATION_PROMPT = r"""
You are given a natural language solution to a math problem. 
Extract:
1. The list of valid prime pairs (p, q)
2. The set of distinct n values
3. The count of distinct n values

Then output Lean code using the template exactly. Replace placeholders with extracted values.

Template:
import Mathlib.Data.Nat.Prime.Basic
import Mathlib.Data.List.Basic
import Mathlib.Data.Finset.Basic
import Mathlib.Data.Finset.Card
import Init.Data.Nat.Basic

open Nat
open Finset
open List

def validPrimePairs : List (ℕ × ℕ) :=
  let primes := (List.range 20).filter Nat.Prime
  List.foldr (· ++ ·) [] (
    primes.map (fun p =>
      (List.range 20).filter Nat.Prime |>.filterMap (fun q =>
        if p ≤ q ∧ p + q < 20 then some (p, q) else none))
  )

def inputPairsList : List (ℕ × ℕ) :=
  <PRIME_PAIRS_LIST>

theorem verify_eqv : validPrimePairs.isEqv inputPairsList (· == ·) = true := by
  decide

def nValues : List ℕ := validPrimePairs.map (fun (p, q) => p * q)

def distinctNValues : Finset ℕ := nValues.toFinset

def inputNValues : Finset ℕ :=
  <N_VALUES_SET>

theorem verify_nValues_correct : distinctNValues = inputNValues := by
  decide

def inputCount : Nat := <COUNT>

theorem verify_cardinality_correct : distinctNValues.card = inputCount := by
  rfl

Use following problems as examples to generate the Lean code with given output format.

Example 1

Input:
**Steps/Formula:**
1. Let the roots be \( p \) and \( q \), where \( p \) and \( q \) are positive prime integers.
2. The polynomial is \( x^2 - (p+q)x + pq \), so \( m = p + q \) and \( n = pq \).
3. Given \( m < 20 \), find all pairs \((p, q)\) such that \( p + q < 20 \).

**Explanation:**
- List all prime pairs \((p, q)\) where \( p \leq q \) and \( p + q < 20 \).
- Calculate \( n = pq \) for each pair and count the distinct values.

**Prime Pairs and \( n \):**
- (2, 2): \( n = 4 \)
- (2, 3): \( n = 6 \)
- (2, 5): \( n = 10 \)
- (2, 7): \( n = 14 \)
- (2, 11): \( n = 22 \) (invalid, \( m = 13 < 20 \))
- (2, 13): \( n = 26 \) (invalid, \( m = 15 < 20 \))
- (2, 17): \( n = 34 \) (invalid, \( m = 19 < 20 \))
- (3, 3): \( n = 9 \)
- (3, 5): \( n = 15 \)
- (3, 7): \( n = 21 \)
- (3, 11): \( n = 33 \) (invalid, \( m = 14 < 20 \))
- (3, 13): \( n = 39 \) (invalid, \( m = 16 < 20 \))
- (5, 5): \( n = 25 \)
- (5, 7): \( n = 35 \)
- (5, 11): \( n = 55 \) (invalid, \( m = 16 < 20 \))
- (7, 7): \( n = 49 \)
- (7, 11): \( n = 77 \) (invalid, \( m = 18 < 20 \))
- (11, 11): \( n = 121 \) (invalid, \( m = 22 \geq 20 \))

**Valid \( n \) values:**
4, 6, 9, 10, 14, 15, 21, 25, 35, 49

**Final Answer:**
\(\boxed{{10}}\)

Output:
import Mathlib.Data.Nat.Prime.Basic
import Mathlib.Data.List.Basic
import Mathlib.Data.Finset.Basic
import Mathlib.Data.Finset.Card
import Init.Data.Nat.Basic

open Nat
open Finset
open List

def validPrimePairs : List (ℕ × ℕ) :=
  let primes := (List.range 20).filter Nat.Prime
  List.foldr (· ++ ·) [] (
    primes.map (fun p =>
      (List.range 20).filter Nat.Prime |>.filterMap (fun q =>
        if p ≤ q ∧ p + q < 20 then some (p, q) else none))
  )

def inputPairsList : List (ℕ × ℕ) :=
  [
  (2, 2),
  (2, 3),
  (2, 5),
  (2, 7),
  (3, 3),
  (3, 5),
  (3, 7),
  (5, 5),
  (5, 7),
  (7, 7)
]

theorem verify_eqv : validPrimePairs.isEqv inputPairsList (· == ·) = true := by
  decide

def nValues : List ℕ := validPrimePairs.map (fun (p, q) => p * q)

def distinctNValues : Finset ℕ := nValues.toFinset

def inputNValues : Finset ℕ :=
  ({{4, 6, 9, 10, 14, 15, 21, 25, 35, 49}} : Finset ℕ)

theorem verify_nValues_correct : distinctNValues = inputNValues := by
  decide

def inputCount : Nat := 10

theorem verify_cardinality_correct : distinctNValues.card = inputCount := by
  rfl

Example 2:

Input:

**Steps/Formula:**
1. Let the roots be \( p \) and \( q \), where \( p \) and \( q \) are positive prime integers.
2. The polynomial is \( x^2 - (p+q)x + pq \), so \( m = p + q \) and \( n = pq \).
3. Given \( m < 20 \), find all pairs \((p, q)\) such that \( p + q < 20 \).

**Explanation:**
- List all prime pairs \((p, q)\) where \( p \leq q \) and \( p + q < 20 \).
- Calculate \( n = pq \) for each pair and count the distinct values.

**Prime Pairs and \( n \):**
- (2, 2): \( n = 4 \)
- (2, 3): \( n = 6 \)
- (2, 5): \( n = 10 \)
- (2, 7): \( n = 14 \)
- (2, 11): \( n = 22 \)
- (2, 13): \( n = 26 \)
- (2, 17): \( n = 34 \)
- (3, 3): \( n = 9 \)
- (3, 5): \( n = 15 \)
- (3, 7): \( n = 21 \)
- (3, 11): \( n = 33 \)
- (3, 13): \( n = 39 \)
- (5, 5): \( n = 25 \)
- (5, 7): \( n = 35 \)
- (5, 11): \( n = 55 \)
- (7, 7): \( n = 49 \)
- (7, 11): \( n = 77 \)

**Valid \( n \) values:**
4, 6, 9, 10, 14, 15, 21, 22, 25, 26, 33, 34, 35, 39, 49, 55, 77

**Final Answer:**
\(\boxed{{17}}\)

Output:
import Mathlib.Data.Nat.Prime.Basic
import Mathlib.Data.List.Basic
import Mathlib.Data.Finset.Basic
import Mathlib.Data.Finset.Card
import Init.Data.Nat.Basic

open Nat
open Finset
open List

def validPrimePairs : List (ℕ × ℕ) :=
  let primes := (List.range 20).filter Nat.Prime
  List.foldr (· ++ ·) [] (
    primes.map (fun p =>
      (List.range 20).filter Nat.Prime |>.filterMap (fun q =>
        if p ≤ q ∧ p + q < 20 then some (p, q) else none))
  )

def inputPairsList : List (ℕ × ℕ) :=
  [
  (2, 2),
  (2, 3),
  (2, 5),
  (2, 7),
  (2, 11),
  (2, 13),
  (2, 17),
  (3, 3),
  (3, 5),
  (3, 7),
  (3, 11),
  (3, 13),
  (5, 5),
  (5, 7),
  (5, 11),
  (7, 7),
  (7, 11)
]

theorem verify_eqv : validPrimePairs.isEqv inputPairsList (· == ·) = true := by
  decide

def nValues : List ℕ := validPrimePairs.map (fun (p, q) => p * q)

def distinctNValues : Finset ℕ := nValues.toFinset

def inputNValues : Finset ℕ :=
  ({{4, 6, 9, 10, 14, 15, 21, 22, 25, 26, 33, 34, 35, 39, 49, 55, 77}} : Finset ℕ)

theorem verify_nValues_correct : distinctNValues = inputNValues := by
  decide

def inputCount : Nat := 17

theorem verify_cardinality_correct : distinctNValues.card = inputCount := by
  rfl

Input:
{solution_text}

Output:
"""


# --- Helper functions ---

def call_deepseek(user_prompt: str, system_prompt: str) -> str:
    """Send one system/user message pair to the chat model and return its reply.

    Args:
        user_prompt: Text sent as the ``user`` message.
        system_prompt: Text sent as the ``system`` message.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or an empty string when the reply carries no content.
    """
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.0,
        max_tokens=2048,
        stream=False,
    )
    # The SDK types `message.content` as optional; guard against None so an
    # empty reply yields "" instead of raising AttributeError on .strip().
    content = response.choices[0].message.content
    return (content or "").strip()

def write_lean_code(code: str):
    """Overwrite the file at ``LEAN_FILE_PATH`` with ``code``, encoded as UTF-8."""
    with open(LEAN_FILE_PATH, mode="w", encoding="utf-8") as lean_file:
        lean_file.write(code)

def run_lean_code() -> tuple[bool, str]:
    """Run ``lake build`` in ``LEAN_PROJECT_DIR`` and report the outcome.

    The Lean file at ``LEAN_FILE_PATH`` is emptied afterwards, including when
    launching ``lake`` itself raises, so stale code is never left behind.

    Returns:
        A ``(success, output)`` pair. ``success`` is True only when the
        process exited with status 0 and its stderr does not mention
        "error" (case-insensitive). ``output`` is the captured stdout and
        stderr formatted for display.

    Raises:
        OSError: If the ``lake`` executable cannot be started.
    """
    try:
        result = subprocess.run(
            ["lake", "build"],
            cwd=LEAN_PROJECT_DIR,
            capture_output=True,
            text=True,
            # Decode explicitly as UTF-8 (Lean's output contains non-ASCII
            # symbols) instead of relying on the platform locale codec, and
            # never fail the run on an undecodable byte.
            encoding="utf-8",
            errors="replace",
        )
    finally:
        # Clear Lean file after build
        with open(LEAN_FILE_PATH, "w", encoding="utf-8") as f:
            f.write("")

    stdout, stderr = result.stdout, result.stderr

    # The exit status is the authoritative signal; the stderr scan is kept as
    # an additional conservative check so nothing that failed before passes now.
    success = result.returncode == 0 and "error" not in stderr.lower()

    output = f"=== STDOUT ===\n{stdout}\n\n=== STDERR ===\n{stderr}"

    return success, output


# --- Main Feedback Loop ---

def feedback_loop(problem_text: str, max_attempts: int = MAX_ATTEMPTS):
    """Solve a problem with the model and verify the answer via a Lean build.

    Each round asks the model to convert the current natural-language
    solution into Lean code, writes that code to the Lean file and builds it.
    When the build fails, the model is asked to re-solve the problem using the
    retry prompt, which includes its previous answer.

    Args:
        problem_text: The math problem in natural language.
        max_attempts: Maximum number of verify rounds before giving up.

    Returns:
        The natural-language solution whose Lean build succeeded, or ``None``
        if no attempt verified within ``max_attempts`` rounds.
    """
    solver_system_prompt = "You are a concise math problem solver."

    print("\n Asking DeepSeek to solve the problem...\n")
    user_prompt = INITIAL_PROMPT.format(problem=problem_text)
    response = call_deepseek(user_prompt, system_prompt=solver_system_prompt)

    for attempt in range(1, max_attempts + 1):
        print(f"\n Attempt {attempt} ----------------------------")
        print(f"\n DeepSeek Response (natural language solution):\n{response}")

        lean_code_prompt = LEAN_CODE_GENERATION_PROMPT.format(solution_text=response)
        lean_code = call_deepseek(lean_code_prompt, system_prompt="You are a Lean code generator")
        print("\nLean code:\n", lean_code)
        write_lean_code(lean_code)

        print("\n Running Lean build...")
        success, lean_output = run_lean_code()
        print(f"\n Lean Output ({'Success' if success else 'Failure'}):\n{lean_output}")

        if success:
            print("\n Reasoning verified by Lean successfully.")
            return response

        print("\n Lean build failed. Sending error feedback to DeepSeek.")

        retry_prompt = RETRY_SOLVE_PROMPT_TEMPLATE.format(
            problem=problem_text,
            previous_answer=response,
        )

        print(f"\n Retry Prompt:\n{retry_prompt}")

        # Send the retry prompt itself. Re-sending the initial prompt (as the
        # code previously did) never tells the model its answer was rejected.
        response = call_deepseek(retry_prompt, system_prompt=solver_system_prompt)
        time.sleep(1)

    print("\n Maximum attempts reached. Verification unsuccessful.")
    return None

# === Run Example ===
# Script entry point: run the solve/verify loop on the module-level `problem`
# with the default attempt limit.
if __name__ == "__main__":
    feedback_loop(problem_text=problem)
