In [4]:
import getpass
import json
import os
from typing import Annotated

import sympy as sp
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict


def _set_env(var: str):
    """Make sure the environment variable ``var`` has a value.

    If ``var`` is missing or empty in ``os.environ``, the value is read
    interactively with ``getpass.getpass`` (prompt: ``"<var>: "``) and stored
    in ``os.environ``. An existing non-empty value is left untouched.

    Args:
        var: name of the environment variable to check.
    """
    already_set = bool(os.environ.get(var))
    if already_set:
        return
    os.environ[var] = getpass.getpass(f"{var}: ")

# Prompt for the OpenAI credential (only if it is not already exported)
# before the chat client is constructed.
_set_env("OPENAI_API_KEY")

# One shared chat model for the whole notebook. temperature=0 is requested to
# keep generations as repeatable as the API allows.
llm = ChatOpenAI(
    model="gpt-4o-mini-2024-07-18",
    temperature=0,
)

In [5]:
# Define a sympy-based tool
@tool
def calculate(expression: str) -> str:
    """Calculate an arithmetic expression using sympy.

    Args:
        expression: arithmetic expression as a string (e.g., '3 + 4')

    Returns:
        The evaluated result as a string.
    """
    # The docstring above doubles as the tool description shown to the model
    # (LangChain's @tool takes it from the function), so it is kept verbatim.
    #
    # NOTE(review): sympify parses strings with eval() under the hood (see the
    # SymPy docs) and `expression` is model-generated, i.e. untrusted input.
    # Consider a restricted parser if this leaves the notebook.
    try:
        parsed = sp.sympify(expression)
        if parsed.is_number:
            # Purely numeric input: collapse to a floating-point value.
            return str(parsed.evalf())
        # Symbolic input: hand back sympy's canonical form unchanged.
        return str(parsed)
    except Exception as exc:
        # Report failures as text so the model can read them and retry.
        return f"Error: {exc}"

# Define an LLM tool
@tool
def llm_tool(expression: str) -> str:
    """Use LLM like yourself to process the input string.
    This is useful for tasks that require reasoning or understanding of the context.
    For example, you can use it to provide explanations.

    Args:
        expression: input string to be processed by LLM

    Returns:
        The processed output as a string.
    """
    # The docstring above doubles as the tool description shown to the model,
    # so it is kept verbatim.
    # The whole input is sent to the plain (tool-less) `llm` as a single
    # system message; only the text content of the reply is returned.
    prompt_messages = [SystemMessage(content=expression)]
    return llm.invoke(prompt_messages).content

# Register both tools (sympy calculator and nested-LLM helper), build a
# name -> tool lookup table, and bind the tool schemas to the chat model so it
# can emit calls to them.
tools = [calculate, llm_tool]
tools_by_name = {registered.name: registered for registered in tools}
llm_with_tools = llm.bind_tools(tools)

In [6]:
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph.

    It has a single key, ``messages``, holding the conversation so far.
    """

    # Conversation history. `add_messages` is attached as Annotated metadata;
    # presumably LangGraph uses it as the reducer that merges each node's
    # returned messages into this list instead of overwriting it -- confirm
    # against the LangGraph documentation.
    messages: Annotated[list, add_messages]


def chatbot(state: AgentState):
    """Run the tool-enabled model on the conversation so far.

    Args:
        state: current graph state; only ``state["messages"]`` is read.

    Returns:
        A state update ``{"messages": [reply]}`` containing the model's reply.
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}

