In [32]:
import getpass
import os
import re
from typing import Literal

import sympy as sp
from datasets import load_dataset
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from tqdm import tqdm


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")

llm = ChatOpenAI(model='gpt-4o-mini-2024-07-18', temperature=0)

In [33]:
# Define a single sympy-based tool
@tool
def calculate(expression: str) -> str:
    """Calculate an arithmetic expression using sympy.

    Args:
        expression: arithmetic expression as a string (e.g., '3 + 4')

    Returns:
        The evaluated result as a string.
    """
    try:
        expr = sp.sympify(expression)
        evaluated = expr.evalf() if expr.is_number else expr
        return str(evaluated)
    except Exception as e:
        return f"Error: {e}"

# Define a LLM tool
@tool
def llm_tool(expression: str) -> str:
    """Use LLM like yourself to process the input string.
    This is useful for tasks that require reasoning or understanding of the context.
    For example, you can use it to provide explanations.

    Args:
        expression: input string to be processed by LLM

    Returns:
        The processed output as a string.
    """
    response = llm.invoke([SystemMessage(content=expression)])
    return response.content

# Augment the LLM with the sympy tool
tools = [calculate, llm_tool]
tools_by_name = {tool.name: tool for tool in tools}
llm_with_tools = llm.bind_tools(tools)

In [34]:
prompt_system = """
### INSTRUCTIONS ###
1. You are a math problem solver. You are given a math problem and you need to solve it using two tools.
2.Respond with the answer in a format: "Answer: <value>". Value should be a number without any units. \
If you are not sure about the answer, respond with "I don't know".

TOOLS:
(1) Calculate[input]: A tool that is used for solving math expressions using sympy.
(2) LLM[input]: A pretrained LLM like yourself. Useful when you need to act with general world knowledge and \
common sense. Prioritize it when you are confident in solving the problem yourself. Input can be any instruction.

### EXAMPLE ###
## TASK
Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there \
are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How \
many flowers does Mark have in his garden?
## ANSWER
There are 80/100 * 10 = <<80/100*10=8>>8 more purple flowers than yellow flowers.
So in Mark's garden, there are 10 + 8 = <<10+8=18>>18 purple flowers.
Purple and yellow flowers sum up to 10 + 18 = <<10+18=28>>28 flowers.
That means in Mark's garden there are 25/100 * 28 = <<25/100*28=7>>7 green flowers.
So in total Mark has 28 + 7 = <<28+7=35>>35 plants in his garden.
Answer: 35
"""

prompt_human = """
### YOUR TASK ###
Solve the following math problem
TASK: {task}
ANSWER:
"""

prompt_template = ChatPromptTemplate.from_messages([
    ("system", prompt_system),
    ("human", prompt_human)
])

In [40]:
# Nodes
def llm_call(state: MessagesState):
    """LLM decides whether to use tools or answer directly"""
    task = next(msg.content for msg in state["messages"] if isinstance(msg, HumanMessage))
    messages = prompt_template.format_messages(task=task)
    response = llm_with_tools.invoke(messages)
    return {"messages": state["messages"] + [response]}

def tool_node(state: dict):
    """Execute tool calls"""
    result = []
    for tool_call in state["messages"][-1].tool_calls:
        tool = tools_by_name[tool_call["name"]]
        observation = tool.invoke(tool_call["args"])
        result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
    return {"messages": state["messages"] + result}


def should_continue(state: MessagesState) -> Literal["tools", "end"]:
    last_message = state["messages"][-1]
    return "tools" if last_message.tool_calls else "end"

workflow = StateGraph(MessagesState)
workflow.add_node("llm_call", llm_call)
workflow.add_node("tools", tool_node)

workflow.add_edge(START, "llm_call")
workflow.add_conditional_edges(
    "llm_call",
    should_continue,
    {
        "tools": "tools",
        "end": END,
    }
)
workflow.add_edge("tools", "llm_call")

agent = workflow.compile()

task = "James is a first-year student at a University in Chicago. He has a budget of $1000 per semester. He spends 30% \
of his money on food, 15% on accommodation, 25% on entertainment, and the rest on coursework materials. How much money \
does he spend on coursework materials?"

result = agent.invoke({"messages": [HumanMessage(content=task)]})

for m in result["messages"]:
    print(m.pretty_print())

GraphRecursionError: Recursion limit of 25 reached without hitting a stop condition. You can increase the limit by setting the `recursion_limit` config key.
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT

In [None]:
# dataset = load_dataset("gsm8k", "main")
# test_dataset = dataset["test"]
# val_sample = test_dataset.shuffle(seed=42).select(range(100))

# def check_answer(model_answer, true_answer):
#     try:
#         true_answer = true_answer.split(',')

#         true_answer = [int(x) for x in true_answer]
#     except Exception:
#         true_answer = [int(true_answer)]

#     return int(model_answer) in true_answer

# k = 0
# model_answers = []

# for problem in tqdm(val_sample):
#     # Get the math problem and the correct answer
#     math_problem = problem['question']
#     correct_answer = problem['answer'].split("### ")[1]

#     # Generate the model's response using your agent
#     state = agent.invoke({"messages": [HumanMessage(content=math_problem)]})
#     model_response = state["messages"][-1]  # Get final response
#     model_answers.append(model_response.content)

#     # Use regex to parse the numerical answer exactly as before
#     try:
#         model_ans = re.search(r'Answer:\s*[^0-9]*([\d]+(?:\.\d+)?)', model_response.content).group(1).strip()
#         k += 1 if check_answer(model_ans, correct_answer) else 0
#     except Exception:
#         continue

# print(f"Precision: {k/len(val_sample)}")
