In [1]:
import getpass
import json
import os
from typing import Annotated

import sympy as sp
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")

llm = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0, seed=42)

In [2]:
@tool
def calculate_expression(expression, variables):
    """
    Evaluates a mathematical expression by substituting given variable values.

    Parameters:
        expression (str):
            - A string representing a mathematical expression to be evaluated.
            - The expression can include standard mathematical operations and functions (e.g., "x**2 + 2*x + 1", "sin(pi/4)").

        variables (dict of str: str):
            - A dictionary where:
                - The key is the variable name (str).
                - The value is the corresponding value to substitute for the variable in the expression. (str)
            - Example: {"x": "3", "pi": "3.14159265359"}.

    Returns:
        float:
            - The result of the evaluated expression after substituting the given variable values.
            - If an error occurs, an error message string is returned.

    Example usage:
        >>> calculate_expression("x**2 + 2*x + 1", {"x": "3"})
        "16.0"

        >>> calculate_expression("sin(pi/4) + cos(pi/4)", {"pi":" 3.14159265359"})
        "1.4142135623730951"

        >>> calculate_expression("x*y + z", {"x": "2", "y": "3", "z": "1"})
        "7.0"
    """
    try:
        # Create symbols dynamically and assign values
        sym_vars = {k: sp.symbols(k) for k in variables.keys()}

        # Convert the string values to float
        variables = {k: float(v) for k, v in variables.items()}

        # Convert the string to a SymPy expression
        expr = sp.sympify(expression)

        # Substitute the variables with the given values
        result = expr.subs(sym_vars).evalf(subs=variables)

        return str(result)
    except Exception as e:
        return f"Error: {e}"

In [3]:
# @tool
# def solve_equation(equations, variables):
#     """
#     Solves a system of equations or a single equation using SymPy.

#     Parameters:
#         equations (str or list of str):
#             - A single equation as a string (e.g., "x**2 - 4 = 0").
#             - A list of equations as strings (e.g., ["x + y = 5", "x - y = 3"]).
#             - Each equation must contain the '=' sign to be considered valid.
#             - If an equation contains a term with a variable and a coefficient (e.g., "7x + 11x"), ensure that multiplication is explicitly indicated using '*' (e.g., "7*x + 11*x").

#         variables (dict of str: value):
#             - A dictionary where:
#                 - The key is a variable name (str).
#                 - The value is either a number "None" (if unknown). Ensure that the value is a string.
#             - Example: {"x": '5.32, 'y': 'None'}.

#     Returns:
#         list of dict:
#             - A list of dictionaries where keys (str) are variables (str) and values are the solutions.
#             - Example: [{"x": "2"}, {"x": "-2"}].
#             - If there are no solutions or an error occurs, returns a descriptive error message.
#     Notes:
#         - The function uses SymPy to solve the equations.
#         - Ensure that multiplication is explicitly indicated using '*' (e.g., "7*x + 11*x").
#         - The function can handle both single equations and systems of equations.

#     Example usage:
#         >>> solve_equation("x**2 - 4 = 0", {"x": "None"})
#         [{'x': "2"}, {'x': '-2'}]

#         >>> solve_equation(["x + y = 5", "x - y = 3"], {"x": "None", "y": "None"})
#         [{'x': "4", 'y': '1'}]

#         >>> solve_equation(["x + y = 5"], {"x": "4", "y": "None"})
#         [{'y': '1'}]
#     """


#     try:
#         # Create symbols dynamically based on variable names
#         sym_vars = {k: sp.symbols(k) for k in variables.keys()}

#         # Convert single equation to a list for consistent handling
#         if isinstance(equations, str):
#             equations = [equations]

#         # Parse equations and create SymPy Eq objects
#         eqs = []
#         for eq in equations:
#             if "=" not in eq:
#                 return f"Error: Invalid equation format '{eq}' (missing '=')"
#             left, right = eq.split("=")

#             eqs.append(sp.Eq(sp.sympify(left), sp.sympify(right)))

#         # Separate known and unknown variables
#         known_vars = {sym_vars[k]: v for k, v in variables.items() if v != "None"}
#         unknown_vars = [sym_vars[k] for k, v in variables.items() if v == "None"]

#         # Solve the system of equations
#         solution = sp.solve(eqs, unknown_vars, dict=True)

#         # Substitute known variables into the solution
#         if known_vars:
#             solution = [{var: sol[var].subs(known_vars) for var in sol} for sol in solution]

#         return str(solution)
#     except Exception as e:
#         return f"Error: {e}"


In [4]:
# Define a LLM tool
@tool
def llm_tool(expression: str) -> str:
    """Use LLM like yourself to process the input string.
    This is useful for tasks that require reasoning or understanding of the context.
    For example, you can use it to provide explanations.

    Args:
        expression: input string to be processed by LLM

    Returns:
        The processed output as a string.
    """
    response = llm.invoke([SystemMessage(content=expression)])
    return response.content


