In [None]:
import os
import random
import re
import subprocess
import time

from openai import OpenAI

# === Configuration ===

LEAN_FILE_PATH = r"D:\MSc Research Project\MathlibProject\MathlibProject.lean"
LEAN_PROJECT_DIR = r"D:\MSc Research Project\MathlibProject"
MAX_ATTEMPTS = 1000

# DeepSeek API Setup
# The key is taken from the DEEPSEEK_API_KEY environment variable so that real
# credentials never need to be committed; the placeholder is only a fallback.
client = OpenAI(
    api_key=os.environ.get("DEEPSEEK_API_KEY", "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"),
    base_url="https://api.deepseek.com"
)

# Problem under study; feedback_loop compares boxed answers against its known answer.
problem = (
    "The digits $1$, $2$, $3$, $4$, $5$, $6$, $7$, and $9$ are used to form four two-digit prime numbers, "
    "with each digit used exactly once. What is the sum of these four primes? "
)

# === Prompts ===

# First-attempt prompt; {problem} is filled with str.format, so literal braces
# are doubled. "\\boxed" is escaped: a bare "\b" would be a backspace character.
INITIAL_PROMPT = (
    "Given the following math word problem, provide a **concise and brief** solution: include only the minimal necessary explanation and the final answer.\n"
    "Return the answer in the format: \\boxed{{answer}}.\n"
    "Do **not** include detailed steps, just enough explanation to justify the answer.\n"
    "Problem: {problem}"
)

# Retry prompt that shows the model its previous solution plus an error
# description. Placeholders: {problem}, {previous_answer}, {error}.
# "\\boxed" is escaped so the text holds a backslash, not a "\b" backspace.
RETRY_SOLVE_ERROR_PROMPT = (
    "Given the following math problem:\n"
    "{problem}\n\n"
    "You previously solved it like this:\n"
    "{previous_answer}\n\n"
    "But it had the following error:\n"
    "{error}\n\n"
    "Retry from scratch\n"
    "Provide a **concise and brief** solution: include only the minimal necessary explanation and the final answer.\n"
    "Return the answer in the format: \\boxed{{answer}}.\n"
    "Do **not** include detailed steps, just enough explanation to justify the answer."
)

# Retry prompt that only says the previous solution was wrong (no error detail).
# Placeholders: {problem}, {previous_answer}. "\\boxed" is escaped so the text
# holds a literal backslash rather than a "\b" backspace character.
BINARY_CORRECT_PROMPT = (
    "Given the following math problem:\n"
    "{problem}\n\n"
    "You solved it like this:\n"
    "{previous_answer}\n\n"
    "This answer was incorrect. Retry the problem from scratch.\n"
    "Provide a **concise and brief** solution: include only the minimal necessary explanation and the final answer.\n"
    "Return the answer in the format: \\boxed{{answer}}.\n"
    "Do **not** include detailed steps, just enough explanation to justify the answer."
)

# Retry prompt that gives an error description but hides the previous solution.
# Placeholders: {problem}, {error}. "\\boxed" is escaped so the text holds a
# literal backslash rather than a "\b" backspace character.
RETRY_SOLVE_WITHOUT_PREVIOUS_ANSWER_PROMPT = (
    "Given the following math problem:\n"
    "{problem}\n\n"
    "Your previous answer had the following error:\n"
    "{error}\n\n"
    "Retry the problem from scratch.\n"
    "Provide a **concise and brief** solution: include only the minimal necessary explanation and the final answer.\n"
    "Return the answer in the format: \\boxed{{answer}}.\n"
    "Do **not** include detailed steps, just enough explanation to justify the answer."
)

