In [None]:
import os
import re

import sympy as sp
from dotenv import load_dotenv
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr

# Load variables from a local .env file into os.environ.
load_dotenv()

# Read the OpenAI API key from the environment (e.g. provided via .env).
# Keep secrets out of source code instead of assigning the key inline here.
api_key = os.getenv("OPENAI_API_KEY")

# Imported after load_dotenv() so .env values are already in os.environ
# when the OpenAI SDK is loaded.
import openai  # noqa: E402
from openai import OpenAI  # noqa: E402

# Client shared by every llm_to_sympy() call.
client = OpenAI(api_key=api_key)

def normalize_expr(expr):
    """
    Recursively rewrite every ``log(arg, E)`` in *expr* as ``log(arg)``.

    Args:
        expr: A Sympy object, or a tuple/list of them (an equation is
            represented as the tuple ``(lhs, rhs)``). Values that are not
            Sympy objects (plain Python numbers, bools, strings, None) are
            returned unchanged instead of raising ``AttributeError``.

    Returns:
        An object of the same kind as *expr* (tuple stays tuple, list stays
        list) with natural-log calls in single-argument form. Nodes whose
        children did not change are returned as-is rather than rebuilt.
    """
    # Containers: normalize element-wise, keeping the container kind.
    if isinstance(expr, tuple):
        return tuple(normalize_expr(e) for e in expr)
    if isinstance(expr, list):
        return [normalize_expr(e) for e in expr]

    # Sympy atoms have no sub-expressions. Anything lacking Sympy's
    # ``is_Atom`` flag (plain Python values) is likewise a leaf.
    if getattr(expr, "is_Atom", True):
        return expr

    new_args = tuple(normalize_expr(arg) for arg in expr.args)

    # Natural log written with an explicit base E -> single-argument log.
    if expr.func == sp.log and len(new_args) == 2 and new_args[1] == sp.E:
        return sp.log(new_args[0])

    # Rebuild only when a child actually changed: calling expr.func(*args)
    # on an untouched node re-runs Sympy's automatic evaluation, which can
    # alter structure the caller did not ask to change.
    if new_args == expr.args:
        return expr
    return expr.func(*new_args)

def llm_to_sympy(expr_str: str) -> sp.Expr:
    """
    Convert a mathematical expression string to a Sympy object via an LLM.

    The string (plain text or LaTeX) is sent to the chat-completions API with
    instructions to answer with a single Sympy expression. The reply is
    unwrapped from any code fences/backticks and parsed with ``parse_expr``.

    Args:
        expr_str: The expression to convert, e.g. ``"(2√5 - 1) / 2"``.

    Returns:
        The parsed Sympy object.

    Raises:
        ValueError: If the model returns no text, or its reply cannot be
            parsed by ``parse_expr`` (the parser error is chained as the cause).

    Security:
        ``parse_expr`` evaluates its input with ``eval``. The parsed text is
        model output derived from *expr_str*, so do not pass untrusted input
        to this function without sandboxing.
    """
    result = client.chat.completions.create(
        model="gpt-4o",
        # Deterministic sampling: the same input should always convert the
        # same way, so repeated comparisons give the same verdict.
        temperature=0,
        messages=[
            {"role": "system", "content": """You are a Sympy expression converter. Return only the final Sympy expression 
(e.g., (1/S(2))*log(...) + C). Do not include code fences or additional text. 
No imports, definitions, or explanations. Just one line: the Sympy expression."""},
            # The expression sits on its own line so no surrounding quote
            # character can be mistaken for part of it.
            {"role": "user", "content": f"""
You are an agent that converts mathematical expressions to valid Sympy expressions.
Output only the Sympy expression itself on a single line, with no imports, no extra
Python code, and no explanations. Use standard Sympy functions like exp, log, sin,
cos, etc. The result must parse successfully with 'sympy.parse_expr'.

Convert the following expression to Sympy format:
{expr_str}
"""}
        ],
    )
    content = result.choices[0].message.content
    if content is None:
        # message.content can be None; calling .strip() on it would raise
        # AttributeError instead of the ValueError callers handle.
        raise ValueError(
            "Error parsing LLM output:\nOutput: None\nThe model returned no text."
        )

    output = content.strip()
    # The system prompt asks for no code fences, but guard against them
    # anyway: parse_expr would reject ```...``` or `...` wrappers.
    fenced = re.fullmatch(r"```[^\n]*\n(.*?)\s*```", output, flags=re.DOTALL)
    if fenced:
        output = fenced.group(1)
    output = output.strip("`").strip()

    # Names the parser should resolve explicitly. 'e' is mapped to Euler's
    # number, so a bare variable called e cannot appear in the result.
    local_dict = {
        'cos': sp.cos,
        'sin': sp.sin,
        'tan': sp.tan,
        'log': sp.log,
        'exp': sp.exp,
        'Derivative': sp.Derivative,
        'Integral': sp.Integral,
        'E': sp.E,
        'e': sp.E,
        'C': sp.Symbol('C')
    }
    try:
        # SECURITY: parse_expr uses eval(); see the docstring.
        return parse_expr(output, local_dict=local_dict)
    except Exception as e:
        raise ValueError(f"Error parsing LLM output:\nOutput: {output}\n{e}") from e

def _find_equation_split(expr_str: str):
    """Return ``(start, end)`` of the first top-level ``=``/``==`` sign, or None."""
    depth = 0
    for i, ch in enumerate(expr_str):
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth = max(depth - 1, 0)
        elif ch == "=" and depth == 0:
            # Part of <=, >=, != (an inequality) or the tail of an
            # already-skipped sign: not an equation separator.
            if i > 0 and expr_str[i - 1] in "<>!=":
                continue
            end = i + 2 if expr_str.startswith("==", i) else i + 1
            return i, end
    return None