# Augment the LLM with the sympy tool
tools = [calculate_expression, llm_tool]
tools_by_name = {tool.name: tool for tool in tools}
llm_with_tools = llm.bind_tools(tools)

In [25]:
class AgentState(TypedDict):
    messages: Annotated[list, add_messages]


def chatbot(state: AgentState):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        outputs = []
        for tool_call in message.tool_calls:
            tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}


def route_tools(
    state: AgentState,
):
    """
    Use in the conditional_edge to route to the ToolNode if the last message
    has tool calls. Otherwise, route to the end.
    """
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return END


def format_answer(state: AgentState):
    """Final answer formatter using llm_tool"""
    last_ai_message = next(msg for msg in reversed(state["messages"]) if isinstance(msg, AIMessage))

    format_prompt = f"""
    REFINE THIS ANSWER TO EXACT FORMAT:
    Original: {last_ai_message.content}

    Rules:
    1. Extract ONLY the numeric answer
    2. Format EXACTLY as: "Answer: <number>"
    3. Never add explanations or other text
    """

    formatted = llm_tool(format_prompt)

    return {"messages": [HumanMessage(content=formatted)]}

#------------------
solver_prompt = """
### INSTRUCTIONS ###
1) You are a math solver.
2) You are given a task to solve.
3) Look at the previous messages to find the relevant information.
4) Respond with the answer in a format: "Answer: <value>". Value should be a number without \
any units. If you are not sure about the answer, respond with "I don't know".

### TASK ###
{task}

### ANSWER ###
"""

def solve(state: AgentState):

    prompt = solver_prompt.format(task=state["messages"][-1].content)
    # Use LLM to provide the final solution for the task
    result = llm.invoke(state["messages"] + [SystemMessage(content=prompt)])
    return {"messages": state["messages"] + [result]}
# -----------------

tool_node = BasicToolNode(tools=tools)

workflow = StateGraph(AgentState)
workflow.add_node("chatbot", chatbot)
workflow.add_node("tools", tool_node)
workflow.add_node("format_answer", solve)

workflow.add_edge(START, "chatbot")
workflow.add_conditional_edges(
    "chatbot",
    route_tools,
    {
        "tools": "tools",
        END: "format_answer",
    },
)
workflow.add_edge("tools", "chatbot")
workflow.add_edge("format_answer", END)

agent = workflow.compile()

task = "Darrell and Allen's ages are in the ratio of 7:11. If their total age now is 162, calculate Allen's age 10 years from now."


# for m in agent.stream({"messages": [HumanMessage(content=task)]}):
#     print(m)


# result = agent.invoke({"messages": [HumanMessage(content=task)]})
# for m in result["messages"]:
#     m.pretty_print()

## Calculate metrics

In [26]:
from datasets import load_dataset

# Load the GSM8K dataset
dataset = load_dataset("gsm8k", "main")

# Access the training and test splits
train_dataset = dataset["train"]
test_dataset = dataset["test"]

val_sample = test_dataset.shuffle(seed=42).select(range(100))

In [27]:
import re

import pandas as pd
from tqdm import tqdm


def check_answer(model_answer, true_answer):
    try:
        true_answer = true_answer.split(",")

        true_answer = [int(x) for x in true_answer]
    except Exception:
        true_answer = [int(true_answer)]

    return int(model_answer) in true_answer


def process_problem(problem):
    # Get the math problem and the correct answer
    math_problem = problem["question"]
    correct_answer = problem["answer"].split("### ")[1]

    human_message = HumanMessage(content=math_problem)

    # Generate the model's response asynchronously
    model_response = agent.invoke({"messages": [human_message]})
    # Extract the content of the last AI message
    content = model_response["messages"][-1].content

    # Use regex to parse the numerical answer
    try:
        model_ans = re.search(r"Answer:\s*[^0-9]*([\d]+(?:\.\d+)?)", content).group(1).strip()
        is_correct = check_answer(model_ans, correct_answer)
    except Exception:
        is_correct = False
    return "\n".join(m.pretty_repr() for m in model_response["messages"]), is_correct, problem["question"], problem["answer"]


def bench(val_sample):
    k = 0
    model_answers = []

    # Process all tasks with progress tracking
    for problem in tqdm(val_sample, total=len(val_sample)):
        content, is_correct, problem, answer = process_problem(problem)
        model_answers.append({
            "Agent answer": content,
            "is_correct": is_correct,
            "Problem": problem,
            "Right Answer": answer
        })
        if is_correct:
            k += 1
    print(f"Precision: {k / len(val_sample)}")

    return pd.DataFrame(model_answers)


In [28]:
results = bench(val_sample)

100%|██████████| 100/100 [08:41<00:00,  5.22s/it]

Precision: 0.88





In [29]:
# Add a symbol for proper import in Excel
results['Agent answer'] = results['Agent answer'].apply(lambda x: '-' + x)

# Save the results to an Excel file
results.to_excel("results/react_agent_bench.xlsx", index=False)