LEAN_CODE_GENERATION_PROMPT = """
You are given a natural language solution to a math problem. 
Extract the valid list of 4 two digit numbers (used to calculate the sum) from the solution and the sum of the list (final answer), then output Lean code using the template exactly. 
Replace placeholders with extracted values.

Template:
import Mathlib.Data.List.Basic
open Nat List
def digits (n : Nat) : List Nat :=
  [n / 10, n % 10]
def digitsDistinct (n : Nat) : Bool :=
  let ds := digits n
  ds.length = ds.eraseDup.length
def listElemBool (a : Nat) (l : List Nat) : Bool :=
  l.any (fun x => x = a)
def digitsInSet (allowed : List Nat) (n : Nat) : Bool :=
  let ds := digits n
  ds.all (fun d => listElemBool d allowed)
def isPrimeBool (n : Nat) : Bool :=
  if n < 2 then false else
  !((List.range (n - 2)).any (fun d => n % (d + 2) = 0))
def allDigits (nums : List Nat) : List Nat :=
  nums.flatMap digits
def allDigitsDistinct (nums : List Nat) : Bool :=
  let ds := allDigits nums
  ds.length = ds.eraseDup.length
def allowedDigits : List Nat := [1, 2, 3, 4, 5, 6, 7, 9]
def inputList : List Nat := <REPLACE_WITH_SOLUTION_LIST>
def inputSum : Nat := <REPLACE_WITH_SOLUTION_SUM>
def allPrime : Bool := inputList.all isPrimeBool
def noDigit8 : Bool := inputList.all (digitsInSet allowedDigits)
def allDistinctDigits : Bool := allDigitsDistinct inputList
def sumOfList : Nat := inputList.foldl (· + ·) 0
def sumCorrect : Bool := inputList.foldl (· + ·) 0 = inputSum
#eval ("All prime?", allPrime)
#eval ("No digit 8?", noDigit8)
#eval ("All digits distinct?", allDistinctDigits)
#eval ("Sum correct?", sumCorrect)

Use following problems as examples to generate the Lean code with given output format.

Example 1:

Input:
 ### Steps:
1. **List all two-digit primes using digits 1-9 without repetition.**
   - Possible primes: 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97.
2. **Select four primes that use each digit (1, 2, 3, 4, 5, 6, 7, 9) exactly once.**
   - One valid combination: 13, 29, 47, 61 (digits used: 1,2,3,4,6,7,9).
   - Another valid combination: 13, 29, 67, 41 (digits used: 1,2,3,4,6,7,9).
   - Another valid combination: 17, 23, 59, 41 (digits used: 1,2,3,4,5,7,9).
   - Another valid combination: 17, 23, 59, 61 (digits used: 1,2,3,5,6,7,9).
   - Another valid combination: 17, 29, 43, 61 (digits used: 1,2,3,4,6,7,9).
3. **Calculate the sum for one valid combination.**
   - For 13, 29, 47, 61:  
     Sum = 13 + 29 + 47 + 61 = 150.
### Final Answer:
\boxed{{150}}

Output:
import Mathlib.Data.List.Basic
open Nat List
def digits (n : Nat) : List Nat :=
  [n / 10, n % 10]
def digitsDistinct (n : Nat) : Bool :=
  let ds := digits n
  ds.length = ds.eraseDup.length
def listElemBool (a : Nat) (l : List Nat) : Bool :=
  l.any (fun x => x = a)
def digitsInSet (allowed : List Nat) (n : Nat) : Bool :=
  let ds := digits n
  ds.all (fun d => listElemBool d allowed)
def isPrimeBool (n : Nat) : Bool :=
  if n < 2 then false else
  !((List.range (n - 2)).any (fun d => n % (d + 2) = 0))
def allDigits (nums : List Nat) : List Nat :=
  nums.flatMap digits
def allDigitsDistinct (nums : List Nat) : Bool :=
  let ds := allDigits nums
  ds.length = ds.eraseDup.length
def allowedDigits : List Nat := [1, 2, 3, 4, 5, 6, 7, 9]
def inputList : List Nat := [13, 29, 47, 61]
def inputSum : Nat := 150
def allPrime : Bool := inputList.all isPrimeBool
def noDigit8 : Bool := inputList.all (digitsInSet allowedDigits)
def allDistinctDigits : Bool := allDigitsDistinct inputList
def sumOfList : Nat := inputList.foldl (· + ·) 0
def sumCorrect : Bool := inputList.foldl (· + ·) 0 = inputSum
#eval ("All prime?", allPrime)
#eval ("No digit 8?", noDigit8)
#eval ("All digits distinct?", allDistinctDigits)
#eval ("Sum correct?", sumCorrect)

Example 2:

Input:
 ### Steps:
1. **List all two-digit primes using digits 1-9 without repetition:**
   - 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97.
2. **Select four primes using each digit (1, 2, 3, 4, 5, 6, 7, 9) exactly once:**
   - Possible set: 61, 43, 47, 29 (uses all digits without repetition).
3. **Calculate the sum:**
   - \(61 + 43 + 47 + 29 = 180\).
### Final Answer:
\boxed{{180}}

Output:
import Mathlib.Data.List.Basic
open Nat List
def digits (n : Nat) : List Nat :=
  [n / 10, n % 10]
def digitsDistinct (n : Nat) : Bool :=
  let ds := digits n
  ds.length = ds.eraseDup.length
def listElemBool (a : Nat) (l : List Nat) : Bool :=
  l.any (fun x => x = a)
def digitsInSet (allowed : List Nat) (n : Nat) : Bool :=
  let ds := digits n
  ds.all (fun d => listElemBool d allowed)
def isPrimeBool (n : Nat) : Bool :=
  if n < 2 then false else
  !((List.range (n - 2)).any (fun d => n % (d + 2) = 0))
def allDigits (nums : List Nat) : List Nat :=
  nums.flatMap digits
def allDigitsDistinct (nums : List Nat) : Bool :=
  let ds := allDigits nums
  ds.length = ds.eraseDup.length
def allowedDigits : List Nat := [1, 2, 3, 4, 5, 6, 7, 9]
def inputList : List Nat := [61, 43, 47, 29]
def inputSum : Nat := 180
def allPrime : Bool := inputList.all isPrimeBool
def noDigit8 : Bool := inputList.all (digitsInSet allowedDigits)
def allDistinctDigits : Bool := allDigitsDistinct inputList
def sumOfList : Nat := inputList.foldl (· + ·) 0
def sumCorrect : Bool := inputList.foldl (· + ·) 0 = inputSum
#eval ("All prime?", allPrime)
#eval ("No digit 8?", noDigit8)
#eval ("All digits distinct?", allDistinctDigits)
#eval ("Sum correct?", sumCorrect)

Now do this for the following input:
{initial_solution}
"""