def preprocess_expression(expr_str: str, *, converter=None):
    """
    Turn an expression string into a Sympy object, or an equation tuple.

    If the string contains an equals sign (``=`` or ``==``) outside any
    parentheses/brackets/braces, it is split at the first such sign into
    left- and right-hand parts, and each part is processed recursively. The
    result is the tuple ``(lhs, rhs)``. Otherwise the whole string is
    converted directly. Signs inside grouping, such as the ``i=1`` in the
    LaTeX ``\\sum_{i=1}^{n}``, and inequality signs (``<=``, ``>=``,
    ``!=``) do not split.

    Args:
        expr_str: The expression or equation text. Surrounding whitespace
            is ignored.
        converter: Callable mapping one expression string to a Sympy object.
            Defaults to ``llm_to_sympy``. For example, pass ``parse_latex`` to
            parse LaTeX locally instead.

    Returns:
        The converted object, or an ``(lhs, rhs)`` tuple for an equation.

    Raises:
        ValueError: If one side of an equation is empty, or if raised by
            the converter.
    """
    if converter is None:
        converter = llm_to_sympy

    expr_str = expr_str.strip()
    split = _find_equation_split(expr_str)
    if split is None:
        return converter(expr_str)

    start, end = split
    left_str = expr_str[:start].strip()
    right_str = expr_str[end:].strip()
    # Reject "x =" / "= 5" here rather than asking the converter to parse "".
    if not left_str or not right_str:
        raise ValueError(f"Equation has an empty side: {expr_str!r}")
    return (
        preprocess_expression(left_str, converter=converter),
        preprocess_expression(right_str, converter=converter),
    )

def compare_expressions(expr1, expr2) -> bool:
    """
    Decide whether two parsed expressions (or equations) are equivalent.

    Both inputs are first passed through ``normalize_expr`` (e.g. turning
    ``log(expr, E)`` into ``log(expr)``). Equations are the ``(lhs, rhs)``
    tuples produced by ``preprocess_expression``. Two equations match when
    their sides are pairwise equal, either in the same order or swapped, so
    ``x = 2`` matches ``2 = x``. Plain expressions match when ``sp.simplify``
    reduces their difference to zero.

    An equation compared with a plain expression, or tuples of different
    lengths, are reported as not equivalent instead of raising ``TypeError``.

    Returns:
        True if equivalent. False otherwise, including when ``simplify``
        cannot reduce a difference to exactly zero.
    """
    norm_expr1 = normalize_expr(expr1)
    norm_expr2 = normalize_expr(expr2)

    if _sides_equal(norm_expr1, norm_expr2):
        return True
    # Equality is symmetric: also try the equation with its sides swapped.
    if (
        isinstance(norm_expr1, tuple)
        and isinstance(norm_expr2, tuple)
        and len(norm_expr2) == 2
    ):
        return _sides_equal(norm_expr1, norm_expr2[::-1])
    return False


def _sides_equal(a, b) -> bool:
    """Return True when *a* and *b* are symbolically equal, recursing into tuples."""
    if isinstance(a, tuple) or isinstance(b, tuple):
        # Shape mismatch (equation vs. expression, or differing lengths).
        if not (isinstance(a, tuple) and isinstance(b, tuple)) or len(a) != len(b):
            return False
        return all(_sides_equal(x, y) for x, y in zip(a, b))
    return sp.simplify(a - b) == 0

def compare_pipeline(expr_str1: str, expr_str2: str) -> bool:
    """
    Run the whole parse-and-compare flow for two expression strings.

    Steps:
      1. ``preprocess_expression`` converts each string into a Sympy object,
         or an ``(lhs, rhs)`` tuple when the string is an equation.
      2. ``compare_expressions`` normalizes both results and compares them.

    A ``ValueError`` raised while parsing is printed as a parsing error and
    the function returns False. Otherwise both parsed forms and the verdict
    are printed.

    Returns:
        True when the two inputs are judged equivalent.
    """
    try:
        parsed = [preprocess_expression(expr_str1), preprocess_expression(expr_str2)]
    except ValueError as err:
        print("Parsing error:", err)
        return False

    is_equivalent = compare_expressions(*parsed)

    labels = ("Parsed Expression 1:", "Parsed Expression 2:")
    for label, value in zip(labels, parsed):
        print(label, value)

    if is_equivalent:
        print("The expressions are equivalent.")
    else:
        print("The expressions are NOT equivalent.")
    return is_equivalent

# ---------------- Example Usage ----------------
if __name__ == "__main__":
    # Compare a plain-text expression with a LaTeX one. These two are NOT
    # equal: (2√5 - 1)/2 = √5 - 1/2 ≈ 1.736, while the LaTeX value is
    # (1 + √2 + √(2√2 - 1))/2 ≈ 1.883, so the expected verdict is
    # "Not equivalent".
    expr_plain = r'x = (2√5 - 1) / 2'
    expr_latex = r'x = \dfrac{1 + \sqrt{2} + \sqrt{2\sqrt{2} - 1}}{2}'
    
    print("Comparing expressions")
    result = compare_pipeline(expr_plain, expr_latex)
    print("Result:", "Equivalent" if result else "Not equivalent")


Comparing expressions
Parsed Expression 1: (x, -1/2 + sqrt(5))
Parsed Expression 2: (x, 1/2 + sqrt(-1 + 2*sqrt(2))/2 + sqrt(2)/2)
The expressions are NOT equivalent.
Result: Not equivalent
