In [1]:
import getpass
import os


def _set_env(var: str):
    """Make sure the environment variable `var` holds a non-empty value.

    If the variable is missing or empty, the user is prompted for it with
    `getpass` (input is not echoed) and the entered value is stored in
    `os.environ`. An existing non-empty value is left untouched.
    """
    # Guard clause: nothing to do when a usable value is already present
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(f"{var}: ")


# Make sure both API keys are in the environment; _set_env only prompts for
# keys that are missing or empty.
_set_env("OPENAI_API_KEY")
# NOTE(review): no visible cell references TAVILY_API_KEY — confirm it is still needed.
_set_env("TAVILY_API_KEY")

In [7]:
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from typing import Literal
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import sympy as sp
import re

# Chat model shared by the whole notebook (pinned model snapshot, temperature 0)
llm = ChatOpenAI(
    model="gpt-4o-mini-2024-07-18",
    temperature=0,
)

# Define the tools the agent can call: a sympy-backed calculator and an answer formatter
@tool
def calculate(expression: str) -> str:
    """Calculate an arithmetic expression using sympy.

    Args:
        expression: arithmetic expression as a string (e.g., '3 + 4')

    Returns:
        The evaluated result as a string.
    """
    # NOTE: the docstring above is kept verbatim because langchain's @tool
    # exposes it to the model as the tool description.
    try:
        # SECURITY: sympify() parses its input with eval(), and `expression`
        # is model-generated text. Do not expose this tool to untrusted
        # prompts without sandboxing or a restricted parser.
        expr = sp.sympify(expression)
        # Exact integers are returned as-is: evalf() would turn 7 into
        # '7.00000000000000', which is noisy for the model and for the
        # downstream number extraction.
        if expr.is_Integer:
            return str(expr)
        # Other numeric results are evaluated to floating point; expressions
        # that still contain symbols are returned in symbolic form.
        evaluated = expr.evalf() if expr.is_number else expr
        return str(evaluated)
    except Exception as e:
        # Report failures to the model as text instead of aborting the run
        return f"Error: {e}"
    

@tool
def format_answer(text: str) -> str:
    """Returns only the numerical result (no text)."""
    # NOTE: the docstring above is kept verbatim because langchain's @tool
    # exposes it to the model as the tool description.
    #
    # Behaviour: returns the LAST number found in `text`, or `text` unchanged
    # when it contains no number.
    #   * an optional sign is kept for integers as well as decimals
    #     (previously "-5" was returned as "5")
    #   * thousands separators are understood ("1,250" -> "1250"; previously
    #     "250" was returned)
    #   * a fractional part made only of zeros is dropped ("3.0", "12.00")
    number_pattern = (
        r"[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"  # 1,250.5 / 42 / 3.14
        r"|[-+]?\.\d+"  # bare fraction such as .5
    )
    numbers = re.findall(number_pattern, text)
    if not numbers:
        return text

    last = numbers[-1].replace(",", "")
    whole, _, fraction = last.partition(".")
    # Trim ".0", ".00", ... but only when there are digits left in front of it
    if fraction and set(fraction) == {"0"} and whole.lstrip("+-"):
        return whole
    return last



# Register the tools, index them by name for dispatch, and bind them to the LLM
tools = [calculate, format_answer]
tools_by_name = dict((t.name, t) for t in tools)
llm_with_tools = llm.bind_tools(tools)

# Nodes
def llm_call(state: MessagesState):
    """Ask the tool-enabled LLM for its next message.

    The fixed system prompt is placed in front of the conversation held in
    `state["messages"]`; the model's reply is returned as a one-element
    "messages" update.
    """
    system_prompt = SystemMessage(
        content=(
            "You are a math problem solver. Follow these steps:\n"
            "1. Use calculate tool for math operations\n"
            "2. Use format_answer to present final results\n"
            "3. Always return answers in 'Answer: number' format"
        )
    )
    response = llm_with_tools.invoke([system_prompt] + state["messages"])
    return {"messages": [response]}

def tool_node(state: dict):
    """Execute every tool call requested by the most recent message.

    Args:
        state: graph state; `state["messages"][-1]` must expose `tool_calls`.

    Returns:
        A "messages" update containing one ToolMessage per tool call, in the
        same order as the calls.

    An unknown tool name or a failing tool does not raise: the problem is
    reported back to the model as an "Error: ..." observation (the same
    format `calculate` uses), so a single bad call cannot abort the run.
    """
    result = []
    for tool_call in state["messages"][-1].tool_calls:
        name = tool_call["name"]
        selected_tool = tools_by_name.get(name)
        if selected_tool is None:
            # The model asked for a tool that was never registered
            available = ", ".join(tools_by_name)
            observation = f"Error: unknown tool '{name}'. Available tools: {available}"
        else:
            try:
                observation = selected_tool.invoke(tool_call["args"])
            except Exception as e:
                # e.g. arguments that do not match the tool's schema
                observation = f"Error: {e}"
        result.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))
    return {"messages": result}