LEAN_CODE_INTERPRETATION_PROMPT = """
    Given the given Lean code input, and its output, interpret the results and provide a concise explanation of the error.
    Do not include any additional information or context, just the interpretation of the Lean output.
    
    Use the following examples as a guide for your response:

    Example 1:
    
    Input:

    Lean code:

    import Mathlib.Data.List.Basic
    open Nat List
    def digits (n : Nat) : List Nat :=
    [n / 10, n % 10]
    def digitsDistinct (n : Nat) : Bool :=
    let ds := digits n
    ds.length = ds.eraseDup.length
    def listElemBool (a : Nat) (l : List Nat) : Bool :=
    l.any (fun x => x = a)
    def digitsInSet (allowed : List Nat) (n : Nat) : Bool :=
    let ds := digits n
    ds.all (fun d => listElemBool d allowed)
    def isPrimeBool (n : Nat) : Bool :=
    if n < 2 then false else
    !((List.range (n - 2)).any (fun d => n % (d + 2) = 0))
    def allDigits (nums : List Nat) : List Nat :=
    nums.flatMap digits
    def allDigitsDistinct (nums : List Nat) : Bool :=
    let ds := allDigits nums
    ds.length = ds.eraseDup.length
    def allowedDigits : List Nat := [1, 2, 3, 4, 5, 6, 7, 9]
    def inputList : List Nat := [13, 29, 47, 61]
    def inputSum : Nat := 150
    def allPrime : Bool := inputList.all isPrimeBool
    def noDigit8 : Bool := inputList.all (digitsInSet allowedDigits)
    def allDistinctDigits : Bool := allDigitsDistinct inputList
    def sumOfList : Nat := inputList.foldl (· + ·) 0
    def sumCorrect : Bool := inputList.foldl (· + ·) 0 = inputSum
    #eval ("All prime?", allPrime)
    #eval ("No digit 8?", noDigit8)
    #eval ("All digits distinct?", allDistinctDigits)
    #eval ("Sum correct?", sumCorrect)
    
    Lean output:

    info: MathlibProject.lean:30:0: ("All prime?", true)
    info: MathlibProject.lean:31:0: ("No digit 8?", true)
    info: MathlibProject.lean:32:0: ("All digits distinct?", false)
    info: MathlibProject.lean:33:0: ("Sum correct?", true)

    Output:

    The digits in the numbers are not all distinct.

    Example 2:

    Input:

    Lean code:

    import Mathlib.Data.List.Basic
    open Nat List
    def digits (n : Nat) : List Nat :=
    [n / 10, n % 10]
    def digitsDistinct (n : Nat) : Bool :=
    let ds := digits n
    ds.length = ds.eraseDup.length
    def listElemBool (a : Nat) (l : List Nat) : Bool :=
    l.any (fun x => x = a)
    def digitsInSet (allowed : List Nat) (n : Nat) : Bool :=
    let ds := digits n
    ds.all (fun d => listElemBool d allowed)
    def isPrimeBool (n : Nat) : Bool :=
    if n < 2 then false else
    !((List.range (n - 2)).any (fun d => n % (d + 2) = 0))
    def allDigits (nums : List Nat) : List Nat :=
    nums.flatMap digits
    def allDigitsDistinct (nums : List Nat) : Bool :=
    let ds := allDigits nums
    ds.length = ds.eraseDup.length
    def allowedDigits : List Nat := [1, 2, 3, 4, 5, 6, 7, 9]
    def inputList : List Nat := [13, 29, 47, 65]
    def inputSum : Nat := 154
    def allPrime : Bool := inputList.all isPrimeBool
    def noDigit8 : Bool := inputList.all (digitsInSet allowedDigits)
    def allDistinctDigits : Bool := allDigitsDistinct inputList
    def sumOfList : Nat := inputList.foldl (· + ·) 0
    def sumCorrect : Bool := inputList.foldl (· + ·) 0 = inputSum
    #eval ("All prime?", allPrime)
    #eval ("No digit 8?", noDigit8)
    #eval ("All digits distinct?", allDistinctDigits)
    #eval ("Sum correct?", sumCorrect)
    
    Lean output:

    info: MathlibProject.lean:30:0: ("All prime?", false)
    info: MathlibProject.lean:31:0: ("No digit 8?", true)
    info: MathlibProject.lean:32:0: ("All digits distinct?", false)
    info: MathlibProject.lean:33:0: ("Sum correct?", true)

    Output:

    Not all numbers in the list are prime.

    Now do this for the following input:

    Lean code:
    {lean_code}

    Lean output:
    {lean_output}
    """

