In [1]:
import getpass
import os
import re
from typing import Dict, List, TypedDict

import sympy as sp
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")

_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")

llm = ChatOpenAI(model='gpt-4o-mini-2024-07-18', temperature=0, seed=42)

In [2]:
class AgentState(TypedDict):
    messages: List[MessagesState]
    task: str
    plan_string: str
    steps: List
    results: Dict
    result: str


# Define a sympy-based tool
@tool
def calculate(expression: str) -> str:
    """Calculate an arithmetic expression using sympy.

    Args:
        expression: arithmetic expression as a string (e.g., '3 + 4')

    Returns:
        The evaluated result as a string.
    """
    try:
        expr = sp.sympify(expression)
        evaluated = expr.evalf() if expr.is_number else expr
        return str(evaluated)
    except Exception as e:
        return f"Error: {e}"

# Define a LLM tool
@tool
def llm_tool(expression: str) -> str:
    """Use LLM like yourself to process the input string.
    This is useful for tasks that require reasoning or understanding of the context.
    For example, you can use it to provide explanations.

    Args:
        expression: input string to be processed by LLM

    Returns:
        The processed output as a string.
    """
    response = llm.invoke([SystemMessage(content=expression)])
    return response.content


# Augment the LLM with the sympy tool
tools = [calculate, llm_tool]
tools_by_name = {tool.name: tool for tool in tools}

In [3]:
planer_prompt_system = """
### INSTRUCTIONS ###
You are a math problem solving planner. For the following task, make plans that can solve \
the problem step by step. For each plan, indicate which external tool together with tool input to retrieve \
evidence. You can store the evidence into a variable #E that can be called by later tools \
(Plan, #E1, Plan, #E2, Plan, ...)

Tools can be one of the following:
(1) Calculate[input]: A tool that is used for solving math expressions using sympy.
(2) LLM[input]: A pretrained LLM like yourself. Useful when you need to act with general world knowledge and \
common sense. Prioritize it when you are confident in solving the problem yourself. Input can be any instruction.

### EXAMPLE ###
Task: A person has $500. They spend 40% on groceries, 20% on utilities, and 10% on transportation. \
How much money do they have left?

Plan: Calculate the total amount spent on groceries, utilities, and transportation.
#E1 = Calculate[0.40 * 500 + 0.20 * 500 + 0.10 * 500]

Plan: Subtract the total amount spent from the initial amount to find the remaining money.
#E2 = Calculate[500 - #E1]
"""

planer_prompt_human = """
### YOUR TASK ###
Describe your plans with rich details. Each Plan should be followed by only one #E.

Task: {task}
Plan:
"""

planer_prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", planer_prompt_system),
        ("user", planer_prompt_human)
    ]
)

solver_prompt = """
### INSTRUCTIONS ###
Solve the following task or problem. To solve the problem, we have made step-by-step Plan and \
retrieved corresponding Evidence to each Plan. Use them with caution since long evidence might \
contain irrelevant information.

Respond with the answer in a format: "Answer: <value>". Value should be a number without \
any units. If you are not sure about the answer, respond with "I don't know".

### TASK ###
{task}

### PLAN ###
{plan}

### ANSWER ###
"""

In [4]:
# Planner node
def planner(state: AgentState):
    """Generate a step-by-step plan to solve the problem"""
    regex_pattern = r"Plan:\s*(.+)\s*(#E\d+)\s*=\s*(\w+)\s*\[([^\]]+)\]"

    task = state["task"]

    prompt = planer_prompt_template.format_messages(task=task)
    response = llm.invoke(prompt)

    # Extract the plan string from the response
    matches = re.findall(regex_pattern, response.content)
    return {"steps": matches, "plan_string": response.content, "messages": state["messages"] + [response]}

def _get_current_task(state: AgentState):
    if "results" not in state or state["results"] is None:
        return 1
    if len(state["results"]) == len(state["steps"]):
        return None
    else:
        return len(state["results"]) + 1