# Conditional edge function: route to the tool node when the LLM made a tool call, otherwise end
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Pick the routing key for the edge that leaves the LLM node.

    Returns:
        "Action" when the most recent message carries tool calls, otherwise
        END. These are routing keys, not node names: the mapping passed to
        `add_conditional_edges` translates "Action" into the "environment"
        node. The return annotation therefore lists "Action" (it previously
        listed "environment", a value this function never returns).
    """
    last_message = state["messages"][-1]
    # If the LLM makes a tool call, then perform an action
    if last_message.tool_calls:
        return "Action"
    # Otherwise, we stop (reply to the user)
    return END

# Build workflow: a graph over MessagesState with one LLM node and one tool node
agent_builder = StateGraph(MessagesState)

# Add nodes
agent_builder.add_node("llm_call", llm_call)  # asks the tool-enabled LLM for the next message
agent_builder.add_node("environment", tool_node)  # runs the tool calls the LLM requested

# Add edges to connect nodes
agent_builder.add_edge(START, "llm_call")
agent_builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {
        # Name returned by should_continue : Name of next node to visit
        # (keys must match the values should_continue returns: "Action" or END)
        "Action": "environment",
        END: END,
    },
)
# Tool results always go back to the LLM, which then decides whether to call another tool or finish
agent_builder.add_edge("environment", "llm_call")

# Compile the agent
agent = agent_builder.compile()

# Show the agent
# display(Image(agent.get_graph(xray=True).draw_mermaid_png()))

# Invoke with an example arithmetic query using sympy
# messages = [HumanMessage(content="James is a first-year student at a University in Chicago. He has a budget of $1000 per semester. He spends 30% of his money on food, 15% on accommodation, 25% on entertainment, and the rest on coursework materials. How much money does he spend on coursework materials?")]
# messages = agent.invoke({"messages": messages})
# # for m in messages["messages"]:
# #     print(m)


# final_answer = None

# for m in reversed(messages["messages"]):
#     if hasattr(m, 'content') and m.content and re.search(r'\d+', m.content):
#         final_answer = m.content
#         break

# print(final_answer) 

Answer: 300


In [None]:
# Load GSM8K ("main" configuration) and take its test split
dataset = load_dataset("gsm8k", "main")
test_dataset = dataset["test"]
# Evaluate on 100 problems; the shuffle seed is fixed (42), presumably so the
# same subset is selected on every run.
val_sample = test_dataset.shuffle(seed=42).select(range(100))

def check_answer(model_answer, true_answer):
    """Return True when the model's answer equals the reference answer.

    Args:
        model_answer: the number extracted from the model output (string or
            number); thousands separators are ignored.
        true_answer: the reference answer. A value written with thousands
            separators ("1,000", "12,345.5") is read as ONE number; any other
            comma-separated value ("3,4") is read as a list of acceptable
            answers.

    Returns:
        True if `model_answer` is numerically equal to an accepted reference
        value, False otherwise — including when either side cannot be parsed
        as a number.

    Values are compared numerically, so "300.0" matches "300".
    """
    try:
        predicted = float(str(model_answer).replace(",", "").strip())
    except ValueError:
        return False

    reference = str(true_answer).strip()
    # Previously "1,000" was split into [1, 0]: the right answer (1000) was
    # rejected while 0 and 1 were accepted.
    if re.fullmatch(r"[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?", reference):
        candidates = [reference.replace(",", "")]
    else:
        candidates = reference.split(",")

    accepted = []
    for candidate in candidates:
        try:
            accepted.append(float(candidate))
        except ValueError:
            # Skip pieces that are not numbers instead of failing the check
            continue
    return predicted in accepted

k = 0  # number of problems answered correctly
model_answers = []  # raw final message of every run, kept for inspection

# Matches the number that follows "Answer:". Non-digit filler such as "$" or
# "**" is skipped lazily so a leading minus sign is kept, and thousands
# separators are captured ("1,250") instead of cutting the number at the comma.
ANSWER_PATTERN = re.compile(r"Answer:\s*[^0-9]*?(-?\d+(?:,\d{3})*(?:\.\d+)?)")

for problem in tqdm(val_sample):
    # Get the math problem and the correct answer (the text after the "####" marker)
    math_problem = problem['question']
    correct_answer = problem['answer'].split("####")[-1].strip()

    # Generate the model's response using the agent
    state = agent.invoke({"messages": [HumanMessage(content=math_problem)]})
    model_response = state["messages"][-1]  # Get final response
    model_answers.append(model_response.content)

    # Parse the numerical answer; a response without one counts as wrong
    match = ANSWER_PATTERN.search(str(model_response.content))
    if match is None:
        continue
    model_ans = match.group(1).replace(",", "")

    try:
        if check_answer(model_ans, correct_answer):
            k += 1
    except (ValueError, TypeError):
        # An answer that cannot be compared counts as wrong; any other
        # exception is a real bug and is allowed to surface.
        continue

# NOTE(review): this is correct / total, i.e. accuracy; the "Precision" label
# is kept so the printed output does not change.
print(f"Precision: {k/len(val_sample)}")