# === Helper functions ===

def call_deepseek(user_prompt: str, system_prompt: str, *, model: str = "deepseek-chat",
                  temperature: float = 0.0, max_tokens: int = 4000, api_client=None) -> str:
    """Send one system + user chat turn and return the model's reply text.

    Args:
        user_prompt: Content of the user message.
        system_prompt: Content of the system message.
        model: Model name passed to ``chat.completions.create``.
        temperature: Sampling temperature passed to the API.
        max_tokens: Completion token limit passed to the API.
        api_client: OpenAI-compatible client to use; defaults to the
            module-level ``client`` when omitted.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` when the API returned no content.
    """
    api = client if api_client is None else api_client
    response = api.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        stream=False
    )
    # The SDK types message.content as optional; treat a missing reply as empty
    # text instead of crashing on None.strip().
    content = response.choices[0].message.content
    return (content or "").strip()

def write_lean_code(code: str, path=None):
    """Overwrite a Lean source file with ``code``, encoded as UTF-8.

    Args:
        code: Full Lean source text; replaces any existing file contents.
        path: File to write. Defaults to ``LEAN_FILE_PATH`` when omitted, so
            existing callers are unaffected.
    """
    target = LEAN_FILE_PATH if path is None else path
    with open(target, "w", encoding="utf-8") as f:
        f.write(code)