# Executor node
def executor(state: AgentState):
    """Worker node that executes the tools of a given plan."""
    _step = _get_current_task(state)
    try:
        _, step_name, tool_name, tool_input = state["steps"][_step - 1]
    except IndexError:
        # There was an error in the plan
        return {"messages": state["messages"] + [SystemMessage(content="Error in the plan.")]}

    _results = (state["results"] or {}) if "results" in state else {}
    for k, v in _results.items():
        tool_input = tool_input.replace(k, v)
    if tool_name == "Calculate":
        result = calculate.invoke(tool_input)
    elif tool_name == "LLM":
        result = llm_tool.invoke(tool_input)
    else:
        raise ValueError
    _results[step_name] = str(result)

    tool_message = ToolMessage(content=f"{tool_input}\nResult: {result}", artifact=result, tool_call_id=step_name)

    return {"results": _results, "messages": state["messages"] + [tool_message]}

def solve(state: AgentState):
    plan = ""
    for _plan, step_name, tool_name, tool_input in state["steps"]:
        _results = (state["results"] or {}) if "results" in state else {}
        for k, v in _results.items():
            tool_input = tool_input.replace(k, v)
            step_name = step_name.replace(k, v)
        plan += f"Plan: {_plan}\n{step_name} = {tool_name}[{tool_input}]"
    prompt = solver_prompt.format(plan=plan, task=state["task"])
    result = llm.invoke(prompt)
    return {"result": result.content, "messages": state["messages"] + [result]}

def _route(state: AgentState):
    _step = _get_current_task(state)
    if _step is None:
        # We have executed all tasks
        return "solve"
    else:
        # We are still executing tasks, loop back to the "tool" node
        return "tool"

graph = StateGraph(AgentState)

graph.add_node("plan", planner)
graph.add_node("tool", executor)
graph.add_node("solve", solve)

graph.add_edge("plan", "tool")
graph.add_edge("solve", END)
graph.add_conditional_edges("tool", _route)
graph.add_edge(START, "plan")

agent = graph.compile()


task = "Indras has 6 letters in her name. Her sister's name has 4 more letters than half of the letters in Indras' name. How many letters are in Indras and her sister's names?"

messages = [HumanMessage(content=task)]

# result = agent.invoke({"task": task, "messages": messages})
# # Print full conversation with steps
# for m in result["messages"]:
#     m.pretty_print()

# # Print full conversation with steps
# for m in agent.stream({"task": task, "messages": messages}):
#     print(m)

## Calculate metrics

In [5]:
from datasets import load_dataset

# Load the GSM8K dataset
dataset = load_dataset("gsm8k", "main")

# Access the training and test splits
train_dataset = dataset["train"]
test_dataset = dataset["test"]

val_sample = test_dataset.shuffle(seed=42).select(range(100))

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import re

import pandas as pd
from tqdm import tqdm


def check_answer(model_answer, true_answer):
    true_answer = true_answer.replace(",", "")

    return str(model_answer) in true_answer


def process_problem(problem):
    # Get the math problem and the correct answer
    math_problem = problem["question"]
    correct_answer = problem["answer"].split("### ")[1]

    human_message = HumanMessage(content=math_problem)

    # Generate the model's response asynchronously
    model_response = agent.invoke({"task": problem, "messages": [human_message]}, {"recursion_limit": 50})
    # Extract the content of the last AI message
    content = model_response["messages"][-1].content

    # Use regex to parse the numerical answer
    try:
        model_ans = re.search(r"Answer:\s*[^0-9]*([\d]+(?:\.\d+)?)", content).group(1).strip()
        is_correct = check_answer(model_ans, correct_answer)
    except Exception:
        is_correct = False
    return "\n".join(m.pretty_repr() for m in model_response["messages"]), is_correct, problem["question"], problem["answer"]


def bench(val_sample):
    k = 0
    model_answers = []

    # Process all tasks with progress tracking
    for problem in tqdm(val_sample, total=len(val_sample)):
        content, is_correct, problem, answer = process_problem(problem)
        model_answers.append({
            "Agent answer": content,
            "is_correct": is_correct,
            "Problem": problem,
            "Right Answer": answer
        })
        if is_correct:
            k += 1
    print(f"Precision: {k / len(val_sample)}")

    return pd.DataFrame(model_answers)


In [9]:
results = bench(val_sample)

100%|██████████| 100/100 [07:52<00:00,  4.72s/it]

Precision: 1.0





In [10]:
# Add a symbol for proper import in Excel
results['Agent answer'] = results['Agent answer'].apply(lambda x: '-' + x)

# Save the results to an Excel file
results.to_excel("results/rewoo_agent_bench.xlsx", index=False)