Merged
32 commits
ccc8459
feat: add bench for demo and supervisor agent
jmatejcz Aug 4, 2025
d96b113
feat: add validator to demo task
jmatejcz Aug 4, 2025
38b3799
feat: customized planner supervisor
jmatejcz Aug 5, 2025
cfd4313
refactor: moved planner supoervisor agent to rai core
jmatejcz Aug 5, 2025
8e386b0
feat: added megamind
jmatejcz Aug 11, 2025
43f9cef
refactor: level arg to logger
jmatejcz Aug 11, 2025
1eaed3e
feat: modifications to megamind
jmatejcz Aug 11, 2025
7941b41
fix: validators in SortTask fix
jmatejcz Aug 13, 2025
7ca6c75
style: formatting fixed
jmatejcz Aug 13, 2025
b331e58
refactor: updated prompts
jmatejcz Aug 13, 2025
9ea3239
style: remove unused code
jmatejcz Aug 18, 2025
68ee2a5
style: remove commented code
jmatejcz Sep 1, 2025
cfef6a7
style: change file name
jmatejcz Sep 1, 2025
100e409
refactor: move plan agent to rai core
jmatejcz Sep 1, 2025
3cae02b
style: format
jmatejcz Sep 1, 2025
b2633e7
style: names change
jmatejcz Sep 1, 2025
1249161
style: renamed example file
jmatejcz Sep 1, 2025
44d4a6d
fix: import fix
jmatejcz Sep 1, 2025
01cb10d
refactor: remove examples from megamind prompt
jmatejcz Sep 2, 2025
658d1ce
style: applied requested changes
jmatejcz Sep 4, 2025
ab78de7
refactor: added toolrunner for subagents
jmatejcz Sep 11, 2025
6d0acc4
refactor: megamind into generic class and fucntion
jmatejcz Sep 11, 2025
6b8b683
chore: delete planner supervisor
jmatejcz Sep 11, 2025
84205ef
fix: remove left import
jmatejcz Sep 12, 2025
2123d25
refactor: change to gpts in example
jmatejcz Sep 12, 2025
6c8f224
refactor: recurssion limit adjutment
jmatejcz Sep 12, 2025
5ad369e
fix: deafult langfuse to false
jmatejcz Sep 12, 2025
a8b3b5d
refactor: move langgraph agents to core
jmatejcz Sep 12, 2025
7ce2cf5
refactor: remove megamind class
jmatejcz Sep 12, 2025
f365ecd
style: applied requested changes
jmatejcz Sep 12, 2025
ac60d58
feat: introduced planning prompt
jmatejcz Sep 12, 2025
c98a3f7
feat: added helper funtions
jmatejcz Sep 12, 2025
1 change: 0 additions & 1 deletion src/rai_bench/rai_bench/examples/manipulation_o3de.py
@@ -31,7 +31,6 @@
model_name=args.model_name,
vendor=args.vendor,
)

run_benchmark(
llm=llm,
out_dir=experiment_dir,
100 changes: 100 additions & 0 deletions src/rai_bench/rai_bench/examples/tool_calling_custom_agent.py
@@ -0,0 +1,100 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import uuid
from datetime import datetime
from pathlib import Path

from rai.agents.langchain.core import (
Executor,
create_megamind,
get_initial_megamind_state,
)

from rai_bench import (
define_benchmark_logger,
)
from rai_bench.tool_calling_agent.benchmark import ToolCallingAgentBenchmark
from rai_bench.tool_calling_agent.interfaces import TaskArgs
from rai_bench.tool_calling_agent.tasks.warehouse import SortingTask
from rai_bench.utils import get_llm_for_benchmark

if __name__ == "__main__":
now = datetime.now()
out_dir = f"src/rai_bench/rai_bench/experiments/tool_calling/{now.strftime('%Y-%m-%d_%H-%M-%S')}"
experiment_dir = Path(out_dir)
experiment_dir.mkdir(parents=True, exist_ok=True)
bench_logger = define_benchmark_logger(out_dir=experiment_dir, level=logging.DEBUG)

task = SortingTask(task_args=TaskArgs(extra_tool_calls=50))
task.set_logger(bench_logger)

supervisor_name = "gpt-4o"

executor_name = "gpt-4o-mini"
model_name = f"supervisor-{supervisor_name}_executor-{executor_name}"
supervisor_llm = get_llm_for_benchmark(model_name=supervisor_name, vendor="openai")
executor_llm = get_llm_for_benchmark(
model_name=executor_name,
vendor="openai",
)

benchmark = ToolCallingAgentBenchmark(
tasks=[task],
logger=bench_logger,
model_name=model_name,
results_dir=experiment_dir,
)
manipulation_system_prompt = """You are a manipulation specialist robot agent.
Your role is to handle object manipulation tasks, including picking up and dropping objects using the provided tools.

Ask the VLM for object detection and positions before performing any manipulation action.
If the VLM doesn't see the objects that are the objectives of the task, return this information without proceeding"""

navigation_system_prompt = """You are a navigation specialist robot agent.
Your role is to handle navigation tasks in space using the provided tools.

After performing a navigation action, always check your current position to ensure success"""

executors = [
Executor(
name="manipulation",
llm=executor_llm,
tools=task.manipulation_tools(),
system_prompt=manipulation_system_prompt,
),
Executor(
name="navigation",
llm=executor_llm,
tools=task.navigation_tools(),
system_prompt=navigation_system_prompt,
),
]
agent = create_megamind(
megamind_llm=supervisor_llm,
megamind_system_prompt=task.get_system_prompt(),
Collaborator:
We could also add a planning_prompt for megamind (example) and use it in the structured output node to summarize the overall task progress based on the steps done. In my experiments, this fairly consistently helps decision-making for the agent using qwen3. This will be provided by the bench task (example).

Note, the planning_prompt didn't help the agent with gpt-4o-mini that much; perhaps we need a different prompt for it. Not sure how gpt-4o works for you; for me, the results vary a lot.

Contributor Author:
I like this idea, I copied it to this branch here: ac60d58

executors=executors,
task_planning_prompt=task.get_planning_prompt(),
)

experiment_id = uuid.uuid4()
benchmark.run_next(
agent=agent,
initial_state=get_initial_megamind_state(task=task.get_prompt()),
experiment_id=experiment_id,
)

bench_logger.info("===============================================================")
bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!")
bench_logger.info("===============================================================")
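Following the planning-prompt suggestion in the review thread above, the sketch below shows one way a planning prompt could drive a structured progress summary between supervisor steps. It is a hypothetical illustration only: `PlanSummary` and `summarize_progress` are invented names, and the only APIs assumed to exist are LangChain's `with_structured_output`, its message types, and Pydantic models.

```python
from typing import List

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, SystemMessage
from pydantic import BaseModel, Field


class PlanSummary(BaseModel):
    """Structured summary of task progress (hypothetical schema)."""

    completed_steps: List[str] = Field(description="Steps already carried out")
    remaining_steps: List[str] = Field(description="Steps still to be done")
    next_step: str = Field(description="The single step to execute next")


def summarize_progress(
    llm: BaseChatModel, planning_prompt: str, messages: List[BaseMessage]
) -> PlanSummary:
    # Ask a model that supports structured output to condense the conversation
    # into a plan, which the supervisor can use before picking the next executor.
    structured_llm = llm.with_structured_output(PlanSummary)
    return structured_llm.invoke([SystemMessage(content=planning_prompt), *messages])
```

In a megamind-style graph, something like this could run before the supervisor node, so the next-step decision is made from a condensed plan rather than the full message history.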
73 changes: 32 additions & 41 deletions src/rai_bench/rai_bench/tool_calling_agent/benchmark.py
@@ -19,14 +19,14 @@
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.errors import GraphRecursionError
from langgraph.graph.state import CompiledStateGraph
from rai.agents.langchain.core import (
create_conversational_agent,
)
from rai.messages import HumanMultimodalMessage
from rai.agents.langchain.core.react_agent import ReActAgentState

from rai_bench.agents import create_multimodal_to_tool_agent
from rai_bench.base_benchmark import BaseBenchmark, TimeoutException
@@ -38,9 +38,6 @@
TaskResult,
ToolCallingAgentRunSummary,
)
from rai_bench.tool_calling_agent.tasks.spatial import (
SpatialReasoningAgentTask,
)
from rai_bench.utils import get_llm_model_name


@@ -67,7 +64,12 @@ def __init__(
self.tasks_results: List[TaskResult] = []
self.csv_initialize(self.results_filename, TaskResult)

def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
def run_next(
self,
agent: CompiledStateGraph,
initial_state: ReActAgentState,
experiment_id: uuid.UUID,
) -> None:
"""Runs the next task of the benchmark.

Parameters
@@ -87,14 +89,16 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
)
callbacks = self.score_tracing_handler.get_callbacks()
run_id = uuid.uuid4()
# NOTE (jmatejcz) recursion limit calculated as all_nodes_num -> one pass though whole node
# plus (task.max_tool_calls_number-1 because the first pass is already added in)
# times number of nodes - 2 because we dont cout start and end node
# this can be to much for larger graphs that dont use all nodes on extra calls
# in such ase adjust this value
recurssion_limit = len(agent.get_graph().nodes) + (
task.max_tool_calls_number - 1
) * (len(agent.get_graph().nodes) - 2)
# NOTE (jmatejcz) recursion limit calculated as (all_nodes_num - 2) * required tool calls
# -2 because we don't want to include the START and END nodes
# then we add the number of additional calls that can be made
# and +2 as we have to pass once through START and END

recurssion_limit = (
(len(agent.get_graph().nodes) - 2) * task.required_calls
+ task.additional_calls
+ 2
)
config: RunnableConfig = {
"run_id": run_id,
"callbacks": callbacks,
@@ -113,40 +117,27 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
messages: List[BaseMessage] = []
prev_count: int = 0
try:
with self.time_limit(20 * task.max_tool_calls_number):
if isinstance(task, SpatialReasoningAgentTask):
for state in agent.stream(
{
"messages": [
HumanMultimodalMessage(
content=task.get_prompt(), images=task.get_images()
)
]
},
config=config,
):
node = next(iter(state))
all_messages = state[node]["messages"]
for new_msg in all_messages[prev_count:]:
messages.append(new_msg)
prev_count = len(messages)
else:
for state in agent.stream(
{
"messages": [
HumanMultimodalMessage(content=task.get_prompt())
]
},
config=config,
):
node = next(iter(state))
with self.time_limit(200 * task.max_tool_calls_number):
for state in agent.stream(
initial_state,
config=config,
):
node = next(iter(state))
if "messages" in state[node]:
all_messages = state[node]["messages"]
for new_msg in all_messages[prev_count:]:
messages.append(new_msg)
if isinstance(new_msg, AIMessage):
self.logger.debug(
f"Message from node '{node}': {new_msg.content}, tool_calls: {new_msg.tool_calls}"
)
prev_count = len(messages)
except TimeoutException as e:
self.logger.error(msg=f"Task timeout: {e}")
except GraphRecursionError as e:
tool_calls = task.get_tool_calls_from_messages(messages=messages)
score = task.validate(tool_calls=tool_calls)
score = 0.0
self.logger.error(msg=f"Reached recursion limit {e}")

tool_calls = task.get_tool_calls_from_messages(messages=messages)
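To make the new recursion-limit formula above concrete, here is a small worked example; the node and call counts are invented for illustration and do not come from the benchmark.

```python
# Illustrative numbers only (not taken from the benchmark).
graph_nodes = 6        # total nodes in the compiled graph, including START and END
required_calls = 8     # minimal tool calls needed to complete the task
additional_calls = 50  # optional + extra tool calls the task still tolerates

# (nodes - 2): exclude START and END from the per-call cost,
# +2: one pass through START and END.
recursion_limit = (graph_nodes - 2) * required_calls + additional_calls + 2
print(recursion_limit)  # 84
```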
8 changes: 8 additions & 0 deletions src/rai_bench/rai_bench/tool_calling_agent/interfaces.py
@@ -585,6 +585,14 @@ def max_tool_calls_number(self) -> int:
+ self.extra_tool_calls
)

@property
def additional_calls(self) -> int:
"""number of additional calls that can be done to still pass task.
Includes extra tool calls params.
and optional tool calls number which depends on task.
"""
return self.optional_tool_calls_number + self.extra_tool_calls

@property
def required_calls(self) -> int:
"""Minimal number of calls required to complete task"""
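A quick, hypothetical illustration of how the two new properties relate (all numbers are invented; `optional_tool_calls_number` is assumed to be the task-dependent value the docstring refers to):

```python
# Invented example values for a hypothetical task.
required_calls = 4              # minimal tool calls needed to solve the task
optional_tool_calls_number = 2  # task-dependent optional calls
extra_tool_calls = 50           # slack configured via TaskArgs(extra_tool_calls=...)

# additional_calls is the slack the benchmark tolerates on top of required_calls.
additional_calls = optional_tool_calls_number + extra_tool_calls  # 52
```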