class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage.

    Each entry of the last message's ``tool_calls`` is dispatched to the tool
    registered under that name, and the result is wrapped in a ``ToolMessage``
    carrying the originating call id.
    """

    def __init__(self, tools: list) -> None:
        """Index ``tools`` by their ``name`` attribute for lookup at call time."""
        self.tools_by_name = {registered.name: registered for registered in tools}

    def __call__(self, inputs: dict):
        """Execute every tool call found on the last message of ``inputs``.

        Args:
            inputs: graph state; must contain a non-empty ``"messages"`` list.

        Returns:
            ``{"messages": [...]}`` with one ``ToolMessage`` per tool call, in
            request order. The list is empty when the last message carries no
            tool calls.

        Raises:
            ValueError: if ``inputs`` has no ``"messages"`` or the list is empty.
        """
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        message = messages[-1]

        outputs = []
        # Messages without a `tool_calls` attribute (or with None) are treated
        # as "nothing to run" instead of raising AttributeError.
        for tool_call in getattr(message, "tool_calls", None) or []:
            name = tool_call["name"]
            selected = self.tools_by_name.get(name)
            if selected is None:
                # The model asked for a tool that is not registered. Report it
                # back as the tool result so the model can correct itself,
                # rather than aborting the whole run with a KeyError.
                available = ", ".join(sorted(self.tools_by_name))
                tool_result = (
                    f"Error: unknown tool '{name}'. Available tools: {available}"
                )
            else:
                tool_result = selected.invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    # Results are JSON-encoded, so string results arrive quoted.
                    content=json.dumps(tool_result),
                    name=name,
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}

def route_tools(
    state: AgentState,
):
    """Conditional-edge router for the chatbot node.

    Accepts either a bare list of messages or a state mapping with a
    ``"messages"`` key, and inspects the last message.

    Returns:
        ``"tools"`` when the last message has a non-empty ``tool_calls``
        attribute, otherwise ``END``.

    Raises:
        ValueError: if a mapping state has no messages.
    """
    if isinstance(state, list):
        history = state
    else:
        history = state.get("messages", [])
        if not history:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")

    ai_message = history[-1]
    wants_tools = hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0
    return "tools" if wants_tools else END

def format_answer(state: AgentState):
    """Rewrite the agent's last reply into the strict ``Answer: <number>`` form.

    The most recent ``AIMessage`` in ``state["messages"]`` is embedded in a
    formatting prompt and passed through ``llm_tool``.

    Args:
        state: current graph state; only ``state["messages"]`` is read.

    Returns:
        A state update ``{"messages": [AIMessage]}`` whose content is the
        formatted answer text.

    Raises:
        ValueError: if the state contains no ``AIMessage`` to format.
    """
    last_ai_message = next(
        (msg for msg in reversed(state["messages"]) if isinstance(msg, AIMessage)),
        None,
    )
    if last_ai_message is None:
        # Fail with a descriptive error instead of leaking a bare StopIteration.
        raise ValueError("No AIMessage found in state to format")

    format_prompt = f"""
    REFINE THIS ANSWER TO EXACT FORMAT:
    Original: {last_ai_message.content}

    Rules:
    1. Extract ONLY the numeric answer
    2. Format EXACTLY as: "Answer: <number>"
    3. Never add explanations or other text
    """

    # `llm_tool` is a LangChain tool object, not a plain function. Calling it
    # directly relies on the deprecated BaseTool.__call__, so go through the
    # supported `.invoke` entry point with the tool's named argument.
    formatted = llm_tool.invoke({"expression": format_prompt})

    # The text is model-generated, so it is recorded as an AI message rather
    # than being attributed to the human user.
    return {"messages": [AIMessage(content=formatted)]}

# Assemble the agent graph:
#   START -> chatbot -> (tools -> chatbot)* -> format_answer -> END
tool_node = BasicToolNode(tools=tools)

workflow = StateGraph(AgentState)
for node_name, node_action in (
    ("chatbot", chatbot),
    ("tools", tool_node),
    ("format_answer", format_answer),
):
    workflow.add_node(node_name, node_action)

# route_tools returns either "tools" or END. END is remapped here so that a
# finished conversation is passed through the formatter before terminating.
chatbot_routes = {"tools": "tools", END: "format_answer"}

workflow.add_edge(START, "chatbot")
workflow.add_conditional_edges("chatbot", route_tools, chatbot_routes)
workflow.add_edge("tools", "chatbot")  # feed tool results back to the model
workflow.add_edge("format_answer", END)

agent = workflow.compile()

# Demo problem: a single word problem sent through the compiled agent.
task = "James is a first-year student at a University in Chicago. He has a budget of $1000 per semester. He spends 30% \
of his money on food, 15% on accommodation, 25% on entertainment, and the rest on coursework materials. How much money \
does he spend on coursework materials?"

# To watch intermediate node outputs instead, iterate over
# agent.stream(initial_state) and print each item.
initial_state = {"messages": [HumanMessage(content=task)]}
result = agent.invoke(initial_state)

# Show the full transcript: question, tool call, tool result, answer, formatted answer.
for message in result["messages"]:
    message.pretty_print()


James is a first-year student at a University in Chicago. He has a budget of $1000 per semester. He spends 30% of his money on food, 15% on accommodation, 25% on entertainment, and the rest on coursework materials. How much money does he spend on coursework materials?
Tool Calls:
  calculate (call_2cYqxtJIyq2yaswVnbEmdvYJ)
 Call ID: call_2cYqxtJIyq2yaswVnbEmdvYJ
  Args:
    expression: 1000 * (1 - 0.30 - 0.15 - 0.25)
Name: calculate

"300.000000000000"

James spends $300 on coursework materials.

Answer: 300


In [7]:
# dataset = load_dataset("gsm8k", "main")
# test_dataset = dataset["test"]
# val_sample = test_dataset.shuffle(seed=42).select(range(100))

# def check_answer(model_answer, true_answer):
#     try:
#         true_answer = true_answer.split(',')

#         true_answer = [int(x) for x in true_answer]
#     except Exception:
#         true_answer = [int(true_answer)]

#     return int(model_answer) in true_answer

# k = 0
# model_answers = []

# for problem in tqdm(val_sample):
#     # Get the math problem and the correct answer
#     math_problem = problem['question']
#     correct_answer = problem['answer'].split("### ")[1]

#     # Generate the model's response using your agent
#     state = agent.invoke({"messages": [HumanMessage(content=math_problem)]})
#     model_response = state["messages"][-1]  # Get final response
#     model_answers.append(model_response.content)

#     # Use regex to parse the numerical answer exactly as before
#     try:
#         model_ans = re.search(r'Answer:\s*[^0-9]*([\d]+(?:\.\d+)?)', model_response.content).group(1).strip()
#         k += 1 if check_answer(model_ans, correct_answer) else 0
#     except Exception:
#         continue

# print(f"Precision: {k/len(val_sample)}")