def run_lean_code(project_dir=None, command=None) -> tuple[bool, str]:
    """Build the Lean project and capture the build output.

    Args:
        project_dir: Working directory for the build; defaults to
            ``LEAN_PROJECT_DIR``.
        command: Build command as an argv list; defaults to
            ``["lake", "build"]``.

    Returns:
        ``(success, output)``: ``success`` is True when the build process exits
        with status 0; ``output`` combines the captured stdout and stderr under
        "=== STDOUT ===" / "=== STDERR ===" headers.
    """
    cwd = LEAN_PROJECT_DIR if project_dir is None else project_dir
    argv = ["lake", "build"] if command is None else list(command)
    result = subprocess.run(
        argv,
        cwd=cwd,
        capture_output=True,
        text=True
    )
    stdout, stderr = result.stdout, result.stderr

    # Judge success by exit status rather than by searching stderr for the word
    # "error": a failure reported only on stdout (or worded differently) would
    # otherwise be counted as a successful build.
    success = result.returncode == 0

    output = f"=== STDOUT ===\n{stdout}\n\n=== STDERR ===\n{stderr}"

    # Clear Lean file after build (optional)
    # with open(LEAN_FILE_PATH, "w", encoding="utf-8") as f:
    #     f.write("")

    return success, output

def clean_lean_code(raw_code: str) -> str:
    """
    Extract Lean code from an API response.

    If the response contains a fenced block (```lean ... ``` or ``` ... ```),
    the contents of the first such block are returned without the fence lines
    or the language tag, and any prose before or after the block is dropped.
    An opening fence that is never closed is removed together with its tag.
    A response without fences is returned with only whitespace stripped.
    """
    if not raw_code:
        return ""

    # First complete fenced block: skip "```" plus its info string (e.g. "lean")
    # up to the end of that line, then capture lazily up to the closing fence.
    fenced = re.search(r"```[^\n`]*\n(.*?)```", raw_code, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()

    # Unterminated fence at the start: drop only the opening fence line.
    opening = re.match(r"\s*```[^\n]*(?:\n|$)", raw_code)
    if opening:
        return raw_code[opening.end():].strip()

    return raw_code.strip()

# === Main Feedback Loop ===

def feedback_loop(problem_text: str, max_attempts: int = MAX_ATTEMPTS, expected_answer: str = "190"):
    """Run repeated solve / Lean-check / retry trials and print a results tally.

    Each trial:
      1. Asks the model to solve ``problem_text`` and extracts the first
         ``\\boxed{...}`` value. A reply with no boxed value is counted as
         incomplete and the trial is re-run without consuming the budget.
      2. If the boxed value equals ``expected_answer``, the trial ends as a
         first-try success.
      3. Otherwise the model translates its solution into Lean code, which is
         written with ``write_lean_code`` and built with ``run_lean_code``; the
         model then summarizes the Lean output as an error description.
      4. Four retry arms (no feedback, binary correctness, error feedback with
         and without the previous answer) each run once in random order, and
         every arm whose boxed answer equals ``expected_answer`` is tallied.

    Args:
        problem_text: Problem statement substituted into every prompt.
        max_attempts: Number of trials that must yield a boxed first answer.
        expected_answer: Ground-truth answer compared (after ``strip()``) with
            boxed values. Defaults to "190", the answer to the configured problem.
    """
    boxed_pattern = re.compile(r'\\boxed\{([^}]*)\}')
    solver_system_prompt = "You are a concise math problem solver."

    first_try_correct = 0
    incomplete_answer = 0
    retried_trials = 0
    arm_correct = {
        "no_feedback": 0,
        "binary_correctness": 0,
        "error_feedback_without_previous_answer": 0,
        "error_feedback": 0,
    }

    # A `for` loop cannot "un-count" a trial (reassigning its variable has no
    # effect), so the trial budget is tracked manually.
    attempt = 0
    while attempt < max_attempts:
        print(f"\n--- Trial {attempt + 1} ---")

        # Step 1: initial solve attempt.
        user_prompt = INITIAL_PROMPT.format(problem=problem_text)
        print("User Prompt:", user_prompt)
        response = call_deepseek(user_prompt, system_prompt=solver_system_prompt)
        print("\nInitial Response:\n", response)

        match = boxed_pattern.search(response)
        if not match:
            incomplete_answer += 1
            # Stop once unboxed replies outnumber the requested trials so a model
            # that never boxes its answer cannot loop forever.
            if incomplete_answer > max_attempts:
                print("Too many responses without a boxed answer — stopping early.")
                break
            print("No boxed answer found — retrying without incrementing attempt count.")
            continue

        attempt += 1
        answer = match.group(1).strip()

        # Step 2: check the first answer.
        if answer == expected_answer:
            print("First try correct — skipping retries.")
            first_try_correct += 1
            continue

        print("First try incorrect, proceeding with retries.")
        retried_trials += 1

        # Step 3: translate the solution into Lean, write it, and build it.
        lean_code = call_deepseek(
            user_prompt=LEAN_CODE_GENERATION_PROMPT.format(initial_solution=response),
            system_prompt="You are a Lean code generator and verifier for math solutions."
        )
        print("\nGenerated Lean code:\n", lean_code)

        clean_code = clean_lean_code(lean_code)
        write_lean_code(clean_code)

        success, lean_output = run_lean_code()
        print(f"Lean build success: {success}")
        print(f"Lean output:\n{lean_output}")

        # Step 4: have the model turn the Lean output into an error description.
        error_description = call_deepseek(
            user_prompt=LEAN_CODE_INTERPRETATION_PROMPT.format(lean_code=clean_code, lean_output=lean_output),
            system_prompt="You are a Lean code interpreter and verifier for math solutions."
        )
        print("Parsed error description:", error_description)

        # Step 5: build one prompt per feedback style and randomize their order.
        arms = [
            ("no_feedback", user_prompt),
            ("binary_correctness", BINARY_CORRECT_PROMPT.format(problem=problem_text, previous_answer=response)),
            ("error_feedback_without_previous_answer", RETRY_SOLVE_WITHOUT_PREVIOUS_ANSWER_PROMPT.format(
                problem=problem_text,
                error=error_description
            )),
            ("error_feedback", RETRY_SOLVE_ERROR_PROMPT.format(
                problem=problem_text,
                previous_answer=response,
                error=error_description
            ))
        ]
        random.shuffle(arms)

        # Step 6: run each retry arm once and tally correct answers.
        for arm_type, retry_prompt in arms:
            # Random delay to avoid timing patterns
            time.sleep(random.uniform(0.2, 1.0))
            print("Retry prompt:", retry_prompt)
            retry_response = call_deepseek(retry_prompt, system_prompt=solver_system_prompt)
            print(f"\n[{arm_type.upper()}] Retry Response:\n", retry_response)

            match_retry = boxed_pattern.search(retry_response)
            retry_answer = match_retry.group(1).strip() if match_retry else None

            if retry_answer == expected_answer:
                print(f"[{arm_type.upper()}] Retry correct!")
                arm_correct[arm_type] += 1
            else:
                print(f"[{arm_type.upper()}] Retry incorrect or no boxed answer.")

    # === Final summary ===
    print("\n=== RESULTS ===")
    print(f"Trials completed: {attempt}")
    print(f"First try correct: {first_try_correct}")
    print(f"Trials retried (first answer incorrect): {retried_trials}")
    print(f"No feedback correct: {arm_correct['no_feedback']}")
    print(f"Binary correctness correct: {arm_correct['binary_correctness']}")
    print(f"Error feedback without previous answer correct: {arm_correct['error_feedback_without_previous_answer']}")
    print(f"Error feedback correct: {arm_correct['error_feedback']}")
    print(f"Incomplete answers (no boxed answer): {incomplete_answer}")


# === Run if main ===

if __name__ == "__main__":
    # Entry point: run the full experiment on the configured problem statement
    # with the module-level trial budget.
    feedback_loop(problem_text=problem, max_attempts=MAX_ATTEMPTS)
