From 49a869e5aeda654839825337d13f62e23c838efe Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Fri, 8 Aug 2025 02:58:41 -0700 Subject: [PATCH 01/14] feat: Add tracing to LLM invoke (#650) --- .../langchain/core/conversational_agent.py | 7 +- .../rai/agents/langchain/core/react_agent.py | 10 ++- .../langchain/core/state_based_agent.py | 10 ++- .../agents/langchain/invocation_helpers.py | 74 +++++++++++++++++++ 4 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 src/rai_core/rai/agents/langchain/invocation_helpers.py diff --git a/src/rai_core/rai/agents/langchain/core/conversational_agent.py b/src/rai_core/rai/agents/langchain/core/conversational_agent.py index 8b940cdf6..b74d07654 100644 --- a/src/rai_core/rai/agents/langchain/core/conversational_agent.py +++ b/src/rai_core/rai/agents/langchain/core/conversational_agent.py @@ -22,12 +22,14 @@ BaseMessage, SystemMessage, ) +from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.graph.state import CompiledStateGraph from langgraph.prebuilt.tool_node import tools_condition from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing class State(TypedDict): @@ -39,6 +41,7 @@ def agent( logger: logging.Logger, system_prompt: str | SystemMessage, state: State, + config: RunnableConfig, ): logger.info("Running thinker") @@ -54,7 +57,9 @@ def agent( else system_prompt ) state["messages"].insert(0, system_msg) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) return state diff --git a/src/rai_core/rai/agents/langchain/core/react_agent.py b/src/rai_core/rai/agents/langchain/core/react_agent.py index d49fd4617..f872add3f 100644 --- a/src/rai_core/rai/agents/langchain/core/react_agent.py +++ b/src/rai_core/rai/agents/langchain/core/react_agent.py @@ -22,12 +22,13 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, SystemMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.prebuilt.tool_node import tools_condition from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing from rai.initialization import get_llm_model from rai.messages import SystemMultimodalMessage @@ -48,6 +49,7 @@ def llm_node( llm: BaseChatModel, system_prompt: Optional[str | SystemMultimodalMessage], state: ReActAgentState, + config: RunnableConfig, ): """Process messages using the LLM. @@ -57,6 +59,8 @@ def llm_node( The language model to use for processing state : ReActAgentState Current state containing messages + config : RunnableConfig + Configuration including callbacks for tracing Returns ------- @@ -75,7 +79,9 @@ def llm_node( # at this point, state['messages'] length should at least be 1 if not isinstance(state["messages"][0], SystemMessage): state["messages"].insert(0, SystemMessage(content=system_prompt)) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) diff --git a/src/rai_core/rai/agents/langchain/core/state_based_agent.py b/src/rai_core/rai/agents/langchain/core/state_based_agent.py index 0ffcff6d5..ea651acd3 100644 --- a/src/rai_core/rai/agents/langchain/core/state_based_agent.py +++ b/src/rai_core/rai/agents/langchain/core/state_based_agent.py @@ -26,12 +26,13 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.prebuilt.tool_node import tools_condition from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing from rai.initialization import get_llm_model from rai.messages import HumanMultimodalMessage, SystemMultimodalMessage @@ -52,6 +53,7 @@ def llm_node( llm: BaseChatModel, system_prompt: Optional[str | SystemMultimodalMessage], state: ReActAgentState, + config: RunnableConfig, ): """Process messages using the LLM. @@ -61,6 +63,8 @@ def llm_node( The language model to use for processing state : ReActAgentState Current state containing messages + config : RunnableConfig + Configuration including callbacks for tracing Returns ------- @@ -79,7 +83,9 @@ def llm_node( # at this point, state['messages'] length should at least be 1 if not isinstance(state["messages"][0], SystemMessage): state["messages"].insert(0, SystemMessage(content=system_prompt)) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) diff --git a/src/rai_core/rai/agents/langchain/invocation_helpers.py b/src/rai_core/rai/agents/langchain/invocation_helpers.py new file mode 100644 index 000000000..c2a667f29 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/invocation_helpers.py @@ -0,0 +1,74 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, List, Optional + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.runnables import RunnableConfig + +from rai.initialization import get_tracing_callbacks + +logger = logging.getLogger(__name__) + + +def invoke_llm_with_tracing( + llm: BaseChatModel, + messages: List[BaseMessage], + config: Optional[RunnableConfig] = None, +) -> Any: + """ + Invoke an LLM with enhanced tracing callbacks. + + This function automatically adds tracing callbacks (like Langfuse) to LLM calls + within LangGraph nodes, solving the callback propagation issue. + + Parameters + ---------- + llm : BaseChatModel + The language model to invoke + messages : List[BaseMessage] + Messages to send to the LLM + config : Optional[RunnableConfig] + Existing configuration (may contain some callbacks) + + Returns + ------- + Any + The LLM response + """ + tracing_callbacks = get_tracing_callbacks() + + if len(tracing_callbacks) == 0: + # No tracing callbacks available, use config as-is + return llm.invoke(messages, config=config) + + # Create enhanced config with tracing callbacks + enhanced_config = config.copy() if config else {} + + # Add tracing callbacks to existing callbacks + existing_callbacks = config.get("callbacks", []) if config else [] + + if hasattr(existing_callbacks, "handlers"): + # Merge with existing CallbackManager + all_callbacks = existing_callbacks.handlers + tracing_callbacks + elif isinstance(existing_callbacks, list): + all_callbacks = existing_callbacks + tracing_callbacks + else: + all_callbacks = tracing_callbacks + + enhanced_config["callbacks"] = all_callbacks + + return llm.invoke(messages, config=enhanced_config) From 04c7900a05fdd3c2c8de411538ff1850a7711c70 Mon Sep 17 00:00:00 2001 From: Magdalena Kotynia Date: Mon, 11 Aug 2025 15:28:00 +0200 Subject: [PATCH 02/14] feat: vlm benchmark with structured output (#666) Co-authored-by: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> --- docs/simulation_and_benchmarking/rai_bench.md | 1 - docs/tutorials/benchmarking.md | 1 - src/rai_bench/README.md | 16 ++ src/rai_bench/rai_bench/__init__.py | 2 + .../docs/tool_calling_agent_benchmark.md | 2 +- .../rai_bench/examples/benchmarking_models.py | 1 - .../rai_bench/examples/dual_agent.py | 8 +- .../rai_bench/examples/vlm_benchmark.py | 48 ++++ .../langfuse_scores_tracing.py | 4 +- src/rai_bench/rai_bench/test_models.py | 2 - .../rai_bench/tool_calling_agent/benchmark.py | 42 +--- .../tool_calling_agent/predefined/tasks.py | 101 +-------- .../tool_calling_agent/tasks/spatial.py | 126 ----------- src/rai_bench/rai_bench/utils.py | 55 ++--- .../rai_bench/vlm_benchmark/__init__.py | 18 ++ .../rai_bench/vlm_benchmark/benchmark.py | 211 ++++++++++++++++++ .../rai_bench/vlm_benchmark/interfaces.py | 174 +++++++++++++++ .../predefined/images/image_1.jpg | Bin .../predefined/images/image_2.jpg | Bin .../predefined/images/image_3.jpg | Bin .../predefined/images/image_4.jpg | Bin .../predefined/images/image_5.jpg | Bin .../predefined/images/image_6.jpg | Bin .../predefined/images/image_7.jpg | Bin .../vlm_benchmark/predefined/tasks.py | 112 ++++++++++ .../vlm_benchmark/results_tracking.py | 35 +++ .../rai_bench/vlm_benchmark/tasks/tasks.py | 76 +++++++ .../rai/agents/langchain/core/__init__.py | 2 + .../langchain/core/structured_output_agent.py | 60 +++++ 29 files changed, 798 insertions(+), 299 deletions(-) create mode 100644 src/rai_bench/rai_bench/examples/vlm_benchmark.py delete mode 100644 src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/__init__.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/benchmark.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/interfaces.py rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_1.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_2.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_3.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_4.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_5.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_6.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_7.jpg (100%) create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py create mode 100644 src/rai_core/rai/agents/langchain/core/structured_output_agent.py diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index d07905ac5..ea3ff34f1 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -127,7 +127,6 @@ Tasks of this benchmark are grouped by type: - Basic - basic usage of tools - Navigation -- Spatial reasoning - questions about surroundings with images attached - Manipulation - Custom Interfaces - requires using messages with custom interfaces diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index fb4cb663d..db2a56cd6 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -94,7 +94,6 @@ if __name__ == "__main__": extra_tool_calls=5, # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", - "spatial_reasoning", "manipulation", ], repeats=1, diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md index d475ad42c..b9a720aa3 100644 --- a/src/rai_bench/README.md +++ b/src/rai_bench/README.md @@ -163,6 +163,22 @@ python src/rai_bench/rai_bench/examples/tool_calling_agent/main.py --model-name > [!NOTE] > The configs of vendors are defined in [config.toml](../../config.toml) Change ithem if needed. +## VLM Benchmark + +The VLM Benchmark is a benchmark for VLM models. It includes a set of tasks containing questions related to images and evaluates the performance of the agent that returns the answer in the structured format. + +### Running + +To set up tracing backends, please follow the instructions in the [tracing.md](../../docs/tracing.md) document. + +To run the benchmark: + +```bash +cd rai +source setup_shell.sh +python src/rai_bench/rai_bench/examples/vlm_benchmark.py --model-name gemma3:4b --vendor ollama +``` + ## Testing Models To test multiple models, different benchamrks or couple repeats in one go - use script [test_models](./rai_bench/examples/test_models.py) diff --git a/src/rai_bench/rai_bench/__init__.py b/src/rai_bench/rai_bench/__init__.py index 395b5cf9a..9ca566649 100644 --- a/src/rai_bench/rai_bench/__init__.py +++ b/src/rai_bench/rai_bench/__init__.py @@ -22,6 +22,7 @@ get_llm_for_benchmark, parse_manipulation_o3de_benchmark_args, parse_tool_calling_benchmark_args, + parse_vlm_benchmark_args, ) __all__ = [ @@ -31,6 +32,7 @@ "get_llm_for_benchmark", "parse_manipulation_o3de_benchmark_args", "parse_tool_calling_benchmark_args", + "parse_vlm_benchmark_args", "test_dual_agents", "test_models", ] diff --git a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md index 0472170c5..387b01d84 100644 --- a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md +++ b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md @@ -14,4 +14,4 @@ Implementations can be found: - Validators [Validators](../tool_calling_agent/validators.py) - Subtasks [Validators](../tool_calling_agent/tasks/subtasks.py) -- Tasks, including navigation, spatial, custom interfaces and other [Tasks](../tool_calling_agent/tasks/) +- Tasks, including navigation, custom interfaces and other [Tasks](../tool_calling_agent/tasks/) diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index 3a3feefa3..6a962417b 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -35,7 +35,6 @@ extra_tool_calls=5, # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", - "spatial_reasoning", "manipulation", ], repeats=1, diff --git a/src/rai_bench/rai_bench/examples/dual_agent.py b/src/rai_bench/rai_bench/examples/dual_agent.py index cf38cc8d9..5ba89b5ac 100644 --- a/src/rai_bench/rai_bench/examples/dual_agent.py +++ b/src/rai_bench/rai_bench/examples/dual_agent.py @@ -16,7 +16,6 @@ from rai_bench import ( ManipulationO3DEBenchmarkConfig, - ToolCallingAgentBenchmarkConfig, test_dual_agents, ) @@ -29,11 +28,6 @@ tool_llm = ChatOpenAI(model="gpt-4o-mini", base_url="https://api.openai.com/v1/") # Define benchmarks that will be used - tool_conf = ToolCallingAgentBenchmarkConfig( - extra_tool_calls=0, # how many extra tool calls allowed to still pass - task_types=["spatial_reasoning"], - repeats=15, - ) man_conf = ManipulationO3DEBenchmarkConfig( o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config @@ -48,6 +42,6 @@ test_dual_agents( multimodal_llms=[m_llm], tool_calling_models=[tool_llm], - benchmark_configs=[man_conf, tool_conf], + benchmark_configs=[man_conf], out_dir=out_dir, ) diff --git a/src/rai_bench/rai_bench/examples/vlm_benchmark.py b/src/rai_bench/rai_bench/examples/vlm_benchmark.py new file mode 100644 index 000000000..65f8526a6 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/vlm_benchmark.py @@ -0,0 +1,48 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from rai_bench import ( + define_benchmark_logger, + parse_vlm_benchmark_args, +) +from rai_bench.utils import get_llm_for_benchmark +from rai_bench.vlm_benchmark import get_spatial_tasks, run_benchmark + +if __name__ == "__main__": + args = parse_vlm_benchmark_args() + experiment_dir = Path(args.out_dir) + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + try: + tasks = get_spatial_tasks() + for task in tasks: + task.set_logger(bench_logger) + + llm = get_llm_for_benchmark( + model_name=args.model_name, + vendor=args.vendor, + ) + run_benchmark( + llm=llm, + out_dir=experiment_dir, + tasks=tasks, + bench_logger=bench_logger, + ) + except Exception as e: + bench_logger.critical( + msg=f"Benchmark failed with error: {e}", + exc_info=True, + ) diff --git a/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py b/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py index 772974b40..bee9c70d4 100644 --- a/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py +++ b/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py @@ -47,7 +47,7 @@ def send_score( if isinstance(callback, CallbackHandler): callback.langfuse.score( trace_id=str(run_id), - name="tool calls result", + name="result", value=score, comment=comment, ) @@ -55,7 +55,7 @@ def send_score( if isinstance(callback, LangChainTracer): callback.client.create_feedback( run_id=run_id, - key="tool calls result", + key="result", score=score, comment=comment, ) diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index 51f1c7cb5..8be2db7e4 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -64,14 +64,12 @@ class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): "manipulation", "navigation", "custom_interfaces", - "spatial_reasoning", ] ] = [ "basic", "manipulation", "navigation", "custom_interfaces", - "spatial_reasoning", ] @property diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index 32547b90c..b8bcc52af 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -38,9 +38,6 @@ TaskResult, ToolCallingAgentRunSummary, ) -from rai_bench.tool_calling_agent.tasks.spatial import ( - SpatialReasoningAgentTask, -) from rai_bench.utils import get_llm_model_name @@ -106,36 +103,15 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: prev_count: int = 0 try: with self.time_limit(60): - if isinstance(task, SpatialReasoningAgentTask): - for state in agent.stream( - { - "messages": [ - HumanMultimodalMessage( - content=task.get_prompt(), images=task.get_images() - ) - ] - }, - config=config, - ): - node = next(iter(state)) - all_messages = state[node]["messages"] - for new_msg in all_messages[prev_count:]: - messages.append(new_msg) - prev_count = len(messages) - else: - for state in agent.stream( - { - "messages": [ - HumanMultimodalMessage(content=task.get_prompt()) - ] - }, - config=config, - ): - node = next(iter(state)) - all_messages = state[node]["messages"] - for new_msg in all_messages[prev_count:]: - messages.append(new_msg) - prev_count = len(messages) + for state in agent.stream( + {"messages": [HumanMultimodalMessage(content=task.get_prompt())]}, + config=config, + ): + node = next(iter(state)) + all_messages = state[node]["messages"] + for new_msg in all_messages[prev_count:]: + messages.append(new_msg) + prev_count = len(messages) except TimeoutException as e: self.logger.error(msg=f"Task timeout: {e}") except GraphRecursionError as e: diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py index cdeefc8be..26d342485 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py @@ -13,7 +13,7 @@ # limitations under the License. import random -from typing import List, Literal, Sequence +from typing import List, Literal from rai.tools.ros2 import MoveToPointToolInput @@ -44,76 +44,11 @@ NavigateToPointTask, SpinAroundTask, ) -from rai_bench.tool_calling_agent.tasks.spatial import ( - BoolImageTask, - BoolImageTaskInput, -) from rai_bench.tool_calling_agent.validators import ( NotOrderedCallsValidator, OrderedCallsValidator, ) -IMG_PATH = "src/rai_bench/rai_bench/tool_calling_agent/predefined/images/" -true_response_inputs: List[BoolImageTaskInput] = [ - BoolImageTaskInput( - question="Is the door on the left from the desk?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is the light on in the room?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Are there any pictures on the wall?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="Are there 3 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant behind the rack?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Is there a pillow on the armchain?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), -] -false_response_inputs: List[BoolImageTaskInput] = [ - BoolImageTaskInput( - question="Is the door open?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is someone in the room?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="Are there 4 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a rack on the left from the sofa?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant on the right from the window?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is there a red pillow on the armchair?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), -] ########## SUBTASKS ####################################################################################### get_topics_subtask = CheckArgsToolCallSubTask( expected_tool_name="get_ros2_topics_names_and_types", expected_args={} @@ -281,16 +216,6 @@ ), ] -true_spatial_tasks: List[Task] = [ - BoolImageTask( - task_input=input_item, validators=[ret_true_ord_val], extra_tool_calls=0 - ) - for input_item in true_response_inputs -] -false_spatial_tasks: List[Task] = [ - BoolImageTask(task_input=input_item, validators=[ret_false_ord_val]) - for input_item in false_response_inputs -] navigation_tasks: List[Task] = [ NavigateToPointTask(validators=[start_navigate_action_ord_val], extra_tool_calls=5), @@ -391,26 +316,6 @@ def get_custom_interfaces_tasks(extra_tool_calls: int = 0) -> List[Task]: ] -def get_spatial_tasks(extra_tool_calls: int = 0) -> Sequence[Task]: - true_tasks = [ - BoolImageTask( - task_input=input_item, - validators=[ret_true_ord_val], - extra_tool_calls=extra_tool_calls, - ) - for input_item in true_response_inputs - ] - false_tasks = [ - BoolImageTask( - task_input=input_item, - validators=[ret_false_ord_val], - extra_tool_calls=extra_tool_calls, - ) - for input_item in false_response_inputs - ] - return true_tasks + false_tasks - - def get_tasks( extra_tool_calls: int = 0, complexities: List[Literal["easy", "medium", "hard"]] = ["easy", "medium", "hard"], @@ -420,14 +325,12 @@ def get_tasks( "manipulation", "navigation", "custom_interfaces", - "spatial_reasoning", ] ] = [ "basic", "manipulation", "navigation", "custom_interfaces", - "spatial_reasoning", ], ) -> List[Task]: # TODO (jmatejcz) implement complexity sorting @@ -440,8 +343,6 @@ def get_tasks( tasks += get_manipulation_tasks(extra_tool_calls=extra_tool_calls) if "navigation" in task_types: tasks += get_navigation_tasks(extra_tool_calls=extra_tool_calls) - if "spatial_reasoning" in task_types: - tasks += get_spatial_tasks(extra_tool_calls=extra_tool_calls) random.shuffle(tasks) return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py deleted file mode 100644 index 24e328d31..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from abc import abstractmethod -from typing import List - -from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field -from rai.messages import preprocess_image - -from rai_bench.tool_calling_agent.interfaces import Task, Validator - -loggers_type = logging.Logger - -SPATIAL_REASONING_SYSTEM_PROMPT = "You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response with the use of the provided tools." - - -class TaskParametrizationError(Exception): - """Exception raised when the task parameters are not valid.""" - - pass - - -class SpatialReasoningAgentTask(Task): - """Abstract class for spatial reasoning tasks for tool calling agent.""" - - def __init__( - self, - validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, - ) -> None: - super().__init__( - validators=validators, - extra_tool_calls=extra_tool_calls, - logger=logger, - ) - self.expected_tools: List[BaseTool] - self.question: str - self.images_paths: List[str] - - @property - def type(self) -> str: - return "spatial_reasoning" - - @abstractmethod - def get_images(self) -> List[str]: - """Get the images related to the task. - - Returns - ------- - List[str] - List of image paths - """ - pass - - def get_system_prompt(self) -> str: - return SPATIAL_REASONING_SYSTEM_PROMPT - - -class ReturnBoolResponseToolInput(BaseModel): - response: bool = Field(..., description="The response to the question.") - - -class ReturnBoolResponseTool(BaseTool): - """Tool that returns a boolean response.""" - - name: str = "return_bool_response" - description: str = "Return a bool response to the question." - args_schema = ReturnBoolResponseToolInput - - def _run(self, response: bool) -> bool: - if type(response) is bool: - return response - raise ValueError("Invalid response type. Response must be a boolean.") - - -class BoolImageTaskInput(BaseModel): - question: str = Field(..., description="The question to be answered.") - images_paths: List[str] = Field( - ..., - description="List of image file paths to be used for answering the question.", - ) - - -class BoolImageTask(SpatialReasoningAgentTask): - complexity = "easy" - - def __init__( - self, - task_input: BoolImageTaskInput, - validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, - ) -> None: - super().__init__( - validators=validators, - extra_tool_calls=extra_tool_calls, - logger=logger, - ) - self.question = task_input.question - self.images_paths = task_input.images_paths - - @property - def available_tools(self) -> List[BaseTool]: - return [ReturnBoolResponseTool()] - - def get_prompt(self): - return self.question - - def get_images(self): - images = [preprocess_image(image_path) for image_path in self.images_paths] - return images diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index 0da9aee5f..b40498d3c 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -22,8 +22,9 @@ from rai.initialization import get_llm_model_direct -def parse_tool_calling_benchmark_args(): - parser = argparse.ArgumentParser(description="Run the Tool Calling Agent Benchmark") +def parse_base_benchmark_args(description: str, default_out_subdir: str): + """Parse common benchmark arguments shared across different benchmark types.""" + parser = argparse.ArgumentParser(description=description) parser.add_argument( "--model-name", type=str, @@ -31,6 +32,21 @@ def parse_tool_calling_benchmark_args(): required=True, ) parser.add_argument("--vendor", type=str, help="Vendor of the model", required=True) + + now = datetime.now() + parser.add_argument( + "--out-dir", + type=str, + default=f"src/rai_bench/rai_bench/experiments/{default_out_subdir}/{now.strftime('%Y-%m-%d_%H-%M-%S')}", + help="Output directory for results and logs", + ) + return parser + + +def parse_tool_calling_benchmark_args(): + parser = parse_base_benchmark_args( + "Run the Tool Calling Agent Benchmark", "tool_calling" + ) parser.add_argument( "--extra-tool-calls", type=int, @@ -54,36 +70,22 @@ def parse_tool_calling_benchmark_args(): "manipulation", "navigation", "custom_interfaces", - "spatial_reasoning", ], default=[ "basic", "manipulation", "navigation", "custom_interfaces", - "spatial_reasoning", ], help="Types of tasks to include in the benchmark", ) - now = datetime.now() - parser.add_argument( - "--out-dir", - type=str, - default=f"src/rai_bench/rai_bench/experiments/tool_calling/{now.strftime('%Y-%m-%d_%H-%M-%S')}", - help="Output directory for results and logs", - ) return parser.parse_args() def parse_manipulation_o3de_benchmark_args(): - parser = argparse.ArgumentParser(description="Run the Manipulation O3DE Benchmark") - parser.add_argument( - "--model-name", - type=str, - help="Model name to use for benchmarking", - required=True, + parser = parse_base_benchmark_args( + "Run the Manipulation O3DE Benchmark", "o3de_manipulation" ) - parser.add_argument("--vendor", type=str, help="Vendor of the model", required=True) parser.add_argument( "--o3de-config-path", type=str, @@ -98,13 +100,11 @@ def parse_manipulation_o3de_benchmark_args(): default=["trivial", "easy", "medium", "hard", "very_hard"], help="Difficulty levels to include in the benchmark", ) - now = datetime.now() - parser.add_argument( - "--out-dir", - type=str, - default=f"src/rai_bench/rai_bench/experiments/o3de_manipulation/{now.strftime('%Y-%m-%d_%H-%M-%S')}", - help="Output directory for results and logs", - ) + return parser.parse_args() + + +def parse_vlm_benchmark_args(): + parser = parse_base_benchmark_args("Run the VLM Benchmark", "vlm_benchmark") return parser.parse_args() @@ -119,11 +119,16 @@ def define_benchmark_logger(out_dir: Path) -> logging.Logger: ) file_handler.setFormatter(formatter) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + bench_logger = logging.getLogger("Benchmark logger") for handler in bench_logger.handlers: bench_logger.removeHandler(handler) bench_logger.setLevel(logging.INFO) bench_logger.addHandler(file_handler) + bench_logger.addHandler(console_handler) return bench_logger diff --git a/src/rai_bench/rai_bench/vlm_benchmark/__init__.py b/src/rai_bench/rai_bench/vlm_benchmark/__init__.py new file mode 100644 index 000000000..b920f5bf0 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/__init__.py @@ -0,0 +1,18 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .benchmark import run_benchmark +from .predefined.tasks import get_spatial_tasks + +__all__ = ["get_spatial_tasks", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py b/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py new file mode 100644 index 000000000..8f1c9d699 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py @@ -0,0 +1,211 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import statistics +import time +import uuid +from pathlib import Path +from typing import Iterator, List, Sequence, Tuple + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.runnables.config import RunnableConfig +from langgraph.errors import GraphRecursionError +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel +from rai.agents.langchain.core import ( + create_structured_output_runnable, +) +from rai.messages import HumanMultimodalMessage + +from rai_bench.base_benchmark import BaseBenchmark, RunSummary, TimeoutException +from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler +from rai_bench.utils import get_llm_model_name +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask, TaskValidationError +from rai_bench.vlm_benchmark.results_tracking import ( + TaskResult, +) + + +class VLMBenchmark(BaseBenchmark): + """Benchmark for VLMs.""" + + def __init__( + self, + tasks: Sequence[ImageReasoningTask[BaseModel]], + model_name: str, + results_dir: Path, + logger: logging.Logger | None = None, + ) -> None: + super().__init__( + model_name=model_name, + results_dir=results_dir, + logger=logger, + ) + self._tasks: Iterator[Tuple[int, ImageReasoningTask[BaseModel]]] = enumerate( + iter(tasks) + ) + self.num_tasks = len(tasks) + self.task_results: List[TaskResult] = [] + + self.score_tracing_handler = ScoreTracingHandler() + self.tasks_results: List[TaskResult] = [] + self.csv_initialize(self.results_filename, TaskResult) + + def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: + """Runs the next task of the benchmark. + + Parameters + ---------- + agent : CompiledStateGraph + LangChain tool calling agent. + model_name : str + Name of the LLM model. + """ + # try: + i, task = next(self._tasks) + self.logger.info( + "======================================================================================" + ) + self.logger.info( + f"RUNNING TASK NUMBER {i + 1} / {self.num_tasks}, TASK {task.get_prompt()}" + ) + callbacks = self.score_tracing_handler.get_callbacks() + run_id = uuid.uuid4() + config: RunnableConfig = { + "run_id": run_id, + "callbacks": callbacks, + "tags": [ + f"experiment-id:{experiment_id}", + "benchmark:vlm-benchmark", + self.model_name, + f"task-complexity:{task.complexity}", + ], + "recursion_limit": len(agent.get_graph().nodes), + } + + ts = time.perf_counter() + messages: List[BaseMessage] = [] + prev_count: int = 0 + errors: List[str] = [] + try: + with self.time_limit(60): + for state in agent.stream( + { + "messages": [ + HumanMultimodalMessage( + content=task.get_prompt(), images=task.get_images() + ) + ] + }, + config=config, + ): + node = next(iter(state)) + all_messages = state[node]["messages"] + for new_msg in all_messages[prev_count:]: + messages.append(new_msg) + prev_count = len(messages) + except TimeoutException as e: + self.logger.error(msg=f"Task timeout: {e}") + except GraphRecursionError as e: + self.logger.error(msg=f"Reached recursion limit {e}") + + structured_output = None + try: + structured_output = task.get_structured_output_from_messages( + messages=messages + ) + except TaskValidationError as e: + errors.append(str(e)) + + if structured_output is not None: + score = task.validate(output=structured_output) + else: + errors.append(f"Not valid structured output: {type(structured_output)}") + score = False + + te = time.perf_counter() + total_time = te - ts + + self.logger.info(f"TASK SCORE: {score}, TOTAL TIME: {total_time:.3f}") + + task_result = TaskResult( + task_prompt=task.get_prompt(), + system_prompt=task.get_system_prompt(), + type=task.type, + complexity=task.complexity, + model_name=self.model_name, + score=score, + total_time=total_time, + run_id=run_id, + ) + + self.task_results.append(task_result) + + self.csv_writerow(self.results_filename, task_result) + # computing after every iteration in case of early stopping + self.compute_and_save_summary() + + for callback in callbacks: + self.score_tracing_handler.send_score( + callback=callback, + run_id=run_id, + score=score, + errors=[errors], + ) + + def compute_and_save_summary(self): + self.logger.info("Computing and saving average results...") + + success_count = sum(1 for r in self.task_results if r.score == 1.0) + success_rate = success_count / len(self.task_results) * 100 + avg_time = statistics.mean(r.total_time for r in self.task_results) + + summary = RunSummary( + model_name=self.model_name, + success_rate=round(success_rate, 2), + avg_time=round(avg_time, 3), + total_tasks=len(self.task_results), + ) + self.csv_initialize(self.summary_filename, RunSummary) + self.csv_writerow(self.summary_filename, summary) + + +def run_benchmark( + llm: BaseChatModel, + out_dir: Path, + tasks: List[ImageReasoningTask[BaseModel]], + bench_logger: logging.Logger, + experiment_id: uuid.UUID = uuid.uuid4(), +): + benchmark = VLMBenchmark( + tasks=tasks, + logger=bench_logger, + model_name=get_llm_model_name(llm), + results_dir=out_dir, + ) + + for task in tasks: + agent = create_structured_output_runnable( + llm=llm, + structured_output=task.structured_output, + system_prompt=task.get_system_prompt(), + logger=bench_logger, + ) + + benchmark.run_next(agent=agent, experiment_id=experiment_id) + + bench_logger.info("===============================================================") + bench_logger.info("ALL TASKS DONE. BENCHMARK COMPLETED!") + bench_logger.info("===============================================================") diff --git a/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py b/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py new file mode 100644 index 000000000..97f769b93 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py @@ -0,0 +1,174 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from typing import Generic, List, Literal, Optional, TypeVar + +from langchain_core.messages import BaseMessage +from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT +from pydantic import BaseModel, ConfigDict, ValidationError + +loggers_type = logging.Logger + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) + + +IMAGE_REASONING_SYSTEM_PROMPT = "You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response in requested structured output format." + + +class LangchainRawOutputModel(BaseModel): + """ + A Pydantic model for wrapping Langchain message parsing results from a structured output agent. See documentation for more details: + https://github.com/langchain-ai/langchain/blob/02001212b0a2b37d90451d8493089389ea220cab/libs/core/langchain_core/language_models/chat_models.py#L1430-L1432 + + + Attributes + ---------- + raw : BaseMessage + The original raw message object from Langchain before parsing. + parsed : BaseModel + The parsed and validated Pydantic model instance derived from the raw message. + parsing_error : Optional[BaseException] + Any exception that occurred during the parsing process, None if parsing + was successful. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + raw: BaseMessage + parsed: BaseModel + parsing_error: Optional[BaseException] + + +class TaskValidationError(Exception): + pass + + +class ImageReasoningTask(ABC, Generic[BaseModelT]): + complexity: Literal["easy", "medium", "hard"] + recursion_limit: int = DEFAULT_RECURSION_LIMIT + + def __init__( + self, + logger: loggers_type | None = None, + ) -> None: + """ + Abstract base class representing a complete image reasoning task to be validated. + + Each Task has a consistent prompt and structured output schema, along + with validation methods that check the output against the expected result. + + Attributes + ---------- + logger : logging.Logger + Logger for recording task validation results and errors. + """ + if logger: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + self.question: str + self.images_paths: List[str] + + def set_logger(self, logger: loggers_type): + self.logger = logger + + @property + @abstractmethod + def structured_output(self) -> type[BaseModelT]: + """Structured output that agent should return.""" + pass + + @property + @abstractmethod + def type(self) -> str: + """Type of task, for example: image_reasoning""" + pass + + def get_system_prompt(self) -> str: + """Get the system prompt that will be passed to agent + + Returns + ------- + str + System prompt + """ + return IMAGE_REASONING_SYSTEM_PROMPT + + @abstractmethod + def get_prompt(self) -> str: + """Get the task instruction - the prompt that will be passed to agent. + + Returns + ------- + str + Prompt + """ + pass + + @abstractmethod + def validate(self, output: BaseModelT) -> bool: + """Validate result of the task.""" + pass + + @abstractmethod + def get_images(self) -> List[str]: + """Get the images related to the task. + + Returns + ------- + List[str] + List of image paths + """ + pass + + def get_structured_output_from_messages( + self, messages: List[BaseMessage] + ) -> BaseModelT | None: + """Extract and validate structured output from a list of messages. + + Iterates through messages in reverse order, attempting to find the message that is + a LangchainRawOutputModel containing the structured output. + + Parameters + ---------- + messages : List[BaseMessage] + List of messages to search for structured output. + + Returns + ------- + BaseModelT | None + The first valid structured output found that matches the task's expected + output type, or None if no valid structured output is found. + + Raises + ------ + TaskValidationError + If a message contains a parsing error during validation. + """ + for message in reversed(messages): + if isinstance(message, dict): + try: + validated_message = LangchainRawOutputModel.model_validate(message) + if validated_message.parsing_error is not None: + raise TaskValidationError( + f"Parsing error: {validated_message.parsing_error}" + ) + + parsed = validated_message.parsed + if isinstance(parsed, self.structured_output): + return parsed + except ValidationError: + continue + return None diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_1.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_1.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_1.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_1.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_2.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_2.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_2.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_2.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_3.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_3.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_3.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_3.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_4.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_4.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_4.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_4.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_5.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_5.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_5.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_5.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_6.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_6.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_6.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_6.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_7.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_7.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_7.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_7.jpg diff --git a/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py b/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py new file mode 100644 index 000000000..bd1d9cd61 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py @@ -0,0 +1,112 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, cast + +from pydantic import BaseModel + +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask +from rai_bench.vlm_benchmark.tasks.tasks import BoolImageTask, BoolImageTaskInput + +IMG_PATH = "src/rai_bench/rai_bench/vlm_benchmark/predefined/images/" +true_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door on the left from the desk?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is the light on in the room?", + images_paths=[IMG_PATH + "image_2.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_2.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Are there any pictures on the wall?", + images_paths=[IMG_PATH + "image_3.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Are there 3 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is there a plant behind the rack?", + images_paths=[IMG_PATH + "image_5.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is there a pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + expected_answer=True, + ), +] +false_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door open?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is someone in the room?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_3.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Are there 4 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a rack on the left from the sofa?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a plant on the right from the window?", + images_paths=[IMG_PATH + "image_6.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a red pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + expected_answer=False, + ), +] + + +def get_spatial_tasks() -> List[ImageReasoningTask[BaseModel]]: + true_tasks = [ + BoolImageTask( + task_input=input_item, + ) + for input_item in true_response_inputs + ] + false_tasks = [ + BoolImageTask( + task_input=input_item, + ) + for input_item in false_response_inputs + ] + return cast(List[ImageReasoningTask[BaseModel]], true_tasks + false_tasks) diff --git a/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py b/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py new file mode 100644 index 000000000..5d8b77b71 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py @@ -0,0 +1,35 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from uuid import UUID + +from pydantic import BaseModel, Field + + +class TaskResult(BaseModel): + task_prompt: str = Field(..., description="The task prompt.") + system_prompt: str = Field(..., description="The system prompt.") + complexity: str = Field(..., description="Complexity of the task.") + type: str = Field( + ..., description="Type of task, for example: bool_response_image_task" + ) + model_name: str = Field(..., description="Name of the LLM.") + score: float = Field( + ..., + description="Value between 0 and 1.", + ) + + total_time: float = Field(..., description="Total time taken to complete the task.") + run_id: UUID = Field(..., description="UUID of the task run.") diff --git a/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py b/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py new file mode 100644 index 000000000..639b50400 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py @@ -0,0 +1,76 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from typing import List + +from pydantic import BaseModel, Field +from rai.messages import preprocess_image + +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask + +loggers_type = logging.Logger + + +class BoolAnswerWithJustification(BaseModel): + """A boolean answer to the user question along with justification for the answer.""" + + answer: bool + justification: str + + +class BoolImageTaskInput(BaseModel): + question: str = Field(..., description="The question to be answered.") + images_paths: List[str] = Field( + ..., + description="List of image file paths to be used for answering the question.", + ) + expected_answer: bool = Field( + ..., description="The expected answer to the question." + ) + + +class BoolImageTask(ImageReasoningTask[BoolAnswerWithJustification]): + complexity = "easy" + + def __init__( + self, + task_input: BoolImageTaskInput, + logger: loggers_type | None = None, + ) -> None: + super().__init__( + logger=logger, + ) + self.question = task_input.question + self.images_paths = task_input.images_paths + self.expected_answer = task_input.expected_answer + + @property + def structured_output(self) -> type[BoolAnswerWithJustification]: + return BoolAnswerWithJustification + + @property + def type(self) -> str: + return "bool_response_image_task" + + def get_prompt(self): + return self.question + + def get_images(self): + images = [preprocess_image(image_path) for image_path in self.images_paths] + return images + + def validate(self, output: BoolAnswerWithJustification) -> bool: + return output.answer == self.expected_answer diff --git a/src/rai_core/rai/agents/langchain/core/__init__.py b/src/rai_core/rai/agents/langchain/core/__init__.py index ea6cb46bf..917348fd0 100644 --- a/src/rai_core/rai/agents/langchain/core/__init__.py +++ b/src/rai_core/rai/agents/langchain/core/__init__.py @@ -19,6 +19,7 @@ create_react_runnable, ) from .state_based_agent import create_state_based_runnable +from .structured_output_agent import create_structured_output_runnable from .tool_runner import ToolRunner __all__ = [ @@ -28,4 +29,5 @@ "create_conversational_agent", "create_react_runnable", "create_state_based_runnable", + "create_structured_output_runnable", ] diff --git a/src/rai_core/rai/agents/langchain/core/structured_output_agent.py b/src/rai_core/rai/agents/langchain/core/structured_output_agent.py new file mode 100644 index 000000000..7c1878cd3 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/core/structured_output_agent.py @@ -0,0 +1,60 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from functools import partial +from typing import Optional + +from langchain.chat_models.base import BaseChatModel +from langchain_core.messages import ( + SystemMessage, +) +from langgraph.graph import START, StateGraph +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel + +from rai.agents.langchain.core.conversational_agent import State, agent + + +def create_structured_output_runnable( + llm: BaseChatModel, + structured_output: type[BaseModel], + system_prompt: str | SystemMessage, + logger: Optional[logging.Logger] = None, + debug: bool = False, +) -> CompiledStateGraph: + _logger = None + if logger: + _logger = logger + else: + _logger = logging.getLogger(__name__) + + _logger.info("Creating structured output runnable") + + llm_with_structured_output = llm.with_structured_output( + schema=structured_output, include_raw=True + ) + + workflow = StateGraph(State) + + workflow.add_node( + "thinker", partial(agent, llm_with_structured_output, _logger, system_prompt) + ) + + workflow.add_edge(START, "thinker") + + app = workflow.compile(debug=debug) + _logger.info("State based agent created") + return app From f71f6c192851cf24dbb9b3eee5e55a689a61f8ae Mon Sep 17 00:00:00 2001 From: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> Date: Mon, 18 Aug 2025 12:50:46 +0200 Subject: [PATCH 03/14] chore: deprecate create_conversational_agent (#671) --- .../rai/agents/langchain/core/conversational_agent.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/rai_core/rai/agents/langchain/core/conversational_agent.py b/src/rai_core/rai/agents/langchain/core/conversational_agent.py index b74d07654..e008fdb7d 100644 --- a/src/rai_core/rai/agents/langchain/core/conversational_agent.py +++ b/src/rai_core/rai/agents/langchain/core/conversational_agent.py @@ -17,6 +17,7 @@ from functools import partial from typing import List, Optional, TypedDict +from deprecated import deprecated from langchain.chat_models.base import BaseChatModel from langchain_core.messages import ( BaseMessage, @@ -64,6 +65,10 @@ def agent( return state +@deprecated( + "Use rai.agents.langchain.core.create_react_runnable instead. " + "Support for the conversational agent will be removed in the 3.0 release." +) def create_conversational_agent( llm: BaseChatModel, tools: List[BaseTool], From fbf056f5e7a0ce05ffc9fafc3a41e25c4a17c94e Mon Sep 17 00:00:00 2001 From: Pawel Kotowski Date: Tue, 19 Aug 2025 11:29:27 +0200 Subject: [PATCH 04/14] docs: improve rai_bench readme (#674) --- docs/simulation_and_benchmarking/rai_bench.md | 6 ++-- docs/tutorials/benchmarking.md | 28 +++++++++---------- src/rai_bench/README.md | 27 +++++++++--------- 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index ea3ff34f1..09258c9f2 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -96,9 +96,9 @@ Evaluates agent performance independently from any simulation, based only on too The `SubTask` class is used to validate just one tool call. Following classes are available: - `CheckArgsToolCallSubTask` - verify if a certain tool was called with expected arguments -- `CheckTopicFieldsToolCallSubTask` - verify if a message published to ROS 2topic was of proper type and included expected fields -- `CheckServiceFieldsToolCallSubTask` - verify if a message published to ROS 2service was of proper type and included expected fields -- `CheckActionFieldsToolCallSubTask` - verify if a message published to ROS 2action was of proper type and included expected fields +- `CheckTopicFieldsToolCallSubTask` - verify if a message published to ROS2 topic was of proper type and included expected fields +- `CheckServiceFieldsToolCallSubTask` - verify if a message published to ROS2 service was of proper type and included expected fields +- `CheckActionFieldsToolCallSubTask` - verify if a message published to ROS2 action was of proper type and included expected fields ### Validator diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index db2a56cd6..f46cfcb24 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -22,21 +22,21 @@ If your goal is creating custom tasks and scenarios, visit [Creating Custom Task level: RoboticManipulationBenchmark robotic_stack_command: ros2 launch examples/manipulation-demo-no-binary.launch.py required_simulation_ros2_interfaces: - services: - - /spawn_entity - - /delete_entity - topics: - - /color_image5 - - /depth_image5 - - /color_camera_info5 - actions: [] + services: + - /spawn_entity + - /delete_entity + topics: + - /color_image5 + - /depth_image5 + - /color_camera_info5 + actions: [] required_robotic_ros2_interfaces: - services: - - /grounding_dino_classify - - /grounded_sam_segment - - /manipulator_move_to - topics: [] - actions: [] + services: + - /grounding_dino_classify + - /grounded_sam_segment + - /manipulator_move_to + topics: [] + actions: [] ``` - Run the benchmark with: diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md index b9a720aa3..739a228ea 100644 --- a/src/rai_bench/README.md +++ b/src/rai_bench/README.md @@ -10,9 +10,9 @@ The Manipulation O3DE Benchmark [manipulation_o3de_benchmark_module](./rai_bench - **GroupObjectsTask** - **BuildCubeTowerTask** - **PlaceObjectAtCoordTask** -- **RotateObjectTask** (currently not applicable due to limitations in the ManipulatorMoveTo tool) +- **RotateObjectTask** (currently not applicable due to limitations in the `ManipulatorMoveTo` tool) -The result of a task is a value between 0 and 1, calculated like initially_misplaced_now_correct / initially_misplaced. This score is calculated at the end of each scenario. +The result of a task is a value between 0 and 1, calculated like `initially_misplaced_now_correct / initially_misplaced`. This score is calculated at the end of each scenario. ### Frame Components @@ -92,7 +92,7 @@ python src/rai_bench/rai_bench/examples/manipulation_o3de/main.py --model-name l ``` > [!NOTE] -> For now benchmark runs all available scenarios (~160). See [Examples](#example-usege) +> For now benchmark runs all available scenarios (~160). See [Examples](#example-usage) > section for details. ### Development @@ -100,7 +100,7 @@ python src/rai_bench/rai_bench/examples/manipulation_o3de/main.py --model-name l When creating new task or changing existing ones, make sure to add unit tests for score calculation in [rai_bench_tests](../../tests/rai_bench/manipulation_o3de/tasks/). This applies also when you are adding or changing the helper methods in `Task` or `ManipulationTask`. -The number of scenarios can be easily extened without writing new tasks, by increasing number of variants of the same task and adding more simulation configs but it won't improve variety of scenarios as much as creating new tasks. +The number of scenarios can be easily extended without writing new tasks, by increasing number of variants of the same task and adding more simulation configs but it won't improve variety of scenarios as much as creating new tasks. ## Tool Calling Agent Benchmark @@ -109,15 +109,16 @@ The Tool Calling Agent Benchmark is the benchmark for LangChain tool calling age ### Frame Components - [Tool Calling Agent Benchmark](rai_bench//tool_calling_agent/benchmark.py) - Benchmark for LangChain tool calling agents -- [Scores tracing](rai_bench/tool_calling_agent_bench/scores_tracing.py) - Component handling sending scores to tracing backends -- [Interfaces](rai_bench//tool_calling_agent/interfaces.py) - Interfaces for validation classes - Task, Validator, SubTask - For detailed description of validation visit -> [Validation](.//rai_bench/docs/tool_calling_agent_benchmark.md) +- [Scores tracing](rai_bench/results_processing/langfuse_scores_tracing.py) - Component handling sending scores to tracing backends +- [Interfaces](rai_bench//tool_calling_agent/interfaces.py) - Interfaces for validation classes - `Task`, `Validator`, `SubTask` -[tool_calling_agent_test_bench.py](rai_bench/examples/tool_calling_agent/main.py) - Script providing benchmark on tasks based on the ROS2 tools usage. + For detailed description of validation visit -> [Validation](./rai_bench/docs/tool_calling_agent_benchmark.md) + +[tool_calling_agent_test_bench.py](./rai_bench/examples/tool_calling_agent/main.py) - Script providing benchmark on tasks based on the ROS2 tools usage. ### Example Usage -Validators can be constructed from any SubTasks, Tasks can be validated by any numer of Validators, which makes whole validation process incredibly versital. +`Validators` can be constructed from any `SubTasks`, `Tasks` can be validated by any number of `Validators`, which makes whole validation process incredibly versatile. ```python # subtasks @@ -144,7 +145,7 @@ GetROS2RGBCameraTask(validators=[topics_ord_val, color_image_ord_val]), ### Running -To set up tracing backends, please follow the instructions in the [tracing.md](../../docs/tracing.md) document. +To set up tracing backends, please follow the instructions in the [tracing.md](../../docs/setup/tracing.md) document. To run the benchmark: @@ -169,7 +170,7 @@ The VLM Benchmark is a benchmark for VLM models. It includes a set of tasks cont ### Running -To set up tracing backends, please follow the instructions in the [tracing.md](../../docs/tracing.md) document. +To set up tracing backends, please follow the instructions in the [tracing.md](../../docs/setup/tracing.md) document. To run the benchmark: @@ -181,7 +182,7 @@ python src/rai_bench/rai_bench/examples/vlm_benchmark.py --model-name gemma3:4b ## Testing Models -To test multiple models, different benchamrks or couple repeats in one go - use script [test_models](./rai_bench/examples/test_models.py) +To test multiple models, different benchmarks or couple repeats in one go - use script [test_models](./rai_bench/examples/test_models.py) Modify these params: @@ -216,7 +217,7 @@ When you run a test via: python src/rai_bench/rai_bench/examples/test_models.py ``` -results will be saved to separate folder in [results](./rai_bench/experiments/), with prefix `run_` +results will be saved to separate folder in [experiments](./rai_bench/experiments/), with prefix `run_` To visualise the results run: From e567d5b2625ee1d716eb5f251f0f44acd1275e88 Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Mon, 25 Aug 2025 04:31:05 -0700 Subject: [PATCH 05/14] fix: invoke_with_tracing breaks rai used outside of developer space (#677) --- src/rai_core/rai/agents/langchain/__init__.py | 2 + .../agents/langchain/invocation_helpers.py | 2 + .../initialization/model_initialization.py | 13 +- .../agents/langchain/test_langchain_agent.py | 111 +++++++++++++++ tests/conftest.py | 129 ++++++++++++++++++ tests/initialization/test_tracing.py | 73 ++++++++++ 6 files changed, 326 insertions(+), 4 deletions(-) create mode 100644 tests/initialization/test_tracing.py diff --git a/src/rai_core/rai/agents/langchain/__init__.py b/src/rai_core/rai/agents/langchain/__init__.py index b608b97bb..186ad47d8 100644 --- a/src/rai_core/rai/agents/langchain/__init__.py +++ b/src/rai_core/rai/agents/langchain/__init__.py @@ -19,6 +19,7 @@ create_react_runnable, create_state_based_runnable, ) +from .invocation_helpers import invoke_llm_with_tracing from .react_agent import ReActAgent from .state_based_agent import BaseStateBasedAgent, StateBasedConfig @@ -32,5 +33,6 @@ "StateBasedConfig", "create_react_runnable", "create_state_based_runnable", + "invoke_llm_with_tracing", "newMessageBehaviorType", ] diff --git a/src/rai_core/rai/agents/langchain/invocation_helpers.py b/src/rai_core/rai/agents/langchain/invocation_helpers.py index c2a667f29..d3ecee749 100644 --- a/src/rai_core/rai/agents/langchain/invocation_helpers.py +++ b/src/rai_core/rai/agents/langchain/invocation_helpers.py @@ -35,6 +35,8 @@ def invoke_llm_with_tracing( This function automatically adds tracing callbacks (like Langfuse) to LLM calls within LangGraph nodes, solving the callback propagation issue. + Tracing is controlled by config.toml. If the file is missing, no tracing is applied. + Parameters ---------- llm : BaseChatModel diff --git a/src/rai_core/rai/initialization/model_initialization.py b/src/rai_core/rai/initialization/model_initialization.py index ce0c4ace5..036457ed8 100644 --- a/src/rai_core/rai/initialization/model_initialization.py +++ b/src/rai_core/rai/initialization/model_initialization.py @@ -273,11 +273,16 @@ def get_embeddings_model( def get_tracing_callbacks( - override_use_langfuse: bool = False, override_use_langsmith: bool = False + config_path: Optional[str] = None, ) -> List[BaseCallbackHandler]: - config = load_config() + try: + config = load_config(config_path) + except Exception as e: + logger.warning(f"Failed to load config for tracing: {e}, tracing disabled") + return [] + callbacks: List[BaseCallbackHandler] = [] - if config.tracing.langfuse.use_langfuse or override_use_langfuse: + if config.tracing.langfuse.use_langfuse: from langfuse.callback import CallbackHandler # type: ignore public_key = os.getenv("LANGFUSE_PUBLIC_KEY", None) @@ -292,7 +297,7 @@ def get_tracing_callbacks( ) callbacks.append(callback) - if config.tracing.langsmith.use_langsmith or override_use_langsmith: + if config.tracing.langsmith.use_langsmith: os.environ["LANGCHAIN_TRACING_V2"] = "true" os.environ["LANGCHAIN_PROJECT"] = config.tracing.project api_key = os.getenv("LANGCHAIN_API_KEY", None) diff --git a/tests/agents/langchain/test_langchain_agent.py b/tests/agents/langchain/test_langchain_agent.py index c220f7bbe..c046eed3f 100644 --- a/tests/agents/langchain/test_langchain_agent.py +++ b/tests/agents/langchain/test_langchain_agent.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from collections import deque from typing import List +from unittest.mock import MagicMock, patch import pytest +from rai.agents.langchain import invoke_llm_with_tracing from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType +from rai.initialization import get_tracing_callbacks @pytest.mark.parametrize( @@ -39,3 +43,110 @@ def test_reduce_messages( output_ = LangChainAgent._apply_reduction_behavior(new_message_behavior, buffer) assert output == output_ assert buffer == deque(out_buffer) + + +class TestTracingConfiguration: + """Test tracing configuration integration with langchain agents.""" + + def test_tracing_with_missing_config_file(self): + """Test that tracing gracefully handles missing config.toml file in langchain context.""" + # This should not crash even without config.toml + callbacks = get_tracing_callbacks() + assert len(callbacks) == 0 + + def test_tracing_with_config_file_present(self, test_config_toml): + """Test that tracing works when config.toml is present in langchain context.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=True, langsmith_enabled=False + ) + + try: + # Mock environment variables to avoid actual API calls + with patch.dict( + os.environ, + { + "LANGFUSE_PUBLIC_KEY": "test_key", + "LANGFUSE_SECRET_KEY": "test_secret", + }, + ): + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 1 callback for langfuse + assert len(callbacks) == 1 + finally: + cleanup() + + +class TestInvokeLLMWithTracing: + """Test the invoke_llm_with_tracing function.""" + + def test_invoke_llm_without_tracing(self): + """Test that invoke_llm_with_tracing works when no tracing callbacks are available.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock get_tracing_callbacks to return empty list (no config.toml) + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = [] + + result = invoke_llm_with_tracing(mock_llm, mock_messages) + + mock_llm.invoke.assert_called_once_with(mock_messages, config=None) + assert result == "test response" + + def test_invoke_llm_with_tracing(self): + """Test that invoke_llm_with_tracing works when tracing callbacks are available.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock get_tracing_callbacks to return some callbacks + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = ["tracing_callback"] + + _ = invoke_llm_with_tracing(mock_llm, mock_messages) + + # Verify that the LLM was called with enhanced config + mock_llm.invoke.assert_called_once() + call_args = mock_llm.invoke.call_args + assert call_args[0][0] == mock_messages + assert "callbacks" in call_args[1]["config"] + assert "tracing_callback" in call_args[1]["config"]["callbacks"] + + def test_invoke_llm_with_existing_config(self): + """Test that invoke_llm_with_tracing preserves existing config.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock existing config + existing_config = {"callbacks": ["existing_callback"]} + + # Mock get_tracing_callbacks to return some callbacks + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = ["tracing_callback"] + + _ = invoke_llm_with_tracing(mock_llm, mock_messages, existing_config) + + # Verify that the LLM was called with enhanced config + mock_llm.invoke.assert_called_once() + call_args = mock_llm.invoke.call_args + assert call_args[0][0] == mock_messages + assert "callbacks" in call_args[1]["config"] + assert "existing_callback" in call_args[1]["config"]["callbacks"] + assert "tracing_callback" in call_args[1]["config"]["callbacks"] diff --git a/tests/conftest.py b/tests/conftest.py index 97ceef6f0..adb9e1850 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,3 +11,132 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import os +import tempfile + +import pytest + + +@pytest.fixture +def test_config_toml(): + """ + Fixture to create a temporary test config.toml file with tracing enabled. + + Returns + ------- + tuple + (config_path, cleanup_function) - The path to the config file and a function to clean it up + """ + + def _create_config(langfuse_enabled=False, langsmith_enabled=False): + # Create a temporary config.toml file + f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) + + # Base config sections (always required) + config_content = """[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "text-embedding-ada-002" + +[aws] +simple_model = "anthropic.claude-instant-v1" +complex_model = "anthropic.claude-v2" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-3.5-turbo" +complex_model = "gpt-4" +embeddings_model = "text-embedding-ada-002" +base_url = "https://api.openai.com/v1" + +[ollama] +simple_model = "llama2" +complex_model = "llama2" +embeddings_model = "llama2" +base_url = "http://localhost:11434" + +[tracing] +project = "test-project" + +[tracing.langfuse] +use_langfuse = {langfuse_enabled} +host = "http://localhost:3000" + +[tracing.langsmith] +use_langsmith = {langsmith_enabled} +host = "https://api.smith.langchain.com" +""".format( + langfuse_enabled=str(langfuse_enabled).lower(), + langsmith_enabled=str(langsmith_enabled).lower(), + ) + + f.write(config_content) + f.close() + + def cleanup(): + try: + f.close() # Ensure file is properly closed + os.unlink(f.name) + except (OSError, PermissionError): + pass # File might already be deleted or have permission issues + + return f.name, cleanup + + return _create_config + + +@pytest.fixture +def test_config_no_tracing(): + """ + Fixture to create a temporary test config.toml file with no tracing section. + + Returns + ------- + tuple + (config_path, cleanup_function) - The path to the config file and a function to clean it up + """ + + def _create_config(): + # Create a temporary config.toml file + f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) + + # Base config sections (always required) + config_content = """[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "text-embedding-ada-002" + +[aws] +simple_model = "anthropic.claude-instant-v1" +complex_model = "anthropic.claude-v2" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-3.5-turbo" +complex_model = "gpt-4" +embeddings_model = "text-embedding-ada-002" +base_url = "https://api.openai.com/v1" + +[ollama] +simple_model = "llama2" +complex_model = "llama2" +embeddings_model = "llama2" +base_url = "http://localhost:11434" +""" + + f.write(config_content) + f.close() + + def cleanup(): + try: + f.close() # Ensure file is properly closed + os.unlink(f.name) + except (OSError, PermissionError): + pass # File might already be deleted or have permission issues + + return f.name, cleanup + + return _create_config diff --git a/tests/initialization/test_tracing.py b/tests/initialization/test_tracing.py new file mode 100644 index 000000000..659941755 --- /dev/null +++ b/tests/initialization/test_tracing.py @@ -0,0 +1,73 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import patch + +from rai.initialization import get_tracing_callbacks + + +class TestInitializationTracing: + """Test the initialization module's tracing functionality.""" + + def test_tracing_with_missing_config_file(self): + """Test that tracing gracefully handles missing config.toml file.""" + # This should not crash even without config.toml + callbacks = get_tracing_callbacks() + assert len(callbacks) == 0 + + def test_tracing_with_config_file_present_tracing_disabled(self, test_config_toml): + """Test that tracing works when config.toml is present but tracing is disabled.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=False, langsmith_enabled=False + ) + + try: + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 0 callbacks since both langfuse and langsmith are disabled + assert len(callbacks) == 0 + finally: + cleanup() + + def test_tracing_with_config_file_present_tracing_enabled(self, test_config_toml): + """Test that tracing works when config.toml is present and tracing is enabled.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=True, langsmith_enabled=False + ) + + try: + # Mock environment variables to avoid actual API calls + with patch.dict( + os.environ, + { + "LANGFUSE_PUBLIC_KEY": "test_key", + "LANGFUSE_SECRET_KEY": "test_secret", + }, + ): + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 1 callback for langfuse + assert len(callbacks) == 1 + finally: + cleanup() + + def test_tracing_with_valid_config_file_no_tracing(self, test_config_no_tracing): + """Test that tracing works when config.toml is valid but has no tracing sections.""" + config_path, cleanup = test_config_no_tracing() + + try: + # This should not crash, should return empty callbacks + callbacks = get_tracing_callbacks(config_path=config_path) + assert len(callbacks) == 0 + finally: + cleanup() From 6074a1b28998245093f5242b7d5a5bf640d83108 Mon Sep 17 00:00:00 2001 From: Brian Tuan Date: Tue, 9 Sep 2025 02:14:35 -0700 Subject: [PATCH 06/14] fix: refactor messages to use BaseMessage.text() (#683) --- .../rai_bench/manipulation_o3de/benchmark.py | 2 +- .../rai/communication/hri_connector.py | 3 +- src/rai_core/rai/messages/multimodal.py | 4 -- .../agents/langchain/test_langchain_agent.py | 13 +++++++ tests/messages/test_multimodal_message.py | 37 +++++++++++++++++++ 5 files changed, 52 insertions(+), 7 deletions(-) create mode 100644 tests/messages/test_multimodal_message.py diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index 4dce98527..33cbd3354 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -286,7 +286,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: for msg in new_messages: if isinstance(msg, HumanMultimodalMessage): - last_msg = msg.text + last_msg = msg.text() elif isinstance(msg, BaseMessage): if isinstance(msg.content, list): if len(msg.content) == 1: diff --git a/src/rai_core/rai/communication/hri_connector.py b/src/rai_core/rai/communication/hri_connector.py index 17f18d416..bffc06ba0 100644 --- a/src/rai_core/rai/communication/hri_connector.py +++ b/src/rai_core/rai/communication/hri_connector.py @@ -103,12 +103,11 @@ def from_langchain( seq_no: int = 0, seq_end: bool = False, ) -> "HRIMessage": + text = message.text() if isinstance(message, RAIMultimodalMessage): - text = message.text images = message.images audios = message.audios else: - text = str(message.content) images = None audios = None if message.type not in ["ai", "human"]: diff --git a/src/rai_core/rai/messages/multimodal.py b/src/rai_core/rai/messages/multimodal.py index b1646f37d..1b5129d78 100644 --- a/src/rai_core/rai/messages/multimodal.py +++ b/src/rai_core/rai/messages/multimodal.py @@ -63,10 +63,6 @@ def __init__( _content.extend(_image_content) self.content = _content - @property - def text(self) -> str: - return self.content[0]["text"] - class HumanMultimodalMessage(HumanMessage, MultimodalMessage): def __repr_args__(self) -> Any: diff --git a/tests/agents/langchain/test_langchain_agent.py b/tests/agents/langchain/test_langchain_agent.py index c046eed3f..44e883120 100644 --- a/tests/agents/langchain/test_langchain_agent.py +++ b/tests/agents/langchain/test_langchain_agent.py @@ -18,9 +18,13 @@ from unittest.mock import MagicMock, patch import pytest +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.language_models.fake_chat_models import ParrotFakeChatModel +from langchain_core.runnables import RunnableConfig from rai.agents.langchain import invoke_llm_with_tracing from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType from rai.initialization import get_tracing_callbacks +from rai.messages import HumanMultimodalMessage @pytest.mark.parametrize( @@ -150,3 +154,12 @@ def test_invoke_llm_with_existing_config(self): assert "callbacks" in call_args[1]["config"] assert "existing_callback" in call_args[1]["config"]["callbacks"] assert "tracing_callback" in call_args[1]["config"]["callbacks"] + + def test_invoke_llm_with_callback_integration(self): + """Test that invoke_llm_with_tracing works with a callback handler.""" + llm = ParrotFakeChatModel() + human_msg = HumanMultimodalMessage(content="human") + response = llm.invoke( + [human_msg], config=RunnableConfig(callbacks=[BaseCallbackHandler()]) + ) + assert response.content == [{"type": "text", "text": "human"}] diff --git a/tests/messages/test_multimodal_message.py b/tests/messages/test_multimodal_message.py new file mode 100644 index 000000000..221ec0fed --- /dev/null +++ b/tests/messages/test_multimodal_message.py @@ -0,0 +1,37 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from rai.messages import HumanMultimodalMessage + + +class TestMultimodalMessage: + """Test the MultimodalMessage class and expected behaviors.""" + + def test_human_multimodal_message_text_simple(self): + """Test text() method with simple text content.""" + msg = HumanMultimodalMessage(content="Hello world") + assert msg.text() == "Hello world" + assert isinstance(msg.text(), str) + + def test_human_multimodal_message_text_with_images(self): + """Test text() method with text and images.""" + # Use a small valid base64 image (1x1 pixel PNG) + valid_base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + msg = HumanMultimodalMessage( + content="Look at this image", images=[valid_base64_image] + ) + assert msg.text() == "Look at this image" + # Should only return text type blocks, not image content + assert valid_base64_image not in msg.text() From f3a94135744bfb79fb72491fe1d396e6b07d7cf2 Mon Sep 17 00:00:00 2001 From: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> Date: Fri, 12 Sep 2025 12:04:51 +0200 Subject: [PATCH 07/14] feat: service client cache (#685) --- src/rai_core/pyproject.toml | 2 +- .../rai/communication/ros2/api/service.py | 67 ++++++++++++++----- .../ros2/connectors/service_mixin.py | 5 ++ tests/communication/ros2/test_api.py | 36 ++++++++-- 4 files changed, 88 insertions(+), 22 deletions(-) diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 4f2523a07..9c8152a88 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.2.1" +version = "2.3.0" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" diff --git a/src/rai_core/rai/communication/ros2/api/service.py b/src/rai_core/rai/communication/ros2/api/service.py index 464625aed..06dac1bc1 100644 --- a/src/rai_core/rai/communication/ros2/api/service.py +++ b/src/rai_core/rai/communication/ros2/api/service.py @@ -14,6 +14,7 @@ import os import uuid +from threading import Lock from typing import ( Any, Callable, @@ -30,6 +31,7 @@ import rclpy.qos import rclpy.subscription import rclpy.task +from rclpy.client import Client from rclpy.service import Service from rai.communication.ros2.api.base import ( @@ -39,12 +41,18 @@ class ROS2ServiceAPI(BaseROS2API): - """Handles ROS2 service operations including calling services.""" + """Handles ROS 2 service operations including calling services.""" def __init__(self, node: rclpy.node.Node) -> None: self.node = node self._logger = node.get_logger() self._services: Dict[str, Service] = {} + self._persistent_clients: Dict[str, Client] = {} + self._persistent_clients_lock = Lock() + + def release_client(self, service_name: str) -> bool: + with self._persistent_clients_lock: + return self._persistent_clients.pop(service_name, None) is not None def call_service( self, @@ -52,30 +60,57 @@ def call_service( service_type: str, request: Any, timeout_sec: float = 5.0, + *, + reuse_client: bool = True, ) -> Any: """ - Call a ROS2 service. + Call a ROS 2 service. Args: - service_name: Name of the service to call - service_type: ROS2 service type as string - request: Request message content + service_name: Fully-qualified service name. + service_type: ROS 2 service type string (e.g., 'std_srvs/srv/SetBool'). + request: Request payload dict. + timeout_sec: Seconds to wait for availability/response. + reuse_client: Reuse a cached client. Client creation is synchronized; set + False to create a new client per call. Returns: - The response message + Response message instance. + + Raises: + ValueError: Service not available within the timeout. + AttributeError: Service type or request cannot be constructed. + + Note: + With reuse_client=True, access to the cached client (including the + service call) is serialized by a lock, preventing concurrent calls + through the same client. Use reuse_client=False for per-call clients + when concurrent service calls are required. """ srv_msg, srv_cls = self.build_ros2_service_request(service_type, request) - service_client = self.node.create_client(srv_cls, service_name) # type: ignore - client_ready = service_client.wait_for_service(timeout_sec=timeout_sec) - if not client_ready: - raise ValueError( - f"Service {service_name} not ready within {timeout_sec} seconds. " - "Try increasing the timeout or check if the service is running." - ) - if os.getenv("ROS_DISTRO") == "humble": - return service_client.call(srv_msg) + + def _call_service(client: Client, timeout_sec: float) -> Any: + is_service_available = client.wait_for_service(timeout_sec=timeout_sec) + if not is_service_available: + raise ValueError( + f"Service {service_name} not ready within {timeout_sec} seconds. " + "Try increasing the timeout or check if the service is running." + ) + if os.getenv("ROS_DISTRO") == "humble": + return client.call(srv_msg) + else: + return client.call(srv_msg, timeout_sec=timeout_sec) + + if reuse_client: + with self._persistent_clients_lock: + client = self._persistent_clients.get(service_name, None) + if client is None: + client = self.node.create_client(srv_cls, service_name) # type: ignore + self._persistent_clients[service_name] = client + return _call_service(client, timeout_sec) else: - return service_client.call(srv_msg, timeout_sec=timeout_sec) + client = self.node.create_client(srv_cls, service_name) # type: ignore + return _call_service(client, timeout_sec) def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]: return self.node.get_service_names_and_types() diff --git a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py index 985de6c16..7c1597a56 100644 --- a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py +++ b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py @@ -30,6 +30,9 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: f"{self.__class__.__name__} instance must have an attribute '_service_api' of type ROS2ServiceAPI" ) + def release_client(self, service_name: str) -> bool: + return self._service_api.release_client(service_name) + def service_call( self, message: ROS2Message, @@ -37,6 +40,7 @@ def service_call( timeout_sec: float = 5.0, *, msg_type: str, + reuse_client: bool = True, **kwargs: Any, ) -> ROS2Message: msg = self._service_api.call_service( @@ -44,6 +48,7 @@ def service_call( service_type=msg_type, request=message.payload, timeout_sec=timeout_sec, + reuse_client=reuse_client, ) return ROS2Message( payload=msg, metadata={"msg_type": str(type(msg)), "service": target} diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index bc102a58a..d51f30378 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -138,11 +138,14 @@ def test_ros2_single_message_publish_wrong_qos_setup( shutdown_executors_and_threads(executors, threads) -def service_call_helper(service_name: str, service_api: ROS2ServiceAPI): +def invoke_set_bool_service( + service_name: str, service_api: ROS2ServiceAPI, reuse_client: bool = True +): response = service_api.call_service( service_name, service_type="std_srvs/srv/SetBool", request={"data": True}, + reuse_client=reuse_client, ) assert response.success assert response.message == "Test service called" @@ -164,7 +167,7 @@ def test_ros2_service_single_call( try: service_api = ROS2ServiceAPI(node) - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api) finally: shutdown_executors_and_threads(executors, threads) @@ -186,7 +189,30 @@ def test_ros2_service_multiple_calls( try: service_api = ROS2ServiceAPI(node) for _ in range(3): - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api, reuse_client=False) + finally: + shutdown_executors_and_threads(executors, threads) + + +@pytest.mark.parametrize( + "callback_group", + [MutuallyExclusiveCallbackGroup(), ReentrantCallbackGroup()], + ids=["MutuallyExclusiveCallbackGroup", "ReentrantCallbackGroup"], +) +def test_ros2_service_multiple_calls_with_reused_client( + ros_setup: None, request: pytest.FixtureRequest, callback_group: CallbackGroup +) -> None: + service_name = f"{request.node.originalname}_service" # type: ignore + node_name = f"{request.node.originalname}_node" # type: ignore + service_server = ServiceServer(service_name, callback_group) + node = Node(node_name) + executors, threads = multi_threaded_spinner([service_server, node]) + + try: + service_api = ROS2ServiceAPI(node) + for _ in range(3): + invoke_set_bool_service(service_name, service_api, reuse_client=True) + assert service_api.release_client(service_name), "Client not released" finally: shutdown_executors_and_threads(executors, threads) @@ -210,7 +236,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_threading( service_threads: List[threading.Thread] = [] for _ in range(10): thread = threading.Thread( - target=service_call_helper, args=(service_name, service_api) + target=invoke_set_bool_service, args=(service_name, service_api) ) service_threads.append(thread) thread.start() @@ -241,7 +267,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_multiprocessing( service_api = ROS2ServiceAPI(node) with Pool(10) as pool: pool.map( - lambda _: service_call_helper(service_name, service_api), range(10) + lambda _: invoke_set_bool_service(service_name, service_api), range(10) ) finally: shutdown_executors_and_threads(executors, threads) From cefdffd158ab452f6c06418999e8b7d182fee723 Mon Sep 17 00:00:00 2001 From: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> Date: Fri, 12 Sep 2025 14:53:32 +0200 Subject: [PATCH 08/14] feat: nav2 blocking tool (#669) --- src/rai_core/pyproject.toml | 2 +- .../rai/tools/ros2/navigation/__init__.py | 2 + .../tools/ros2/navigation/nav2_blocking.py | 69 +++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 9c8152a88..d13258cbc 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.3.0" +version = "2.4.0" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" diff --git a/src/rai_core/rai/tools/ros2/navigation/__init__.py b/src/rai_core/rai/tools/ros2/navigation/__init__.py index 8c5a94838..4142fcc68 100644 --- a/src/rai_core/rai/tools/ros2/navigation/__init__.py +++ b/src/rai_core/rai/tools/ros2/navigation/__init__.py @@ -20,6 +20,7 @@ Nav2Toolkit, NavigateToPoseTool, ) +from .nav2_blocking import NavigateToPoseBlockingTool __all__ = [ "CancelNavigateToPoseTool", @@ -27,5 +28,6 @@ "GetNavigateToPoseResultTool", "GetOccupancyGridTool", "Nav2Toolkit", + "NavigateToPoseBlockingTool", "NavigateToPoseTool", ] diff --git a/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py b/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py new file mode 100644 index 000000000..b51a79207 --- /dev/null +++ b/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py @@ -0,0 +1,69 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from geometry_msgs.msg import PoseStamped, Quaternion +from nav2_msgs.action import NavigateToPose +from pydantic import BaseModel, Field +from rclpy.action import ActionClient +from tf_transformations import quaternion_from_euler + +from rai.tools.ros2.base import BaseROS2Tool + + +class NavigateToPoseBlockingToolInput(BaseModel): + x: float = Field(..., description="The x coordinate of the pose") + y: float = Field(..., description="The y coordinate of the pose") + z: float = Field(..., description="The z coordinate of the pose") + yaw: float = Field(..., description="The yaw angle of the pose") + + +class NavigateToPoseBlockingTool(BaseROS2Tool): + name: str = "navigate_to_pose_blocking" + description: str = "Navigate to a specific pose" + frame_id: str = Field( + default="map", description="The frame id of the Nav2 stack (map, odom, etc.)" + ) + action_name: str = Field( + default="navigate_to_pose", description="The name of the Nav2 action" + ) + args_schema: Type[NavigateToPoseBlockingToolInput] = NavigateToPoseBlockingToolInput + + def _run(self, x: float, y: float, z: float, yaw: float) -> str: + action_client = ActionClient( + self.connector.node, NavigateToPose, self.action_name + ) + + pose = PoseStamped() + pose.header.frame_id = self.frame_id + pose.header.stamp = self.connector.node.get_clock().now().to_msg() + pose.pose.position.x = x + pose.pose.position.y = y + pose.pose.position.z = z + quat = quaternion_from_euler(0, 0, yaw) + pose.pose.orientation = Quaternion(x=quat[0], y=quat[1], z=quat[2], w=quat[3]) + + goal = NavigateToPose.Goal() + goal.pose = pose + + result = action_client.send_goal(goal) + + if result is None: + return "Navigate to pose action failed. Please try again." + + if result.result.error_code != 0: + return f"Navigate to pose action failed. Error code: {result.result.error_code}" + + return "Navigate to pose successful." From ee2d7a779055aa275c824f480cd553b36d5c725b Mon Sep 17 00:00:00 2001 From: jmatejcz Date: Mon, 15 Sep 2025 16:12:52 +0200 Subject: [PATCH 09/14] fix: remove leftovers from spatial tasks --- docs/tutorials/benchmarking.md | 4 +- src/rai_bench/rai_bench/test_models.py | 6 +- .../tool_calling_agent/predefined/__init__.py | 2 - .../predefined/spatial_reasoning_tasks.py | 254 ------------------ .../tool_calling_agent/predefined/tasks.py | 9 +- 5 files changed, 5 insertions(+), 270 deletions(-) delete mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index 6e014cbbf..ccb933a1c 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -73,7 +73,7 @@ if __name__ == "__main__": extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", - "manipulation", + "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt prompt_detail=["brief", "descriptive"], # how descriptive should task prompt be @@ -94,7 +94,7 @@ if __name__ == "__main__": ) ``` -Based on the example above the `Tool Calling` benchmark will run basic, spatial_reasoning and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. +Based on the example above the `Tool Calling` benchmark will run basic and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. !!! note diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index 5f917a600..e156cb069 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -79,9 +79,9 @@ class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): complexities : List[Literal["easy", "medium", "hard"]], optional complexity levels of tasks to include in the benchmark, by default all levels are included: ["easy", "medium", "hard"] - task_types : List[Literal["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"]], optional + task_types : List[Literal["basic", "manipulation", "navigation", "custom_interfaces"], optional types of tasks to include in the benchmark, by default all types are included: - ["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"] + ["basic", "manipulation", "navigation", "custom_interfaces"] For more detailed explanation of parameters, see the documentation: (https://robotecai.github.io/rai/simulation_and_benchmarking/rai_bench/) @@ -183,7 +183,6 @@ def test_models( out_dir: str, additional_model_args: Optional[List[Dict[str, Any]]] = None, ): - # TODO (jmatejcz) add docstring after passing agent logic will be added if additional_model_args is None: additional_model_args = [{} for _ in model_names] @@ -213,7 +212,6 @@ def test_models( vendor=vendors[i], **additional_model_args[i], ) - # TODO (jmatejcz) take param to set log level bench_logger = define_benchmark_logger(out_dir=Path(curr_out_dir)) try: if isinstance(bench_conf, ToolCallingAgentBenchmarkConfig): diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py index 53ff03a2a..af2958f59 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py @@ -15,11 +15,9 @@ from .basic_tasks import get_basic_tasks from .custom_interfaces_tasks import get_custom_interfaces_tasks from .manipulation_tasks import get_manipulation_tasks -from .spatial_reasoning_tasks import get_spatial_tasks __all__ = [ "get_basic_tasks", "get_custom_interfaces_tasks", "get_manipulation_tasks", - "get_spatial_tasks", ] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py deleted file mode 100644 index d3ccbfa5e..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Literal, Sequence - -from rai_bench.tool_calling_agent.interfaces import ( - Task, - TaskArgs, -) -from rai_bench.tool_calling_agent.subtasks import ( - CheckArgsToolCallSubTask, -) -from rai_bench.tool_calling_agent.tasks.spatial import ( - BoolImageTaskEasy, - BoolImageTaskHard, - BoolImageTaskInput, - BoolImageTaskMedium, -) -from rai_bench.tool_calling_agent.validators import ( - OrderedCallsValidator, -) - -IMG_PATH = "src/rai_bench/rai_bench/tool_calling_agent/predefined/images/" -########## SUBTASKS ################################################################# -return_true_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": True} -) -return_false_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": False} -) - -######### VALIDATORS ######################################################################################### -ret_true_ord_val = OrderedCallsValidator(subtasks=[return_true_subtask]) -ret_false_ord_val = OrderedCallsValidator(subtasks=[return_false_subtask]) - - -def get_spatial_tasks( - extra_tool_calls: List[int] = [0], - prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], - n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], -) -> Sequence[Task]: - """Get predefined spatial reasoning tasks. - - Parameters - ---------- - Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. - See the class documentation for parameter descriptions. - - Returns - ------- - Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. - """ - tasks: List[Task] = [] - - # Categorize tasks by complexity based on question difficulty - easy_true_inputs = [ - # Single object presence/detection - BoolImageTaskInput( - question="Is the chair in the room?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", images_paths=[IMG_PATH + "image_2.jpg"] - ), - BoolImageTaskInput( - question="Are there any pictures on the wall?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="is there a TV in the room?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - ] - - medium_true_inputs = [ - # Object state or counting - BoolImageTaskInput( - question="Are there 3 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is the light on in the room?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Is there something to sit on?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - ] - - hard_true_inputs = [ - # Spatial relationships between objects - BoolImageTaskInput( - question="Is the door on the left from the desk?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant behind the rack?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Is there a rug under the bed?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Is there a pillow on the armchain?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - ] - - easy_false_inputs = [ - # Single object presence/detection - BoolImageTaskInput( - question="Is someone in the room?", images_paths=[IMG_PATH + "image_1.jpg"] - ), - BoolImageTaskInput( - question="Do you see the plant?", images_paths=[IMG_PATH + "image_3.jpg"] - ), - BoolImageTaskInput( - question="Is there a red pillow on the armchair?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - BoolImageTaskInput( - question="Is there a red desk with chair in the room?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Do you see the bed?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - ] - - medium_false_inputs = [ - # Object state or counting - BoolImageTaskInput( - question="Is the door open?", images_paths=[IMG_PATH + "image_1.jpg"] - ), - BoolImageTaskInput( - question="Are there 4 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is the TV switched on?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is the window opened?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - ] - - hard_false_inputs = [ - # Spatial relationships between objects - BoolImageTaskInput( - question="Is there a rack on the left from the sofa?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant on the right from the window?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is the chair next to a bed?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - ] - - for extra_calls in extra_tool_calls: - for detail in prompt_detail: - for shots in n_shots: - task_args = TaskArgs( - extra_tool_calls=extra_calls, - prompt_detail=detail, - examples_in_system_prompt=shots, - ) - - tasks.extend( - [ - BoolImageTaskEasy( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in easy_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskEasy( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in easy_false_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskMedium( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in medium_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskMedium( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in medium_false_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskHard( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in hard_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskHard( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in hard_false_inputs - ] - ) - - return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py index c25c6d1e0..d2302ce22 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py @@ -22,7 +22,6 @@ get_basic_tasks, get_custom_interfaces_tasks, get_manipulation_tasks, - get_spatial_tasks, ) @@ -53,7 +52,7 @@ def get_tasks( Returns ------- List[Task] - sequence of spatial reasoning tasks with varying difficulty levels. + sequence of tasks with varying difficulty levels. There will be every combination of extra_tool_calls x prompt_detail x n_shots tasks generated. """ @@ -76,12 +75,6 @@ def get_tasks( prompt_detail=prompt_detail, n_shots=n_shots, ) - if "spatial_reasoning" in task_types: - all_tasks += get_spatial_tasks( - extra_tool_calls=extra_tool_calls, - prompt_detail=prompt_detail, - n_shots=n_shots, - ) filtered_tasks: List[Task] = [] for task in all_tasks: From 1c7c97546aceb6d20d4eea76bf097b6cabf92aa5 Mon Sep 17 00:00:00 2001 From: jmatejcz Date: Mon, 15 Sep 2025 16:24:27 +0200 Subject: [PATCH 10/14] fix: add missing inital state --- .../rai_bench/tool_calling_agent/benchmark.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index d1843b843..b95e18ae8 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -27,6 +27,7 @@ create_conversational_agent, ) from rai.agents.langchain.core.react_agent import ReActAgentState +from rai.messages import HumanMultimodalMessage from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, TimeoutException @@ -227,7 +228,12 @@ def run_benchmark( system_prompt=task.get_system_prompt(), logger=bench_logger, ) - benchmark.run_next(agent=agent, experiment_id=experiment_id) + initial_state = ReActAgentState( + messages=[HumanMultimodalMessage(content=task.get_prompt())] + ) + benchmark.run_next( + agent=agent, initial_state=initial_state, experiment_id=experiment_id + ) bench_logger.info("===============================================================") bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") @@ -269,7 +275,12 @@ def run_benchmark_dual_agent( debug=False, ) - benchmark.run_next(agent=agent, experiment_id=experiment_id) + initial_state = ReActAgentState( + messages=[HumanMultimodalMessage(content=task.get_prompt())] + ) + benchmark.run_next( + agent=agent, initial_state=initial_state, experiment_id=experiment_id + ) bench_logger.info("===============================================================") bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") From e0bc56162bed5c3922f1eef1487eb253a78b5a5c Mon Sep 17 00:00:00 2001 From: jmatejcz Date: Mon, 15 Sep 2025 16:53:06 +0200 Subject: [PATCH 11/14] chore: remove dual agent functionality --- src/rai_bench/rai_bench/__init__.py | 2 - src/rai_bench/rai_bench/agents.py | 123 ------------------ .../rai_bench/examples/dual_agent.py | 47 ------- .../rai_bench/manipulation_o3de/__init__.py | 4 +- .../rai_bench/manipulation_o3de/benchmark.py | 57 +------- src/rai_bench/rai_bench/test_models.py | 70 ---------- .../rai_bench/tool_calling_agent/__init__.py | 4 +- .../rai_bench/tool_calling_agent/benchmark.py | 62 +-------- 8 files changed, 12 insertions(+), 357 deletions(-) delete mode 100644 src/rai_bench/rai_bench/agents.py delete mode 100644 src/rai_bench/rai_bench/examples/dual_agent.py diff --git a/src/rai_bench/rai_bench/__init__.py b/src/rai_bench/rai_bench/__init__.py index 9ca566649..5dd4df1e4 100644 --- a/src/rai_bench/rai_bench/__init__.py +++ b/src/rai_bench/rai_bench/__init__.py @@ -14,7 +14,6 @@ from .test_models import ( ManipulationO3DEBenchmarkConfig, ToolCallingAgentBenchmarkConfig, - test_dual_agents, test_models, ) from .utils import ( @@ -33,6 +32,5 @@ "parse_manipulation_o3de_benchmark_args", "parse_tool_calling_benchmark_args", "parse_vlm_benchmark_args", - "test_dual_agents", "test_models", ] diff --git a/src/rai_bench/rai_bench/agents.py b/src/rai_bench/rai_bench/agents.py deleted file mode 100644 index a38e4e9d0..000000000 --- a/src/rai_bench/rai_bench/agents.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from functools import partial -from typing import List, Optional - -from langchain.chat_models.base import BaseChatModel -from langchain_core.messages import ( - AIMessage, - BaseMessage, - HumanMessage, -) -from langchain_core.tools import BaseTool -from langgraph.graph import START, StateGraph -from langgraph.graph.state import CompiledStateGraph -from langgraph.prebuilt.tool_node import tools_condition -from rai.agents.langchain.core.conversational_agent import State, agent -from rai.agents.langchain.core.tool_runner import ToolRunner - - -def multimodal_to_tool_bridge(state: State): - """Node of langchain workflow designed to bridge - nodes with llms. Removing images for context - """ - - cleaned_messages: List[BaseMessage] = [] - for msg in state["messages"]: - if isinstance(msg, HumanMessage): - # Remove images but keep the direct request - if isinstance(msg.content, list): - # Extract text only - text_parts = [ - part.get("text", "") - for part in msg.content - if isinstance(part, dict) and part.get("type") == "text" - ] - if text_parts: - cleaned_messages.append(HumanMessage(content=" ".join(text_parts))) - else: - cleaned_messages.append(msg) - elif isinstance(msg, AIMessage): - # Keep AI messages for context - cleaned_messages.append(msg) - - state["messages"] = cleaned_messages - return state - - -def create_multimodal_to_tool_agent( - multimodal_llm: BaseChatModel, - tool_llm: BaseChatModel, - tools: List[BaseTool], - multimodal_system_prompt: str, - tool_system_prompt: str, - logger: Optional[logging.Logger] = None, - debug: bool = False, -) -> CompiledStateGraph: - """ - Creates an agent flow where inputs first go to a multimodal LLM, - then its output is passed to a tool-calling LLM. - Can be usefull when multimodal llm does not provide tool calling. - - Args: - tools: List of tools available to the tool agent - - Returns: - Compiled state graph - """ - _logger = None - if logger: - _logger = logger - else: - _logger = logging.getLogger(__name__) - - _logger.info("Creating multimodal to tool agent flow") - - tool_llm_with_tools = tool_llm.bind_tools(tools) - tool_node = ToolRunner(tools=tools, logger=_logger) - - workflow = StateGraph(State) - workflow.add_node( - "thinker", - partial(agent, multimodal_llm, _logger, multimodal_system_prompt), - ) - # context bridge for altering the - workflow.add_node( - "context_bridge", - multimodal_to_tool_bridge, - ) - workflow.add_node( - "tool_agent", - partial(agent, tool_llm_with_tools, _logger, tool_system_prompt), - ) - workflow.add_node("tools", tool_node) - - workflow.add_edge(START, "thinker") - workflow.add_edge("thinker", "context_bridge") - workflow.add_edge("context_bridge", "tool_agent") - - workflow.add_conditional_edges( - "tool_agent", - tools_condition, - ) - - # Tool node goes back to tool agent - workflow.add_edge("tools", "tool_agent") - - app = workflow.compile(debug=debug) - _logger.info("Multimodal to tool agent flow created") - return app diff --git a/src/rai_bench/rai_bench/examples/dual_agent.py b/src/rai_bench/rai_bench/examples/dual_agent.py deleted file mode 100644 index 5ba89b5ac..000000000 --- a/src/rai_bench/rai_bench/examples/dual_agent.py +++ /dev/null @@ -1,47 +0,0 @@ -# # Copyright (C) 2025 Robotec.AI -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -from langchain_community.chat_models import ChatOllama -from langchain_openai import ChatOpenAI - -from rai_bench import ( - ManipulationO3DEBenchmarkConfig, - test_dual_agents, -) - -if __name__ == "__main__": - # Define models you want to benchmark - model_name = "gemma3:4b" - m_llm = ChatOllama( - model=model_name, base_url="http://localhost:11434", keep_alive=30 - ) - - tool_llm = ChatOpenAI(model="gpt-4o-mini", base_url="https://api.openai.com/v1/") - # Define benchmarks that will be used - - man_conf = ManipulationO3DEBenchmarkConfig( - o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config - levels=[ # define what difficulty of tasks to include in benchmark - "trivial", - ], - repeats=1, # how many times to repeat - ) - - out_dir = "src/rai_bench/rai_bench/experiments/dual_agents/" - - test_dual_agents( - multimodal_llms=[m_llm], - tool_calling_models=[tool_llm], - benchmark_configs=[man_conf], - out_dir=out_dir, - ) diff --git a/src/rai_bench/rai_bench/manipulation_o3de/__init__.py b/src/rai_bench/rai_bench/manipulation_o3de/__init__.py index 206300fff..8ada5037a 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/__init__.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .benchmark import run_benchmark, run_benchmark_dual_agent +from .benchmark import run_benchmark from .predefined.scenarios import get_scenarios -__all__ = ["get_scenarios", "run_benchmark", "run_benchmark_dual_agent"] +__all__ = ["get_scenarios", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index 05198d088..8bf9421d5 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -16,7 +16,7 @@ import time import uuid from pathlib import Path -from typing import List, Optional, TypeVar +from typing import List, TypeVar import rclpy from langchain.tools import BaseTool @@ -45,7 +45,6 @@ ) from rai_open_set_vision.tools import GetGrabbingPointTool -from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, RunSummary, TimeoutException from rai_bench.manipulation_o3de.interfaces import Task from rai_bench.manipulation_o3de.results_tracking import ( @@ -452,57 +451,3 @@ def run_benchmark( connector.shutdown() o3de.shutdown() rclpy.shutdown() - - -def run_benchmark_dual_agent( - multimodal_llm: BaseChatModel, - tool_calling_llm: BaseChatModel, - out_dir: Path, - scenarios: List[Scenario], - o3de_config_path: str, - bench_logger: logging.Logger, - experiment_id: uuid.UUID = uuid.uuid4(), - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - connector, o3de, benchmark, tools = _setup_benchmark_environment( - o3de_config_path, - get_llm_model_name(multimodal_llm), - scenarios, - out_dir, - bench_logger, - ) - basic_tool_system_prompt = ( - "Based on the conversation call the tools with appropriate arguments" - ) - try: - for scenario in scenarios: - agent = create_multimodal_to_tool_agent( - multimodal_llm=multimodal_llm, - tool_llm=tool_calling_llm, - tools=tools, - multimodal_system_prompt=( - m_system_prompt if m_system_prompt else scenario.task.system_prompt - ), - tool_system_prompt=( - tool_system_prompt - if tool_system_prompt - else basic_tool_system_prompt - ), - logger=bench_logger, - ) - - benchmark.run_next(agent=agent, experiment_id=experiment_id) - - bench_logger.info( - "===============================================================" - ) - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info( - "===============================================================" - ) - - finally: - connector.shutdown() - o3de.shutdown() - rclpy.shutdown() diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index e156cb069..84622a5ad 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -17,7 +17,6 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Optional -from langchain.chat_models.base import BaseChatModel from pydantic import BaseModel import rai_bench.manipulation_o3de as manipulation_o3de @@ -25,7 +24,6 @@ from rai_bench.utils import ( define_benchmark_logger, get_llm_for_benchmark, - get_llm_model_name, ) @@ -108,74 +106,6 @@ def name(self) -> str: return "tool_calling_agent" -def test_dual_agents( - multimodal_llms: List[BaseChatModel], - tool_calling_models: List[BaseChatModel], - benchmark_configs: List[BenchmarkConfig], - out_dir: str, - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - if len(multimodal_llms) != len(tool_calling_models): - raise ValueError( - "Number of passed multimodal models must match number of passed tool calling models" - ) - experiment_id = uuid.uuid4() - for bench_conf in benchmark_configs: - # for each bench configuration seperate run folder - now = datetime.now() - run_name = f"run_{now.strftime('%Y-%m-%d_%H-%M-%S')}" - for i, m_llm in enumerate(multimodal_llms): - tool_llm = tool_calling_models[i] - for u in range(bench_conf.repeats): - curr_out_dir = ( - out_dir - + "/" - + run_name - + "/" - + bench_conf.name - + "/" - + get_llm_model_name(m_llm) - + "/" - + str(u) - ) - bench_logger = define_benchmark_logger(out_dir=Path(curr_out_dir)) - try: - if isinstance(bench_conf, ToolCallingAgentBenchmarkConfig): - tool_calling_tasks = tool_calling_agent.get_tasks( - extra_tool_calls=bench_conf.extra_tool_calls, - complexities=bench_conf.complexities, - task_types=bench_conf.task_types, - ) - tool_calling_agent.run_benchmark_dual_agent( - multimodal_llm=m_llm, - tool_calling_llm=tool_llm, - m_system_prompt=m_system_prompt, - tool_system_prompt=tool_system_prompt, - out_dir=Path(curr_out_dir), - tasks=tool_calling_tasks, - experiment_id=experiment_id, - bench_logger=bench_logger, - ) - elif isinstance(bench_conf, ManipulationO3DEBenchmarkConfig): - manipulation_o3de_scenarios = manipulation_o3de.get_scenarios( - levels=bench_conf.levels, - logger=bench_logger, - ) - manipulation_o3de.run_benchmark_dual_agent( - multimodal_llm=m_llm, - tool_calling_llm=tool_llm, - out_dir=Path(curr_out_dir), - o3de_config_path=bench_conf.o3de_config_path, - scenarios=manipulation_o3de_scenarios, - experiment_id=experiment_id, - bench_logger=bench_logger, - ) - except Exception as e: - bench_logger.critical(f"BENCHMARK RUN FAILED: {e}") - raise e - - def test_models( model_names: List[str], vendors: List[str], diff --git a/src/rai_bench/rai_bench/tool_calling_agent/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/__init__.py index 5f9771011..c4c668ac6 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/__init__.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .benchmark import run_benchmark, run_benchmark_dual_agent +from .benchmark import run_benchmark from .predefined.tasks import get_tasks -__all__ = ["get_tasks", "run_benchmark", "run_benchmark_dual_agent"] +__all__ = ["get_tasks", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index b95e18ae8..8d22bbb2c 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -16,7 +16,7 @@ import time import uuid from pathlib import Path -from typing import Iterator, List, Optional, Sequence, Tuple +from typing import Iterator, List, Sequence, Tuple from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage @@ -26,10 +26,8 @@ from rai.agents.langchain.core import ( create_conversational_agent, ) -from rai.agents.langchain.core.react_agent import ReActAgentState from rai.messages import HumanMultimodalMessage -from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, TimeoutException from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler from rai_bench.tool_calling_agent.interfaces import ( @@ -68,7 +66,7 @@ def __init__( def run_next( self, agent: CompiledStateGraph, - initial_state: ReActAgentState, + initial_state: dict, experiment_id: uuid.UUID, ) -> None: """Runs the next task of the benchmark. @@ -228,58 +226,12 @@ def run_benchmark( system_prompt=task.get_system_prompt(), logger=bench_logger, ) - initial_state = ReActAgentState( - messages=[HumanMultimodalMessage(content=task.get_prompt())] - ) - benchmark.run_next( - agent=agent, initial_state=initial_state, experiment_id=experiment_id - ) - - bench_logger.info("===============================================================") - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info("===============================================================") - - -def run_benchmark_dual_agent( - multimodal_llm: BaseChatModel, - tool_calling_llm: BaseChatModel, - out_dir: Path, - tasks: List[Task], - bench_logger: logging.Logger, - experiment_id: uuid.UUID = uuid.uuid4(), - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - benchmark = ToolCallingAgentBenchmark( - tasks=tasks, - logger=bench_logger, - model_name=get_llm_model_name(multimodal_llm), - results_dir=out_dir, - ) - - basic_tool_system_prompt = ( - "Based on the conversation call the tools with appropriate arguments" - ) - for task in tasks: - agent = create_multimodal_to_tool_agent( - multimodal_llm=multimodal_llm, - tool_llm=tool_calling_llm, - tools=task.available_tools, - multimodal_system_prompt=( - m_system_prompt if m_system_prompt else task.get_system_prompt() - ), - tool_system_prompt=( - tool_system_prompt if tool_system_prompt else basic_tool_system_prompt - ), - logger=bench_logger, - debug=False, - ) - - initial_state = ReActAgentState( - messages=[HumanMultimodalMessage(content=task.get_prompt())] - ) benchmark.run_next( - agent=agent, initial_state=initial_state, experiment_id=experiment_id + agent=agent, + initial_state={ + "messages": [HumanMultimodalMessage(content=task.get_prompt())] + }, + experiment_id=experiment_id, ) bench_logger.info("===============================================================") From a9abf2752d3d272832efb5441c1e15a7d9dd1e35 Mon Sep 17 00:00:00 2001 From: jmatejcz Date: Mon, 15 Sep 2025 16:53:29 +0200 Subject: [PATCH 12/14] refactor: adjust example --- .../rai_bench/examples/benchmarking_models.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index 114ec1c23..43cdb3099 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -20,7 +20,7 @@ if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen3:4b", "llama3.2:3b"] + model_names = ["qwen2.5:3b", "llama3.2:3b"] vendors = ["ollama", "ollama"] # Define benchmarks that will be used @@ -36,6 +36,7 @@ extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", + "manipulation", "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt @@ -47,11 +48,6 @@ test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[mani_conf, tool_conf], + benchmark_configs=[tool_conf, mani_conf], out_dir=out_dir, - # if you want to pass any additinal args to model - additional_model_args=[ - {"reasoning": False}, - {}, - ], ) From e542f2274d3e79c6e5a761cd55f14de55183ec0d Mon Sep 17 00:00:00 2001 From: jmatejcz Date: Mon, 15 Sep 2025 16:56:06 +0200 Subject: [PATCH 13/14] chore: bump version --- src/rai_core/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index d13258cbc..714ff448a 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.4.0" +version = "2.5.0" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" From 2e94e2f062c37295198adbaf667d6e9aefc734cf Mon Sep 17 00:00:00 2001 From: jmatejcz Date: Mon, 15 Sep 2025 17:00:38 +0200 Subject: [PATCH 14/14] docs: include vlm benchmark in docs --- docs/simulation_and_benchmarking/rai_bench.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index 7ef070187..24c033a1a 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -6,6 +6,7 @@ RAI Bench is a comprehensive package that both provides benchmarks with ready-to - [Manipulation O3DE Benchmark](#manipulation-o3de-benchmark) - [Tool Calling Agent Benchmark](#tool-calling-agent-benchmark) +- [VLM Benchmark](#vlm-benchmark) ## Manipulation O3DE Benchmark @@ -163,3 +164,17 @@ class TaskArgs(BaseModel): - `GetROS2RGBCameraTask` has 1 required tool call and 1 optional. When `extra_tool_calls` set to 5, agent can correct himself couple times and still pass even with 7 tool calls. There can be 2 types of invalid tool calls, first when the tool is used incorrectly and agent receives an error - this allows him to correct himself easier. Second type is when tool is called properly but it is not the tool that should be called or it is called with wrong params. In this case agent won't get any error so it will be harder for him to correct, but BOTH of these cases are counted as `extra tool call`. If you want to know details about every task, visit `rai_bench/tool_calling_agent/tasks` + +## VLM Benchmark + +The VLM Benchmark is a benchmark for VLM models. It includes a set of tasks containing questions related to images and evaluates the performance of the agent that returns the answer in the structured format. + +### Running + +To run the benchmark: + +```bash +cd rai +source setup_shell.sh +python src/rai_bench/rai_bench/examples/vlm_benchmark.py --model-name gemma3:4b --vendor ollama +```