From 761317e940314b89ca566b30f2fb175cb609e5a1 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Tue, 1 Jul 2025 12:49:17 +0200 Subject: [PATCH 01/13] feat: tool calling benchmark unified across types and prompts variety (#620) --- docs/simulation_and_benchmarking/rai_bench.md | 42 +- docs/tutorials/benchmarking.md | 54 +- src/rai_bench/pyproject.toml | 2 +- .../rai_bench/examples/benchmarking_models.py | 14 +- .../rai_bench/examples/tool_calling_agent.py | 2 + .../results_processing/data_loading.py | 28 +- .../results_processing/data_processing.py | 34 +- .../visualise/tool_calling_agent_display.py | 132 +- src/rai_bench/rai_bench/test_models.py | 41 +- .../rai_bench/tool_calling_agent/benchmark.py | 4 +- .../tool_calling_agent/interfaces.py | 66 +- .../mocked_ros2_interfaces.py | 3400 +++++++++++++++++ .../tool_calling_agent/predefined/__init__.py | 27 + .../predefined/basic_tasks.py | 306 ++ .../predefined/custom_interfaces_tasks.py | 95 + .../predefined/manipulation_tasks.py | 104 + .../predefined/navigation_tasks.py | 118 + .../predefined/spatial_reasoning_tasks.py | 278 ++ .../tool_calling_agent/predefined/tasks.py | 460 +-- .../tool_calling_agent/results_tracking.py | 6 +- .../tool_calling_agent/tasks/basic.py | 372 +- .../tasks/custom_interfaces.py | 1232 ++---- .../tool_calling_agent/tasks/manipulation.py | 376 +- .../tool_calling_agent/tasks/navigation.py | 938 +---- .../tool_calling_agent/tasks/spatial.py | 134 +- src/rai_bench/rai_bench/utils.py | 18 +- .../initialization/model_initialization.py | 2 + 27 files changed, 5532 insertions(+), 2753 deletions(-) create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index d07905ac5..be31f4a05 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -109,7 +109,7 @@ The `Validator` class can combine single or multiple subtasks to create a single ### Task -A Task represents a specific prompt and set of tools available. A list of validators is assigned to validate the performance. +A Task represents a specific prompts and set of tools available. A list of validators is assigned to validate the performance. ??? info "Task class definition" @@ -117,20 +117,52 @@ A Task represents a specific prompt and set of tools available. A list of valida As you can see, the framework is very flexible. Any SubTask can be combined into any Validator that can be later assigned to any Task. +Every Task needs to define it's prompt and system prompt, what tools agent will have available, how many tool calls are required to complete it and how many optional tool calls are possible. + +Optional tool calls mean that a certain tool calls is not obligatory to pass the Task, but shoudn't be considered an error, example: `GetROS2RGBCameraTask` which has prompt: `Get RGB camera image.` requires making one tool call with `get_ros2_image` tool. But listing topics before doing it is a valid approach, so in this case opitonal tool calls is `1`. + ### ToolCallingAgentBenchmark The ToolCallingAgentBenchmark class manages the execution of tasks and collects results. ### Available Tasks -Tasks of this benchmark are grouped by type: +There are predefined Tasks available which are grouped by categories: -- Basic - basic usage of tools +- Basic - require retrieving info from certain topics - Navigation - Spatial reasoning - questions about surroundings with images attached - Manipulation - Custom Interfaces - requires using messages with custom interfaces -If you want to know details about every task, visit `rai_bench/tool_calling_agent/tasks` +Every Task has assigned the `complexity` which reflects the difficulty. + +When creating a Task, you can define few params: + +```python +class TaskArgs(BaseModel): + """Holds the configurations specified by user""" + + extra_tool_calls: int = 0 + prompt_detail: Literal["brief", "descriptive"] = "brief" + examples_in_system_prompt: Literal[0, 2, 5] = 0 +``` -## Test Models +- examples_in_system_prompt - How many examples there are in system prompts, example: + + - `0`: `You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. Be proactive and use the tools to answer questions.` + - `2`: `You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. Be proactive and use the tools to answer questions. Example of tool calls: get_ros2_message_interface, args: {'msg_type': 'geometry_msgs/msg/Twist'} publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}}` + +- prompt_detail - How descriptive should the Task prompt be, example: + + - `brief`: "Get all camera images" + - `descriptive`: "Get all camera images from all available camera sources in the system. + This includes both RGB color images and depth images. + You can discover what camera topics are available and capture images from each." + + Descriptive prompts provides guidance and tips. + +- extra_tool_calls - How many extra tool calls an agent can make and still pass the Task, example: + - `GetROS2RGBCameraTask` has 1 required tool call and 1 optional. When `extra_tool_calls` set to 5, agent can correct himself couple times and still pass even with 7 tool calls. There can be 2 types of invalid tool calls, first when the tool is used incorrectly and agent receives an error - this allows him to correct himself easier. Second type is when tool is called properly but it is not the tool that should be called or it is called with wrong params. In this case agent won't get any error so it will be harder for him to correct, but BOTH of these cases are counted as `extra tool call`. + +If you want to know details about every task, visit `rai_bench/tool_calling_agent/tasks` diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index fb4cb663d..deb9d20a9 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -53,12 +53,12 @@ If your goal is creating custom tasks and scenarios, visit [Creating Custom Task This benchmark does not require any additional setup besides the main one [Basic Setup](../setup/install.md), just run: ```bash -python src/rai_bench/rai_bench/examples/tool_calling_agent.py --model-name --vendor --extra-tool-calls <5> --task-types --out-dir +python src/rai_bench/rai_bench/examples/tool_calling_agent.py --model-name --vendor --extra-tool-calls <0 5> --task-types basic --n-shots <0 2> --prompt-detail --complexities --out-dir ``` !!! note - This Benchmark is significantly faster, but still if just trying out, we recommend choosing just one task-type. + This Benchmark is significantly faster, but still, if just trying out, we recommend choosing just one parameter per flag as every combination on params will create more tasks. ## Testing Models @@ -90,12 +90,17 @@ if __name__ == "__main__": ], repeats=1, # how many times to repeat ) - tool_conf = ToolCallingAgentBenchmarkConfig( - extra_tool_calls=5, # how many extra tool calls allowed to still pass + tool_conf = ToolCallingAgentBenchmarkConfig( + extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", "spatial_reasoning", - "manipulation", + "custom_interfaces", + ], + N_shots=[0, 2], # examples in system prompt + prompt_detail=[ # how descriptive should task prompt be + "brief", + "descriptive" ], repeats=1, ) @@ -222,6 +227,21 @@ class ThrowObjectsOffTableTask(ManipulationTask): incorrect: int = len(selected_type_objects) - correct return correct, incorrect + +# configure existing Task with different params +target_coords = (0.1, 0.1) +disp = 0.1 +task = PlaceObjectAtCoordTask( + obj_type="apple", + target_position=target_coords, + allowable_displacement=disp, +) + +Scenario( + task=task, + scene_config=scene_config, + scene_config_path=path_to_your_config +) ``` As `obj_types` is parameterizable, it enables various variants of this Task. In combination with a lot of simulation configs available, it means that a single Task can provide dozens of scenarios. @@ -240,23 +260,14 @@ from rai_bench.tool_calling_agent.subtasks import ( from rai_bench.tool_calling_agent.validators import ( OrderedCallsValidator, ) -from rai_bench.tool_calling_agent.tasks.basic import BasicTask from rai_bench.tool_calling_agent.mocked_tools import ( MockGetROS2TopicsNamesAndTypesTool, ) +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs from langchain_core.tools import BaseTool from typing import List -# configure existing Task with different params -target_coords = (0.1, 0.1) -disp = 0.1 -task = PlaceObjectAtCoordTask( - obj_type="apple", - target_position=target_coords, - allowable_displacement=disp, -) -Scenario(task=task, scene_config=scene_config, scene_config_path=path_to_your_config) # define subtask that requires receive_robot_pos_subtask = CheckArgsToolCallSubTask( @@ -270,7 +281,7 @@ receive_robot_pos_subtask = CheckArgsToolCallSubTask( topics_ord_val = OrderedCallsValidator(subtasks=[receive_robot_pos_subtask]) -class GetROS2RobotPositionTask(BasicTask): +class GetROS2RobotPositionTask(Task): complexity = "easy" @property @@ -287,9 +298,18 @@ class GetROS2RobotPositionTask(BasicTask): ), ] + def get_system_prompt(self) -> str: + return "You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system." + def get_prompt(self) -> str: return "Get the position of the robot." + @property + def optional_tool_calls_number(self) -> int: + # Listing topics before getting any message + return 1 + # optionally pass number of extra tool calls -task = GetROS2RobotPositionTask(validators=[topics_ord_val], extra_tool_calls=1) +args = TaskArgs(extra_tool_calls=0) +task = GetROS2RobotPositionTask(validators=[topics_ord_val], task_args=args) ``` diff --git a/src/rai_bench/pyproject.toml b/src/rai_bench/pyproject.toml index 90d820982..010d8aca5 100644 --- a/src/rai_bench/pyproject.toml +++ b/src/rai_bench/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "rai-bench" -version = "0.1.2" +version = "0.2.0" description = "Package for running and creating benchmarks." authors = ["Jakub Matejczyk ", "Magdalena Kotynia "] readme = "README.md" diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index 3a3feefa3..c4008eef4 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -20,11 +20,11 @@ if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen2.5:7b", "llama3.2:3b"] - vendors = ["ollama", "ollama"] + model_names = ["qwen2.5:7b"] + vendors = ["ollama"] # Define benchmarks that will be used - man_conf = ManipulationO3DEBenchmarkConfig( + mani_conf = ManipulationO3DEBenchmarkConfig( o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config levels=[ # define what difficulty of tasks to include in benchmark "trivial", @@ -32,12 +32,16 @@ repeats=1, # how many times to repeat ) tool_conf = ToolCallingAgentBenchmarkConfig( - extra_tool_calls=5, # how many extra tool calls allowed to still pass + extra_tool_calls=[0], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", "spatial_reasoning", + # "navigation", + "custom_interfaces", "manipulation", ], + N_shots=[2], # examples in system prompt + prompt_detail=["brief", "descriptive"], # how descriptive should task prompt be repeats=1, ) @@ -45,6 +49,6 @@ test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[man_conf, tool_conf], + benchmark_configs=[tool_conf], out_dir=out_dir, ) diff --git a/src/rai_bench/rai_bench/examples/tool_calling_agent.py b/src/rai_bench/rai_bench/examples/tool_calling_agent.py index e76c18f00..652efe9a4 100644 --- a/src/rai_bench/rai_bench/examples/tool_calling_agent.py +++ b/src/rai_bench/rai_bench/examples/tool_calling_agent.py @@ -34,6 +34,8 @@ extra_tool_calls=args.extra_tool_calls, complexities=args.complexities, task_types=args.task_types, + n_shots=args.n_shots, + prompt_detail=args.prompt_detail, ) for task in tasks: task.set_logger(bench_logger) diff --git a/src/rai_bench/rai_bench/results_processing/data_loading.py b/src/rai_bench/rai_bench/results_processing/data_loading.py index 429d6fc2d..81c359a38 100644 --- a/src/rai_bench/rai_bench/results_processing/data_loading.py +++ b/src/rai_bench/rai_bench/results_processing/data_loading.py @@ -70,20 +70,19 @@ def convert_row_to_task_result(row: pd.Series) -> TaskResult: ) validator_results.append(validator_result) - return TaskResult( - task_prompt=row["task_prompt"], - system_prompt=row["system_prompt"], - complexity=row["complexity"], - type=row["type"], - model_name=row["model_name"], - validation_info=validator_results, - extra_tool_calls=int(row["extra_tool_calls"]), - extra_tool_calls_used=int(row["extra_tool_calls_used"]), - score=float(row["score"]), - total_time=float(row["total_time"]), - run_id=uuid.UUID(row["run_id"]), + row.update( + { + "validation_info": validator_results, + "extra_tool_calls": int(row["extra_tool_calls"]), + "extra_tool_calls_used": int(row["extra_tool_calls_used"]), + "score": float(row["score"]), + "total_time": float(row["total_time"]), + "run_id": uuid.UUID(row["run_id"]), + } ) + return TaskResult(**row) + def convert_row_to_scenario_result(row: pd.Series) -> ScenarioResult: """ @@ -100,10 +99,7 @@ def convert_row_to_scenario_result(row: pd.Series) -> ScenarioResult: A ScenarioResult object """ return ScenarioResult( - task_prompt=row["task_prompt"], - system_prompt=row["system_prompt"], - model_name=row["model_name"], - scene_config_path=row["scene_config_path"], + **row, score=float(row["score"]), level=row["level"], total_time=float(row["total_time"]), diff --git a/src/rai_bench/rai_bench/results_processing/data_processing.py b/src/rai_bench/rai_bench/results_processing/data_processing.py index 755d3ded4..84073ad63 100644 --- a/src/rai_bench/rai_bench/results_processing/data_processing.py +++ b/src/rai_bench/rai_bench/results_processing/data_processing.py @@ -181,10 +181,14 @@ def create_task_metrics_dataframe( def create_task_details_dataframe( - model_results: ModelResults, task_type: Optional[str] = None + model_results: ModelResults, + task_type: Optional[str] = None, + complexity: Optional[str] = None, + examples_in_system_prompt: Optional[int] = None, + prompt_detail: Optional[str] = None, ) -> pd.DataFrame: """ - Create a DataFrame with task details, optionally filtered by task type. + Create a DataFrame with task details, optionally filtered by multiple criteria. Parameters ---------- @@ -192,6 +196,10 @@ def create_task_details_dataframe( The model results object task_type : Optional[str] Task type to filter by + complexity : Optional[str] + Complexity to filter by + examples_in_system_prompt : Optional[str] + Examples in system prompt to filter by Returns ------- @@ -201,14 +209,30 @@ def create_task_details_dataframe( all_detailed_results = get_all_detailed_results_from_model_results( model_results=model_results ) - if not all_detailed_results: return pd.DataFrame() - # filter by task type + # Apply filters if task_type: all_detailed_results = [r for r in all_detailed_results if r.type == task_type] + if complexity: + all_detailed_results = [ + r for r in all_detailed_results if r.complexity == complexity + ] + + if examples_in_system_prompt: + all_detailed_results = [ + r + for r in all_detailed_results + if r.examples_in_system_prompt == examples_in_system_prompt + ] + + if prompt_detail: + all_detailed_results = [ + r for r in all_detailed_results if r.prompt_detail == prompt_detail + ] + rows: List[Dict[str, Any]] = [ { "task_prompt": result.task_prompt, @@ -217,10 +241,10 @@ def create_task_details_dataframe( "score": result.score, "total_time": result.total_time, "extra_tool_calls_used": result.extra_tool_calls_used, + "examples_in_system_prompt": result.examples_in_system_prompt, } for result in all_detailed_results ] - return pd.DataFrame(rows) diff --git a/src/rai_bench/rai_bench/results_processing/visualise/tool_calling_agent_display.py b/src/rai_bench/rai_bench/results_processing/visualise/tool_calling_agent_display.py index 98f42daec..558a7626f 100644 --- a/src/rai_bench/rai_bench/results_processing/visualise/tool_calling_agent_display.py +++ b/src/rai_bench/rai_bench/results_processing/visualise/tool_calling_agent_display.py @@ -100,20 +100,20 @@ def display_task_type_performance(model_results: ModelResults): st.plotly_chart(fig_type_calls, use_container_width=True) # type: ignore -def display_task_complexity_performance(model_results: ModelResults): +def display_task_performance_by_field(model_results: ModelResults, field: str): """Display performance charts by task complexity.""" - complexity_df = create_task_metrics_dataframe(model_results, "complexity") + metric_df = create_task_metrics_dataframe(model_results, field) - if complexity_df.empty: + if metric_df.empty: st.warning("No complexity data available.") return fig_complexity_score = create_bar_chart( - df=complexity_df, - x_column="complexity", + df=metric_df, + x_column=field, y_column="avg_score", - title="Success Rate by Task Complexity", - x_label="Task Complexity", + title=f"Success Rate by {field}", + x_label=field, y_label="Avg Score", y_range=(0.0, 1.0), count_column="total_tasks", @@ -121,63 +121,43 @@ def display_task_complexity_performance(model_results: ModelResults): st.plotly_chart(fig_complexity_score, use_container_width=True) # type: ignore fig_complexity_calls = create_bar_chart( - df=complexity_df, - x_column="complexity", + df=metric_df, + x_column=field, y_column="avg_extra_tool_calls", - title="Avg Extra Tool Calls Used by Task Complexity", - x_label="Task Complexity", + title=f"Avg Extra Tool Calls Used by {field}", + x_label=field, y_label="Avg Extra Tool Calls Used", count_column="total_tasks", ) st.plotly_chart(fig_complexity_calls, use_container_width=True) # type: ignore -def display_detailed_task_type_analysis( - model_results: ModelResults, selected_type: str +def display_detailed_task_analysis( + model_results: ModelResults, + selected_type: str, + selected_complexity: str, + selected_example_num: str, + selected_prompt_detail: str, ): """Display detailed analysis for a specific task type.""" # first, get only the tasks of the selected type - tasks_for_type_df = create_task_details_dataframe(model_results, selected_type) - if tasks_for_type_df.empty: - st.warning(f"No tasks of type {selected_type} found.") - return - # Now aggregate by complexity for that type - filtered_by_complexity = ( - tasks_for_type_df.groupby("complexity") # type: ignore - .agg( - avg_score=("score", "mean"), - avg_time=("total_time", "mean"), - avg_extra_tool_calls=("extra_tool_calls_used", "mean"), - ) - .reset_index() + tasks_df = create_task_details_dataframe( + model_results, + task_type=selected_type if selected_type != "All" else None, + complexity=selected_complexity if selected_complexity != "All" else None, + examples_in_system_prompt=( + int(selected_example_num) if selected_example_num != "All" else None + ), + prompt_detail=( + selected_prompt_detail if selected_prompt_detail != "All" else None + ), ) - filtered_by_complexity = filtered_by_complexity[ - filtered_by_complexity["complexity"].notna() - ] - - # Display success rate by complexity for the selected task type - if not filtered_by_complexity.empty: - fig_complexity_score = create_bar_chart( - df=filtered_by_complexity, - x_column="complexity", - y_column="avg_score", - title=f"Success Rate by Task Complexity for '{selected_type}' Tasks", - x_label="Task Complexity", - y_label="Avg Score", - y_range=(0.0, 1.0), - count_column="total_tasks", - ) - st.plotly_chart(fig_complexity_score, use_container_width=True) # type: ignore - - # Display success rate by individual task - task_details_df = create_task_details_dataframe(model_results, selected_type) - - if task_details_df.empty: - st.warning(f"No task details available for type: {selected_type}") + if tasks_df.empty: + st.warning(f"No tasks of type {selected_type} found.") return task_stats = ( - task_details_df.groupby("task_prompt") # type: ignore + tasks_df.groupby("task_prompt") # type: ignore .agg({"score": "mean", "total_time": "mean"}) .reset_index() ) @@ -200,7 +180,7 @@ def display_detailed_task_type_analysis( df=task_stats, x_column="task_prompt", y_column="score", - title=f"Avg Score for '{selected_type}' Tasks", + title="Avg Score", x_label="Task", y_label="Avg Score", custom_data=["wrapped_prompt", "score"], @@ -216,7 +196,7 @@ def display_detailed_task_type_analysis( df=task_stats, x_column="task_prompt", y_column="total_time", - title=f"Avg Time for '{selected_type}' Tasks", + title="Avg Time", x_label="Task", y_label="Avg Time (s)", custom_data=["wrapped_prompt", "total_time"], @@ -349,20 +329,58 @@ def render_task_performance_tab(bench_results: BenchmarkResults): # Display performance by complexity st.subheader("Performance by Task Complexity") - display_task_complexity_performance(model_results) + display_task_performance_by_field(model_results, "complexity") + + # Display performance by complexity + st.subheader("Performance by system prompt examples") + display_task_performance_by_field(model_results, "examples_in_system_prompt") + + # Display performance by complexity + st.subheader("Performance by Task's prompt detail") + display_task_performance_by_field(model_results, "prompt_detail") - # Per Task Type Analysis st.subheader("Detailed Task Type Analysis") task_types = get_unique_values_from_results(model_results, "type") - if not task_types: st.warning("No task types available.") return selected_type = st.selectbox( - "Select Task Type", sorted(task_types), key="task_type" + "Select Task Type", ["All"] + sorted(task_types), key="task_type" + ) + + # Add selectboxes for the two additional attributes with "All" as default + complexity_values = get_unique_values_from_results(model_results, "complexity") + selected_complexity = st.selectbox( + "Select Complexity", + ["All"] + complexity_values, + key="complexity_select", + ) + + examples_values = get_unique_values_from_results( + model_results, "examples_in_system_prompt" + ) + selected_examples = st.selectbox( + "Select Examples in System Prompt", + ["All"] + sorted(examples_values), + key="n_shots_select", + ) + prompt_detail_values = get_unique_values_from_results( + model_results, "prompt_detail" + ) + selected_prompt_detail = st.selectbox( + "Select prompt decriptiveness", + ["All"] + prompt_detail_values, + key="prompt_detail_select", + ) + + display_detailed_task_analysis( + model_results, + selected_type, + selected_complexity, + selected_examples, + selected_prompt_detail, ) - display_detailed_task_type_analysis(model_results, selected_type) def render_validator_analysis_tab(bench_results: BenchmarkResults): diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index 51f1c7cb5..def97016c 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -40,7 +40,17 @@ def name(self) -> str: class ManipulationO3DEBenchmarkConfig(BenchmarkConfig): - # by default include all + """Configuration for Manipulation O3DE Benchmark. + + Parameters + ---------- + o3de_config_path : str + path to O3DE configuration file + levels : List[Literal["trivial", "easy", "medium", "hard", "very_hard"]], optional + difficulty levels to include in benchmark, by default all levels are included: + ["trivial", "easy", "medium", "hard", "very_hard"] + """ + o3de_config_path: str levels: List[Literal["trivial", "easy", "medium", "hard", "very_hard"]] = [ "trivial", @@ -56,8 +66,32 @@ def name(self) -> str: class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): - extra_tool_calls: int = 0 + """Configuration for Tool Calling Agent Benchmark. + + Parameters + ---------- + extra_tool_calls : List[int], optional + how many extra tool calls allowed to still pass, by default [0] + prompt_detail : List[Literal["brief", "descriptive"]], optional + how descriptive should task prompt be, by default all levels are included: + ["brief", "descriptive"] + N_shots : List[Literal[0, 2, 5]], optional + how many examples are in system prompt, by default all are included: [0, 2, 5] + complexities : List[Literal["easy", "medium", "hard"]], optional + complexity levels of tasks to include in the benchmark, by default all levels are included: + ["easy", "medium", "hard"] + task_types : List[Literal["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"]], optional + types of tasks to include in the benchmark, by default all types are included: + ["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"] + + For more detailed explanation of parameters, see the documentation: + (https://robotecai.github.io/rai/simulation_and_benchmarking/rai_bench/) + """ + + extra_tool_calls: List[int] = [0] complexities: List[Literal["easy", "medium", "hard"]] = ["easy", "medium", "hard"] + N_shots: List[Literal[0, 2, 5]] = [0, 2, 5] + prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"] task_types: List[ Literal[ "basic", @@ -154,6 +188,7 @@ def test_models( out_dir: str, additional_model_args: Optional[List[Dict[str, Any]]] = None, ): + # TODO (jmatejcz) add docstring after passing agent logic will be added if additional_model_args is None: additional_model_args = [{} for _ in model_names] @@ -189,6 +224,8 @@ def test_models( tool_calling_tasks = tool_calling_agent.get_tasks( extra_tool_calls=bench_conf.extra_tool_calls, complexities=bench_conf.complexities, + prompt_detail=bench_conf.prompt_detail, + n_shots=bench_conf.N_shots, task_types=bench_conf.task_types, ) tool_calling_agent.run_benchmark( diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index 32547b90c..4328c0a8d 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -160,8 +160,10 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: self.logger.info(f"TASK SCORE: {score}, TOTAL TIME: {total_time:.3f}") task_result = TaskResult( - task_prompt=task.get_prompt(), + task_prompt=task.get_base_prompt(), system_prompt=task.get_system_prompt(), + examples_in_system_prompt=task.n_shots, + prompt_detail=task.prompt_detail, type=task.type, extra_tool_calls=task.extra_tool_calls, extra_tool_calls_used=total_extra_calls, diff --git a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py index 65ae83b25..e35d5a6f4 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py @@ -19,6 +19,7 @@ from langchain_core.messages import AIMessage, BaseMessage, ToolCall from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT from langchain_core.tools import BaseTool +from pydantic import BaseModel from rai_bench.tool_calling_agent.results_tracking import SubTaskResult, ValidatorResult @@ -454,14 +455,33 @@ def validate(self, tool_calls: List[ToolCall]) -> Tuple[bool, List[ToolCall]]: pass +class TaskArgs(BaseModel): + """Holds the configurations specified by user + + Parameters + ---------- + extra_tool_calls : int, optional + how many extra tool calls allowed to still pass, by default 0 + prompt_detail : Literal["brief", "descriptive"], optional + how descriptive should task prompt be, by default "brief" + examples_in_system_prompt : Literal[0, 2, 5], optional + how many examples are in system prompt, by default 0 + """ + + extra_tool_calls: int = 0 + prompt_detail: Literal["brief", "descriptive"] = "brief" + examples_in_system_prompt: Literal[0, 2, 5] = 0 + + class Task(ABC): complexity: Literal["easy", "medium", "hard"] + type: str recursion_limit: int = DEFAULT_RECURSION_LIMIT def __init__( self, validators: List[Validator], - extra_tool_calls: int = 0, + task_args: TaskArgs, logger: loggers_type | None = None, ) -> None: """ @@ -473,21 +493,30 @@ def __init__( Attributes ---------- + complexity : Literal["easy", "medium", "hard"] + difficulty level of the task + type : str + type identifier for the task + recursion_limit : int, optional + maximum recursion depth allowed, by default DEFAULT_RECURSION_LIMIT + + Parameters + ---------- validators : List[Validator] List of validators that will be applied in sequence. - extra_tool_calls : int - Number of additional tool calls allowed beyond the minimum required. + task_args : TaskArgs + Configuration parameters for the task specified by user logger : logging.Logger Logger for recording task validation results and errors. - result : Result - Object tracking the validation results across all validators. """ if logger: self.logger = logger else: self.logger = logging.getLogger(__name__) self.validators = validators - self.extra_tool_calls = extra_tool_calls + self.extra_tool_calls = task_args.extra_tool_calls + self.prompt_detail = task_args.prompt_detail + self.n_shots = task_args.examples_in_system_prompt def set_logger(self, logger: loggers_type): self.logger = logger @@ -523,16 +552,25 @@ def available_tools(self) -> List[BaseTool]: @property @abstractmethod - def type(self) -> str: - """Type of task, for example: manipulation""" + def optional_tool_calls_number(self) -> int: + """Optional tool calls means calls that are not considered error. + For example listing topics at the beginning.""" pass @property def max_tool_calls_number(self) -> int: - return self.required_calls + self.extra_tool_calls + """maxiumum number of call to still pass task. + Includes extra tool calls params. + and optional tool calls number which depends on task. + """ + return ( + self.required_calls + + self.optional_tool_calls_number + + self.extra_tool_calls + ) @property - def required_calls(self): + def required_calls(self) -> int: """Minimal number of calls required to complete task""" total = 0 for val in self.validators: @@ -550,6 +588,14 @@ def get_system_prompt(self) -> str: """ pass + @abstractmethod + def get_base_prompt(self) -> str: + """ + Get the base task instruciton, + it will be used to identify task in results processing + """ + pass + @abstractmethod def get_prompt(self) -> str: """Get the task instruction - the prompt that will be passed to agent. diff --git a/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py new file mode 100644 index 000000000..f9758499f --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py @@ -0,0 +1,3400 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Type + +from pydantic import BaseModel +from rai.types import ( + CameraInfo, + Image, +) +from rai.types.rai_interfaces import ( + ManipulatorMoveToRequest, + RAIDetectionArray, + RAIGroundedSamRequest, + RAIGroundingDinoRequest, +) + +from rai_bench.tool_calling_agent.messages.actions import ( + AssistedTeleopGoal, + BackUpGoal, + ComputePathThroughPosesGoal, + ComputePathToPoseGoal, + DriveOnHeadingGoal, + FollowPathGoal, + FollowWaypointsGoal, + NavigateThroughPosesGoal, + NavigateToPoseGoal, + SmoothPathGoal, + SpinGoal, + WaitGoal, +) +from rai_bench.tool_calling_agent.messages.base import Clock +from rai_bench.tool_calling_agent.messages.services import ( + StringListRequest, + VectorStoreRetrievalRequest, + WhatISeeRequest, +) +from rai_bench.tool_calling_agent.messages.topics import AudioMessage, HRIMessage + +# dict of interfaces where keys are interfaces types and values are output +# of GetROS2MessageInterfaceTool which are same as ros2 interface show outputs +# the dict contains custom as well as couple other common interfaces + +COMMON_INTERFACES: Dict[str, str] = { + "sensor_msgs/msg/CameraInfo": """ +# This message defines meta information for a camera. It should be in a +# camera namespace on topic "camera_info" and accompanied by up to five +# image topics named: +# +# image_raw - raw data from the camera driver, possibly Bayer encoded +# image - monochrome, distorted +# image_color - color, distorted +# image_rect - monochrome, rectified +# image_rect_color - color, rectified +# +# The image_pipeline contains packages (image_proc, stereo_image_proc) +# for producing the four processed image topics from image_raw and +# camera_info. The meaning of the camera parameters are described in +# detail at http://www.ros.org/wiki/image_pipeline/CameraInfo. +# +# The image_geometry package provides a user-friendly interface to +# common operations using this meta information. If you want to, e.g., +# project a 3d point into image coordinates, we strongly recommend +# using image_geometry. +# +# If the camera is uncalibrated, the matrices D, K, R, P should be left +# zeroed out. In particular, clients may assume that K[0] == 0.0 +# indicates an uncalibrated camera. + +####################################################################### +# Image acquisition info # +####################################################################### + +# Time of image acquisition, camera coordinate frame ID +std_msgs/Header header # Header timestamp should be acquisition time of image + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of camera + # +x should point to the right in the image + # +y should point down in the image + # +z should point into the plane of the image + + +####################################################################### +# Calibration Parameters # +####################################################################### +# These are fixed during camera calibration. Their values will be the # +# same in all messages until the camera is recalibrated. Note that # +# self-calibrating systems may "recalibrate" frequently. # +# # +# The internal parameters can be used to warp a raw (distorted) image # +# to: # +# 1. An undistorted image (requires D and K) # +# 2. A rectified image (requires D, K, R) # +# The projection matrix P projects 3D points into the rectified image.# +####################################################################### + +# The image dimensions with which the camera was calibrated. +# Normally this will be the full camera resolution in pixels. +uint32 height +uint32 width + +# The distortion model used. Supported models are listed in +# sensor_msgs/distortion_models.hpp. For most cameras, "plumb_bob" - a +# simple model of radial and tangential distortion - is sufficent. +string distortion_model + +# The distortion parameters, size depending on the distortion model. +# For "plumb_bob", the 5 parameters are: (k1, k2, t1, t2, k3). +float64[] d + +# Intrinsic camera matrix for the raw (distorted) images. +# [fx 0 cx] +# K = [ 0 fy cy] +# [ 0 0 1] +# Projects 3D points in the camera coordinate frame to 2D pixel +# coordinates using the focal lengths (fx, fy) and principal point +# (cx, cy). +float64[9] k # 3x3 row-major matrix + +# Rectification matrix (stereo cameras only) +# A rotation matrix aligning the camera coordinate system to the ideal +# stereo image plane so that epipolar lines in both stereo images are +# parallel. +float64[9] r # 3x3 row-major matrix + +# Projection/camera matrix +# [fx' 0 cx' Tx] +# P = [ 0 fy' cy' Ty] +# [ 0 0 1 0] +# By convention, this matrix specifies the intrinsic (camera) matrix +# of the processed (rectified) image. That is, the left 3x3 portion +# is the normal camera intrinsic matrix for the rectified image. +# It projects 3D points in the camera coordinate frame to 2D pixel +# coordinates using the focal lengths (fx', fy') and principal point +# (cx', cy') - these may differ from the values in K. +# For monocular cameras, Tx = Ty = 0. Normally, monocular cameras will +# also have R = the identity and P[1:3,1:3] = K. +# For a stereo pair, the fourth column [Tx Ty 0]' is related to the +# position of the optical center of the second camera in the first +# camera's frame. We assume Tz = 0 so both cameras are in the same +# stereo image plane. The first camera always has Tx = Ty = 0. For +# the right (second) camera of a horizontal stereo pair, Ty = 0 and +# Tx = -fx' * B, where B is the baseline between the cameras. +# Given a 3D point [X Y Z]', the projection (x, y) of the point onto +# the rectified image is given by: +# [u v w]' = P * [X Y Z 1]' +# x = u / w +# y = v / w +# This holds for both images of a stereo pair. +float64[12] p # 3x4 row-major matrix + + +####################################################################### +# Operational Parameters # +####################################################################### +# These define the image region actually captured by the camera # +# driver. Although they affect the geometry of the output image, they # +# may be changed freely without recalibrating the camera. # +####################################################################### + +# Binning refers here to any camera setting which combines rectangular +# neighborhoods of pixels into larger "super-pixels." It reduces the +# resolution of the output image to +# (width / binning_x) x (height / binning_y). +# The default values binning_x = binning_y = 0 is considered the same +# as binning_x = binning_y = 1 (no subsampling). +uint32 binning_x +uint32 binning_y + +# Region of interest (subwindow of full camera resolution), given in +# full resolution (unbinned) image coordinates. A particular ROI +# always denotes the same window of pixels on the camera sensor, +# regardless of binning settings. +# The default setting of roi (all values 0) is considered the same as +# full resolution (roi.width = width, roi.height = height). +RegionOfInterest roi + # + uint32 x_offset # + # (0 if the ROI includes the left edge of the image) + uint32 y_offset # + # (0 if the ROI includes the top edge of the image) + uint32 height # + uint32 width # + bool do_rectify +""", + "sensor_msgs/msg/Image": """ +# This message contains an uncompressed image +# (0, 0) is at top-left corner of image + +std_msgs/Header header # Header timestamp should be acquisition time of image + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + +uint32 height # image height, that is, number of rows +uint32 width # image width, that is, number of columns + +# The legal values for encoding are in file src/image_encodings.cpp +# If you want to standardize a new string format, join +# ros-users@lists.ros.org and send an email proposing a new encoding. + +string encoding # Encoding of pixels -- channel meaning, ordering, size + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + +uint8 is_bigendian # is this data bigendian? +uint32 step # Full row length in bytes +uint8[] data # actual matrix data, size is (step * rows) +""", + "rosgraph_msgs/msg/Clock": """ +# This message communicates the current time. +# +# For more information, see https://design.ros2.org/articles/clock_and_time.html. +builtin_interfaces/Time clock + int32 sec + uint32 nanosec +""", +} +MANIPULATION_INTERFACES: Dict[str, str] = { + "moveit_msgs/action/ExecuteTrajectory": """# The trajectory to execute +RobotTrajectory trajectory + trajectory_msgs/JointTrajectory joint_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + trajectory_msgs/MultiDOFJointTrajectory multi_dof_joint_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + MultiDOFJointTrajectoryPoint[] points + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] velocities + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Twist[] accelerations + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + +--- + +# Error code - encodes the overall reason for failure +MoveItErrorCodes error_code + int32 val + int32 SUCCESS=1 + int32 FAILURE=99999 + int32 PLANNING_FAILED=-1 + int32 INVALID_MOTION_PLAN=-2 + int32 MOTION_PLAN_INVALIDATED_BY_ENVIRONMENT_CHANGE=-3 + int32 CONTROL_FAILED=-4 + int32 UNABLE_TO_AQUIRE_SENSOR_DATA=-5 + int32 TIMED_OUT=-6 + int32 PREEMPTED=-7 + int32 START_STATE_IN_COLLISION=-10 + int32 START_STATE_VIOLATES_PATH_CONSTRAINTS=-11 + int32 START_STATE_INVALID=-26 + int32 GOAL_IN_COLLISION=-12 + int32 GOAL_VIOLATES_PATH_CONSTRAINTS=-13 + int32 GOAL_CONSTRAINTS_VIOLATED=-14 + int32 GOAL_STATE_INVALID=-27 + int32 UNRECOGNIZED_GOAL_TYPE=-28 + int32 INVALID_GROUP_NAME=-15 + int32 INVALID_GOAL_CONSTRAINTS=-16 + int32 INVALID_ROBOT_STATE=-17 + int32 INVALID_LINK_NAME=-18 + int32 INVALID_OBJECT_NAME=-19 + int32 FRAME_TRANSFORM_FAILURE=-21 + int32 COLLISION_CHECKING_UNAVAILABLE=-22 + int32 ROBOT_STATE_STALE=-23 + int32 SENSOR_INFO_STALE=-24 + int32 COMMUNICATION_FAILURE=-25 + int32 CRASH=-29 + int32 ABORT=-30 + int32 NO_IK_SOLUTION=-31 + +--- + +# The internal state that the move group action currently is in +string state} +""", + "moveit_msgs/action/MoveGroup": """# Motion planning request to pass to planner +MotionPlanRequest request + WorkspaceParameters workspace_parameters + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Vector3 min_corner + float64 x + float64 y + float64 z + geometry_msgs/Vector3 max_corner + float64 x + float64 y + float64 z + RobotState start_state + sensor_msgs/JointState joint_state + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] name + float64[] position + float64[] velocity + float64[] effort + sensor_msgs/MultiDOFJointState multi_dof_joint_state + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] twist + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Wrench[] wrench + Vector3 force + float64 x + float64 y + float64 z + Vector3 torque + float64 x + float64 y + float64 z + AttachedCollisionObject[] attached_collision_objects + string link_name + CollisionObject object + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string id + object_recognition_msgs/ObjectType type + string key + string db + shape_msgs/SolidPrimitive[] primitives + uint8 BOX=1 + uint8 SPHERE=2 + uint8 CYLINDER=3 + uint8 CONE=4 + uint8 PRISM=5 + uint8 type + float64[<=3] dimensions # + uint8 BOX_X=0 + uint8 BOX_Y=1 + uint8 BOX_Z=2 + uint8 SPHERE_RADIUS=0 + uint8 CYLINDER_HEIGHT=0 + uint8 CYLINDER_RADIUS=1 + uint8 CONE_HEIGHT=0 + uint8 CONE_RADIUS=1 + uint8 PRISM_HEIGHT=0 + geometry_msgs/Polygon polygon + Point32[] points + # + # + float32 x + float32 y + float32 z + geometry_msgs/Pose[] primitive_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Mesh[] meshes + MeshTriangle[] triangles + uint32[3] vertex_indices + geometry_msgs/Point[] vertices + float64 x + float64 y + float64 z + geometry_msgs/Pose[] mesh_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Plane[] planes + # + float64[4] coef + geometry_msgs/Pose[] plane_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string[] subframe_names + geometry_msgs/Pose[] subframe_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + byte ADD=0 + byte REMOVE=1 + byte APPEND=2 + byte MOVE=3 + byte operation + string[] touch_links + trajectory_msgs/JointTrajectory detach_posture + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + float64 weight + bool is_diff + Constraints[] goal_constraints + string name + JointConstraint[] joint_constraints + string joint_name + float64 position + float64 tolerance_above + float64 tolerance_below + float64 weight + PositionConstraint[] position_constraints + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string link_name + geometry_msgs/Vector3 target_point_offset + float64 x + float64 y + float64 z + BoundingVolume constraint_region + shape_msgs/SolidPrimitive[] primitives + uint8 BOX=1 + uint8 SPHERE=2 + uint8 CYLINDER=3 + uint8 CONE=4 + uint8 PRISM=5 + uint8 type + float64[<=3] dimensions # + uint8 BOX_X=0 + uint8 BOX_Y=1 + uint8 BOX_Z=2 + uint8 SPHERE_RADIUS=0 + uint8 CYLINDER_HEIGHT=0 + uint8 CYLINDER_RADIUS=1 + uint8 CONE_HEIGHT=0 + uint8 CONE_RADIUS=1 + uint8 PRISM_HEIGHT=0 + geometry_msgs/Polygon polygon + Point32[] points + # + # + float32 x + float32 y + float32 z + geometry_msgs/Pose[] primitive_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Mesh[] meshes + MeshTriangle[] triangles + uint32[3] vertex_indices + geometry_msgs/Point[] vertices + float64 x + float64 y + float64 z + geometry_msgs/Pose[] mesh_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64 weight + OrientationConstraint[] orientation_constraints + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string link_name + float64 absolute_x_axis_tolerance + float64 absolute_y_axis_tolerance + float64 absolute_z_axis_tolerance + uint8 parameterization + uint8 XYZ_EULER_ANGLES=0 + uint8 ROTATION_VECTOR=1 + float64 weight + VisibilityConstraint[] visibility_constraints + float64 target_radius + geometry_msgs/PoseStamped target_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + int32 cone_sides + geometry_msgs/PoseStamped sensor_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64 max_view_angle + float64 max_range_angle + uint8 SENSOR_Z=0 + uint8 SENSOR_Y=1 + uint8 SENSOR_X=2 + uint8 sensor_view_direction + float64 weight + Constraints path_constraints + string name + JointConstraint[] joint_constraints + string joint_name + float64 position + float64 tolerance_above + float64 tolerance_below + float64 weight + PositionConstraint[] position_constraints + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string link_name + geometry_msgs/Vector3 target_point_offset + float64 x + float64 y + float64 z + BoundingVolume constraint_region + shape_msgs/SolidPrimitive[] primitives + uint8 BOX=1 + uint8 SPHERE=2 + uint8 CYLINDER=3 + uint8 CONE=4 + uint8 PRISM=5 + uint8 type + float64[<=3] dimensions # + uint8 BOX_X=0 + uint8 BOX_Y=1 + uint8 BOX_Z=2 + uint8 SPHERE_RADIUS=0 + uint8 CYLINDER_HEIGHT=0 + uint8 CYLINDER_RADIUS=1 + uint8 CONE_HEIGHT=0 + uint8 CONE_RADIUS=1 + uint8 PRISM_HEIGHT=0 + geometry_msgs/Polygon polygon + Point32[] points + # + # + float32 x + float32 y + float32 z + geometry_msgs/Pose[] primitive_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Mesh[] meshes + MeshTriangle[] triangles + uint32[3] vertex_indices + geometry_msgs/Point[] vertices + float64 x + float64 y + float64 z + geometry_msgs/Pose[] mesh_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64 weight + OrientationConstraint[] orientation_constraints + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string link_name + float64 absolute_x_axis_tolerance + float64 absolute_y_axis_tolerance + float64 absolute_z_axis_tolerance + uint8 parameterization + uint8 XYZ_EULER_ANGLES=0 + uint8 ROTATION_VECTOR=1 + float64 weight + VisibilityConstraint[] visibility_constraints + float64 target_radius + geometry_msgs/PoseStamped target_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + int32 cone_sides + geometry_msgs/PoseStamped sensor_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64 max_view_angle + float64 max_range_angle + uint8 SENSOR_Z=0 + uint8 SENSOR_Y=1 + uint8 SENSOR_X=2 + uint8 sensor_view_direction + float64 weight + TrajectoryConstraints trajectory_constraints + Constraints[] constraints + string name + JointConstraint[] joint_constraints + string joint_name + float64 position + float64 tolerance_above + float64 tolerance_below + float64 weight + PositionConstraint[] position_constraints + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string link_name + geometry_msgs/Vector3 target_point_offset + float64 x + float64 y + float64 z + BoundingVolume constraint_region + shape_msgs/SolidPrimitive[] primitives + uint8 BOX=1 + uint8 SPHERE=2 + uint8 CYLINDER=3 + uint8 CONE=4 + uint8 PRISM=5 + uint8 type + float64[<=3] dimensions # + uint8 BOX_X=0 + uint8 BOX_Y=1 + uint8 BOX_Z=2 + uint8 SPHERE_RADIUS=0 + uint8 CYLINDER_HEIGHT=0 + uint8 CYLINDER_RADIUS=1 + uint8 CONE_HEIGHT=0 + uint8 CONE_RADIUS=1 + uint8 PRISM_HEIGHT=0 + geometry_msgs/Polygon polygon + Point32[] points + # + # + float32 x + float32 y + float32 z + geometry_msgs/Pose[] primitive_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Mesh[] meshes + MeshTriangle[] triangles + uint32[3] vertex_indices + geometry_msgs/Point[] vertices + float64 x + float64 y + float64 z + geometry_msgs/Pose[] mesh_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64 weight + OrientationConstraint[] orientation_constraints + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string link_name + float64 absolute_x_axis_tolerance + float64 absolute_y_axis_tolerance + float64 absolute_z_axis_tolerance + uint8 parameterization + uint8 XYZ_EULER_ANGLES=0 + uint8 ROTATION_VECTOR=1 + float64 weight + VisibilityConstraint[] visibility_constraints + float64 target_radius + geometry_msgs/PoseStamped target_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + int32 cone_sides + geometry_msgs/PoseStamped sensor_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64 max_view_angle + float64 max_range_angle + uint8 SENSOR_Z=0 + uint8 SENSOR_Y=1 + uint8 SENSOR_X=2 + uint8 sensor_view_direction + float64 weight + GenericTrajectory[] reference_trajectories + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + trajectory_msgs/JointTrajectory[] joint_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + moveit_msgs/CartesianTrajectory[] cartesian_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string tracked_frame + CartesianTrajectoryPoint[] points + CartesianPoint point + geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist velocity + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Accel acceleration + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + string pipeline_id + string planner_id + string group_name + int32 num_planning_attempts + float64 allowed_planning_time + float64 max_velocity_scaling_factor + float64 max_acceleration_scaling_factor + string cartesian_speed_end_effector_link + float64 max_cartesian_speed # + +# Planning options +PlanningOptions planning_options + PlanningScene planning_scene_diff + string name + RobotState robot_state + sensor_msgs/JointState joint_state + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] name + float64[] position + float64[] velocity + float64[] effort + sensor_msgs/MultiDOFJointState multi_dof_joint_state + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] twist + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Wrench[] wrench + Vector3 force + float64 x + float64 y + float64 z + Vector3 torque + float64 x + float64 y + float64 z + AttachedCollisionObject[] attached_collision_objects + string link_name + CollisionObject object + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string id + object_recognition_msgs/ObjectType type + string key + string db + shape_msgs/SolidPrimitive[] primitives + uint8 BOX=1 + uint8 SPHERE=2 + uint8 CYLINDER=3 + uint8 CONE=4 + uint8 PRISM=5 + uint8 type + float64[<=3] dimensions # + uint8 BOX_X=0 + uint8 BOX_Y=1 + uint8 BOX_Z=2 + uint8 SPHERE_RADIUS=0 + uint8 CYLINDER_HEIGHT=0 + uint8 CYLINDER_RADIUS=1 + uint8 CONE_HEIGHT=0 + uint8 CONE_RADIUS=1 + uint8 PRISM_HEIGHT=0 + geometry_msgs/Polygon polygon + Point32[] points + # + # + float32 x + float32 y + float32 z + geometry_msgs/Pose[] primitive_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Mesh[] meshes + MeshTriangle[] triangles + uint32[3] vertex_indices + geometry_msgs/Point[] vertices + float64 x + float64 y + float64 z + geometry_msgs/Pose[] mesh_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Plane[] planes + # + float64[4] coef + geometry_msgs/Pose[] plane_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string[] subframe_names + geometry_msgs/Pose[] subframe_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + byte ADD=0 + byte REMOVE=1 + byte APPEND=2 + byte MOVE=3 + byte operation + string[] touch_links + trajectory_msgs/JointTrajectory detach_posture + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + float64 weight + bool is_diff + string robot_model_name + geometry_msgs/TransformStamped[] fixed_frame_transforms + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string child_frame_id + Transform transform + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + AllowedCollisionMatrix allowed_collision_matrix + string[] entry_names + AllowedCollisionEntry[] entry_values + bool[] enabled + string[] default_entry_names + bool[] default_entry_values + LinkPadding[] link_padding + string link_name + float64 padding + LinkScale[] link_scale + string link_name + float64 scale + ObjectColor[] object_colors + string id + std_msgs/ColorRGBA color + float32 r + float32 g + float32 b + float32 a + PlanningSceneWorld world + CollisionObject[] collision_objects + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string id + object_recognition_msgs/ObjectType type + string key + string db + shape_msgs/SolidPrimitive[] primitives + uint8 BOX=1 + uint8 SPHERE=2 + uint8 CYLINDER=3 + uint8 CONE=4 + uint8 PRISM=5 + uint8 type + float64[<=3] dimensions # + uint8 BOX_X=0 + uint8 BOX_Y=1 + uint8 BOX_Z=2 + uint8 SPHERE_RADIUS=0 + uint8 CYLINDER_HEIGHT=0 + uint8 CYLINDER_RADIUS=1 + uint8 CONE_HEIGHT=0 + uint8 CONE_RADIUS=1 + uint8 PRISM_HEIGHT=0 + geometry_msgs/Polygon polygon + Point32[] points + # + # + float32 x + float32 y + float32 z + geometry_msgs/Pose[] primitive_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Mesh[] meshes + MeshTriangle[] triangles + uint32[3] vertex_indices + geometry_msgs/Point[] vertices + float64 x + float64 y + float64 z + geometry_msgs/Pose[] mesh_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Plane[] planes + # + float64[4] coef + geometry_msgs/Pose[] plane_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string[] subframe_names + geometry_msgs/Pose[] subframe_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + byte ADD=0 + byte REMOVE=1 + byte APPEND=2 + byte MOVE=3 + byte operation + octomap_msgs/OctomapWithPose octomap + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Pose origin + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + octomap_msgs/Octomap octomap + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + bool binary + string id + float64 resolution + int8[] data + bool is_diff + bool plan_only + bool look_around + int32 look_around_attempts + float64 max_safe_execution_cost + bool replan + int32 replan_attempts + float64 replan_delay + +--- + +# An error code reflecting what went wrong +MoveItErrorCodes error_code + int32 val + int32 SUCCESS=1 + int32 FAILURE=99999 + int32 PLANNING_FAILED=-1 + int32 INVALID_MOTION_PLAN=-2 + int32 MOTION_PLAN_INVALIDATED_BY_ENVIRONMENT_CHANGE=-3 + int32 CONTROL_FAILED=-4 + int32 UNABLE_TO_AQUIRE_SENSOR_DATA=-5 + int32 TIMED_OUT=-6 + int32 PREEMPTED=-7 + int32 START_STATE_IN_COLLISION=-10 + int32 START_STATE_VIOLATES_PATH_CONSTRAINTS=-11 + int32 START_STATE_INVALID=-26 + int32 GOAL_IN_COLLISION=-12 + int32 GOAL_VIOLATES_PATH_CONSTRAINTS=-13 + int32 GOAL_CONSTRAINTS_VIOLATED=-14 + int32 GOAL_STATE_INVALID=-27 + int32 UNRECOGNIZED_GOAL_TYPE=-28 + int32 INVALID_GROUP_NAME=-15 + int32 INVALID_GOAL_CONSTRAINTS=-16 + int32 INVALID_ROBOT_STATE=-17 + int32 INVALID_LINK_NAME=-18 + int32 INVALID_OBJECT_NAME=-19 + int32 FRAME_TRANSFORM_FAILURE=-21 + int32 COLLISION_CHECKING_UNAVAILABLE=-22 + int32 ROBOT_STATE_STALE=-23 + int32 SENSOR_INFO_STALE=-24 + int32 COMMUNICATION_FAILURE=-25 + int32 CRASH=-29 + int32 ABORT=-30 + int32 NO_IK_SOLUTION=-31 + +# The full starting state of the robot at the start of the trajectory +moveit_msgs/RobotState trajectory_start + sensor_msgs/JointState joint_state + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] name + float64[] position + float64[] velocity + float64[] effort + sensor_msgs/MultiDOFJointState multi_dof_joint_state + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] twist + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Wrench[] wrench + Vector3 force + float64 x + float64 y + float64 z + Vector3 torque + float64 x + float64 y + float64 z + AttachedCollisionObject[] attached_collision_objects + string link_name + CollisionObject object + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string id + object_recognition_msgs/ObjectType type + string key + string db + shape_msgs/SolidPrimitive[] primitives + uint8 BOX=1 + uint8 SPHERE=2 + uint8 CYLINDER=3 + uint8 CONE=4 + uint8 PRISM=5 + uint8 type + float64[<=3] dimensions # + uint8 BOX_X=0 + uint8 BOX_Y=1 + uint8 BOX_Z=2 + uint8 SPHERE_RADIUS=0 + uint8 CYLINDER_HEIGHT=0 + uint8 CYLINDER_RADIUS=1 + uint8 CONE_HEIGHT=0 + uint8 CONE_RADIUS=1 + uint8 PRISM_HEIGHT=0 + geometry_msgs/Polygon polygon + Point32[] points + # + # + float32 x + float32 y + float32 z + geometry_msgs/Pose[] primitive_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Mesh[] meshes + MeshTriangle[] triangles + uint32[3] vertex_indices + geometry_msgs/Point[] vertices + float64 x + float64 y + float64 z + geometry_msgs/Pose[] mesh_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + shape_msgs/Plane[] planes + # + float64[4] coef + geometry_msgs/Pose[] plane_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + string[] subframe_names + geometry_msgs/Pose[] subframe_poses + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + byte ADD=0 + byte REMOVE=1 + byte APPEND=2 + byte MOVE=3 + byte operation + string[] touch_links + trajectory_msgs/JointTrajectory detach_posture + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + float64 weight + bool is_diff + +# The trajectory that moved group produced for execution +moveit_msgs/RobotTrajectory planned_trajectory + trajectory_msgs/JointTrajectory joint_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + trajectory_msgs/MultiDOFJointTrajectory multi_dof_joint_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + MultiDOFJointTrajectoryPoint[] points + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] velocities + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Twist[] accelerations + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + +# The trace of the trajectory recorded during execution +moveit_msgs/RobotTrajectory executed_trajectory + trajectory_msgs/JointTrajectory joint_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + trajectory_msgs/MultiDOFJointTrajectory multi_dof_joint_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + MultiDOFJointTrajectoryPoint[] points + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] velocities + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Twist[] accelerations + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + +# The amount of time it took to complete the motion plan +float64 planning_time + +--- + +# The internal state that the move group action currently is in +string state +""", + "control_msgs/action/FollowJointTrajectory": """# The trajectory for all revolute, continuous or prismatic joints +trajectory_msgs/JointTrajectory trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + JointTrajectoryPoint[] points + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec +# The trajectory for all planar or floating joints (i.e. individual joints with more than one DOF) +trajectory_msgs/MultiDOFJointTrajectory multi_dof_trajectory + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string[] joint_names + MultiDOFJointTrajectoryPoint[] points + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] velocities + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Twist[] accelerations + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + +# Tolerances for the trajectory. If the measured joint values fall +# outside the tolerances the trajectory goal is aborted. Any +# tolerances that are not specified (by being omitted or set to 0) are +# set to the defaults for the action server (often taken from the +# parameter server). + +# Tolerances applied to the joints as the trajectory is executed. If +# violated, the goal aborts with error_code set to +# PATH_TOLERANCE_VIOLATED. +JointTolerance[] path_tolerance + # + string name + float64 position # + float64 velocity # + float64 acceleration # +JointComponentTolerance[] component_path_tolerance + uint16 X_AXIS=1 + uint16 Y_AXIS=2 + uint16 Z_AXIS=3 + uint16 TRANSLATION=4 + uint16 ROTATION=5 + string joint_name + uint16 component + float64 position + float64 velocity + float64 acceleration + +# To report success, the joints must be within goal_tolerance of the +# final trajectory value. The goal must be achieved by time the +# trajectory ends plus goal_time_tolerance. (goal_time_tolerance +# allows some leeway in time, so that the trajectory goal can still +# succeed even if the joints reach the goal some time after the +# precise end time of the trajectory). +# +# If the joints are not within goal_tolerance after "trajectory finish +# time" + goal_time_tolerance, the goal aborts with error_code set to +# GOAL_TOLERANCE_VIOLATED +JointTolerance[] goal_tolerance + # + string name + float64 position # + float64 velocity # + float64 acceleration # +JointComponentTolerance[] component_goal_tolerance + uint16 X_AXIS=1 + uint16 Y_AXIS=2 + uint16 Z_AXIS=3 + uint16 TRANSLATION=4 + uint16 ROTATION=5 + string joint_name + uint16 component + float64 position + float64 velocity + float64 acceleration +builtin_interfaces/Duration goal_time_tolerance + int32 sec + uint32 nanosec + +--- +int32 error_code +int32 SUCCESSFUL = 0 +int32 INVALID_GOAL = -1 +int32 INVALID_JOINTS = -2 +int32 OLD_HEADER_TIMESTAMP = -3 +int32 PATH_TOLERANCE_VIOLATED = -4 +int32 GOAL_TOLERANCE_VIOLATED = -5 + +# Human readable description of the error code. Contains complementary +# information that is especially useful when execution fails, for instance: +# - INVALID_GOAL: The reason for the invalid goal (e.g., the requested +# trajectory is in the past). +# - INVALID_JOINTS: The mismatch between the expected controller joints +# and those provided in the goal. +# - PATH_TOLERANCE_VIOLATED and GOAL_TOLERANCE_VIOLATED: Which joint +# violated which tolerance, and by how much. +string error_string + +--- +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id +string[] joint_names +trajectory_msgs/JointTrajectoryPoint desired + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec +trajectory_msgs/JointTrajectoryPoint actual + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec +trajectory_msgs/JointTrajectoryPoint error + float64[] positions + float64[] velocities + float64[] accelerations + float64[] effort + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec + +string[] multi_dof_joint_names +trajectory_msgs/MultiDOFJointTrajectoryPoint multi_dof_desired + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] velocities + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Twist[] accelerations + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec +trajectory_msgs/MultiDOFJointTrajectoryPoint multi_dof_actual + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] velocities + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Twist[] accelerations + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec +trajectory_msgs/MultiDOFJointTrajectoryPoint multi_dof_error + geometry_msgs/Transform[] transforms + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + geometry_msgs/Twist[] velocities + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + geometry_msgs/Twist[] accelerations + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + builtin_interfaces/Duration time_from_start + int32 sec + uint32 nanosec +""", + "/panda_hand_controller/gripper_cmd": """GripperCommand command +float64 position +float64 max_effort +--- +float64 position # The current gripper gap size (in meters) +float64 effort # The current effort exerted (in Newtons) +bool stalled # True iff the gripper is exerting max effort and not moving +bool reached_goal # True iff the gripper position has reached the commanded setpoint +--- +float64 position # The current gripper gap size (in meters) +float64 effort # The current effort exerted (in Newtons) +bool stalled # True iff the gripper is exerting max effort and not moving +bool reached_goal # True iff the gripper position has reached the commanded setpoint +""", +} + + +NAVIGATION_INTERFACES: Dict[str, str] = { + "nav2_msgs/action/NavigateToPose": """#goal definition +geometry_msgs/PoseStamped pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +string behavior_tree +--- +#result definition +std_msgs/Empty result +--- +#feedback definition +geometry_msgs/PoseStamped current_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +builtin_interfaces/Duration navigation_time + int32 sec + uint32 nanosec +builtin_interfaces/Duration estimated_time_remaining + int32 sec + uint32 nanosec +int16 number_of_recoveries +float32 distance_remaining +""", + "nav2_msgs/action/AssistedTeleop": """#goal definition +builtin_interfaces/Duration time_allowance + int32 sec + uint32 nanosec +--- +#result definition +builtin_interfaces/Duration total_elapsed_time + int32 sec + uint32 nanosec +--- +#feedback +builtin_interfaces/Duration current_teleop_duration + int32 sec + uint32 nanosec""", + "nav2_msgs/action/BackUp": """#goal definition +geometry_msgs/Point target + float64 x + float64 y + float64 z +float32 speed +builtin_interfaces/Duration time_allowance + int32 sec + uint32 nanosec +--- +#result definition +builtin_interfaces/Duration total_elapsed_time + int32 sec + uint32 nanosec +--- +#feedback definition +float32 distance_traveled""", + "nav2_msgs/action/ComputePathThroughPoses": """#goal definition +geometry_msgs/PoseStamped[] goals + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +geometry_msgs/PoseStamped start + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +string planner_id +bool use_start # If false, use current robot pose as path start, if true, use start above instead +--- +#result definition +nav_msgs/Path path + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/PoseStamped[] poses + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +builtin_interfaces/Duration planning_time + int32 sec + uint32 nanosec +--- +#feedback definition""", + "nav2_msgs/action/ComputePathToPose": """#goal definition +geometry_msgs/PoseStamped goal + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +geometry_msgs/PoseStamped start + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +string planner_id +bool use_start # If false, use current robot pose as path start, if true, use start above instead +--- +#result definition +nav_msgs/Path path + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/PoseStamped[] poses + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +builtin_interfaces/Duration planning_time + int32 sec + uint32 nanosec +--- +#feedback definition""", + "nav2_msgs/action/DriveOnHeading": """#goal definition +geometry_msgs/Point target + float64 x + float64 y + float64 z +float32 speed +builtin_interfaces/Duration time_allowance + int32 sec + uint32 nanosec +--- +#result definition +builtin_interfaces/Duration total_elapsed_time + int32 sec + uint32 nanosec +--- +#feedback definition +float32 distance_traveled""", + "nav2_msgs/action/FollowPath": """#goal definition +nav_msgs/Path path + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/PoseStamped[] poses + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +string controller_id +string goal_checker_id +--- +#result definition +std_msgs/Empty result +--- +#feedback definition +float32 distance_to_goal +float32 speed""", + "nav2_msgs/action/FollowWaypoints": """#goal definition +geometry_msgs/PoseStamped[] poses + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +--- +#result definition +int32[] missed_waypoints +--- +#feedback definition +uint32 current_waypoint""", + "nav2_msgs/action/NavigateThroughPoses": """#goal definition +geometry_msgs/PoseStamped[] poses + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +string behavior_tree +--- +#result definition +std_msgs/Empty result +--- +#feedback definition +geometry_msgs/PoseStamped current_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +builtin_interfaces/Duration navigation_time + int32 sec + uint32 nanosec +builtin_interfaces/Duration estimated_time_remaining + int32 sec + uint32 nanosec +int16 number_of_recoveries +float32 distance_remaining +int16 number_of_poses_remaining +""", + "nav2_msgs/action/SmoothPath": """#goal definition +nav_msgs/Path path + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/PoseStamped[] poses + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +string smoother_id +builtin_interfaces/Duration max_smoothing_duration + int32 sec + uint32 nanosec +bool check_for_collisions +--- +#result definition +nav_msgs/Path path + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + geometry_msgs/PoseStamped[] poses + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +builtin_interfaces/Duration smoothing_duration + int32 sec + uint32 nanosec +bool was_completed +--- +#feedback definition +""", + "nav2_msgs/action/Wait": """#goal definition +builtin_interfaces/Duration time + int32 sec + uint32 nanosec +--- +#result definition +builtin_interfaces/Duration total_elapsed_time + int32 sec + uint32 nanosec +--- +#feedback definition +builtin_interfaces/Duration time_left + int32 sec + uint32 nanosec""", +} + +CUSTOM_INTERFACES: Dict[str, str] = { + "rai_interfaces/msg/HRIMessage": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id +string text +sensor_msgs/Image[] images + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +rai_interfaces/AudioMessage[] audios + # + # + # + # + # + int16[] audio + uint16 sample_rate + uint16 channels +string communication_id +int64 seq_no +bool seq_end +""", + "rai_interfaces/msg/AudioMessage": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +int16[] audio +uint16 sample_rate +uint16 channels +""", + "rai_interfaces/msg/RAIDetectionArray": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A list of 2D detections, for a multi-object 2D detector. +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + +# A list of the detected proposals. A multi-proposal detector might generate +# this list with many candidate detections generated from a single input. +vision_msgs/Detection2D[] detections + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + ObjectHypothesisWithPose[] results + ObjectHypothesis hypothesis + string class_id + float64 score + geometry_msgs/PoseWithCovariance pose + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64[36] covariance + BoundingBox2D bbox + vision_msgs/Pose2D center + float64 x + float64 y + float64 theta + float64 size_x + float64 size_y + string id +# a list of classes being detected +string[] detection_classes +""", + "rai_interfaces/srv/ManipulatorMoveTo": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A simplified approach with binary states for the gripper +bool initial_gripper_state +bool final_gripper_state +geometry_msgs/PoseStamped target_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +--- +bool success +""", + "rai_interfaces/srv/RAIGroundedSam": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +RAIDetectionArray detections + # + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + vision_msgs/Detection2D[] detections + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + ObjectHypothesisWithPose[] results + ObjectHypothesis hypothesis + string class_id + float64 score + geometry_msgs/PoseWithCovariance pose + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64[36] covariance + BoundingBox2D bbox + vision_msgs/Pose2D center + float64 x + float64 y + float64 theta + float64 size_x + float64 size_y + string id + string[] detection_classes +sensor_msgs/Image source_img + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +--- +sensor_msgs/Image[] masks + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +""", + "rai_interfaces/srv/RAIGroundingDino": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +string classes +float64 box_threshold +float64 text_threshold +sensor_msgs/Image source_img + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +--- +RAIDetectionArray detections + # + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + vision_msgs/Detection2D[] detections + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + ObjectHypothesisWithPose[] results + ObjectHypothesis hypothesis + string class_id + float64 score + geometry_msgs/PoseWithCovariance pose + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64[36] covariance + BoundingBox2D bbox + vision_msgs/Pose2D center + float64 x + float64 y + float64 theta + float64 size_x + float64 size_y + string id + string[] detection_classes +""", + "rai_interfaces/srv/StringList": """ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Request - empty +--- +# Response +bool success +string[] string_list +""", + "rai_interfaces/srv/VectorStoreRetrieval": """ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Request +string query + +--- +# Response +bool success +string message +string[] documents +float32[] scores +""", + "rai_interfaces/srv/WhatISee": """z +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Request (empty) + +--- +# Response, timed with image timestamp +string[] observations +string perception_source +sensor_msgs/Image image + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +""", + "rai_interfaces/action/Task": """ +# Goal +string task +string description +string priority + +--- +# Result +bool success +string report + +--- +# Feedback +string current_status +""", + "/load_map": """ +string filename +--- +bool success +""", + "/query_planner_interface": """ +--- + +# The planning instances that could be used in the benchmark +PlannerInterfaceDescription[] planner_interfaces + string name + string pipeline_id + string[] planner_ids + +""", +} + +COMMON_TOPICS_AND_TYPES: Dict[str, str] = { + "/clock": "rosgraph_msgs/msg/Clock", + "/parameter_events": "rcl_interfaces/msg/ParameterEvent", + "/rosout": "rcl_interfaces/msg/Log", + "/tf": "tf2_msgs/msg/TFMessage", + "/tf_static": "tf2_msgs/msg/TFMessage", + "/joint_states": "sensor_msgs/msg/JointState", + "/robot_description": "std_msgs/msg/String", + "/robot_description_semantic": "std_msgs/msg/String", + "/bond": "bond/msg/Status", + "/diagnostics": "diagnostic_msgs/msg/DiagnosticArray", + # Perception topics + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/color_image5": "sensor_msgs/msg/Image", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_image5": "sensor_msgs/msg/Image", + "/pointcloud": "sensor_msgs/msg/PointCloud2", + "/scan": "sensor_msgs/msg/LaserScan", + "/odom": "nav_msgs/msg/Odometry", + "/odometry/filtered": "nav_msgs/msg/Odometry", +} + +MANIPULATION_TOPICS_AND_TYPES: Dict[str, str] = { + # MoveIt2 planning and execution + "/move_action/_action/feedback": "moveit_msgs/action/MoveGroup_FeedbackMessage", + "/move_action/_action/status": "action_msgs/msg/GoalStatusArray", + "/execute_trajectory/_action/feedback": "moveit_msgs/action/ExecuteTrajectory_FeedbackMessage", + "/execute_trajectory/_action/status": "action_msgs/msg/GoalStatusArray", + "/motion_plan_request": "moveit_msgs/msg/MotionPlanRequest", + "/display_planned_path": "moveit_msgs/msg/DisplayTrajectory", + "/trajectory_execution_event": "std_msgs/msg/String", + # Planning scene management + "/planning_scene": "moveit_msgs/msg/PlanningScene", + "/planning_scene_world": "moveit_msgs/msg/PlanningSceneWorld", + "/monitored_planning_scene": "moveit_msgs/msg/PlanningScene", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/display_contacts": "visualization_msgs/msg/MarkerArray", + # Arm and gripper controllers + "/panda_arm_controller/follow_joint_trajectory/_action/feedback": "control_msgs/action/FollowJointTrajectory_FeedbackMessage", + "/panda_arm_controller/follow_joint_trajectory/_action/status": "action_msgs/msg/GoalStatusArray", + "/panda_hand_controller/gripper_cmd/_action/feedback": "control_msgs/action/GripperCommand_FeedbackMessage", + "/panda_hand_controller/gripper_cmd/_action/status": "action_msgs/msg/GoalStatusArray", +} + + +NAVIGATION_TOPICS_AND_TYPES: Dict[str, str] = { + # Main navigation actions + "/navigate_to_pose/_action/feedback": "nav2_msgs/action/NavigateToPose_FeedbackMessage", + "/navigate_to_pose/_action/status": "action_msgs/msg/GoalStatusArray", + "/navigate_through_poses/_action/feedback": "nav2_msgs/action/NavigateThroughPoses_FeedbackMessage", + "/navigate_through_poses/_action/status": "action_msgs/msg/GoalStatusArray", + "/follow_path/_action/feedback": "nav2_msgs/action/FollowPath_FeedbackMessage", + "/follow_path/_action/status": "action_msgs/msg/GoalStatusArray", + "/follow_waypoints/_action/feedback": "nav2_msgs/action/FollowWaypoints_FeedbackMessage", + "/follow_waypoints/_action/status": "action_msgs/msg/GoalStatusArray", + # Path planning actions + "/compute_path_to_pose/_action/feedback": "nav2_msgs/action/ComputePathToPose_FeedbackMessage", + "/compute_path_to_pose/_action/status": "action_msgs/msg/GoalStatusArray", + "/compute_path_through_poses/_action/feedback": "nav2_msgs/action/ComputePathThroughPoses_FeedbackMessage", + "/compute_path_through_poses/_action/status": "action_msgs/msg/GoalStatusArray", + "/smooth_path/_action/feedback": "nav2_msgs/action/SmoothPath_FeedbackMessage", + "/smooth_path/_action/status": "action_msgs/msg/GoalStatusArray", + # Behavior actions + "/assisted_teleop/_action/feedback": "nav2_msgs/action/AssistedTeleop_FeedbackMessage", + "/assisted_teleop/_action/status": "action_msgs/msg/GoalStatusArray", + "/backup/_action/feedback": "nav2_msgs/action/BackUp_FeedbackMessage", + "/backup/_action/status": "action_msgs/msg/GoalStatusArray", + "/drive_on_heading/_action/feedback": "nav2_msgs/action/DriveOnHeading_FeedbackMessage", + "/drive_on_heading/_action/status": "action_msgs/msg/GoalStatusArray", + "/spin/_action/feedback": "nav2_msgs/action/Spin_FeedbackMessage", + "/spin/_action/status": "action_msgs/msg/GoalStatusArray", + "/wait/_action/feedback": "nav2_msgs/action/Wait_FeedbackMessage", + "/wait/_action/status": "action_msgs/msg/GoalStatusArray", + # Costmaps and mapping + "/global_costmap/costmap": "nav_msgs/msg/OccupancyGrid", + "/global_costmap/costmap_raw": "nav2_msgs/msg/Costmap", + "/global_costmap/costmap_updates": "map_msgs/msg/OccupancyGridUpdate", + "/global_costmap/footprint": "geometry_msgs/msg/Polygon", + "/global_costmap/published_footprint": "geometry_msgs/msg/PolygonStamped", + "/global_costmap/scan": "sensor_msgs/msg/LaserScan", + "/local_costmap/costmap": "nav_msgs/msg/OccupancyGrid", + "/local_costmap/costmap_raw": "nav2_msgs/msg/Costmap", + "/local_costmap/costmap_updates": "map_msgs/msg/OccupancyGridUpdate", + "/local_costmap/footprint": "geometry_msgs/msg/Polygon", + "/local_costmap/published_footprint": "geometry_msgs/msg/PolygonStamped", + "/local_costmap/scan": "sensor_msgs/msg/LaserScan", + "/map": "nav_msgs/msg/OccupancyGrid", + "/map_metadata": "nav_msgs/msg/MapMetaData", + # SLAM + "/slam_toolbox/feedback": "visualization_msgs/msg/InteractiveMarkerFeedback", + "/slam_toolbox/graph_visualization": "visualization_msgs/msg/MarkerArray", + "/slam_toolbox/scan_visualization": "sensor_msgs/msg/LaserScan", + "/slam_toolbox/update": "visualization_msgs/msg/InteractiveMarkerUpdate", + # Path planning and visualization + "/plan": "nav_msgs/msg/Path", + "/plan_smoothed": "nav_msgs/msg/Path", + "/unsmoothed_plan": "nav_msgs/msg/Path", + "/transformed_global_plan": "nav_msgs/msg/Path", + "/trajectories": "visualization_msgs/msg/MarkerArray", + # Control and goals + "/cmd_vel_nav": "geometry_msgs/msg/Twist", + "/cmd_vel_teleop": "geometry_msgs/msg/Twist", + "/goal_pose": "geometry_msgs/msg/PoseStamped", + "/pose": "geometry_msgs/msg/PoseWithCovarianceStamped", + "/preempt_teleop": "std_msgs/msg/Empty", + "/speed_limit": "nav2_msgs/msg/SpeedLimit", + # Behavior tree + "/behavior_tree_log": "nav2_msgs/msg/BehaviorTreeLog", + # Other + "/led_strip": "sensor_msgs/msg/Image", + # Lifecycle management + "/behavior_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/bt_navigator/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/controller_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/global_costmap/global_costmap/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/local_costmap/local_costmap/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/map_saver/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/planner_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/smoother_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/velocity_smoother/transition_event": "lifecycle_msgs/msg/TransitionEvent", + "/waypoint_follower/transition_event": "lifecycle_msgs/msg/TransitionEvent", +} + +CUSTOM_TOPICS_AND_TYPES: Dict[str, str] = { + "/hri_message": "rai_interfaces/msg/HRIMessage", + "/audio_message": "rai_interfaces/msg/AudioMessage", + "/detection_array": "rai_interfaces/msg/RAIDetectionArray", +} + + +COMMON_SERVICES_AND_TYPES: Dict[str, str] = { + # Core infrastructure + "/tf2_frames": "tf2_msgs/srv/FrameGraph", + # Container management + "/nav2_container/_container/list_nodes": "composition_interfaces/srv/ListNodes", + "/nav2_container/_container/load_node": "composition_interfaces/srv/LoadNode", + "/nav2_container/_container/unload_node": "composition_interfaces/srv/UnloadNode", + # Robot state and transforms + "/robot_state_publisher/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/robot_state_publisher/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/robot_state_publisher/get_parameters": "rcl_interfaces/srv/GetParameters", + "/robot_state_publisher/list_parameters": "rcl_interfaces/srv/ListParameters", + "/robot_state_publisher/set_parameters": "rcl_interfaces/srv/SetParameters", + "/robot_state_publisher/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/static_transform_publisher/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/static_transform_publisher/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/static_transform_publisher/get_parameters": "rcl_interfaces/srv/GetParameters", + "/static_transform_publisher/list_parameters": "rcl_interfaces/srv/ListParameters", + "/static_transform_publisher/set_parameters": "rcl_interfaces/srv/SetParameters", + "/static_transform_publisher/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + # Simulation/Gazebo services + "/delete_entity": "gazebo_msgs/srv/DeleteEntity", + "/get_available_spawnable_names": "gazebo_msgs/srv/GetWorldProperties", + "/get_spawn_point_info": "gazebo_msgs/srv/GetModelState", + "/get_spawn_points_names": "gazebo_msgs/srv/GetWorldProperties", + "/spawn_entity": "gazebo_msgs/srv/SpawnEntity", + # Parameter services (common pattern for all nodes) + "/launch_ros_138640/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/launch_ros_138640/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/launch_ros_138640/get_parameters": "rcl_interfaces/srv/GetParameters", + "/launch_ros_138640/list_parameters": "rcl_interfaces/srv/ListParameters", + "/launch_ros_138640/set_parameters": "rcl_interfaces/srv/SetParameters", + "/launch_ros_138640/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/launch_ros_2375507/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/launch_ros_2375507/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/launch_ros_2375507/get_parameters": "rcl_interfaces/srv/GetParameters", + "/launch_ros_2375507/list_parameters": "rcl_interfaces/srv/ListParameters", + "/launch_ros_2375507/set_parameters": "rcl_interfaces/srv/SetParameters", + "/launch_ros_2375507/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/o3de_ros2_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/o3de_ros2_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/o3de_ros2_node/get_parameters": "rcl_interfaces/srv/GetParameters", + "/o3de_ros2_node/list_parameters": "rcl_interfaces/srv/ListParameters", + "/o3de_ros2_node/set_parameters": "rcl_interfaces/srv/SetParameters", + "/o3de_ros2_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + # AI/ML services (custom interfaces for perception and documentation) + "/grounded_sam/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/grounded_sam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/grounded_sam/get_parameters": "rcl_interfaces/srv/GetParameters", + "/grounded_sam/list_parameters": "rcl_interfaces/srv/ListParameters", + "/grounded_sam/set_parameters": "rcl_interfaces/srv/SetParameters", + "/grounded_sam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/grounding_dino/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/grounding_dino/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/grounding_dino/get_parameters": "rcl_interfaces/srv/GetParameters", + "/grounding_dino/list_parameters": "rcl_interfaces/srv/ListParameters", + "/grounding_dino/set_parameters": "rcl_interfaces/srv/SetParameters", + "/grounding_dino/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/rai_ros2_ari_connector_b6ed00ab6356/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/rai_ros2_ari_connector_b6ed00ab6356/get_parameters": "rcl_interfaces/srv/GetParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/list_parameters": "rcl_interfaces/srv/ListParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters": "rcl_interfaces/srv/SetParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", +} + +MANIPULATION_SERVICES_AND_TYPES: Dict[str, str] = { + # MoveIt2 planning services + "/apply_planning_scene": "moveit_msgs/srv/ApplyPlanningScene", + "/check_state_validity": "moveit_msgs/srv/GetStateValidity", + "/clear_octomap": "std_srvs/srv/Empty", + "/compute_cartesian_path": "moveit_msgs/srv/GetCartesianPath", + "/compute_fk": "moveit_msgs/srv/GetPositionFK", + "/compute_ik": "moveit_msgs/srv/GetPositionIK", + "/get_planner_params": "moveit_msgs/srv/GetPlannerParams", + "/get_planning_scene": "moveit_msgs/srv/GetPlanningScene", + "/load_map": "moveit_msgs/srv/LoadMap", + "/plan_kinematic_path": "moveit_msgs/srv/GetMotionPlan", + "/query_planner_interface": "moveit_msgs/srv/QueryPlannerInterfaces", + "/save_map": "moveit_msgs/srv/SaveMap", + "/set_planner_params": "moveit_msgs/srv/SetPlannerParams", + # Custom manipulation interfaces + "/reset_manipulator": "std_srvs/srv/Trigger", + # MoveIt2 component parameter services + "/move_group/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/move_group/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/move_group/get_parameters": "rcl_interfaces/srv/GetParameters", + "/move_group/list_parameters": "rcl_interfaces/srv/ListParameters", + "/move_group/set_parameters": "rcl_interfaces/srv/SetParameters", + "/move_group/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/move_group_private_96220314512624/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/move_group_private_96220314512624/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/move_group_private_96220314512624/get_parameters": "rcl_interfaces/srv/GetParameters", + "/move_group_private_96220314512624/list_parameters": "rcl_interfaces/srv/ListParameters", + "/move_group_private_96220314512624/set_parameters": "rcl_interfaces/srv/SetParameters", + "/move_group_private_96220314512624/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/moveit_simple_controller_manager/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/moveit_simple_controller_manager/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/moveit_simple_controller_manager/get_parameters": "rcl_interfaces/srv/GetParameters", + "/moveit_simple_controller_manager/list_parameters": "rcl_interfaces/srv/ListParameters", + "/moveit_simple_controller_manager/set_parameters": "rcl_interfaces/srv/SetParameters", + "/moveit_simple_controller_manager/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + # Arm controller services + "/arm_controller/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/arm_controller/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/arm_controller/get_parameters": "rcl_interfaces/srv/GetParameters", + "/arm_controller/list_parameters": "rcl_interfaces/srv/ListParameters", + "/arm_controller/set_parameters": "rcl_interfaces/srv/SetParameters", + "/arm_controller/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/state_controller/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/state_controller/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/state_controller/get_parameters": "rcl_interfaces/srv/GetParameters", + "/state_controller/list_parameters": "rcl_interfaces/srv/ListParameters", + "/state_controller/set_parameters": "rcl_interfaces/srv/SetParameters", + "/state_controller/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", +} + +NAVIGATION_SERVICES_AND_TYPES: Dict[str, str] = { + # Action services for navigation behaviors + "/assisted_teleop/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/assisted_teleop/_action/get_result": "nav2_msgs/action/AssistedTeleop_GetResult", + "/assisted_teleop/_action/send_goal": "nav2_msgs/action/AssistedTeleop_SendGoal", + "/backup/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/backup/_action/get_result": "nav2_msgs/action/BackUp_GetResult", + "/backup/_action/send_goal": "nav2_msgs/action/BackUp_SendGoal", + "/drive_on_heading/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/drive_on_heading/_action/get_result": "nav2_msgs/action/DriveOnHeading_GetResult", + "/drive_on_heading/_action/send_goal": "nav2_msgs/action/DriveOnHeading_SendGoal", + "/follow_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/follow_path/_action/get_result": "nav2_msgs/action/FollowPath_GetResult", + "/follow_path/_action/send_goal": "nav2_msgs/action/FollowPath_SendGoal", + "/follow_waypoints/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/follow_waypoints/_action/get_result": "nav2_msgs/action/FollowWaypoints_GetResult", + "/follow_waypoints/_action/send_goal": "nav2_msgs/action/FollowWaypoints_SendGoal", + "/spin/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/spin/_action/get_result": "nav2_msgs/action/Spin_GetResult", + "/spin/_action/send_goal": "nav2_msgs/action/Spin_SendGoal", + "/wait/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/wait/_action/get_result": "nav2_msgs/action/Wait_GetResult", + "/wait/_action/send_goal": "nav2_msgs/action/Wait_SendGoal", + # Path planning action services + "/compute_path_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/compute_path_through_poses/_action/get_result": "nav2_msgs/action/ComputePathThroughPoses_GetResult", + "/compute_path_through_poses/_action/send_goal": "nav2_msgs/action/ComputePathThroughPoses_SendGoal", + "/compute_path_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/compute_path_to_pose/_action/get_result": "nav2_msgs/action/ComputePathToPose_GetResult", + "/compute_path_to_pose/_action/send_goal": "nav2_msgs/action/ComputePathToPose_SendGoal", + "/smooth_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/smooth_path/_action/get_result": "nav2_msgs/action/SmoothPath_GetResult", + "/smooth_path/_action/send_goal": "nav2_msgs/action/SmoothPath_SendGoal", + # Main navigation action services + "/navigate_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/navigate_through_poses/_action/get_result": "nav2_msgs/action/NavigateThroughPoses_GetResult", + "/navigate_through_poses/_action/send_goal": "nav2_msgs/action/NavigateThroughPoses_SendGoal", + "/navigate_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/navigate_to_pose/_action/get_result": "nav2_msgs/action/NavigateToPose_GetResult", + "/navigate_to_pose/_action/send_goal": "nav2_msgs/action/NavigateToPose_SendGoal", + # Costmap management services + "/global_costmap/clear_around_global_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", + "/global_costmap/clear_entirely_global_costmap": "nav2_msgs/srv/ClearEntireCostmap", + "/global_costmap/clear_except_global_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", + "/global_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", + "/local_costmap/clear_around_local_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", + "/local_costmap/clear_entirely_local_costmap": "nav2_msgs/srv/ClearEntireCostmap", + "/local_costmap/clear_except_local_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", + "/local_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", + # Path validation + "/is_path_valid": "nav2_msgs/srv/IsPathValid", + # SLAM services + "/slam_toolbox/clear_changes": "slam_toolbox/srv/Clear", + "/slam_toolbox/clear_queue": "slam_toolbox/srv/ClearQueue", + "/slam_toolbox/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/slam_toolbox/deserialize_map": "slam_toolbox/srv/DeserializePoseGraph", + "/slam_toolbox/dynamic_map": "nav_msgs/srv/GetMap", + "/slam_toolbox/get_interactive_markers": "visualization_msgs/srv/GetInteractiveMarkers", + "/slam_toolbox/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/slam_toolbox/get_parameters": "rcl_interfaces/srv/GetParameters", + "/slam_toolbox/list_parameters": "rcl_interfaces/srv/ListParameters", + "/slam_toolbox/manual_loop_closure": "slam_toolbox/srv/LoopClosure", + "/slam_toolbox/pause_new_measurements": "slam_toolbox/srv/Pause", + "/slam_toolbox/save_map": "slam_toolbox/srv/SaveMap", + "/slam_toolbox/serialize_map": "slam_toolbox/srv/SerializePoseGraph", + "/slam_toolbox/set_parameters": "rcl_interfaces/srv/SetParameters", + "/slam_toolbox/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/slam_toolbox/toggle_interactive_mode": "slam_toolbox/srv/ToggleInteractive", + # Map saving + "/map_saver/change_state": "lifecycle_msgs/srv/ChangeState", + "/map_saver/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/map_saver/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/map_saver/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/map_saver/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/map_saver/get_parameters": "rcl_interfaces/srv/GetParameters", + "/map_saver/get_state": "lifecycle_msgs/srv/GetState", + "/map_saver/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/map_saver/list_parameters": "rcl_interfaces/srv/ListParameters", + "/map_saver/save_map": "nav2_msgs/srv/SaveMap", + "/map_saver/set_parameters": "rcl_interfaces/srv/SetParameters", + "/map_saver/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + # Navigation server lifecycle and parameter services + "/behavior_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/behavior_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/behavior_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/behavior_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/behavior_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/behavior_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/behavior_server/get_state": "lifecycle_msgs/srv/GetState", + "/behavior_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/behavior_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/behavior_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/behavior_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/bt_navigator/change_state": "lifecycle_msgs/srv/ChangeState", + "/bt_navigator/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/bt_navigator/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/bt_navigator/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/bt_navigator/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/bt_navigator/get_parameters": "rcl_interfaces/srv/GetParameters", + "/bt_navigator/get_state": "lifecycle_msgs/srv/GetState", + "/bt_navigator/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/bt_navigator/list_parameters": "rcl_interfaces/srv/ListParameters", + "/bt_navigator/set_parameters": "rcl_interfaces/srv/SetParameters", + "/bt_navigator/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/bt_navigator_navigate_through_poses_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/bt_navigator_navigate_to_pose_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", + "/bt_navigator_navigate_to_pose_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", + "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", + "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/controller_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/controller_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/controller_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/controller_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/controller_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/controller_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/controller_server/get_state": "lifecycle_msgs/srv/GetState", + "/controller_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/controller_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/controller_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/controller_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/global_costmap/global_costmap/change_state": "lifecycle_msgs/srv/ChangeState", + "/global_costmap/global_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/global_costmap/global_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/global_costmap/global_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/global_costmap/global_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/global_costmap/global_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", + "/global_costmap/global_costmap/get_state": "lifecycle_msgs/srv/GetState", + "/global_costmap/global_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/global_costmap/global_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", + "/global_costmap/global_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", + "/global_costmap/global_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/local_costmap/local_costmap/change_state": "lifecycle_msgs/srv/ChangeState", + "/local_costmap/local_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/local_costmap/local_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/local_costmap/local_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/local_costmap/local_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/local_costmap/local_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", + "/local_costmap/local_costmap/get_state": "lifecycle_msgs/srv/GetState", + "/local_costmap/local_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/local_costmap/local_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", + "/local_costmap/local_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", + "/local_costmap/local_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/planner_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/planner_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/planner_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/planner_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/planner_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/planner_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/planner_server/get_state": "lifecycle_msgs/srv/GetState", + "/planner_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/planner_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/planner_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/planner_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/smoother_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/smoother_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/smoother_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/smoother_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/smoother_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/smoother_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/smoother_server/get_state": "lifecycle_msgs/srv/GetState", + "/smoother_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/smoother_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/smoother_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/smoother_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/velocity_smoother/change_state": "lifecycle_msgs/srv/ChangeState", + "/velocity_smoother/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/velocity_smoother/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/velocity_smoother/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/velocity_smoother/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/velocity_smoother/get_parameters": "rcl_interfaces/srv/GetParameters", + "/velocity_smoother/get_state": "lifecycle_msgs/srv/GetState", + "/velocity_smoother/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/velocity_smoother/list_parameters": "rcl_interfaces/srv/ListParameters", + "/velocity_smoother/set_parameters": "rcl_interfaces/srv/SetParameters", + "/velocity_smoother/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/waypoint_follower/change_state": "lifecycle_msgs/srv/ChangeState", + "/waypoint_follower/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/waypoint_follower/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/waypoint_follower/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/waypoint_follower/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/waypoint_follower/get_parameters": "rcl_interfaces/srv/GetParameters", + "/waypoint_follower/get_state": "lifecycle_msgs/srv/GetState", + "/waypoint_follower/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/waypoint_follower/list_parameters": "rcl_interfaces/srv/ListParameters", + "/waypoint_follower/set_parameters": "rcl_interfaces/srv/SetParameters", + "/waypoint_follower/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + # Lifecycle management services + "/lifecycle_manager_navigation/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/lifecycle_manager_navigation/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/lifecycle_manager_navigation/get_parameters": "rcl_interfaces/srv/GetParameters", + "/lifecycle_manager_navigation/is_active": "std_srvs/srv/Trigger", + "/lifecycle_manager_navigation/list_parameters": "rcl_interfaces/srv/ListParameters", + "/lifecycle_manager_navigation/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", + "/lifecycle_manager_navigation/set_parameters": "rcl_interfaces/srv/SetParameters", + "/lifecycle_manager_navigation/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/lifecycle_manager_slam/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/lifecycle_manager_slam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/lifecycle_manager_slam/get_parameters": "rcl_interfaces/srv/GetParameters", + "/lifecycle_manager_slam/is_active": "std_srvs/srv/Trigger", + "/lifecycle_manager_slam/list_parameters": "rcl_interfaces/srv/ListParameters", + "/lifecycle_manager_slam/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", + "/lifecycle_manager_slam/set_parameters": "rcl_interfaces/srv/SetParameters", + "/lifecycle_manager_slam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", +} +CUSTOM_SERVICES_AND_TYPES: Dict[str, str] = { + "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam", + "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino", + "/manipulator_move_to": "rai_interfaces/srv/ManipulatorMoveTo", + "/get_log_digest": "rai_interfaces/srv/StringList", + "/rai_whoami_documentation_service": "rai_interfaces/srv/VectorStoreRetrieval", + "/rai/whatisee/get": "rai_interfaces/srv/WhatISee", +} + +MANIPULATION_ACTIONS_AND_TYPES: Dict[str, str] = { + "/move_action": "moveit_msgs/action/MoveGroup", + "/execute_trajectory": "moveit_msgs/action/ExecuteTrajectory", + "/panda_arm_controller/follow_joint_trajectory": "control_msgs/action/FollowJointTrajectory", + "/arm_controller/follow_joint_trajectory": "control_msgs/action/FollowJointTrajectory", + "/panda_hand_controller/gripper_cmd": "control_msgs/action/GripperCommand", + "/gripper_controller/gripper_cmd": "control_msgs/action/GripperCommand", + "/pickup": "moveit_msgs/action/Pickup", + "/place": "moveit_msgs/action/Place", +} +NAVIGATION_ACTIONS_AND_TYPES: Dict[str, str] = { + "/navigate_to_pose": "nav2_msgs/action/NavigateToPose", + "/navigate_through_poses": "nav2_msgs/action/Nmoveit_msgs/action/MoveGroupmoveit_msgs/action/MoveGroupavigateThroughPoses", + "/follow_path": "nav2_msgs/action/FollowPath", + "/follow_waypoints": "nav2_msgs/action/FollowWaypoints", + "/compute_path_to_pose": "nav2_msgs/action/ComputePathToPose", + "/compute_path_through_poses": "nav2_msgs/action/ComputePathThroughPoses", + "/smooth_path": "nav2_msgs/action/SmoothPath", + "/spin": "nav2_msgs/action/Spin", + "/backup": "nav2_msgs/action/BackUp", + "/drive_on_heading": "nav2_msgs/action/DriveOnHeading", + "/wait": "nav2_msgs/action/Wait", + "/assisted_teleop": "nav2_msgs/action/AssistedTeleop", + "/clear_costmap": "nav2_msgs/action/ClearEntireCostmap", +} +COMMON_TOPIC_MODELS: Dict[str, Type[BaseModel]] = { + "sensor_msgs/msg/CameraInfo": CameraInfo, + "sensor_msgs/msg/Image": Image, + "rosgraph_msgs/msg/Clock": Clock, +} + +CUSTOM_TOPIC_MODELS: Dict[str, Type[BaseModel]] = { + "rai_interfaces/msg/HRIMessage": HRIMessage, + "rai_interfaces/msg/AudioMessage": AudioMessage, + "rai_interfaces/msg/RAIDetectionArray": RAIDetectionArray, +} + +CUSTOM_SERVICE_MODELS: Dict[str, Type[BaseModel]] = { + "rai_interfaces/srv/ManipulatorMoveTo": ManipulatorMoveToRequest, + "rai_interfaces/srv/RAIGroundedSam": RAIGroundedSamRequest, + "rai_interfaces/srv/RAIGroundingDino": RAIGroundingDinoRequest, + "rai_interfaces/srv/StringList": StringListRequest, + "rai_interfaces/srv/VectorStoreRetrieval": VectorStoreRetrievalRequest, + "rai_interfaces/srv/WhatISee": WhatISeeRequest, +} +MANIPULATION_ACTION_MODELS: Dict[str, Type[BaseModel]] = {} +NAVIGATION_ACTION_MODELS: Dict[str, Type[BaseModel]] = { + "nav2_msgs/action/NavigateToPose": NavigateToPoseGoal, + "nav2_msgs/action/Spin": SpinGoal, + "nav2_msgs/action/AssistedTeleop": AssistedTeleopGoal, + "nav2_msgs/action/BackUp": BackUpGoal, + "nav2_msgs/action/ComputePathThroughPoses": ComputePathThroughPosesGoal, + "nav2_msgs/action/ComputePathToPose": ComputePathToPoseGoal, + "nav2_msgs/action/DriveOnHeading": DriveOnHeadingGoal, + "nav2_msgs/action/FollowPath": FollowPathGoal, + "nav2_msgs/action/FollowWaypoints": FollowWaypointsGoal, + "nav2_msgs/action/NavigateThroughPoses": NavigateThroughPosesGoal, + "nav2_msgs/action/SmoothPath": SmoothPathGoal, + "nav2_msgs/action/Wait": WaitGoal, +} diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py new file mode 100644 index 000000000..953dea515 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py @@ -0,0 +1,27 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .basic_tasks import get_basic_tasks +from .custom_interfaces_tasks import get_custom_interfaces_tasks +from .manipulation_tasks import get_manipulation_tasks +from .navigation_tasks import get_navigation_tasks +from .spatial_reasoning_tasks import get_spatial_tasks + +__all__ = [ + "get_basic_tasks", + "get_custom_interfaces_tasks", + "get_manipulation_tasks", + "get_navigation_tasks", + "get_spatial_tasks", +] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py new file mode 100644 index 000000000..2a2ae7e88 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py @@ -0,0 +1,306 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal + +from rai_bench.tool_calling_agent.interfaces import ( + Task, + TaskArgs, +) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, +) +from rai_bench.tool_calling_agent.tasks.basic import ( + AssessSensorDataQualityTask, + CheckRobotHealthTask, + GetAllROS2CamerasTask, + GetPointcloudTask, + GetRobotDescriptionTask, + GetROS2DepthCameraTask, + GetROS2RGBCameraTask, + GetROS2TopicsTask, +) +from rai_bench.tool_calling_agent.validators import ( + NotOrderedCallsValidator, + OrderedCallsValidator, +) + +########## SUBTASKS ################################################################# +get_topics_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_topics_names_and_types", expected_args={} +) + +color_image5_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_image", + expected_args={"topic": "/color_image5"}, + expected_optional_args={"timeout_sec": int}, +) +depth_image5_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_image", + expected_args={"topic": "/depth_image5"}, + expected_optional_args={"timeout_sec": int}, +) + +color_camera_info5_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_image", + expected_args={"topic": "/color_image5"}, + expected_optional_args={"timeout_sec": int}, +) +depth_camera_info5_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_image", + expected_args={"topic": "/depth_image5"}, + expected_optional_args={"timeout_sec": int}, +) + +receive_robot_desc_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_description"}, + expected_optional_args={"timeout_sec": int}, +) + +receive_pointcloud_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/pointcloud"}, + expected_optional_args={"timeout_sec": int}, +) + +# System health subtasks +diagnostics_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/diagnostics"}, + expected_optional_args={"timeout_sec": int}, +) +rosout_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/rosout"}, + expected_optional_args={"timeout_sec": int}, +) +joint_states_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/joint_states"}, + expected_optional_args={"timeout_sec": int}, +) + +# Odometry subtasks +odom_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/odom"}, + expected_optional_args={"timeout_sec": int}, +) +filtered_odom_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/odometry/filtered"}, + expected_optional_args={"timeout_sec": int}, +) + +# Transform subtasks +tf_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/tf"}, + expected_optional_args={"timeout_sec": int}, +) +tf_static_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/tf_static"}, + expected_optional_args={"timeout_sec": int}, +) + + +# Robot description subtasks +robot_description_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_description"}, + expected_optional_args={"timeout_sec": int}, +) +robot_description_semantic_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_description_semantic"}, + expected_optional_args={"timeout_sec": int}, +) + +# Sensor data subtasks +scan_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/scan"}, + expected_optional_args={"timeout_sec": int}, +) +pointcloud_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/pointcloud"}, + expected_optional_args={"timeout_sec": int}, +) + + +# Robot description subtasks +robot_description_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_description"}, + expected_optional_args={"timeout_sec": int}, +) +robot_description_semantic_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_description_semantic"}, + expected_optional_args={"timeout_sec": int}, +) + +# Sensor data subtasks +scan_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/scan"}, + expected_optional_args={"timeout_sec": int}, +) +pointcloud_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/pointcloud"}, + expected_optional_args={"timeout_sec": int}, +) + +# Robot description subtasks +robot_description_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_description"}, + expected_optional_args={"timeout_sec": int}, +) +robot_description_semantic_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_description_semantic"}, + expected_optional_args={"timeout_sec": int}, +) + +######### VALIDATORS ######################################################################################### +topics_ord_val = OrderedCallsValidator(subtasks=[get_topics_subtask]) + +color_image_ord_val = OrderedCallsValidator(subtasks=[color_image5_subtask]) +depth_image_ord_val = OrderedCallsValidator(subtasks=[depth_image5_subtask]) + +color_camera_info_ord_val = OrderedCallsValidator(subtasks=[color_camera_info5_subtask]) +depth_camera_info_ord_val = OrderedCallsValidator(subtasks=[depth_camera_info5_subtask]) + +color_image_with_info_ord_val = NotOrderedCallsValidator( + subtasks=[color_image5_subtask, color_camera_info5_subtask] +) +depth_image_with_info_ord_val = NotOrderedCallsValidator( + subtasks=[depth_image5_subtask, color_camera_info5_subtask] +) + +all_camera_images_notord_val = NotOrderedCallsValidator( + subtasks=[ + color_image5_subtask, + depth_image5_subtask, + ] +) +all_camera_info_notord_val = NotOrderedCallsValidator( + subtasks=[ + color_camera_info5_subtask, + depth_camera_info5_subtask, + ] +) +all_camera_images_with_info_notord_val = NotOrderedCallsValidator( + subtasks=[ + color_image5_subtask, + depth_image5_subtask, + color_camera_info5_subtask, + depth_camera_info5_subtask, + ] +) + +joint_states_ord_val = OrderedCallsValidator(subtasks=[joint_states_subtask]) +diagnostics_ord_val = OrderedCallsValidator(subtasks=[diagnostics_subtask]) + +get_pointcloud_ord_val = OrderedCallsValidator(subtasks=[receive_pointcloud_subtask]) +get_robot_desc_ord_val = OrderedCallsValidator(subtasks=[receive_robot_desc_subtask]) + +robot_health_val = NotOrderedCallsValidator( + subtasks=[diagnostics_subtask, joint_states_subtask, rosout_subtask] +) + +odometry_comparison_val = NotOrderedCallsValidator( + subtasks=[odom_subtask, filtered_odom_subtask] +) +sensor_data_val = NotOrderedCallsValidator( + subtasks=[ + scan_subtask, + receive_pointcloud_subtask, + color_image5_subtask, + depth_image5_subtask, + color_camera_info5_subtask, + depth_camera_info5_subtask, + ] +) + + +def get_basic_tasks( + extra_tool_calls: List[int] = [0], + prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], + n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], +) -> List[Task]: + """Get predefined basic tasks. + + Parameters + ---------- + Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. + See the class documentation for parameter descriptions. + + Returns + ------- + Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. + """ + tasks: List[Task] = [] + + # Generate all combinations of prompt_detail and n_shots and extra tool calls + for extra_calls in extra_tool_calls: + for detail in prompt_detail: + for shots in n_shots: + task_args = TaskArgs( + extra_tool_calls=extra_calls, + prompt_detail=detail, + examples_in_system_prompt=shots, + ) + + tasks.extend( + [ + GetROS2RGBCameraTask( + validators=[color_image_ord_val], + task_args=task_args, + ), + GetROS2TopicsTask( + validators=[topics_ord_val], + task_args=task_args, + ), + GetROS2DepthCameraTask( + validators=[depth_image_ord_val], + task_args=task_args, + ), + GetAllROS2CamerasTask( + validators=[all_camera_images_notord_val], + task_args=task_args, + ), + GetPointcloudTask( + validators=[get_pointcloud_ord_val], task_args=task_args + ), + GetRobotDescriptionTask( + validators=[get_robot_desc_ord_val], task_args=task_args + ), + CheckRobotHealthTask( + validators=[robot_health_val], + task_args=task_args, + ), + AssessSensorDataQualityTask( + validators=[sensor_data_val], + task_args=task_args, + ), + ] + ) + + return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py new file mode 100644 index 000000000..36f84a31b --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py @@ -0,0 +1,95 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal + +from rai_bench.tool_calling_agent.interfaces import ( + Task, + TaskArgs, +) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, + CheckTopicFieldsToolCallSubTask, +) +from rai_bench.tool_calling_agent.tasks.custom_interfaces import ( + PublishROS2HRIMessageTextTask, +) +from rai_bench.tool_calling_agent.validators import ( + OrderedCallsValidator, +) + +########## SUBTASKS ################################################################# +pub_HRIMessage_text_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic="/to_human", + expected_message_type="rai_interfaces/msg/HRIMessage", + expected_fields={"text": "Hello!"}, +) + +get_HRIMessage_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": "rai_interfaces/msg/HRIMessage"}, +) + + +######### VALIDATORS ######################################################################################### +pub_HRIMessage_text_ord_val = OrderedCallsValidator( + subtasks=[pub_HRIMessage_text_subtask] +) +get_interface_publish_ord_val = OrderedCallsValidator( + subtasks=[ + get_HRIMessage_interface_subtask, + pub_HRIMessage_text_subtask, + ] +) + + +def get_custom_interfaces_tasks( + extra_tool_calls: List[int] = [0], + prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], + n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], +) -> List[Task]: + """Get predefined custom_interfaces tasks. + + Parameters + ---------- + Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. + See the class documentation for parameter descriptions. + + Returns + ------- + Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. + """ + tasks: List[Task] = [] + + for extra_calls in extra_tool_calls: + for detail in prompt_detail: + for shots in n_shots: + task_args = TaskArgs( + extra_tool_calls=extra_calls, + prompt_detail=detail, + examples_in_system_prompt=shots, + ) + tasks.append( + PublishROS2HRIMessageTextTask( + topic="/to_human", + validators=[ + get_interface_publish_ord_val, + ], + task_args=task_args, + text="Hello!", + ), + ) + + return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py new file mode 100644 index 000000000..7e4068c8f --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py @@ -0,0 +1,104 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal + +from rai.tools.ros2 import MoveToPointToolInput +from rai.types import Point + +from rai_bench.tool_calling_agent.interfaces import ( + Task, + TaskArgs, +) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, +) +from rai_bench.tool_calling_agent.tasks.manipulation import ( + MoveToPointTask, +) +from rai_bench.tool_calling_agent.validators import ( + OrderedCallsValidator, +) + +########## SUBTASKS ################################################################# +move_to_point_subtask_grab = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}, +) +move_to_point_subtask_drop = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}, +) + +######### VALIDATORS ######################################################################################### +move_to_point_ord_val_grab = OrderedCallsValidator( + subtasks=[move_to_point_subtask_grab] +) +move_to_point_ord_val_drop = OrderedCallsValidator( + subtasks=[move_to_point_subtask_drop] +) + + +def get_manipulation_tasks( + extra_tool_calls: List[int] = [0], + prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], + n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], +) -> List[Task]: + """Get predefined manipulation tasks. + + Parameters + ---------- + Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. + See the class documentation for parameter descriptions. + + Returns + ------- + Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. + """ + tasks: List[Task] = [] + + objects = { + "banana": [Point(x=0.1, y=0.2, z=0.3), Point(x=0.4, y=0.5, z=0.6)], + "cube": [Point(x=0.7, y=0.8, z=0.9)], + } + for extra_calls in extra_tool_calls: + for detail in prompt_detail: + for shots in n_shots: + task_args = TaskArgs( + extra_tool_calls=extra_calls, + prompt_detail=detail, + examples_in_system_prompt=shots, + ) + tasks.extend( + [ + MoveToPointTask( + objects=objects, + move_to_tool_input=MoveToPointToolInput( + x=1.0, y=2.0, z=3.0, task="grab" + ), + validators=[move_to_point_ord_val_grab], + task_args=task_args, + ), + MoveToPointTask( + objects=objects, + move_to_tool_input=MoveToPointToolInput( + x=1.2, y=2.3, z=3.4, task="drop" + ), + validators=[move_to_point_ord_val_drop], + task_args=task_args, + ), + ] + ) + + return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py new file mode 100644 index 000000000..4099e5a22 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py @@ -0,0 +1,118 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal + +from rai_bench.tool_calling_agent.interfaces import ( + Task, + TaskArgs, +) +from rai_bench.tool_calling_agent.subtasks import ( + CheckActionFieldsToolCallSubTask, +) +from rai_bench.tool_calling_agent.tasks.navigation import ( + MoveToBedTask, + MoveToFrontTask, + NavigateToPointTask, + SpinAroundTask, +) +from rai_bench.tool_calling_agent.validators import ( + OrderedCallsValidator, +) + +########## SUBTASKS ################################################################# + +start_nav_action_subtask = CheckActionFieldsToolCallSubTask( + expected_tool_name="start_ros2_action", + expected_action="/navigate_to_pose", + expected_action_type="nav2_msgs/action/NavigateToPose", + expected_fields={ + "pose": { + "header": {"frame_id": "map"}, + "pose": { + "position": {"x": 2.0, "y": 2.0, "z": 0.0}, + }, + }, + }, +) +start_spin_action_subtask = CheckActionFieldsToolCallSubTask( + expected_tool_name="start_ros2_action", + expected_action="/spin", + expected_action_type="nav2_msgs/action/Spin", + expected_fields={"target_yaw": 3}, +) +start_move_front_action_subtask = CheckActionFieldsToolCallSubTask( + expected_tool_name="start_ros2_action", + expected_action="/drive_on_heading", + expected_action_type="nav2_msgs/action/DriveOnHeading", + expected_fields={ + "target": {"y": 0.0, "z": 0.0}, + }, +) +######### VALIDATORS ######################################################################################### +start_navigate_action_ord_val = OrderedCallsValidator( + subtasks=[start_nav_action_subtask] +) +start_spin_action_ord_val = OrderedCallsValidator(subtasks=[start_spin_action_subtask]) +move_ahead_ord_val = OrderedCallsValidator(subtasks=[start_move_front_action_subtask]) + + +def get_navigation_tasks( + extra_tool_calls: List[int] = [0], + prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], + n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], +) -> List[Task]: + """Get predefined navigation tasks. + + Parameters + ---------- + Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. + See the class documentation for parameter descriptions. + + Returns + ------- + Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. + """ + tasks: List[Task] = [] + + for extra_calls in extra_tool_calls: + for detail in prompt_detail: + for shots in n_shots: + task_args = TaskArgs( + extra_tool_calls=extra_calls, + prompt_detail=detail, + examples_in_system_prompt=shots, + ) + tasks.extend( + [ + NavigateToPointTask( + validators=[start_navigate_action_ord_val], + task_args=task_args, + ), + SpinAroundTask( + validators=[start_spin_action_ord_val], + task_args=task_args, + ), + MoveToBedTask( + validators=[move_ahead_ord_val], + task_args=task_args, + ), + MoveToFrontTask( + validators=[move_ahead_ord_val], + task_args=task_args, + ), + ] + ) + + return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py new file mode 100644 index 000000000..f67d19792 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py @@ -0,0 +1,278 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal, Sequence + +from rai_bench.tool_calling_agent.interfaces import ( + Task, + TaskArgs, +) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, +) +from rai_bench.tool_calling_agent.tasks.spatial import ( + BoolImageTaskEasy, + BoolImageTaskHard, + BoolImageTaskInput, + BoolImageTaskMedium, +) +from rai_bench.tool_calling_agent.validators import ( + OrderedCallsValidator, +) + +IMG_PATH = "src/rai_bench/rai_bench/tool_calling_agent/predefined/images/" +true_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door on the left from the desk?", + images_paths=[IMG_PATH + "image_1.jpg"], + ), + BoolImageTaskInput( + question="Is the light on in the room?", + images_paths=[IMG_PATH + "image_2.jpg"], + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_2.jpg"], + ), + BoolImageTaskInput( + question="Are there any pictures on the wall?", + images_paths=[IMG_PATH + "image_3.jpg"], + ), + BoolImageTaskInput( + question="Are there 3 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + ), + BoolImageTaskInput( + question="Is there a plant behind the rack?", + images_paths=[IMG_PATH + "image_5.jpg"], + ), + BoolImageTaskInput( + question="Is there a pillow on the armchain?", + images_paths=[IMG_PATH + "image_7.jpg"], + ), +] +false_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door open?", + images_paths=[IMG_PATH + "image_1.jpg"], + ), + BoolImageTaskInput( + question="Is someone in the room?", + images_paths=[IMG_PATH + "image_1.jpg"], + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_3.jpg"], + ), + BoolImageTaskInput( + question="Are there 4 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + ), + BoolImageTaskInput( + question="Is there a rack on the left from the sofa?", + images_paths=[IMG_PATH + "image_4.jpg"], + ), + BoolImageTaskInput( + question="Is there a plant on the right from the window?", + images_paths=[IMG_PATH + "image_6.jpg"], + ), + BoolImageTaskInput( + question="Is there a red pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + ), +] +########## SUBTASKS ################################################################# +return_true_subtask = CheckArgsToolCallSubTask( + expected_tool_name="return_bool_response", expected_args={"response": True} +) +return_false_subtask = CheckArgsToolCallSubTask( + expected_tool_name="return_bool_response", expected_args={"response": False} +) + +######### VALIDATORS ######################################################################################### +ret_true_ord_val = OrderedCallsValidator(subtasks=[return_true_subtask]) +ret_false_ord_val = OrderedCallsValidator(subtasks=[return_false_subtask]) + + +def get_spatial_tasks( + extra_tool_calls: List[int] = [0], + prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], + n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], +) -> Sequence[Task]: + """Get predefined spatial reasoning tasks. + + Parameters + ---------- + Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. + See the class documentation for parameter descriptions. + + Returns + ------- + Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. + """ + tasks: List[Task] = [] + + # Categorize tasks by complexity based on question difficulty + easy_true_inputs = [ + # Single object presence/detection + BoolImageTaskInput( + question="Is the light on in the room?", + images_paths=[IMG_PATH + "image_2.jpg"], + ), + BoolImageTaskInput( + question="Do you see the plant?", images_paths=[IMG_PATH + "image_2.jpg"] + ), + BoolImageTaskInput( + question="Are there any pictures on the wall?", + images_paths=[IMG_PATH + "image_3.jpg"], + ), + BoolImageTaskInput( + question="Is there a pillow on the armchain?", + images_paths=[IMG_PATH + "image_7.jpg"], + ), + ] + + medium_true_inputs = [ + # Object state or counting + BoolImageTaskInput( + question="Are there 3 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + ), + ] + + hard_true_inputs = [ + # Spatial relationships between objects + BoolImageTaskInput( + question="Is the door on the left from the desk?", + images_paths=[IMG_PATH + "image_1.jpg"], + ), + BoolImageTaskInput( + question="Is there a plant behind the rack?", + images_paths=[IMG_PATH + "image_5.jpg"], + ), + ] + + easy_false_inputs = [ + # Single object presence/detection + BoolImageTaskInput( + question="Is someone in the room?", images_paths=[IMG_PATH + "image_1.jpg"] + ), + BoolImageTaskInput( + question="Do you see the plant?", images_paths=[IMG_PATH + "image_3.jpg"] + ), + BoolImageTaskInput( + question="Is there a red pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + ), + ] + + medium_false_inputs = [ + # Object state or counting + BoolImageTaskInput( + question="Is the door open?", images_paths=[IMG_PATH + "image_1.jpg"] + ), + BoolImageTaskInput( + question="Are there 4 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + ), + ] + + hard_false_inputs = [ + # Spatial relationships between objects + BoolImageTaskInput( + question="Is there a rack on the left from the sofa?", + images_paths=[IMG_PATH + "image_4.jpg"], + ), + BoolImageTaskInput( + question="Is there a plant on the right from the window?", + images_paths=[IMG_PATH + "image_6.jpg"], + ), + ] + + for extra_calls in extra_tool_calls: + for detail in prompt_detail: + for shots in n_shots: + task_args = TaskArgs( + extra_tool_calls=extra_calls, + prompt_detail=detail, + examples_in_system_prompt=shots, + ) + + tasks.extend( + [ + BoolImageTaskEasy( + task_input=input_item, + validators=[ret_true_ord_val], + task_args=task_args, + ) + for input_item in easy_true_inputs + ] + ) + + tasks.extend( + [ + BoolImageTaskEasy( + task_input=input_item, + validators=[ret_false_ord_val], + task_args=task_args, + ) + for input_item in easy_false_inputs + ] + ) + + tasks.extend( + [ + BoolImageTaskMedium( + task_input=input_item, + validators=[ret_true_ord_val], + task_args=task_args, + ) + for input_item in medium_true_inputs + ] + ) + + tasks.extend( + [ + BoolImageTaskMedium( + task_input=input_item, + validators=[ret_false_ord_val], + task_args=task_args, + ) + for input_item in medium_false_inputs + ] + ) + + tasks.extend( + [ + BoolImageTaskHard( + task_input=input_item, + validators=[ret_true_ord_val], + task_args=task_args, + ) + for input_item in hard_true_inputs + ] + ) + + tasks.extend( + [ + BoolImageTaskHard( + task_input=input_item, + validators=[ret_false_ord_val], + task_args=task_args, + ) + for input_item in hard_false_inputs + ] + ) + + return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py index cdeefc8be..f3a166c8a 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py @@ -13,407 +13,25 @@ # limitations under the License. import random -from typing import List, Literal, Sequence - -from rai.tools.ros2 import MoveToPointToolInput +from typing import List, Literal from rai_bench.tool_calling_agent.interfaces import ( Task, ) -from rai_bench.tool_calling_agent.subtasks import ( - CheckActionFieldsToolCallSubTask, - CheckArgsToolCallSubTask, - CheckTopicFieldsToolCallSubTask, -) -from rai_bench.tool_calling_agent.tasks.basic import ( - GetAllROS2DepthCamerasTask, - GetAllROS2RGBCamerasTask, - GetROS2DepthCameraTask, - GetROS2RGBCameraTask, - GetROS2TopicsTask, -) -from rai_bench.tool_calling_agent.tasks.custom_interfaces import ( - PublishROS2HRIMessageTextTask, -) -from rai_bench.tool_calling_agent.tasks.manipulation import ( - MoveToPointTask, -) -from rai_bench.tool_calling_agent.tasks.navigation import ( - MoveToBedTask, - MoveToFrontTask, - NavigateToPointTask, - SpinAroundTask, -) -from rai_bench.tool_calling_agent.tasks.spatial import ( - BoolImageTask, - BoolImageTaskInput, -) -from rai_bench.tool_calling_agent.validators import ( - NotOrderedCallsValidator, - OrderedCallsValidator, -) - -IMG_PATH = "src/rai_bench/rai_bench/tool_calling_agent/predefined/images/" -true_response_inputs: List[BoolImageTaskInput] = [ - BoolImageTaskInput( - question="Is the door on the left from the desk?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is the light on in the room?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Are there any pictures on the wall?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="Are there 3 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant behind the rack?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Is there a pillow on the armchain?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), -] -false_response_inputs: List[BoolImageTaskInput] = [ - BoolImageTaskInput( - question="Is the door open?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is someone in the room?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="Are there 4 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a rack on the left from the sofa?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant on the right from the window?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is there a red pillow on the armchair?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), -] -########## SUBTASKS ####################################################################################### -get_topics_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_topics_names_and_types", expected_args={} -) - -color_image5_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_image", - expected_args={"topic": "/color_image5"}, - expected_optional_args={"timeout_sec": int}, -) -depth_image5_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_image", - expected_args={"topic": "/depth_image5"}, - expected_optional_args={"timeout_sec": int}, -) - -receive_robot_desc_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description"}, - expected_optional_args={"timeout_sec": int}, -) - -move_to_point_subtask_grab = CheckArgsToolCallSubTask( - expected_tool_name="move_to_point", - expected_args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}, -) -move_to_point_subtask_drop = CheckArgsToolCallSubTask( - expected_tool_name="move_to_point", - expected_args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}, +from rai_bench.tool_calling_agent.predefined import ( + get_basic_tasks, + get_custom_interfaces_tasks, + get_manipulation_tasks, + get_navigation_tasks, + get_spatial_tasks, ) -pub_HRIMessage_text_subtask = CheckTopicFieldsToolCallSubTask( - expected_tool_name="publish_ros2_message", - expected_topic="/to_human", - expected_message_type="rai_interfaces/msg/HRIMessage", - expected_fields={"text": "Hello!"}, -) - -get_tohuman_interface_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_message_interface", - expected_args={"msg_type": "rai_interfaces/msg/HRIMessage"}, -) - - -return_true_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": True} -) -return_false_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": False} -) - -start_nav_action_subtask = CheckActionFieldsToolCallSubTask( - expected_tool_name="start_ros2_action", - expected_action="/navigate_to_pose", - expected_action_type="nav2_msgs/action/NavigateToPose", - expected_fields={ - "pose": { - "header": {"frame_id": "map"}, - "pose": { - "position": {"x": 2.0, "y": 2.0, "z": 0.0}, - }, - }, - }, -) -start_spin_action_subtask = CheckActionFieldsToolCallSubTask( - expected_tool_name="start_ros2_action", - expected_action="/spin", - expected_action_type="nav2_msgs/action/Spin", - expected_fields={"target_yaw": 3}, -) -start_move_front_action_subtask = CheckActionFieldsToolCallSubTask( - expected_tool_name="start_ros2_action", - expected_action="/drive_on_heading", - expected_action_type="nav2_msgs/action/DriveOnHeading", - expected_fields={ - "target": {"y": 0.0, "z": 0.0}, - }, -) -######### VALIDATORS ######################################################################################### -topics_ord_val = OrderedCallsValidator(subtasks=[get_topics_subtask]) -topics_and_color_image_ord_val = OrderedCallsValidator( - subtasks=[ - get_topics_subtask, - color_image5_subtask, - ] -) -color_image_ord_val = OrderedCallsValidator(subtasks=[color_image5_subtask]) -depth_image_ord_val = OrderedCallsValidator(subtasks=[depth_image5_subtask]) -all_color_images_notord_val = NotOrderedCallsValidator( - subtasks=[color_image5_subtask, color_image5_subtask] -) -all_depth_images_notord_val = NotOrderedCallsValidator( - subtasks=[depth_image5_subtask, depth_image5_subtask] -) - -move_to_point_ord_val_grab = OrderedCallsValidator( - subtasks=[move_to_point_subtask_grab] -) -move_to_point_ord_val_drop = OrderedCallsValidator( - subtasks=[move_to_point_subtask_drop] -) - -pub_HRIMessage_text_ord_val = OrderedCallsValidator( - subtasks=[pub_HRIMessage_text_subtask] -) - -list_topic_get_interface_publish_ord_val = OrderedCallsValidator( - subtasks=[ - get_topics_subtask, - get_tohuman_interface_subtask, - pub_HRIMessage_text_subtask, - ] -) - -ret_true_ord_val = OrderedCallsValidator(subtasks=[return_true_subtask]) -ret_false_ord_val = OrderedCallsValidator(subtasks=[return_false_subtask]) - -start_navigate_action_ord_val = OrderedCallsValidator( - subtasks=[start_nav_action_subtask] -) -start_spin_action_ord_val = OrderedCallsValidator(subtasks=[start_spin_action_subtask]) -move_ahead_ord_val = OrderedCallsValidator(subtasks=[start_move_front_action_subtask]) - -######### TASKS ############################################################################################ -basic_tasks: List[Task] = [ - # 3 options to validate same task: - # most strict, agent has to call both tool correctly to pass this validator - GetROS2RGBCameraTask(validators=[topics_and_color_image_ord_val]), - # verifing only if the GetCameraImage call was made properly - GetROS2RGBCameraTask(validators=[color_image_ord_val]), - # Soft verification. verifing in separate vaidators the list topic and get image. - # agent can get 0.5 score by only calling list topics - GetROS2RGBCameraTask(validators=[topics_ord_val, color_image_ord_val]), - # we can also add extra tool calls to allow model to correct itself - GetROS2RGBCameraTask( - validators=[topics_ord_val, color_image_ord_val], extra_tool_calls=3 - ), - GetROS2TopicsTask(validators=[topics_ord_val]), - GetROS2DepthCameraTask(validators=[depth_image_ord_val]), - GetAllROS2RGBCamerasTask(validators=[all_color_images_notord_val]), - GetAllROS2DepthCamerasTask(validators=[all_depth_images_notord_val]), -] -manipulation_tasks: List[Task] = [ - MoveToPointTask( - move_to_tool_input=MoveToPointToolInput(x=1.0, y=2.0, z=3.0, task="grab"), - validators=[move_to_point_ord_val_drop], - ), - MoveToPointTask( - move_to_tool_input=MoveToPointToolInput(x=1.2, y=2.3, z=3.4, task="drop"), - validators=[move_to_point_ord_val_drop], - ), -] - -custom_interfaces_tasks: List[Task] = [ - PublishROS2HRIMessageTextTask( - topic="/to_human", - text="Hello!", - validators=[pub_HRIMessage_text_ord_val], - extra_tool_calls=2, - ), - PublishROS2HRIMessageTextTask( - topic="/to_human", - text="Hello!", - validators=[list_topic_get_interface_publish_ord_val], - ), -] - -true_spatial_tasks: List[Task] = [ - BoolImageTask( - task_input=input_item, validators=[ret_true_ord_val], extra_tool_calls=0 - ) - for input_item in true_response_inputs -] -false_spatial_tasks: List[Task] = [ - BoolImageTask(task_input=input_item, validators=[ret_false_ord_val]) - for input_item in false_response_inputs -] - -navigation_tasks: List[Task] = [ - NavigateToPointTask(validators=[start_navigate_action_ord_val], extra_tool_calls=5), - SpinAroundTask(validators=[start_spin_action_ord_val], extra_tool_calls=5), - MoveToBedTask(validators=[move_ahead_ord_val], extra_tool_calls=5), - MoveToFrontTask(validators=[move_ahead_ord_val], extra_tool_calls=5), -] - - -def get_basic_tasks(extra_tool_calls: int = 0) -> List[Task]: - return [ - # 3 options to validate same task: - # most strict, agent has to call both tool correctly to pass this validator - GetROS2RGBCameraTask( - validators=[topics_and_color_image_ord_val], - extra_tool_calls=extra_tool_calls, - ), - # verifing only if the GetCameraImage call was made properly - GetROS2RGBCameraTask( - validators=[color_image_ord_val], extra_tool_calls=extra_tool_calls - ), - # Soft verification. verifing in separate vaidators the list topic and get image. - # agent can get 0.5 score by only calling list topics - GetROS2RGBCameraTask( - validators=[topics_ord_val, color_image_ord_val], - extra_tool_calls=extra_tool_calls, - ), - # we can also add extra tool calls to allow model to correct itself - GetROS2RGBCameraTask( - validators=[topics_ord_val, color_image_ord_val], - extra_tool_calls=extra_tool_calls, - ), - GetROS2TopicsTask( - validators=[topics_ord_val], extra_tool_calls=extra_tool_calls - ), - GetROS2DepthCameraTask( - validators=[depth_image_ord_val], extra_tool_calls=extra_tool_calls - ), - GetAllROS2RGBCamerasTask( - validators=[all_color_images_notord_val], extra_tool_calls=extra_tool_calls - ), - GetAllROS2DepthCamerasTask( - validators=[all_depth_images_notord_val], extra_tool_calls=extra_tool_calls - ), - ] - - -def get_navigation_tasks(extra_tool_calls: int = 0) -> List[Task]: - return [ - NavigateToPointTask( - validators=[start_navigate_action_ord_val], - extra_tool_calls=extra_tool_calls, - ), - SpinAroundTask( - validators=[start_spin_action_ord_val], - extra_tool_calls=extra_tool_calls, - ), - MoveToBedTask( - validators=[move_ahead_ord_val], - extra_tool_calls=extra_tool_calls, - ), - MoveToFrontTask( - validators=[move_ahead_ord_val], - extra_tool_calls=extra_tool_calls, - ), - ] - - -def get_manipulation_tasks(extra_tool_calls: int = 0) -> List[Task]: - return [ - MoveToPointTask( - move_to_tool_input=MoveToPointToolInput(x=1.0, y=2.0, z=3.0, task="grab"), - validators=[move_to_point_ord_val_grab], - extra_tool_calls=extra_tool_calls, - ), - MoveToPointTask( - move_to_tool_input=MoveToPointToolInput(x=1.2, y=2.3, z=3.4, task="drop"), - validators=[move_to_point_ord_val_drop], - extra_tool_calls=extra_tool_calls, - ), - ] - - -def get_custom_interfaces_tasks(extra_tool_calls: int = 0) -> List[Task]: - return [ - PublishROS2HRIMessageTextTask( - topic="/to_human", - text="Hello!", - validators=[pub_HRIMessage_text_ord_val], - extra_tool_calls=extra_tool_calls, - ), - PublishROS2HRIMessageTextTask( - topic="/to_human", - text="Hello!", - validators=[list_topic_get_interface_publish_ord_val], - extra_tool_calls=extra_tool_calls, - ), - ] - - -def get_spatial_tasks(extra_tool_calls: int = 0) -> Sequence[Task]: - true_tasks = [ - BoolImageTask( - task_input=input_item, - validators=[ret_true_ord_val], - extra_tool_calls=extra_tool_calls, - ) - for input_item in true_response_inputs - ] - false_tasks = [ - BoolImageTask( - task_input=input_item, - validators=[ret_false_ord_val], - extra_tool_calls=extra_tool_calls, - ) - for input_item in false_response_inputs - ] - return true_tasks + false_tasks - def get_tasks( - extra_tool_calls: int = 0, + extra_tool_calls: List[int] = [0], complexities: List[Literal["easy", "medium", "hard"]] = ["easy", "medium", "hard"], + prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], + n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], task_types: List[ Literal[ "basic", @@ -430,18 +48,58 @@ def get_tasks( "spatial_reasoning", ], ) -> List[Task]: - # TODO (jmatejcz) implement complexity sorting - tasks: List[Task] = [] + """Get a list of tasks based on the provided configuration. + + Parameters + ---------- + Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. + See the class documentation for parameter descriptions. + + Returns + ------- + List[Task] + sequence of spatial reasoning tasks with varying difficulty levels. + There will be every combination of extra_tool_calls x prompt_detail x n_shots tasks generated. + """ + + all_tasks: List[Task] = [] if "basic" in task_types: - tasks += get_basic_tasks(extra_tool_calls=extra_tool_calls) + all_tasks += get_basic_tasks( + extra_tool_calls=extra_tool_calls, + prompt_detail=prompt_detail, + n_shots=n_shots, + ) if "custom_interfaces" in task_types: - tasks += get_custom_interfaces_tasks(extra_tool_calls=extra_tool_calls) + all_tasks += get_custom_interfaces_tasks( + extra_tool_calls=extra_tool_calls, + prompt_detail=prompt_detail, + n_shots=n_shots, + ) if "manipulation" in task_types: - tasks += get_manipulation_tasks(extra_tool_calls=extra_tool_calls) + all_tasks += get_manipulation_tasks( + extra_tool_calls=extra_tool_calls, + prompt_detail=prompt_detail, + n_shots=n_shots, + ) if "navigation" in task_types: - tasks += get_navigation_tasks(extra_tool_calls=extra_tool_calls) + all_tasks += get_navigation_tasks( + extra_tool_calls=extra_tool_calls, + prompt_detail=prompt_detail, + n_shots=n_shots, + ) if "spatial_reasoning" in task_types: - tasks += get_spatial_tasks(extra_tool_calls=extra_tool_calls) + all_tasks += get_spatial_tasks( + extra_tool_calls=extra_tool_calls, + prompt_detail=prompt_detail, + n_shots=n_shots, + ) + + filtered_tasks: List[Task] = [] + for task in all_tasks: + if task.complexity not in complexities: + continue + + filtered_tasks.append(task) - random.shuffle(tasks) - return tasks + random.shuffle(all_tasks) + return all_tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/results_tracking.py b/src/rai_bench/rai_bench/tool_calling_agent/results_tracking.py index 7d6a1a1d9..ba0034fa1 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/results_tracking.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/results_tracking.py @@ -41,8 +41,12 @@ class ValidatorResult(BaseModel): class TaskResult(BaseModel): - task_prompt: str = Field(..., description="The task prompt.") + task_prompt: str = Field(..., description="The task base prompt.") system_prompt: str = Field(..., description="The system prompt.") + examples_in_system_prompt: int = Field( + ..., description="Number of examples on how to use tool in system prompt" + ) + prompt_detail: str = Field(..., description="How detailed the task prompt is.") complexity: str = Field(..., description="Complexity of the task.") type: str = Field(..., description="Type of task, for example: manipulation") model_name: str = Field(..., description="Name of the LLM.") diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py index 511cf2c93..9eb1a159c 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py @@ -14,328 +14,206 @@ import logging from abc import ABC -from typing import Dict, List +from typing import List -import inflect from langchain_core.tools import BaseTool -from rai.types import Point from rai_bench.tool_calling_agent.interfaces import ( Task, - Validator, ) +from rai_bench.tool_calling_agent.mocked_ros2_interfaces import COMMON_TOPICS_AND_TYPES from rai_bench.tool_calling_agent.mocked_tools import ( - MockGetObjectPositionsTool, MockGetROS2ImageTool, MockGetROS2TopicsNamesAndTypesTool, MockReceiveROS2MessageTool, ) loggers_type = logging.Logger +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT = """You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. +Be proactive and use the tools to answer questions.""" - -PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT = """You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. -Be proactive and use the tools to answer questions. +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT = ( + PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT + + """ Example of tool calls: - get_ros2_message_interface, args: {'msg_type': 'geometry_msgs/msg/Twist'} -- publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}} -- get_ros2_message_interface, args: {'msg_type': 'turtlesim/srv/TeleportAbsolute'} +- publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}}""" +) + +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT = ( + PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + + """ +- get_ros2_topics_names_and_types, args: {} +- get_ros2_image, args: {'topic': '/camera/image_raw', 'timeout_sec': 10} - publish_ros2_message, args: {'topic': '/turtle1/teleport_absolute', 'message_type': 'turtlesim/srv/TeleportAbsolute', 'message': {x: 5.0, y: 2.0, theta: 1.57}}""" +) -CAMERA_TOPICS_AND_TYPES = [ - "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_image5\ntype: sensor_msgs/msg/Image\n", - "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n", -] -CAMERA_TOPICS = [ - "/color_camera_info5", - "/color_image5", - "/depth_camera_info5", - "/depth_image5", +TOPIC_STRINGS = [ + f"topic: {topic}\ntype: {topic_type}\n" + for topic, topic_type in COMMON_TOPICS_AND_TYPES.items() ] class BasicTask(Task, ABC): - @property - def type(self) -> str: - return "basic" - - def get_system_prompt(self) -> str: - return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT - - -class GetROS2TopicsTask(BasicTask): - complexity = "easy" + type = "basic" @property def available_tools(self) -> List[BaseTool]: return [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /display_contacts\ntype: visualization_msgs/msg/MarkerArray\n", - "topic: /display_planned_path\ntype: moveit_msgs/msg/DisplayTrajectory\n", - "topic: /execute_trajectory/_action/feedback\ntype: moveit_msgs/action/ExecuteTrajectory_FeedbackMessage\n", - "topic: /execute_trajectory/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /joint_states\ntype: sensor_msgs/msg/JointState\n", - "topic: /monitored_planning_scene\ntype: moveit_msgs/msg/PlanningScene\n", - "topic: /motion_plan_request\ntype: moveit_msgs/msg/MotionPlanRequest\n", - "topic: /move_action/_action/feedback\ntype: moveit_msgs/action/MoveGroup_FeedbackMessage\n", - "topic: /move_action/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /panda_arm_controller/follow_joint_trajectory/_action/feedback\ntype: control_msgs/action/FollowJointTrajectory_FeedbackMessage\n", - "topic: /panda_arm_controller/follow_joint_trajectory/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /panda_hand_controller/gripper_cmd/_action/feedback\ntype: control_msgs/action/GripperCommand_FeedbackMessage\n", - "topic: /panda_hand_controller/gripper_cmd/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /parameter_events\ntype: rcl_interfaces/msg/ParameterEvent\n", - "topic: /planning_scene\ntype: moveit_msgs/msg/PlanningScene\n", - "topic: /planning_scene_world\ntype: moveit_msgs/msg/PlanningSceneWorld\n", - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /robot_description_semantic\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /trajectory_execution_event\ntype: std_msgs/msg/String\n", - ] - + CAMERA_TOPICS_AND_TYPES - ) + mock_topics_names_and_types=TOPIC_STRINGS + ), + MockGetROS2ImageTool(available_topics=list(COMMON_TOPICS_AND_TYPES.keys())), + MockReceiveROS2MessageTool( + available_topics=list(COMMON_TOPICS_AND_TYPES.keys()) + ), ] - def get_prompt(self) -> str: - return "Get the names and types of all ROS2 topics" - - -class GetROS2TopicsTask2(BasicTask): - complexity = "easy" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - ] - + CAMERA_TOPICS_AND_TYPES - ) - ] + def optional_tool_calls_number(self) -> int: + # Listing topics before getting any message + return 1 - def get_prompt(self) -> str: - return "What is in the ROS2 network?" + def get_system_prompt(self) -> str: + if self.n_shots == 0: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT + elif self.n_shots == 2: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + else: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT -class GetROS2RGBCameraTask(BasicTask): +class GetROS2TopicsTask(BasicTask): complexity = "easy" @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - ] - + CAMERA_TOPICS_AND_TYPES - ), - MockGetROS2ImageTool(available_topics=CAMERA_TOPICS), - ] + def optional_tool_calls_number(self) -> int: + return 0 + + def get_base_prompt(self) -> str: + return "Get all topics" def get_prompt(self) -> str: - return "Get the RGB image from the camera." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} available in the ROS2 system with their names and message types. " + "You can discover what topics are currently active." + ) -class GetROS2DepthCameraTask(BasicTask): +class GetROS2RGBCameraTask(BasicTask): complexity = "easy" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - ] - + CAMERA_TOPICS_AND_TYPES - ), - MockGetROS2ImageTool(available_topics=CAMERA_TOPICS), - ] + def get_base_prompt(self) -> str: + return "Get RGB camera image." def get_prompt(self) -> str: - return "Get the depth image from the camera." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can explore available camera topics and capture the RGB color image." + ) -class GetAllROS2RGBCamerasTask(BasicTask): +class GetROS2DepthCameraTask(BasicTask): complexity = "easy" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - ] - + CAMERA_TOPICS_AND_TYPES - ), - MockGetROS2ImageTool(available_topics=CAMERA_TOPICS), - ] + def get_base_prompt(self) -> str: + return "Get depth camera image." def get_prompt(self) -> str: - return "Get RGB images from all of the available cameras." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can explore available camera topics and capture the depth image data." + ) -class GetAllROS2DepthCamerasTask(BasicTask): +class GetPointcloudTask(BasicTask): complexity = "easy" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - ] - + CAMERA_TOPICS_AND_TYPES - ), - MockGetROS2ImageTool(available_topics=CAMERA_TOPICS), - ] + def get_base_prompt(self) -> str: + return "Get the pointcloud data." def get_prompt(self) -> str: - return "Get depth images from all of the available cameras." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover available sensor topics and receive the pointcloud information." + ) -# NOTE (jm) is this task redundant? -class GetROS2MessageTask(BasicTask): +class GetRobotDescriptionTask(BasicTask): complexity = "easy" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - ] - + CAMERA_TOPICS_AND_TYPES - ), - MockReceiveROS2MessageTool(available_topics=CAMERA_TOPICS), - ] - - def get_system_prompt(self) -> str: - return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + def get_base_prompt(self) -> str: + return "Get robot description." def get_prompt(self) -> str: - return "Get RGB image." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} You can explore the system " + "to find robot description data." + ) -class GetRobotDescriptionTask(BasicTask): - complexity = "easy" +class GetAllROS2CamerasTask(BasicTask): + complexity = "medium" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /trajectory_execution_event\ntype: std_msgs/msg/String\n", - ] - ), - MockReceiveROS2MessageTool(available_topics=["/robot_description"]), - ] + def get_base_prompt(self) -> str: + return "Get all camera images" def get_prompt(self) -> str: - return "Give me description of the robot." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} from all available camera sources in the system. " + "This includes both RGB color images and depth images. " + "You can discover what camera topics are available and capture images from each." + ) -class GetPointcloudTask(BasicTask): - complexity = "easy" +class CheckRobotHealthTask(BasicTask): + complexity = "medium" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /trajectory_execution_event\ntype: std_msgs/msg/String\n", - ] - ), - MockReceiveROS2MessageTool(available_topics=["/pointcloud"]), - ] + def get_base_prompt(self) -> str: + return "Check robot health status" def get_prompt(self) -> str: - return "Get the pointcloud." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} by examining system diagnostics and monitoring data. " + "You can explore available diagnostic topics and gather information " + "about robot health, joint states, and system logs." + ) -class GetObjectPositionsTask(Task): - complexity = "easy" +class AssessSensorDataQualityTask(BasicTask): + complexity = "hard" - def __init__( - self, - objects: Dict[str, List[Point]], - validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, - ) -> None: - super().__init__( - validators=validators, extra_tool_calls=extra_tool_calls, logger=logger - ) - """Task to get the positions of the objects - - Examples - -------- - objects = { - "banana": [(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], - "cube": [(0.7, 0.8, 0.9)], - } - """ - self.objects = objects - - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - ] - ), - MockGetObjectPositionsTool(mock_objects=self.objects), - ] + def get_base_prompt(self) -> str: + return "Assess sensor data quality" def get_prompt(self) -> str: - """Generates a prompt based on the objects provided in the task. If there is more than one object, the object in the prompt will be pluralized. - Returns: - str: Formatted prompt for the task - """ - inflector = inflect.engine() - object_counts = {obj: len(positions) for obj, positions in self.objects.items()} - formatted_objects = [ - inflector.plural(obj) if count > 1 else obj - for obj, count in object_counts.items() - ] - if len(formatted_objects) > 1: - objects_list = ( - ", ".join(formatted_objects[:-1]) + f", and {formatted_objects[-1]}" - ) + if self.prompt_detail == "brief": + return self.get_base_prompt() else: - objects_list = formatted_objects[0] - return f"Get the {objects_list} positions." + return ( + f"{self.get_base_prompt()} across all available sensors in the robot system. " + "You can explore sensor topics and gather data from various sources " + "including laser scans, cameras, pointclouds, and odometry to evaluate " + "overall sensor performance." + ) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py index 2d9e36360..0cff01ff7 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py @@ -14,16 +14,13 @@ import logging from abc import ABC -from typing import Any, Dict, List, Type +from typing import Any, List from langchain_core.tools import BaseTool -from pydantic import BaseModel from rai.types import ( BoundingBox2D, - CameraInfo, Detection2D, Header, - Image, Point, Pose, Pose2D, @@ -32,20 +29,21 @@ Time, ) from rai.types.rai_interfaces import ( - ManipulatorMoveToRequest, RAIDetectionArray, - RAIGroundedSamRequest, - RAIGroundingDinoRequest, ) -from rai_bench.tool_calling_agent.interfaces import Task, Validator -from rai_bench.tool_calling_agent.messages.base import Clock -from rai_bench.tool_calling_agent.messages.services import ( - StringListRequest, - VectorStoreRetrievalRequest, - WhatISeeRequest, +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator +from rai_bench.tool_calling_agent.mocked_ros2_interfaces import ( + COMMON_INTERFACES, + COMMON_SERVICES_AND_TYPES, + COMMON_TOPIC_MODELS, + COMMON_TOPICS_AND_TYPES, + CUSTOM_INTERFACES, + CUSTOM_SERVICE_MODELS, + CUSTOM_SERVICES_AND_TYPES, + CUSTOM_TOPIC_MODELS, + CUSTOM_TOPICS_AND_TYPES, ) -from rai_bench.tool_calling_agent.messages.topics import AudioMessage, HRIMessage from rai_bench.tool_calling_agent.mocked_tools import ( MockCallROS2ServiceTool, MockGetROS2MessageInterfaceTool, @@ -55,764 +53,17 @@ ) loggers_type = logging.Logger +INTERFACES = COMMON_INTERFACES | CUSTOM_INTERFACES -# dict of interfaces where keys are interfaces types and values are output -# of GetROS2MessageInterfaceTool which are same as ros2 interface show outputs -# the dict contains custom as well as couple other common interfaces -MOCK_INTERFACES: Dict[str, str] = { - "sensor_msgs/msg/CameraInfo": """ -# This message defines meta information for a camera. It should be in a -# camera namespace on topic "camera_info" and accompanied by up to five -# image topics named: -# -# image_raw - raw data from the camera driver, possibly Bayer encoded -# image - monochrome, distorted -# image_color - color, distorted -# image_rect - monochrome, rectified -# image_rect_color - color, rectified -# -# The image_pipeline contains packages (image_proc, stereo_image_proc) -# for producing the four processed image topics from image_raw and -# camera_info. The meaning of the camera parameters are described in -# detail at http://www.ros.org/wiki/image_pipeline/CameraInfo. -# -# The image_geometry package provides a user-friendly interface to -# common operations using this meta information. If you want to, e.g., -# project a 3d point into image coordinates, we strongly recommend -# using image_geometry. -# -# If the camera is uncalibrated, the matrices D, K, R, P should be left -# zeroed out. In particular, clients may assume that K[0] == 0.0 -# indicates an uncalibrated camera. - -####################################################################### -# Image acquisition info # -####################################################################### - -# Time of image acquisition, camera coordinate frame ID -std_msgs/Header header # Header timestamp should be acquisition time of image - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - # Header frame_id should be optical frame of camera - # origin of frame should be optical center of camera - # +x should point to the right in the image - # +y should point down in the image - # +z should point into the plane of the image - - -####################################################################### -# Calibration Parameters # -####################################################################### -# These are fixed during camera calibration. Their values will be the # -# same in all messages until the camera is recalibrated. Note that # -# self-calibrating systems may "recalibrate" frequently. # -# # -# The internal parameters can be used to warp a raw (distorted) image # -# to: # -# 1. An undistorted image (requires D and K) # -# 2. A rectified image (requires D, K, R) # -# The projection matrix P projects 3D points into the rectified image.# -####################################################################### - -# The image dimensions with which the camera was calibrated. -# Normally this will be the full camera resolution in pixels. -uint32 height -uint32 width - -# The distortion model used. Supported models are listed in -# sensor_msgs/distortion_models.hpp. For most cameras, "plumb_bob" - a -# simple model of radial and tangential distortion - is sufficent. -string distortion_model - -# The distortion parameters, size depending on the distortion model. -# For "plumb_bob", the 5 parameters are: (k1, k2, t1, t2, k3). -float64[] d - -# Intrinsic camera matrix for the raw (distorted) images. -# [fx 0 cx] -# K = [ 0 fy cy] -# [ 0 0 1] -# Projects 3D points in the camera coordinate frame to 2D pixel -# coordinates using the focal lengths (fx, fy) and principal point -# (cx, cy). -float64[9] k # 3x3 row-major matrix - -# Rectification matrix (stereo cameras only) -# A rotation matrix aligning the camera coordinate system to the ideal -# stereo image plane so that epipolar lines in both stereo images are -# parallel. -float64[9] r # 3x3 row-major matrix - -# Projection/camera matrix -# [fx' 0 cx' Tx] -# P = [ 0 fy' cy' Ty] -# [ 0 0 1 0] -# By convention, this matrix specifies the intrinsic (camera) matrix -# of the processed (rectified) image. That is, the left 3x3 portion -# is the normal camera intrinsic matrix for the rectified image. -# It projects 3D points in the camera coordinate frame to 2D pixel -# coordinates using the focal lengths (fx', fy') and principal point -# (cx', cy') - these may differ from the values in K. -# For monocular cameras, Tx = Ty = 0. Normally, monocular cameras will -# also have R = the identity and P[1:3,1:3] = K. -# For a stereo pair, the fourth column [Tx Ty 0]' is related to the -# position of the optical center of the second camera in the first -# camera's frame. We assume Tz = 0 so both cameras are in the same -# stereo image plane. The first camera always has Tx = Ty = 0. For -# the right (second) camera of a horizontal stereo pair, Ty = 0 and -# Tx = -fx' * B, where B is the baseline between the cameras. -# Given a 3D point [X Y Z]', the projection (x, y) of the point onto -# the rectified image is given by: -# [u v w]' = P * [X Y Z 1]' -# x = u / w -# y = v / w -# This holds for both images of a stereo pair. -float64[12] p # 3x4 row-major matrix - - -####################################################################### -# Operational Parameters # -####################################################################### -# These define the image region actually captured by the camera # -# driver. Although they affect the geometry of the output image, they # -# may be changed freely without recalibrating the camera. # -####################################################################### - -# Binning refers here to any camera setting which combines rectangular -# neighborhoods of pixels into larger "super-pixels." It reduces the -# resolution of the output image to -# (width / binning_x) x (height / binning_y). -# The default values binning_x = binning_y = 0 is considered the same -# as binning_x = binning_y = 1 (no subsampling). -uint32 binning_x -uint32 binning_y - -# Region of interest (subwindow of full camera resolution), given in -# full resolution (unbinned) image coordinates. A particular ROI -# always denotes the same window of pixels on the camera sensor, -# regardless of binning settings. -# The default setting of roi (all values 0) is considered the same as -# full resolution (roi.width = width, roi.height = height). -RegionOfInterest roi - # - uint32 x_offset # - # (0 if the ROI includes the left edge of the image) - uint32 y_offset # - # (0 if the ROI includes the top edge of the image) - uint32 height # - uint32 width # - bool do_rectify -""", - "sensor_msgs/msg/Image": """ -# This message contains an uncompressed image -# (0, 0) is at top-left corner of image - -std_msgs/Header header # Header timestamp should be acquisition time of image - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - # Header frame_id should be optical frame of camera - # origin of frame should be optical center of cameara - # +x should point to the right in the image - # +y should point down in the image - # +z should point into to plane of the image - # If the frame_id here and the frame_id of the CameraInfo - # message associated with the image conflict - # the behavior is undefined - -uint32 height # image height, that is, number of rows -uint32 width # image width, that is, number of columns - -# The legal values for encoding are in file src/image_encodings.cpp -# If you want to standardize a new string format, join -# ros-users@lists.ros.org and send an email proposing a new encoding. - -string encoding # Encoding of pixels -- channel meaning, ordering, size - # taken from the list of strings in include/sensor_msgs/image_encodings.hpp - -uint8 is_bigendian # is this data bigendian? -uint32 step # Full row length in bytes -uint8[] data # actual matrix data, size is (step * rows) -""", - "rosgraph_msgs/msg/Clock": """ -# This message communicates the current time. -# -# For more information, see https://design.ros2.org/articles/clock_and_time.html. -builtin_interfaces/Time clock - int32 sec - uint32 nanosec -""", - "rai_interfaces/msg/HRIMessage": """ -# -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id -string text -sensor_msgs/Image[] images - std_msgs/Header header # - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - # Header frame_id should be optical frame of camera - # origin of frame should be optical center of cameara - # +x should point to the right in the image - # +y should point down in the image - # +z should point into to plane of the image - # If the frame_id here and the frame_id of the CameraInfo - # message associated with the image conflict - # the behavior is undefined - uint32 height # - uint32 width # - string encoding # - # taken from the list of strings in include/sensor_msgs/image_encodings.hpp - uint8 is_bigendian # - uint32 step # - uint8[] data # -rai_interfaces/AudioMessage[] audios - # - # - # - # - # - int16[] audio - uint16 sample_rate - uint16 channels -string communication_id -int64 seq_no -bool seq_end -""", - "rai_interfaces/msg/AudioMessage": """ -# -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -int16[] audio -uint16 sample_rate -uint16 channels -""", - "rai_interfaces/msg/RAIDetectionArray": """ -# -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# A list of 2D detections, for a multi-object 2D detector. -std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - -# A list of the detected proposals. A multi-proposal detector might generate -# this list with many candidate detections generated from a single input. -vision_msgs/Detection2D[] detections - # - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - ObjectHypothesisWithPose[] results - ObjectHypothesis hypothesis - string class_id - float64 score - geometry_msgs/PoseWithCovariance pose - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 - float64[36] covariance - BoundingBox2D bbox - vision_msgs/Pose2D center - float64 x - float64 y - float64 theta - float64 size_x - float64 size_y - string id -# a list of classes being detected -string[] detection_classes -""", - "rai_interfaces/srv/ManipulatorMoveTo": """ -# -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# A simplified approach with binary states for the gripper -bool initial_gripper_state -bool final_gripper_state -geometry_msgs/PoseStamped target_pose - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 ---- -bool success -""", - "rai_interfaces/srv/RAIGroundedSam": """ -# -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -RAIDetectionArray detections - # - # - # - # - # - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - vision_msgs/Detection2D[] detections - # - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - ObjectHypothesisWithPose[] results - ObjectHypothesis hypothesis - string class_id - float64 score - geometry_msgs/PoseWithCovariance pose - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 - float64[36] covariance - BoundingBox2D bbox - vision_msgs/Pose2D center - float64 x - float64 y - float64 theta - float64 size_x - float64 size_y - string id - string[] detection_classes -sensor_msgs/Image source_img - std_msgs/Header header # - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - # Header frame_id should be optical frame of camera - # origin of frame should be optical center of cameara - # +x should point to the right in the image - # +y should point down in the image - # +z should point into to plane of the image - # If the frame_id here and the frame_id of the CameraInfo - # message associated with the image conflict - # the behavior is undefined - uint32 height # - uint32 width # - string encoding # - # taken from the list of strings in include/sensor_msgs/image_encodings.hpp - uint8 is_bigendian # - uint32 step # - uint8[] data # ---- -sensor_msgs/Image[] masks - std_msgs/Header header # - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - # Header frame_id should be optical frame of camera - # origin of frame should be optical center of cameara - # +x should point to the right in the image - # +y should point down in the image - # +z should point into to plane of the image - # If the frame_id here and the frame_id of the CameraInfo - # message associated with the image conflict - # the behavior is undefined - uint32 height # - uint32 width # - string encoding # - # taken from the list of strings in include/sensor_msgs/image_encodings.hpp - uint8 is_bigendian # - uint32 step # - uint8[] data # -""", - "rai_interfaces/srv/RAIGroundingDino": """ -# -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -string classes -float64 box_threshold -float64 text_threshold -sensor_msgs/Image source_img - std_msgs/Header header # - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - # Header frame_id should be optical frame of camera - # origin of frame should be optical center of cameara - # +x should point to the right in the image - # +y should point down in the image - # +z should point into to plane of the image - # If the frame_id here and the frame_id of the CameraInfo - # message associated with the image conflict - # the behavior is undefined - uint32 height # - uint32 width # - string encoding # - # taken from the list of strings in include/sensor_msgs/image_encodings.hpp - uint8 is_bigendian # - uint32 step # - uint8[] data # ---- -RAIDetectionArray detections - # - # - # - # - # - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - vision_msgs/Detection2D[] detections - # - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - ObjectHypothesisWithPose[] results - ObjectHypothesis hypothesis - string class_id - float64 score - geometry_msgs/PoseWithCovariance pose - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 - float64[36] covariance - BoundingBox2D bbox - vision_msgs/Pose2D center - float64 x - float64 y - float64 theta - float64 size_x - float64 size_y - string id - string[] detection_classes -""", - "rai_interfaces/srv/StringList": """ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Request - empty ---- -# Response -bool success -string[] string_list -""", - "rai_interfaces/srv/VectorStoreRetrieval": """ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# +TOPICS_AND_TYPES = COMMON_TOPICS_AND_TYPES | CUSTOM_TOPICS_AND_TYPES +TOPIC_MODELS = COMMON_TOPIC_MODELS | CUSTOM_TOPIC_MODELS -# Request -string query - ---- -# Response -bool success -string message -string[] documents -float32[] scores -""", - "rai_interfaces/srv/WhatISee": """z -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Request (empty) - ---- -# Response, timed with image timestamp -string[] observations -string perception_source -sensor_msgs/Image image - std_msgs/Header header # - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - # Header frame_id should be optical frame of camera - # origin of frame should be optical center of cameara - # +x should point to the right in the image - # +y should point down in the image - # +z should point into to plane of the image - # If the frame_id here and the frame_id of the CameraInfo - # message associated with the image conflict - # the behavior is undefined - uint32 height # - uint32 width # - string encoding # - # taken from the list of strings in include/sensor_msgs/image_encodings.hpp - uint8 is_bigendian # - uint32 step # - uint8[] data # -geometry_msgs/Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -""", - "rai_interfaces/action/Task": """ -# Goal -string task -string description -string priority - ---- -# Result -bool success -string report - ---- -# Feedback -string current_status -""", - "/load_map": """ -string filename ---- -bool success -""", - "/query_planner_interface": """ ---- - -# The planning instances that could be used in the benchmark -PlannerInterfaceDescription[] planner_interfaces - string name - string pipeline_id - string[] planner_ids - -""", -} - - -SERVICES_AND_TYPES = { - # sample interfaces - # "/load_map": "moveit_msgs/srv/LoadMap", - # "/query_planner_interface": "moveit_msgs/srv/QueryPlannerInterfaces", - # custom interfaces - "/manipulator_move_to": "rai_interfaces/srv/ManipulatorMoveTo", - "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam", - "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino", - "/get_log_digest": "rai_interfaces/srv/StringList", - "/rai_whoami_documentation_service": "rai_interfaces/srv/VectorStoreRetrieval", - "/rai/whatisee/get": "rai_interfaces/srv/WhatISee", -} - -SERVICE_MODELS: Dict[str, Type[BaseModel]] = { - "rai_interfaces/srv/ManipulatorMoveTo": ManipulatorMoveToRequest, - "rai_interfaces/srv/RAIGroundedSam": RAIGroundedSamRequest, - "rai_interfaces/srv/RAIGroundingDino": RAIGroundingDinoRequest, - "rai_interfaces/srv/StringList": StringListRequest, - "rai_interfaces/srv/VectorStoreRetrieval": VectorStoreRetrievalRequest, - "rai_interfaces/srv/WhatISee": WhatISeeRequest, -} - -TOPICS_AND_TYPES: Dict[str, str] = { - # sample topics - "/camera_image_color": "sensor_msgs/msg/Image", - "/camera_image_depth": "sensor_msgs/msg/Image", - "/clock": "rosgraph_msgs/msg/Clock", - "/color_camera_info": "sensor_msgs/msg/CameraInfo", - "/color_camera_info5": "sensor_msgs/msg/CameraInfo", - "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", - "/depth_image5": "sensor_msgs/msg/Image", - # custom topics - "/to_human": "rai_interfaces/msg/HRIMessage", - "/send_audio": "rai_interfaces/msg/AudioMessage", - "/send_detections": "rai_interfaces/msg/RAIDetectionArray", -} - -ACTIONS_AND_TYPES = { - # custom actions - "/perform_task": "rai_interfaces/action/Task", - # some sample actions - # "/execute_trajectory": "moveit_msgs/action/ExecuteTrajectory", - # "/move_action": "moveit_msgs/action/MoveGroup", - # "/follow_joint_trajectory": "control_msgs/action/FollowJointTrajectory", - # "/gripper_cmd": "control_msgs/action/GripperCommand", -} +SERVICES_AND_TYPES = COMMON_SERVICES_AND_TYPES | CUSTOM_SERVICES_AND_TYPES TOPIC_STRINGS = [ f"topic: {topic}\ntype: {msg_type}\n" for topic, msg_type in TOPICS_AND_TYPES.items() ] -TOPIC_MODELS: Dict[str, Type[BaseModel]] = { - "sensor_msgs/msg/CameraInfo": CameraInfo, - "sensor_msgs/msg/Image": Image, - "rosgraph_msgs/msg/Clock": Clock, - "rai_interfaces/msg/HRIMessage": HRIMessage, - "rai_interfaces/msg/AudioMessage": AudioMessage, - "rai_interfaces/msg/RAIDetectionArray": RAIDetectionArray, -} - -IMAGE_TOPICS: Dict[str, str] = { - "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", - "/camera_image_color": "sensor_msgs/msg/Image", - "/camera_image_depth": "sensor_msgs/msg/Image", - "/clock": "rosgraph_msgs/msg/Clock", - "/collision_object": "moveit_msgs/msg/CollisionObject", - "/color_camera_info": "sensor_msgs/msg/CameraInfo", - "/color_camera_info5": "sensor_msgs/msg/CameraInfo", - "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", -} + SERVICE_STRINGS = [ f"service: {service}\ntype: {msg_type}\n" @@ -820,34 +71,49 @@ ] -PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT = """You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. -Be proactive and use the tools to answer questions. +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT = """You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. +Be proactive and use the tools to answer questions.""" + +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT = ( + PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT + + """ Example of tool calls: - get_ros2_message_interface, args: {'msg_type': 'geometry_msgs/msg/Twist'} -- publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}} -- get_ros2_message_interface, args: {'msg_type': 'turtlesim/srv/TeleportAbsolute'} -- publish_ros2_message, args: {'topic': '/turtle1/teleport_absolute', 'message_type': 'turtlesim/srv/TeleportAbsolute', 'message': {x: 5.0, y: 2.0, theta: 1.57}}""" +- publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}}""" +) + +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT = ( + PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + + """ +- get_ros2_topics_names_and_types, args: {} +- get_ros2_message_interface, args: {'msg_type': 'rai_interfaces/msg/HRIMessage'} +- call_ros2_service, args: {'service': '/grounding_dino_classify', 'service_type': 'rai_interfaces/srv/RAIGroundingDino', 'request': {'classes': 'bottle, book', 'box_threshold': 0.4, 'text_threshold': 0.25}}""" +) class CustomInterfaceTask(Task, ABC): + type = "custom_interface" + @property - def type(self) -> str: - return "custom_interface" + def optional_tool_calls_number(self) -> int: + # list topics + # get interface is not optional + return 1 + + def get_system_prompt(self) -> str: + if self.n_shots == 0: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT + elif self.n_shots == 2: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + else: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT class CustomInterfacesTopicTask(CustomInterfaceTask, ABC): def __init__( - self, - topic: str, - validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, + self, topic: str, validators: List[Validator], task_args: TaskArgs ) -> None: - super().__init__( - validators=validators, - extra_tool_calls=extra_tool_calls, - logger=logger, - ) + super().__init__(validators=validators, task_args=task_args) self.topic = topic @property @@ -856,7 +122,7 @@ def available_tools(self) -> List[BaseTool]: MockGetROS2TopicsNamesAndTypesTool( mock_topics_names_and_types=TOPIC_STRINGS ), - MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES), + MockGetROS2MessageInterfaceTool(mock_interfaces=INTERFACES), MockPublishROS2MessageTool( available_topics=list(TOPICS_AND_TYPES.keys()), available_message_types=list(TOPICS_AND_TYPES.values()), @@ -864,9 +130,6 @@ def available_tools(self) -> List[BaseTool]: ), ] - def get_system_prompt(self) -> str: - return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT - class CustomInterfacesServiceTask(CustomInterfaceTask, ABC): def __init__( @@ -874,14 +137,9 @@ def __init__( service: str, service_args: dict[str, Any], validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, + task_args: TaskArgs, ) -> None: - super().__init__( - validators=validators, - extra_tool_calls=extra_tool_calls, - logger=logger, - ) + super().__init__(validators=validators, task_args=task_args) self.service = service self.service_args = service_args @@ -891,211 +149,337 @@ def available_tools(self) -> List[BaseTool]: MockGetROS2ServicesNamesAndTypesTool( mock_service_names_and_types=SERVICE_STRINGS ), - MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES), + MockGetROS2MessageInterfaceTool(mock_interfaces=INTERFACES), MockCallROS2ServiceTool( available_services=list(SERVICES_AND_TYPES.keys()), available_service_types=list(SERVICES_AND_TYPES.values()), - available_service_models=SERVICE_MODELS, + available_service_models=CUSTOM_SERVICE_MODELS, ), ] -# TODO (jm) add actions Tasks - - -# TODO (jm) should we and how to parametrize these classes? class PublishROS2HRIMessageTextTask(CustomInterfacesTopicTask): complexity = "easy" def __init__( self, topic: str, - text: str, validators: List[Validator], - extra_tool_calls: int = 0, - logger: logging.Logger | None = None, + task_args: TaskArgs, + text: str = "Hello!", ) -> None: - super().__init__(topic, validators, extra_tool_calls, logger) + super().__init__(topic, validators=validators, task_args=task_args) self.text = text + def get_base_prompt(self) -> str: + return f"Publish message to topic '{self.topic}' with text: '{self.text}'." + def get_prompt(self) -> str: - return ( - f"You need to publish a message to the topic '{self.topic}' with the text value: '{self.text}'.\n" - "Before publishing, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 topics and their message types.\n" - f"2. Find the message type for the topic '{self.topic}'.\n" - "3. Retrieve the full message interface definition for that type.\n" - "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Publish the message to '{self.topic}' using the correct message type and interface.\n" - ) + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover available topics, examine the message interface " + f"structure, and publish an HRI message containing the text '{self.text}'." + ) class PublishROS2AudioMessageTask(CustomInterfacesTopicTask): complexity = "easy" - expected_audio: List[int] = [123, 456, 789] - expected_sample_rate: int = 44100 - expected_channels: int = 2 - def get_prompt(self) -> str: + def __init__( + self, + topic: str, + validators: List[Validator], + task_args: TaskArgs, + audio: List[int] = [123, 456, 789], + sample_rate: int = 44100, + channels: int = 2, + ) -> None: + super().__init__(topic, validators=validators, task_args=task_args) + self.expected_audio = audio + self.expected_sample_rate = sample_rate + self.expected_channels = channels + + def get_base_prompt(self) -> str: return ( - f"You need to publish a message to the topic '{self.topic}' with audio samples {self.expected_audio}, " - f"sample rate {self.expected_sample_rate}, and {self.expected_channels} channels.\n" - "Before publishing, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 topics and their message types.\n" - f"2. Find the message type for the topic '{self.topic}'.\n" - "3. Retrieve the full message interface definition for that type.\n" - "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Publish the message to '{self.topic}' using the correct message type and interface.\n" + f"Publish audio message to topic '{self.topic}' with samples " + f"{self.expected_audio}, sample rate {self.expected_sample_rate}, " + f"channels {self.expected_channels}." ) + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can explore available audio topics, examine the message " + f"interface, and publish audio data with samples={self.expected_audio}, " + f"sample_rate={self.expected_sample_rate}, and channels={self.expected_channels}." + ) + class PublishROS2DetectionArrayTask(CustomInterfacesTopicTask): complexity = "easy" - expected_detection_classes: List[str] = ["person", "car"] - expected_detections: List[Detection2D] = [ - Detection2D( - bbox=BoundingBox2D( - center=Pose2D(x=320.0, y=240.0, theta=0.0), - size_x=50.0, - size_y=50.0, + def __init__( + self, + topic: str, + validators: List[Validator], + task_args: TaskArgs, + detection_classes: List[str] = ["person", "car"], + bbox_center_x: float = 320.0, + bbox_center_y: float = 320.0, + bbox_size_x: float = 50.0, + bbox_size_y: float = 50.0, + ) -> None: + super().__init__(topic, validators=validators, task_args=task_args) + self.expected_detection_classes = detection_classes + self.expected_detections = [ + Detection2D( + bbox=BoundingBox2D( + center=Pose2D(x=bbox_center_x, y=bbox_center_y, theta=0.0), + size_x=bbox_size_x, + size_y=bbox_size_y, + ) ) - ) - ] + ] - def get_prompt(self) -> str: + def get_base_prompt(self) -> str: + bbox_center = self.expected_detections[0].bbox.center + bbox_size = self.expected_detections[0].bbox return ( - f"You need to publish a detection message to the topic '{self.topic}' with one detection:\n" - f"{self.expected_detections[0].model_dump()} and detection classes {self.expected_detection_classes}.\n" - "Before publishing, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 topics and their message types.\n" - f"2. Find the message type for the topic '{self.topic}'.\n" - "3. Retrieve the full message interface definition for that type.\n" - "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Publish the message to '{self.topic}' using the correct message type and interface.\n" + f"Publish detection array to topic '{self.topic}' with classes " + f"{self.expected_detection_classes} and bbox center " + f"({bbox_center.x}, {bbox_center.y}) size {bbox_size.size_x}x{bbox_size.size_y}." ) + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + bbox_center = self.expected_detections[0].bbox.center + bbox_size = self.expected_detections[0].bbox + return ( + f"{self.get_base_prompt()} " + "You can explore available detection topics, examine the message " + f"interface, and publish detection data with classes={self.expected_detection_classes} " + f"and bounding box at center ({bbox_center.x}, {bbox_center.y}) " + f"with size_x={bbox_size.size_x}, size_y={bbox_size.size_y}." + ) + class CallROS2ManipulatorMoveToServiceTask(CustomInterfacesServiceTask): complexity = "easy" - expected_initial_gripper_state = True - expected_final_gripper_state = False - expected_target_pose: PoseStamped = PoseStamped( - pose=Pose( - position=Point(x=1.0, y=2.0, z=3.0), - orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0), + def __init__( + self, + service: str, + service_args: dict[str, Any], + validators: List[Validator], + task_args: TaskArgs, + target_x: float = 1.0, + target_y: float = 2.0, + target_z: float = 3.0, + initial_gripper_state: bool = True, + final_gripper_state: bool = False, + frame_id: str = "base_link", + ) -> None: + super().__init__( + service, service_args, validators=validators, task_args=task_args + ) + self.expected_initial_gripper_state = initial_gripper_state + self.expected_final_gripper_state = final_gripper_state + self.expected_target_pose = PoseStamped( + header=Header(frame_id=frame_id), + pose=Pose( + position=Point(x=target_x, y=target_y, z=target_z), + orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0), + ), ) - ) - def get_prompt(self) -> str: + def get_base_prompt(self) -> str: + pos = self.expected_target_pose.pose.position return ( - f"You need to call the service '{self.service}' with a target_pose: " - f"{self.expected_target_pose.model_dump()} and gripper states (initial: {self.expected_initial_gripper_state}, final: {self.expected_final_gripper_state}).\n" - "Before calling, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 services and their types.\n" - f"2. Find the service type for '{self.service}'.\n" - "3. Retrieve the full message interface definition for that service.\n" - "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Call the service '{self.service}' using the correct message type and interface.\n" + f"Call service '{self.service}' to move manipulator to pose " + f"({pos.x}, {pos.y}, {pos.z}) with initial_gripper={self.expected_initial_gripper_state}, " + f"final_gripper={self.expected_final_gripper_state}." ) + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + pos = self.expected_target_pose.pose.position + return ( + f"{self.get_base_prompt()} " + "You can discover available manipulation services, examine the service " + f"interface, and call the service with target_pose position (x={pos.x}, " + f"y={pos.y}, z={pos.z}), initial_gripper_state={self.expected_initial_gripper_state}, " + f"and final_gripper_state={self.expected_final_gripper_state}." + ) + class CallGroundedSAMSegmentTask(CustomInterfacesServiceTask): complexity = "easy" - expected_detections: RAIDetectionArray = RAIDetectionArray( - header=Header(stamp=Time(sec=0, nanosec=0), frame_id="camera_frame"), - detections=[], - ) + def __init__( + self, + service: str, + service_args: dict[str, Any], + validators: List[Validator], + task_args: TaskArgs, + frame_id: str = "camera_frame", + ) -> None: + super().__init__( + service, service_args, validators=validators, task_args=task_args + ) + self.expected_detections = RAIDetectionArray( + header=Header(stamp=Time(sec=0, nanosec=0), frame_id=frame_id), + detections=[], + ) + + def get_base_prompt(self) -> str: + return f"Call service '{self.service}' for image segmentation." def get_prompt(self) -> str: - return ( - f"You need to call the service '{self.service}' with detections: {self.expected_detections.model_dump()}\n" - "Before calling, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 services and their types.\n" - f"2. Find the service type for '{self.service}'.\n" - "3. Retrieve the full message interface definition for that service.\n" - "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Call the service '{self.service}' using the correct message type and interface.\n" - ) + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + frame_id = self.expected_detections.header.frame_id + return ( + f"{self.get_base_prompt()} " + "You can discover available AI vision services, examine the service " + f"interface, and call the segmentation service with detections array " + f"(empty detections, header frame_id='{frame_id}') and source image." + ) class CallGroundingDinoClassify(CustomInterfacesServiceTask): complexity = "easy" - expected_classes: str = "bottle, book, chair" - expected_box_threshold: float = 0.4 - expected_text_threshold: float = 0.25 + def __init__( + self, + service: str, + service_args: dict[str, Any], + validators: List[Validator], + task_args: TaskArgs, + classes: str = "bottle, book, chair", + box_threshold: float = 0.4, + text_threshold: float = 0.25, + ) -> None: + super().__init__( + service, service_args, validators=validators, task_args=task_args + ) + self.expected_classes = classes + self.expected_box_threshold = box_threshold + self.expected_text_threshold = text_threshold - def get_prompt(self) -> str: + def get_base_prompt(self) -> str: return ( - f"You need to call the service '{self.service}' with classes: '{self.expected_classes}', " - f"box_threshold: {self.expected_box_threshold}, text_threshold: {self.expected_text_threshold}, " - "Before calling, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 services and their types.\n" - f"2. Find the service type for '{self.service}'.\n" - "3. Retrieve the full message interface definition for that service.\n" - "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Call the service '{self.service}' using the correct message type and interface.\n" + f"Call service '{self.service}' for object classification with classes " + f"'{self.expected_classes}', box_threshold {self.expected_box_threshold}, " + f"text_threshold {self.expected_text_threshold}." ) + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover available AI detection services, examine the service " + f"interface, and call the classification service with classes='{self.expected_classes}', " + f"box_threshold={self.expected_box_threshold}, and text_threshold={self.expected_text_threshold}." + ) + class CallGetLogDigestTask(CustomInterfacesServiceTask): complexity = "easy" - def get_prompt(self) -> str: - return ( - f"You need to call the service '{self.service}' with an empty request.\n" - "Before calling, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 services and their types.\n" - f"2. Find the service type for '{self.service}'.\n" - "3. Retrieve the full message interface definition for that service.\n" - "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Call the service '{self.service}' using the correct message type and interface.\n" + def __init__( + self, + service: str, + service_args: dict[str, Any], + validators: List[Validator], + task_args: TaskArgs, + ) -> None: + super().__init__( + service, service_args, validators=validators, task_args=task_args ) + def get_base_prompt(self) -> str: + return f"Call service '{self.service}' to get log digest." + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover available logging services, examine the service " + "interface, and call the service with an empty request to retrieve " + "system log information." + ) + class CallVectorStoreRetrievalTask(CustomInterfacesServiceTask): complexity = "easy" - expected_query: str = "What is the purpose of this robot?" - def get_prompt(self) -> str: - return ( - f"You need to call the service '{self.service}' with the query: '{self.expected_query}'.\n" - "Before calling, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 services and their types.\n" - f"2. Find the service type for '{self.service}'.\n" - "3. Retrieve the full message interface definition for that service.\n" - "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Call the service '{self.service}' using the correct message type and interface.\n" + def __init__( + self, + service: str, + service_args: dict[str, Any], + validators: List[Validator], + task_args: TaskArgs, + query: str = "What is the purpose of this robot?", + ) -> None: + super().__init__( + service, service_args, validators=validators, task_args=task_args ) + self.expected_query = query + + def get_base_prompt(self) -> str: + return f"Call service '{self.service}' with query '{self.expected_query}'" + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover available knowledge services, examine the service " + f"interface, and call the retrieval service with query='{self.expected_query}' " + "to search the robot's knowledge base." + ) class CallWhatISeeTask(CustomInterfacesServiceTask): complexity = "easy" - expected_observations: List[str] = ["table", "cup", "notebook"] - expected_perception_source: str = "front_camera" - - expected_image: Image = Image( - header=Header(frame_id="camera_frame"), - height=480, - width=640, - ) + def __init__( + self, + service: str, + service_args: dict[str, Any], + validators: List[Validator], + task_args: TaskArgs, + ) -> None: + super().__init__( + service, service_args, validators=validators, task_args=task_args + ) - expected_pose: Pose = Pose( - position=Point(x=1.0, y=2.0, z=0.5), - orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0), - ) + def get_base_prompt(self) -> str: + return f"Call service '{self.service}' to get visual observations." def get_prompt(self) -> str: - return ( - f"You need to call the service '{self.service}' with an empty request.\n" - "Before calling, follow these steps:\n" - "1. Use the tool to retrieve the available ROS2 services and their types.\n" - f"2. Find the service type for '{self.service}'.\n" - "3. Retrieve the full message interface definition for that service.\n" - "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" - f"5. Call the service '{self.service}' using the correct message type and interface.\n" - ) + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover available vision services, examine the service " + "interface, and call the service with an empty request to get " + "visual observations and camera pose information." + ) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py index 1cf65fa37..bf6bea985 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py @@ -12,25 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Any, Dict, List import inflect from langchain_core.tools import BaseTool from rai.tools.ros2 import MoveToPointToolInput from rai.types import Point -from rai_bench.tool_calling_agent.interfaces import Task, Validator +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator +from rai_bench.tool_calling_agent.mocked_ros2_interfaces import ( + COMMON_INTERFACES, + COMMON_SERVICES_AND_TYPES, + COMMON_TOPICS_AND_TYPES, + MANIPULATION_ACTIONS_AND_TYPES, + MANIPULATION_INTERFACES, + MANIPULATION_SERVICES_AND_TYPES, + MANIPULATION_TOPICS_AND_TYPES, +) from rai_bench.tool_calling_agent.mocked_tools import ( MockGetObjectPositionsTool, + MockGetROS2MessageInterfaceTool, + MockGetROS2ServicesNamesAndTypesTool, MockGetROS2TopicsNamesAndTypesTool, MockMoveToPointTool, ) -loggers_type = logging.Logger +INTERFACES = COMMON_INTERFACES | MANIPULATION_INTERFACES +TOPCIS_AND_TYPES = COMMON_TOPICS_AND_TYPES | MANIPULATION_TOPICS_AND_TYPES +SERVICES_AND_TYPES = COMMON_SERVICES_AND_TYPES | MANIPULATION_SERVICES_AND_TYPES + +TOPIC_STRINGS = [ + f"topic: {topic}\ntype: {topic_type}\n" + for topic, topic_type in COMMON_TOPICS_AND_TYPES.items() +] + +ACTION_STRINGS = [ + f"action: {action}\ntype: {act_type}\n" + for action, act_type in MANIPULATION_ACTIONS_AND_TYPES.items() +] + +SERVICE_STRINGS = [ + f"service: {service}\ntype: {srv_type}\n" + for service, srv_type in SERVICES_AND_TYPES.items() +] -PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT = """ +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT = """ You are a robotic arm with interfaces to detect and manipulate objects. Here are the coordinates information: x - front to back (positive is forward) @@ -38,6 +65,23 @@ z - up to down (positive is up). """ +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT = ( + PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT + + """ +Example of tool calls: +- get_object_positions, args: {} +- move_to_point, args: {'x': 0.5, 'y': 0.2, 'z': 0.3, 'task': 'grab'}""" +) + +PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT = ( + PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + + """ +- move_to_point, args: {'x': 1.7, 'y': 1.8, 'z': 1.9, 'task': 'drop'} +- move_to_point, args: {'x': 0.1, 'y': -0.2, 'z': 0.1, 'task': 'grab'} +- move_to_point, args: {'x': 0.7, 'y': 0.8, 'z': 0.9, 'task': 'drop'} +""" +) + class TaskParametrizationError(Exception): """Exception raised when the task parameters are not valid.""" @@ -46,48 +90,28 @@ class TaskParametrizationError(Exception): class ManipulationTask(Task, ABC): - @property - def type(self) -> str: - return "manipulation" - - def get_system_prompt(self) -> str: - return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT - + type = "manipulation" -class GrabTask(ManipulationTask, ABC): def __init__( self, objects: Dict[str, List[Point]], - object_to_grab: str, validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, + task_args: TaskArgs, + **kwargs: Any, ) -> None: - super().__init__( - validators=validators, - extra_tool_calls=extra_tool_calls, - logger=logger, - ) + super().__init__(validators=validators, task_args=task_args, **kwargs) self.objects = objects - self.object_to_grab = object_to_grab self._verify_args() - @abstractmethod - def _verify_args(self) -> None: - pass + @property + def optional_tool_calls_number(self) -> int: + return 0 @property def available_tools(self) -> List[BaseTool]: return [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - ] + mock_topics_names_and_types=TOPIC_STRINGS ), MockGetObjectPositionsTool( target_frame="panda_link0", @@ -98,41 +122,76 @@ def available_tools(self) -> List[BaseTool]: mock_objects=self.objects, ), MockMoveToPointTool(manipulator_frame="panda_link0"), + MockGetROS2ServicesNamesAndTypesTool( + mock_service_names_and_types=SERVICE_STRINGS + ), + MockGetROS2MessageInterfaceTool(mock_interfaces=INTERFACES), ] + def _verify_args(self) -> None: + pass + + def get_system_prompt(self) -> str: + if self.n_shots == 0: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT + elif self.n_shots == 2: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + else: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT + + +class GrabTask(ManipulationTask, ABC): + def __init__( + self, + objects: Dict[str, List[Point]], + object_to_grab: str, + validators: List[Validator], + task_args: TaskArgs, + **kwargs: Any, + ) -> None: + super().__init__( + validators=validators, objects=objects, task_args=task_args, **kwargs + ) + self.object_to_grab = object_to_grab + self._verify_args() + + @abstractmethod + def _verify_args(self) -> None: + pass + class MoveToPointTask(ManipulationTask): complexity = "easy" def __init__( self, + objects: Dict[str, List[Point]], move_to_tool_input: MoveToPointToolInput, validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, + task_args: TaskArgs, + **kwargs: Any, ) -> None: super().__init__( - validators=validators, extra_tool_calls=extra_tool_calls, logger=logger + validators=validators, objects=objects, task_args=task_args, **kwargs ) - self.move_to_tool_input = move_to_tool_input - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - ] - ), - MockMoveToPointTool(manipulator_frame="base_link"), - ] + def get_base_prompt(self) -> str: + return ( + f"Move the arm to point x={self.move_to_tool_input.x}, " + f"y={self.move_to_tool_input.y}, z={self.move_to_tool_input.z} " + f"to {self.move_to_tool_input.task} an object." + ) def get_prompt(self) -> str: - return f"Move the arm to a point x={self.move_to_tool_input.x}, y={self.move_to_tool_input.y}, z={self.move_to_tool_input.z} to {self.move_to_tool_input.task} an object." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can control the arm movement to the specified coordinates " + f"and perform the {self.move_to_tool_input.task} action at that location." + ) class GetObjectPositionsTask(ManipulationTask): @@ -142,42 +201,15 @@ def __init__( self, objects: Dict[str, List[Point]], validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, + task_args: TaskArgs, + **kwargs: Any, ) -> None: super().__init__( - validators=validators, extra_tool_calls=extra_tool_calls, logger=logger + validators=validators, objects=objects, task_args=task_args, **kwargs ) - """Task to get the positions of the objects - - Examples - -------- - objects = { - "banana": [(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], - "cube": [(0.7, 0.8, 0.9)], - } - """ self.objects = objects - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - ] - ), - MockGetObjectPositionsTool(mock_objects=self.objects), - ] - - def get_prompt(self) -> str: - """Generates a prompt based on the objects provided in the task. If there is more than one object, the object in the prompt will be pluralized. - Returns: - str: Formatted prompt for the task - """ + def get_base_prompt(self) -> str: inflector = inflect.engine() object_counts = {obj: len(positions) for obj, positions in self.objects.items()} formatted_objects = [ @@ -190,24 +222,35 @@ def get_prompt(self) -> str: ) else: objects_list = formatted_objects[0] + return f"Get the {objects_list} positions." + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can detect all objects and retrieve their 3D coordinates " + "for manipulation planning." + ) + class GrabExistingObjectTask(GrabTask): complexity = "medium" - """ - Task to grab an object. - Parameters - ---------- - objects : Dict[str, List[dict[str, float]]] - Dictionary of object types and their positions. - object_to_grab : str - The object to be grabbed (must have a single position). - """ + def get_base_prompt(self) -> str: + return f"Grab {self.object_to_grab}." def get_prompt(self) -> str: - return f"Grab {self.object_to_grab}." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can locate the object in the workspace and move the arm " + "to grab it at the correct coordinates." + ) def _verify_args(self): if self.object_to_grab not in self.objects: @@ -223,19 +266,19 @@ def _verify_args(self): class GrabNotExistingObjectTask(GrabTask): complexity = "medium" - """ - Task to attempt grabbing an object that does not exist. - Parameters - ---------- - objects : Dict[str, List[dict[str, float]]] - Available objects and their positions. - object_to_grab : str - Object that should not be present in the list. - """ + def get_base_prompt(self) -> str: + return f"Grab {self.object_to_grab}." def get_prompt(self) -> str: - return f"Grab {self.object_to_grab}." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can check if the object exists in the environment and " + "attempt to grab it if found." + ) def _verify_args(self): if self.object_to_grab in self.objects: @@ -245,30 +288,20 @@ def _verify_args(self): class MoveExistingObjectLeftTask(GrabTask): - """Task to move an existing object to the left. - - Parameters - ---------- - objects : Dict[str, List[dict[str, float]]] - Dictionary containing the object types and their positions. Object type should be passed as singular. - object_to_grab : str - Object type should be passed as singular. Object to be grabbed should be defined in the objects argument with only one instance (one position). - logger : loggers_type | None, optional - Logger, by default None - - Examples - -------- - objects = { - "banana": [(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], - "cube": [(0.7, 0.8, 0.9)], - } - object_to_grab = "cube" - """ + complexity = "hard" - complexity = "medium" + def get_base_prompt(self) -> str: + return f"Move {self.object_to_grab} 20 cm to the left." def get_prompt(self) -> str: - return f"Move {self.object_to_grab} 20 cm to the left." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can locate the object, grab it with the manipulator, " + "and move it to a position 20 cm to the left of its current location." + ) def _verify_args(self): if self.object_to_grab not in self.objects: @@ -283,22 +316,20 @@ def _verify_args(self): class MoveExistingObjectFrontTask(GrabTask): - """Task to move an existing object to the front - - Parameters - ---------- - objects : Dict[str, List[dict[str, float]]] - Dictionary containing the object types and their positions. Object type should be passed as singular. - object_to_grab : str - Object to grab. Object type should be passed as singular. Object to be grabbed should be defined in the objects argument with only one instance (one position). - logger : loggers_type | None, optional - Logger, by default None - """ + complexity = "hard" - complexity = "medium" + def get_base_prompt(self) -> str: + return f"Move {self.object_to_grab} 60 cm to the front." def get_prompt(self) -> str: - return f"Move {self.object_to_grab} 60 cm to the front." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can locate the object, grab it with the manipulator, " + "and move it to a position 60 cm forward from its current location." + ) def _verify_args(self): if self.object_to_grab not in self.objects: @@ -312,71 +343,37 @@ def _verify_args(self): raise TaskParametrizationError(error_message) -class SwapObjectsTask(Task): - """Task to swap objects - - Parameters - ---------- - objects : Dict[str, List[Dict[str, float]]] - Dictionary containing the object types and their positions. Object type should be passed as singular. - objects_to_swap : List[str] - Objects to be swapped. Object type should be passed as singular. Objects to be swapped should be defined in the objects argument with only one instance (one position). - logger : loggers_type | None, optional - Logger, by default None - - Examples - -------- - objects = { - "banana": [(0.1, 0.2, 0.1)], - "cube": [(0.7, 0.8, 0.1)], - "apple": [(0.3, 0.4, 0.1), (0.5, 0.6, 0.1)], - - } - objects_to_swap = ["cube", "banana"] - """ - +class SwapObjectsTask(ManipulationTask): complexity = "hard" def __init__( self, objects: Dict[str, List[Point]], - objects_to_swap: str, + objects_to_swap: List[str], validators: List[Validator], - extra_tool_calls: int = 0, - logger: loggers_type | None = None, + task_args: TaskArgs, + **kwargs: Any, ) -> None: super().__init__( - validators=validators, - extra_tool_calls=extra_tool_calls, - logger=logger, + validators=validators, objects=objects, task_args=task_args, **kwargs ) self.objects = objects self.objects_to_swap = objects_to_swap self._verify_args() - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - ] - ), - MockGetObjectPositionsTool( - target_frame="panda_link0", - source_frame="RGBDCamera5", - camera_topic="/color_image5", - depth_topic="/depth_image5", - camera_info_topic="/color_camera_info5", - mock_objects=self.objects, - ), - MockMoveToPointTool(manipulator_frame="panda_link0"), - ] + def get_base_prompt(self) -> str: + return f"Swap {self.objects_to_swap[0]} and {self.objects_to_swap[1]}." + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can locate both objects in the workspace, then perform a sequence " + f"of grab and move operations to swap the positions of {self.objects_to_swap[0]} " + f"and {self.objects_to_swap[1]}." + ) def _verify_args(self): for obj in self.objects_to_swap: @@ -392,6 +389,3 @@ def _verify_args(self): error_message = f"Number of requested objects to swap {len(self.objects_to_swap)} should be equal to 2." self.logger.error(msg=error_message) raise TaskParametrizationError(error_message) - - def get_prompt(self) -> str: - return f"Move {self.objects_to_swap[0]} to the initial position of {self.objects_to_swap[1]}, and move {self.objects_to_swap[1]} to the initial position of {self.objects_to_swap[0]}." diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py index 037872e04..79514d256 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py @@ -12,44 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Dict, List, Type +from typing import List from langchain_core.tools import BaseTool -from pydantic import BaseModel -from rai_open_set_vision.tools.gdino_tools import ( - DistanceMeasurement, -) from rai_bench.tool_calling_agent.interfaces import Task -from rai_bench.tool_calling_agent.messages.actions import ( - AssistedTeleopGoal, - BackUpGoal, - ComputePathThroughPosesGoal, - ComputePathToPoseGoal, - DriveOnHeadingGoal, - FollowPathGoal, - FollowWaypointsGoal, - NavigateThroughPosesGoal, - NavigateToPoseGoal, - SmoothPathGoal, - SpinGoal, - WaitGoal, +from rai_bench.tool_calling_agent.mocked_ros2_interfaces import ( + COMMON_SERVICES_AND_TYPES, + COMMON_TOPICS_AND_TYPES, + NAVIGATION_ACTION_MODELS, + NAVIGATION_ACTIONS_AND_TYPES, + NAVIGATION_INTERFACES, + NAVIGATION_SERVICES_AND_TYPES, + NAVIGATION_TOPICS_AND_TYPES, ) from rai_bench.tool_calling_agent.mocked_tools import ( MockActionsToolkit, - MockGetDistanceToObjectsTool, - MockGetROS2ActionFeedbackTool, - MockGetROS2ActionResultTool, - MockGetROS2ActionsNamesAndTypesTool, MockGetROS2MessageInterfaceTool, + MockGetROS2ServicesNamesAndTypesTool, MockGetROS2TopicsNamesAndTypesTool, - MockStartROS2ActionTool, ) -loggers_type = logging.Logger - -ROBOT_NAVIGATION_SYSTEM_PROMPT = """You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. +ROBOT_NAVIGATION_SYSTEM_PROMPT_0_SHOT = """You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. Do not make assumptions about the environment you are currently in. You can use ros2 topics, services and actions to operate. @@ -98,768 +82,36 @@ (0.79, 5.73, 0.0), (0.92, 1.01, 0.0) - Before starting anything, make sure to load available topics, services and actions. + Before starting anything, make sure to load available topics, services and actions.""" + +ROBOT_NAVIGATION_SYSTEM_PROMPT_2_SHOT = ( + ROBOT_NAVIGATION_SYSTEM_PROMPT_0_SHOT + + """ + Example tool calls: - - get_ros2_message_interface, args: {'msg_type': 'turtlesim/srv/TeleportAbsolute'} - - publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}} - - start_ros2_action, args: {'action_name': '/dock', 'action_type': 'nav2_msgs/action/Dock', 'action_args': {}} - """ - -TOPICS_NAMES_AND_TYPES = [ - "topic: /assisted_teleop/_action/feedback\ntype: nav2_msgs/action/AssistedTeleop_FeedbackMessage\n", - "topic: /assisted_teleop/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /backup/_action/feedback\ntype: nav2_msgs/action/BackUp_FeedbackMessage\n", - "topic: /backup/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /behavior_server/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /behavior_tree_log\ntype: nav2_msgs/msg/BehaviorTreeLog\n", - "topic: /bond\ntype: bond/msg/Status\n", - "topic: /bt_navigator/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /camera/camera/color/camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /camera/camera/color/image_raw\ntype: sensor_msgs/msg/Image\n", - "topic: /camera/camera/depth/camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /camera/camera/depth/image_rect_raw\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /cmd_vel_nav\ntype: geometry_msgs/msg/Twist\n", - "topic: /cmd_vel_teleop\ntype: geometry_msgs/msg/Twist\n", - "topic: /compute_path_through_poses/_action/feedback\ntype: nav2_msgs/action/ComputePathThroughPoses_FeedbackMessage\n", - "topic: /compute_path_through_poses/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /compute_path_to_pose/_action/feedback\ntype: nav2_msgs/action/ComputePathToPose_FeedbackMessage\n", - "topic: /compute_path_to_pose/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /controller_server/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /diagnostics\ntype: diagnostic_msgs/msg/DiagnosticArray\n", - "topic: /drive_on_heading/_action/feedback\ntype: nav2_msgs/action/DriveOnHeading_FeedbackMessage\n", - "topic: /drive_on_heading/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /follow_path/_action/feedback\ntype: nav2_msgs/action/FollowPath_FeedbackMessage\n", - "topic: /follow_path/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /follow_waypoints/_action/feedback\ntype: nav2_msgs/action/FollowWaypoints_FeedbackMessage\n", - "topic: /follow_waypoints/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /global_costmap/costmap\ntype: nav_msgs/msg/OccupancyGrid\n", - "topic: /global_costmap/costmap_raw\ntype: nav2_msgs/msg/Costmap\n", - "topic: /global_costmap/costmap_updates\ntype: map_msgs/msg/OccupancyGridUpdate\n", - "topic: /global_costmap/footprint\ntype: geometry_msgs/msg/Polygon\n", - "topic: /global_costmap/global_costmap/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /global_costmap/published_footprint\ntype: geometry_msgs/msg/PolygonStamped\n", - "topic: /global_costmap/scan\ntype: sensor_msgs/msg/LaserScan\n", - "topic: /goal_pose\ntype: geometry_msgs/msg/PoseStamped\n", - "topic: /led_strip\ntype: sensor_msgs/msg/Image\n", - "topic: /local_costmap/costmap\ntype: nav_msgs/msg/OccupancyGrid\n", - "topic: /local_costmap/costmap_raw\ntype: nav2_msgs/msg/Costmap\n", - "topic: /local_costmap/costmap_updates\ntype: map_msgs/msg/OccupancyGridUpdate\n", - "topic: /local_costmap/footprint\ntype: geometry_msgs/msg/Polygon\n", - "topic: /local_costmap/local_costmap/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /local_costmap/published_footprint\ntype: geometry_msgs/msg/PolygonStamped\n", - "topic: /local_costmap/scan\ntype: sensor_msgs/msg/LaserScan\n", - "topic: /map\ntype: nav_msgs/msg/OccupancyGrid\n", - "topic: /map_metadata\ntype: nav_msgs/msg/MapMetaData\n", - "topic: /map_saver/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /navigate_through_poses/_action/feedback\ntype: nav2_msgs/action/NavigateThroughPoses_FeedbackMessage\n", - "topic: /navigate_through_poses/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /navigate_to_pose/_action/feedback\ntype: nav2_msgs/action/NavigateToPose_FeedbackMessage\n", - "topic: /navigate_to_pose/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /odom\ntype: nav_msgs/msg/Odometry\n", - "topic: /odometry/filtered\ntype: nav_msgs/msg/Odometry\n", - "topic: /parameter_events\ntype: rcl_interfaces/msg/ParameterEvent\n", - "topic: /plan\ntype: nav_msgs/msg/Path\n", - "topic: /plan_smoothed\ntype: nav_msgs/msg/Path\n", - "topic: /planner_server/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /pose\ntype: geometry_msgs/msg/PoseWithCovarianceStamped\n", - "topic: /preempt_teleop\ntype: std_msgs/msg/Empty\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /scan\ntype: sensor_msgs/msg/LaserScan\n", - "topic: /slam_toolbox/feedback\ntype: visualization_msgs/msg/InteractiveMarkerFeedback\n", - "topic: /slam_toolbox/graph_visualization\ntype: visualization_msgs/msg/MarkerArray\n", - "topic: /slam_toolbox/scan_visualization\ntype: sensor_msgs/msg/LaserScan\n", - "topic: /slam_toolbox/update\ntype: visualization_msgs/msg/InteractiveMarkerUpdate\n", - "topic: /smooth_path/_action/feedback\ntype: nav2_msgs/action/SmoothPath_FeedbackMessage\n", - "topic: /smooth_path/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /smoother_server/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /speed_limit\ntype: nav2_msgs/msg/SpeedLimit\n", - "topic: /spin/_action/feedback\ntype: nav2_msgs/action/Spin_FeedbackMessage\n", - "topic: /spin/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /trajectories\ntype: visualization_msgs/msg/MarkerArray\n", - "topic: /transformed_global_plan\ntype: nav_msgs/msg/Path\n", - "topic: /unsmoothed_plan\ntype: nav_msgs/msg/Path\n", - "topic: /velocity_smoother/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", - "topic: /wait/_action/feedback\ntype: nav2_msgs/action/Wait_FeedbackMessage\n", - "topic: /wait/_action/status\ntype: action_msgs/msg/GoalStatusArray\n", - "topic: /waypoint_follower/transition_event\ntype: lifecycle_msgs/msg/TransitionEvent\n", + - get_ros2_actions_names_and_types, args: {} + - start_ros2_action, args: {'action': '/navigate_to_pose', 'action_type': 'nav2_msgs/action/NavigateToPose', 'goal': {'pose': {'header': {'frame_id': 'map'}, 'pose': {'position': {'x': 2.0, 'y': 2.0, 'z': 0.0}}}}}""" +) + +ROBOT_NAVIGATION_SYSTEM_PROMPT_5_SHOT = ( + ROBOT_NAVIGATION_SYSTEM_PROMPT_2_SHOT + + """ + - get_ros2_message_interface, args: {'msg_type': 'nav2_msgs/action/Spin'} + - start_ros2_action, args: {'action': '/spin', 'action_type': 'nav2_msgs/action/Spin', 'goal': {'target_yaw': 3.14}} + - start_ros2_action, args: {'action': '/drive_on_heading', 'action_type': 'nav2_msgs/action/DriveOnHeading', 'goal': {'target': {'x': 1.0, 'y': 0.0, 'z': 0.0}, 'speed': 0.5}}""" +) +TOPICS_AND_TYPES = COMMON_TOPICS_AND_TYPES | NAVIGATION_TOPICS_AND_TYPES +SERVICES_AND_TYPES = COMMON_SERVICES_AND_TYPES | NAVIGATION_SERVICES_AND_TYPES + + +TOPIC_STRINGS = [ + f"topic: {topic}\ntype: {topic_type}\n" + for topic, topic_type in COMMON_TOPICS_AND_TYPES.items() ] -ACTIONS_AND_TYPES: Dict[str, str] = { - "/assisted_teleop": "nav2_msgs/action/AssistedTeleop", - "/backup": "nav2_msgs/action/BackUp", - "/compute_path_through_poses": "nav2_msgs/action/ComputePathThroughPoses", - "/compute_path_to_pose": "nav2_msgs/action/ComputePathToPose", - "/drive_on_heading": "nav2_msgs/action/DriveOnHeading", - "/follow_path": "nav2_msgs/action/FollowPath", - "/follow_waypoints": "nav2_msgs/action/FollowWaypoints", - "/navigate_through_poses": "nav2_msgs/action/NavigateThroughPoses", - "/navigate_to_pose": "nav2_msgs/action/NavigateToPose", - "/smooth_path": "nav2_msgs/action/SmoothPath", - "/spin": "nav2_msgs/action/Spin", - "/wait": "nav2_msgs/action/Wait", -} - -SERVICES_AND_TYPES: Dict[str, str] = { - "/assisted_teleop/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/assisted_teleop/_action/get_result": "nav2_msgs/action/AssistedTeleop_GetResult", - "/assisted_teleop/_action/send_goal": "nav2_msgs/action/AssistedTeleop_SendGoal", - "/backup/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/backup/_action/get_result": "nav2_msgs/action/BackUp_GetResult", - "/backup/_action/send_goal": "nav2_msgs/action/BackUp_SendGoal", - "/behavior_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/behavior_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/behavior_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/behavior_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/behavior_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/behavior_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/behavior_server/get_state": "lifecycle_msgs/srv/GetState", - "/behavior_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/behavior_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/behavior_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/behavior_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/bt_navigator/change_state": "lifecycle_msgs/srv/ChangeState", - "/bt_navigator/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/bt_navigator/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/bt_navigator/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/bt_navigator/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/bt_navigator/get_parameters": "rcl_interfaces/srv/GetParameters", - "/bt_navigator/get_state": "lifecycle_msgs/srv/GetState", - "/bt_navigator/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/bt_navigator/list_parameters": "rcl_interfaces/srv/ListParameters", - "/bt_navigator/set_parameters": "rcl_interfaces/srv/SetParameters", - "/bt_navigator/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/bt_navigator_navigate_through_poses_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/bt_navigator_navigate_to_pose_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/compute_path_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/compute_path_through_poses/_action/get_result": "nav2_msgs/action/ComputePathThroughPoses_GetResult", - "/compute_path_through_poses/_action/send_goal": "nav2_msgs/action/ComputePathThroughPoses_SendGoal", - "/compute_path_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/compute_path_to_pose/_action/get_result": "nav2_msgs/action/ComputePathToPose_GetResult", - "/compute_path_to_pose/_action/send_goal": "nav2_msgs/action/ComputePathToPose_SendGoal", - "/controller_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/controller_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/controller_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/controller_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/controller_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/controller_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/controller_server/get_state": "lifecycle_msgs/srv/GetState", - "/controller_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/controller_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/controller_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/controller_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/drive_on_heading/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/drive_on_heading/_action/get_result": "nav2_msgs/action/DriveOnHeading_GetResult", - "/drive_on_heading/_action/send_goal": "nav2_msgs/action/DriveOnHeading_SendGoal", - "/follow_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/follow_path/_action/get_result": "nav2_msgs/action/FollowPath_GetResult", - "/follow_path/_action/send_goal": "nav2_msgs/action/FollowPath_SendGoal", - "/follow_waypoints/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/follow_waypoints/_action/get_result": "nav2_msgs/action/FollowWaypoints_GetResult", - "/follow_waypoints/_action/send_goal": "nav2_msgs/action/FollowWaypoints_SendGoal", - "/global_costmap/clear_around_global_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", - "/global_costmap/clear_entirely_global_costmap": "nav2_msgs/srv/ClearEntireCostmap", - "/global_costmap/clear_except_global_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", - "/global_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", - "/global_costmap/global_costmap/change_state": "lifecycle_msgs/srv/ChangeState", - "/global_costmap/global_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/global_costmap/global_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/global_costmap/global_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/global_costmap/global_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/global_costmap/global_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", - "/global_costmap/global_costmap/get_state": "lifecycle_msgs/srv/GetState", - "/global_costmap/global_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/global_costmap/global_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", - "/global_costmap/global_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", - "/global_costmap/global_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/grounded_sam/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/grounded_sam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/grounded_sam/get_parameters": "rcl_interfaces/srv/GetParameters", - "/grounded_sam/list_parameters": "rcl_interfaces/srv/ListParameters", - "/grounded_sam/set_parameters": "rcl_interfaces/srv/SetParameters", - "/grounded_sam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam", - "/grounding_dino/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/grounding_dino/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/grounding_dino/get_parameters": "rcl_interfaces/srv/GetParameters", - "/grounding_dino/list_parameters": "rcl_interfaces/srv/ListParameters", - "/grounding_dino/set_parameters": "rcl_interfaces/srv/SetParameters", - "/grounding_dino/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino", - "/is_path_valid": "nav2_msgs/srv/IsPathValid", - "/launch_ros_138640/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/launch_ros_138640/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/launch_ros_138640/get_parameters": "rcl_interfaces/srv/GetParameters", - "/launch_ros_138640/list_parameters": "rcl_interfaces/srv/ListParameters", - "/launch_ros_138640/set_parameters": "rcl_interfaces/srv/SetParameters", - "/launch_ros_138640/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/lifecycle_manager_navigation/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/lifecycle_manager_navigation/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/lifecycle_manager_navigation/get_parameters": "rcl_interfaces/srv/GetParameters", - "/lifecycle_manager_navigation/is_active": "std_srvs/srv/Trigger", - "/lifecycle_manager_navigation/list_parameters": "rcl_interfaces/srv/ListParameters", - "/lifecycle_manager_navigation/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", - "/lifecycle_manager_navigation/set_parameters": "rcl_interfaces/srv/SetParameters", - "/lifecycle_manager_navigation/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/lifecycle_manager_slam/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/lifecycle_manager_slam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/lifecycle_manager_slam/get_parameters": "rcl_interfaces/srv/GetParameters", - "/lifecycle_manager_slam/is_active": "std_srvs/srv/Trigger", - "/lifecycle_manager_slam/list_parameters": "rcl_interfaces/srv/ListParameters", - "/lifecycle_manager_slam/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", - "/lifecycle_manager_slam/set_parameters": "rcl_interfaces/srv/SetParameters", - "/lifecycle_manager_slam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/local_costmap/clear_around_local_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", - "/local_costmap/clear_entirely_local_costmap": "nav2_msgs/srv/ClearEntireCostmap", - "/local_costmap/clear_except_local_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", - "/local_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", - "/local_costmap/local_costmap/change_state": "lifecycle_msgs/srv/ChangeState", - "/local_costmap/local_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/local_costmap/local_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/local_costmap/local_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/local_costmap/local_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/local_costmap/local_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", - "/local_costmap/local_costmap/get_state": "lifecycle_msgs/srv/GetState", - "/local_costmap/local_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/local_costmap/local_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", - "/local_costmap/local_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", - "/local_costmap/local_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/map_saver/change_state": "lifecycle_msgs/srv/ChangeState", - "/map_saver/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/map_saver/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/map_saver/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/map_saver/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/map_saver/get_parameters": "rcl_interfaces/srv/GetParameters", - "/map_saver/get_state": "lifecycle_msgs/srv/GetState", - "/map_saver/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/map_saver/list_parameters": "rcl_interfaces/srv/ListParameters", - "/map_saver/save_map": "nav2_msgs/srv/SaveMap", - "/map_saver/set_parameters": "rcl_interfaces/srv/SetParameters", - "/map_saver/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/nav2_container/_container/list_nodes": "composition_interfaces/srv/ListNodes", - "/nav2_container/_container/load_node": "composition_interfaces/srv/LoadNode", - "/nav2_container/_container/unload_node": "composition_interfaces/srv/UnloadNode", - "/navigate_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/navigate_through_poses/_action/get_result": "nav2_msgs/action/NavigateThroughPoses_GetResult", - "/navigate_through_poses/_action/send_goal": "nav2_msgs/action/NavigateThroughPoses_SendGoal", - "/navigate_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/navigate_to_pose/_action/get_result": "nav2_msgs/action/NavigateToPose_GetResult", - "/navigate_to_pose/_action/send_goal": "nav2_msgs/action/NavigateToPose_SendGoal", - "/o3de_ros2_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/o3de_ros2_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/o3de_ros2_node/get_parameters": "rcl_interfaces/srv/GetParameters", - "/o3de_ros2_node/list_parameters": "rcl_interfaces/srv/ListParameters", - "/o3de_ros2_node/set_parameters": "rcl_interfaces/srv/SetParameters", - "/o3de_ros2_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/planner_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/planner_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/planner_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/planner_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/planner_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/planner_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/planner_server/get_state": "lifecycle_msgs/srv/GetState", - "/planner_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/planner_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/planner_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/planner_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/rai_ros2_ari_connector_b6ed00ab6356/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/rai_ros2_ari_connector_b6ed00ab6356/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/rai_ros2_ari_connector_b6ed00ab6356/get_parameters": "rcl_interfaces/srv/GetParameters", - "/rai_ros2_ari_connector_b6ed00ab6356/list_parameters": "rcl_interfaces/srv/ListParameters", - "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters": "rcl_interfaces/srv/SetParameters", - "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/slam_toolbox/clear_changes": "slam_toolbox/srv/Clear", - "/slam_toolbox/clear_queue": "slam_toolbox/srv/ClearQueue", - "/slam_toolbox/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/slam_toolbox/deserialize_map": "slam_toolbox/srv/DeserializePoseGraph", - "/slam_toolbox/dynamic_map": "nav_msgs/srv/GetMap", - "/slam_toolbox/get_interactive_markers": "visualization_msgs/srv/GetInteractiveMarkers", - "/slam_toolbox/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/slam_toolbox/get_parameters": "rcl_interfaces/srv/GetParameters", - "/slam_toolbox/list_parameters": "rcl_interfaces/srv/ListParameters", - "/slam_toolbox/manual_loop_closure": "slam_toolbox/srv/LoopClosure", - "/slam_toolbox/pause_new_measurements": "slam_toolbox/srv/Pause", - "/slam_toolbox/save_map": "slam_toolbox/srv/SaveMap", - "/slam_toolbox/serialize_map": "slam_toolbox/srv/SerializePoseGraph", - "/slam_toolbox/set_parameters": "rcl_interfaces/srv/SetParameters", - "/slam_toolbox/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/slam_toolbox/toggle_interactive_mode": "slam_toolbox/srv/ToggleInteractive", - "/smooth_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/smooth_path/_action/get_result": "nav2_msgs/action/SmoothPath_GetResult", - "/smooth_path/_action/send_goal": "nav2_msgs/action/SmoothPath_SendGoal", - "/smoother_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/smoother_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/smoother_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/smoother_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/smoother_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/smoother_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/smoother_server/get_state": "lifecycle_msgs/srv/GetState", - "/smoother_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/smoother_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/smoother_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/smoother_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/spin/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/spin/_action/get_result": "nav2_msgs/action/Spin_GetResult", - "/spin/_action/send_goal": "nav2_msgs/action/Spin_SendGoal", - "/tf2_frames": "tf2_msgs/srv/FrameGraph", - "/velocity_smoother/change_state": "lifecycle_msgs/srv/ChangeState", - "/velocity_smoother/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/velocity_smoother/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/velocity_smoother/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/velocity_smoother/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/velocity_smoother/get_parameters": "rcl_interfaces/srv/GetParameters", - "/velocity_smoother/get_state": "lifecycle_msgs/srv/GetState", - "/velocity_smoother/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/velocity_smoother/list_parameters": "rcl_interfaces/srv/ListParameters", - "/velocity_smoother/set_parameters": "rcl_interfaces/srv/SetParameters", - "/velocity_smoother/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/wait/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/wait/_action/get_result": "nav2_msgs/action/Wait_GetResult", - "/wait/_action/send_goal": "nav2_msgs/action/Wait_SendGoal", - "/waypoint_follower/change_state": "lifecycle_msgs/srv/ChangeState", - "/waypoint_follower/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/waypoint_follower/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/waypoint_follower/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/waypoint_follower/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/waypoint_follower/get_parameters": "rcl_interfaces/srv/GetParameters", - "/waypoint_follower/get_state": "lifecycle_msgs/srv/GetState", - "/waypoint_follower/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/waypoint_follower/list_parameters": "rcl_interfaces/srv/ListParameters", - "/waypoint_follower/set_parameters": "rcl_interfaces/srv/SetParameters", - "/waypoint_follower/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", -} -INTERFACES: Dict[str, str] = { - "nav2_msgs/action/NavigateToPose": """#goal definition -geometry_msgs/PoseStamped pose - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string behavior_tree ---- -#result definition -std_msgs/Empty result ---- -#feedback definition -geometry_msgs/PoseStamped current_pose - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration navigation_time - int32 sec - uint32 nanosec -builtin_interfaces/Duration estimated_time_remaining - int32 sec - uint32 nanosec -int16 number_of_recoveries -float32 distance_remaining -""", - "nav2_msgs/action/AssistedTeleop": """#goal definition -builtin_interfaces/Duration time_allowance - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback -builtin_interfaces/Duration current_teleop_duration - int32 sec - uint32 nanosec""", - "nav2_msgs/action/BackUp": """#goal definition -geometry_msgs/Point target - float64 x - float64 y - float64 z -float32 speed -builtin_interfaces/Duration time_allowance - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback definition -float32 distance_traveled""", - "nav2_msgs/action/ComputePathThroughPoses": """#goal definition -geometry_msgs/PoseStamped[] goals - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -geometry_msgs/PoseStamped start - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string planner_id -bool use_start # If false, use current robot pose as path start, if true, use start above instead ---- -#result definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration planning_time - int32 sec - uint32 nanosec ---- -#feedback definition""", - "nav2_msgs/action/ComputePathToPose": """#goal definition -geometry_msgs/PoseStamped goal - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -geometry_msgs/PoseStamped start - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string planner_id -bool use_start # If false, use current robot pose as path start, if true, use start above instead ---- -#result definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration planning_time - int32 sec - uint32 nanosec ---- -#feedback definition""", - "nav2_msgs/action/DriveOnHeading": """#goal definition -geometry_msgs/Point target - float64 x - float64 y - float64 z -float32 speed -builtin_interfaces/Duration time_allowance - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback definition -float32 distance_traveled""", - "nav2_msgs/action/FollowPath": """#goal definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string controller_id -string goal_checker_id ---- -#result definition -std_msgs/Empty result ---- -#feedback definition -float32 distance_to_goal -float32 speed""", - "nav2_msgs/action/FollowWaypoints": """#goal definition -geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 ---- -#result definition -int32[] missed_waypoints ---- -#feedback definition -uint32 current_waypoint""", - "nav2_msgs/action/NavigateThroughPoses": """#goal definition -geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string behavior_tree ---- -#result definition -std_msgs/Empty result ---- -#feedback definition -geometry_msgs/PoseStamped current_pose - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration navigation_time - int32 sec - uint32 nanosec -builtin_interfaces/Duration estimated_time_remaining - int32 sec - uint32 nanosec -int16 number_of_recoveries -float32 distance_remaining -int16 number_of_poses_remaining -""", - "nav2_msgs/action/SmoothPath": """#goal definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string smoother_id -builtin_interfaces/Duration max_smoothing_duration - int32 sec - uint32 nanosec -bool check_for_collisions ---- -#result definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration smoothing_duration - int32 sec - uint32 nanosec -bool was_completed ---- -#feedback definition -""", - "nav2_msgs/action/Wait": """#goal definition -builtin_interfaces/Duration time - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback definition -builtin_interfaces/Duration time_left - int32 sec - uint32 nanosec""", -} - -ACTION_MODELS: Dict[str, Type[BaseModel]] = { - "nav2_msgs/action/NavigateToPose": NavigateToPoseGoal, - "nav2_msgs/action/Spin": SpinGoal, - "nav2_msgs/action/AssistedTeleop": AssistedTeleopGoal, - "nav2_msgs/action/BackUp": BackUpGoal, - "nav2_msgs/action/ComputePathThroughPoses": ComputePathThroughPosesGoal, - "nav2_msgs/action/ComputePathToPose": ComputePathToPoseGoal, - "nav2_msgs/action/DriveOnHeading": DriveOnHeadingGoal, - "nav2_msgs/action/FollowPath": FollowPathGoal, - "nav2_msgs/action/FollowWaypoints": FollowWaypointsGoal, - "nav2_msgs/action/NavigateThroughPoses": NavigateThroughPosesGoal, - "nav2_msgs/action/SmoothPath": SmoothPathGoal, - "nav2_msgs/action/Wait": WaitGoal, -} ACTION_STRINGS = [ f"action: {action}\ntype: {act_type}\n" - for action, act_type in ACTIONS_AND_TYPES.items() + for action, act_type in NAVIGATION_ACTIONS_AND_TYPES.items() ] SERVICE_STRINGS = [ @@ -869,79 +121,111 @@ class NavigationTask(Task): - @property - def type(self) -> str: - return "navigation" - - def get_system_prompt(self) -> str: - return ROBOT_NAVIGATION_SYSTEM_PROMPT + type = "navigation" @property def available_tools(self) -> List[BaseTool]: tools = MockActionsToolkit( mock_actions_names_and_types=ACTION_STRINGS, - available_actions=list(ACTIONS_AND_TYPES.keys()), - available_action_types=list(ACTIONS_AND_TYPES.values()), - available_action_models=ACTION_MODELS, + available_actions=list(NAVIGATION_ACTIONS_AND_TYPES.keys()), + available_action_types=list(NAVIGATION_ACTIONS_AND_TYPES.values()), + available_action_models=NAVIGATION_ACTION_MODELS, ).get_tools() - tools.append(MockGetROS2MessageInterfaceTool(mock_interfaces=INTERFACES)) + tools.append( + MockGetROS2TopicsNamesAndTypesTool( + mock_topics_names_and_types=TOPIC_STRINGS + ) + ) + tools.append( + MockGetROS2ServicesNamesAndTypesTool( + mock_service_names_and_types=SERVICE_STRINGS + ) + ) + tools.append( + MockGetROS2MessageInterfaceTool(mock_interfaces=NAVIGATION_INTERFACES) + ) return tools + @property + def optional_tool_calls_number(self) -> int: + # list topics and get interface + return 2 + + def get_system_prompt(self) -> str: + if self.n_shots == 0: + return ROBOT_NAVIGATION_SYSTEM_PROMPT_0_SHOT + elif self.n_shots == 2: + return ROBOT_NAVIGATION_SYSTEM_PROMPT_2_SHOT + else: + return ROBOT_NAVIGATION_SYSTEM_PROMPT_5_SHOT + class NavigateToPointTask(NavigationTask): - complexity = "medium" + complexity = "easy" + + def get_base_prompt(self) -> str: + return "Navigate to point (2.0, 2.0, 0.0)." def get_prompt(self) -> str: - return "Navigate to the point (2.0, 2.0, 0.0). Remember to list actions and get interface" + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can use the navigation tools to move the robot to the specified coordinates. " + "First get the available actions, then set up the navigation goal to reach point (2.0, 2.0, 0.0)." + ) class SpinAroundTask(NavigationTask): recursion_limit = 50 complexity = "medium" - def get_prompt(self) -> str: + def get_base_prompt(self) -> str: return "Spin around by 3 radians." + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can locate the robot's current orientation and execute a spinning motion " + "to rotate the robot by 3 radians from its current heading." + ) + class MoveToFrontTask(NavigationTask): recursion_limit = 50 complexity = "medium" - def get_prompt(self) -> str: + def get_base_prompt(self) -> str: return "Move 2 meters to the front." + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can determine the robot's current position and orientation, " + "then move it 2 meters forward in the direction it is currently facing." + ) + class MoveToBedTask(NavigationTask): recursion_limit = 50 - complexity = "medium" + complexity = "hard" - @property - def available_tools(self) -> List[BaseTool]: - return [ - MockGetROS2ActionsNamesAndTypesTool( - mock_actions_names_and_types=ACTION_STRINGS - ), - MockStartROS2ActionTool( - available_actions=list(ACTIONS_AND_TYPES.keys()), - available_action_types=list(ACTIONS_AND_TYPES.values()), - available_action_models=ACTION_MODELS, - ), - MockGetROS2ActionFeedbackTool(), - MockGetROS2ActionResultTool(), - MockGetROS2MessageInterfaceTool(mock_interfaces=INTERFACES), - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=TOPICS_NAMES_AND_TYPES - ), - MockGetDistanceToObjectsTool( - available_topics=[ - "/camera/camera/color/image_raw", - "/camera/camera/depth/image_rect_raw", - ], - mock_distance_measurements=[ - DistanceMeasurement(name="bed", distance=5.0) - ], - ), - ] + def get_base_prompt(self) -> str: + return "Move closer to the bed leaving 1 meter space." def get_prompt(self) -> str: - return "Move closer to the to the bed. Leave 1 meter of space between the bed and you." + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can locate the bed in the environment, calculate the appropriate position " + "that maintains 1 meter distance from the bed, and navigate to that position." + ) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py index 24e328d31..94e81742e 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py @@ -14,18 +14,36 @@ import logging -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import List from langchain_core.tools import BaseTool from pydantic import BaseModel, Field from rai.messages import preprocess_image -from rai_bench.tool_calling_agent.interfaces import Task, Validator +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator loggers_type = logging.Logger -SPATIAL_REASONING_SYSTEM_PROMPT = "You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response with the use of the provided tools." +SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT = """You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response with the use of the provided tools.""" + +SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT = ( + SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT + + """ + +Example of tool calls: +- return_bool_response, args: {'response': True} +- return_bool_response, args: {'response': False}""" +) + +# NOTE (jmatejcz) In this case we are using only one tool so there is no difference bettween 2 and 5 shot +SPATIAL_REASONING_SYSTEM_PROMPT_5_SHOT = ( + SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT + + """ +- return_bool_response, args: {'response': True} # When object is clearly visible +- return_bool_response, args: {'response': False} # When object is not present +- return_bool_response, args: {'response': True} # When spatial relationship is correct""" +) class TaskParametrizationError(Exception): @@ -34,32 +52,54 @@ class TaskParametrizationError(Exception): pass +class ReturnBoolResponseToolInput(BaseModel): + response: bool = Field(..., description="The response to the question.") + + +class ReturnBoolResponseTool(BaseTool): + """Tool that returns a boolean response.""" + + name: str = "return_bool_response" + description: str = "Return a bool response to the question." + args_schema = ReturnBoolResponseToolInput + + def _run(self, response: bool) -> bool: + if type(response) is bool: + return response + raise ValueError("Invalid response type. Response must be a boolean.") + + +class BoolImageTaskInput(BaseModel): + question: str = Field(..., description="The question to be answered.") + images_paths: List[str] = Field( + ..., + description="List of image file paths to be used for answering the question.", + ) + + class SpatialReasoningAgentTask(Task): """Abstract class for spatial reasoning tasks for tool calling agent.""" + type = "spatial_reasoning" + def __init__( self, validators: List[Validator], - extra_tool_calls: int = 0, + task_args: TaskArgs, logger: loggers_type | None = None, ) -> None: super().__init__( validators=validators, - extra_tool_calls=extra_tool_calls, + task_args=task_args, logger=logger, ) self.expected_tools: List[BaseTool] self.question: str self.images_paths: List[str] - @property - def type(self) -> str: - return "spatial_reasoning" - @abstractmethod def get_images(self) -> List[str]: """Get the images related to the task. - Returns ------- List[str] @@ -68,47 +108,25 @@ def get_images(self) -> List[str]: pass def get_system_prompt(self) -> str: - return SPATIAL_REASONING_SYSTEM_PROMPT - - -class ReturnBoolResponseToolInput(BaseModel): - response: bool = Field(..., description="The response to the question.") + if self.n_shots == 0: + return SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT + elif self.n_shots == 2: + return SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT + else: + return SPATIAL_REASONING_SYSTEM_PROMPT_5_SHOT -class ReturnBoolResponseTool(BaseTool): - """Tool that returns a boolean response.""" - - name: str = "return_bool_response" - description: str = "Return a bool response to the question." - args_schema = ReturnBoolResponseToolInput - - def _run(self, response: bool) -> bool: - if type(response) is bool: - return response - raise ValueError("Invalid response type. Response must be a boolean.") - - -class BoolImageTaskInput(BaseModel): - question: str = Field(..., description="The question to be answered.") - images_paths: List[str] = Field( - ..., - description="List of image file paths to be used for answering the question.", - ) - - -class BoolImageTask(SpatialReasoningAgentTask): - complexity = "easy" - +class BoolImageTask(SpatialReasoningAgentTask, ABC): def __init__( self, task_input: BoolImageTaskInput, validators: List[Validator], - extra_tool_calls: int = 0, + task_args: TaskArgs, logger: loggers_type | None = None, ) -> None: super().__init__( validators=validators, - extra_tool_calls=extra_tool_calls, + task_args=task_args, logger=logger, ) self.question = task_input.question @@ -118,9 +136,41 @@ def __init__( def available_tools(self) -> List[BaseTool]: return [ReturnBoolResponseTool()] - def get_prompt(self): + @property + def optional_tool_calls_number(self) -> int: + return 0 + + def get_base_prompt(self) -> str: return self.question + def get_prompt(self): + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()}" + "You can examine the provided image(s) carefully to identify relevant features, " + "analyze the visual content, and provide a boolean response based on your observations." + ) + def get_images(self): images = [preprocess_image(image_path) for image_path in self.images_paths] return images + + +# NOTE (jmatejcz) spatial reasoning task's difficulty is based solely on prompt and image +# so in this case when declaring task, please subjectivly decide how hard is the task +# examples: +# easy -> locating single object, tell if it is present +# medium -> tell in what state is the object (is door open?) or locating multiple objects +# hard -> locating multiple objects and resoning about their relative positions (is X on the right side of Y?) +class BoolImageTaskEasy(BoolImageTask): + complexity = "easy" + + +class BoolImageTaskMedium(BoolImageTask): + complexity = "medium" + + +class BoolImageTaskHard(BoolImageTask): + complexity = "hard" diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index 0da9aee5f..87629ac43 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -43,7 +43,23 @@ def parse_tool_calling_benchmark_args(): nargs="+", choices=["easy", "medium", "hard"], default=["easy", "medium", "hard"], - help="Complexity levels to include in the benchmark", + help="Complexity levels of tasks to include in the benchmark", + ) + parser.add_argument( + "--prompt-detail", + type=str, + nargs="+", + choices=["brief", "descriptive"], + default=["brief", "descriptive"], + help="Prompt detail level to include in the benchmark", + ) + parser.add_argument( + "--n-shots", + type=int, + nargs="+", + choices=[0, 2, 5], + default=[0, 2, 5], + help="Number of examples in system prompt for few-shot prompting", ) parser.add_argument( "--task-types", diff --git a/src/rai_core/rai/initialization/model_initialization.py b/src/rai_core/rai/initialization/model_initialization.py index ce0c4ace5..c7f4fa263 100644 --- a/src/rai_core/rai/initialization/model_initialization.py +++ b/src/rai_core/rai/initialization/model_initialization.py @@ -190,6 +190,8 @@ def get_llm_model_direct( elif vendor == "ollama": from langchain_ollama import ChatOllama + # Suppress httpx info logging for Ollama + logging.getLogger("httpx").setLevel(logging.WARNING) model_config = cast(OllamaConfig, model_config) return ChatOllama(model=model_name, base_url=model_config.base_url, **kwargs) else: From 7c499588e6b2e8b36a2fd3ec558a763713a3d64d Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Thu, 3 Jul 2025 09:17:03 +0200 Subject: [PATCH 02/13] feat: basic tasks extension (#644) --- .../rai_bench/tool_calling_agent/benchmark.py | 12 +- .../tool_calling_agent/interfaces.py | 112 +- .../mocked_ros2_interfaces.py | 495 +++++- .../tool_calling_agent/mocked_tools.py | 122 +- .../predefined/basic_tasks.py | 264 ++-- .../tool_calling_agent/tasks/basic.py | 530 ++++++- .../tool_calling_agent/validators.py | 51 + .../tool_calling_agent/test_mock_tools.py | 163 ++ .../test_predefined_tasks.py | 1362 +++++++++++++++++ .../tool_calling_agent/test_subtasks.py | 258 ++-- .../tool_calling_agent/test_validators.py | 343 +++++ 11 files changed, 3359 insertions(+), 353 deletions(-) create mode 100644 tests/rai_bench/tool_calling_agent/test_mock_tools.py create mode 100644 tests/rai_bench/tool_calling_agent/test_predefined_tasks.py diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index 4328c0a8d..cf73af238 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -87,6 +87,14 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: ) callbacks = self.score_tracing_handler.get_callbacks() run_id = uuid.uuid4() + # NOTE (jmatejcz) recursion limit calculated as all_nodes_num -> one pass though whole node + # plus (task.max_tool_calls_number-1 because the first pass is already added in) + # times number of nodes - 2 because we dont cout start and end node + # this can be to much for larger graphs that dont use all nodes on extra calls + # in such ase adjust this value + recurssion_limit = len(agent.get_graph().nodes) + ( + task.max_tool_calls_number - 1 + ) * (len(agent.get_graph().nodes) - 2) config: RunnableConfig = { "run_id": run_id, "callbacks": callbacks, @@ -97,9 +105,9 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: f"task-complexity:{task.complexity}", f"extra-tool-calls:{task.extra_tool_calls}", ], - "recursion_limit": len(agent.get_graph().nodes) - * task.max_tool_calls_number, + "recursion_limit": recurssion_limit, } + self.logger.debug(f"recurssion limit: {recurssion_limit}") ts = time.perf_counter() messages: List[BaseMessage] = [] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py index e35d5a6f4..3ee930bbf 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py @@ -118,6 +118,61 @@ def _check_tool_call( ) return True + def _check_nested_fields( + self, field_path: str, args: Dict[str, Any], expected_value: Any + ) -> bool: + """Check if a nested field in the arguments matches the expected value. + + Parameters + ---------- + field_path : str + Dot-separated path to the field (e.g., "header.frame_id"). + Including indexes for lists possible (e.g., "data[0].value"). + args : Dict[str, Any] + The arguments dictionary to check. + expected_value : Any + The expected value for the field. + + Returns + ------- + bool + True if the field matches the expected value, False otherwise. + + + Raises + ------ + SubTaskValidationError + If the field is not found or does not match the expected value. + """ + + keys = field_path.split(".") + value: Any = args + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + elif isinstance(value, list): + # try to access index + try: + index = int(key) + except ValueError: + raise SubTaskValidationError(f"Expected numeric index, got '{key}'") + + if 0 <= index < len(value): + value = value[index] + else: + raise SubTaskValidationError(f"Index {index} out of range") + + else: + raise SubTaskValidationError( + f"Field path '{field_path}' not found in the message." + ) + + if value != expected_value: + raise SubTaskValidationError( + f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'." + ) + return True + def _check_topic_tool_call_field( self, tool_call: ToolCall, @@ -173,22 +228,9 @@ def _check_topic_tool_call_field( "Tool call does not contain a 'message' argument." ) - keys = field_path.split(".") - value: Any = message - for key in keys: - if isinstance(value, dict) and key in value: - value = value[key] - else: - raise SubTaskValidationError( - f"Field path '{field_path}' not found in the message." - ) - - if value != expected_value: - raise SubTaskValidationError( - f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'." - ) - - return True + return self._check_nested_fields( + field_path=field_path, args=message, expected_value=expected_value + ) def _check_service_tool_call_field( self, @@ -253,22 +295,9 @@ def _check_service_tool_call_field( f"Expected empty service_args, but got: {service_args}" ) - keys = field_path.split(".") - value: Any = service_args - for key in keys: - if isinstance(value, dict) and key in value: - value = value[key] - else: - raise SubTaskValidationError( - f"Field path '{field_path}' not found in the message." - ) - - if value != expected_value: - raise SubTaskValidationError( - f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'." - ) - - return True + return self._check_nested_fields( + field_path=field_path, args=service_args, expected_value=expected_value + ) def _check_action_tool_call_field( self, @@ -333,22 +362,9 @@ def _check_action_tool_call_field( f"Expected empty action_args, but got: {action_args}" ) - keys = field_path.split(".") - value: Any = action_args - for key in keys: - if isinstance(value, dict) and key in value: - value = value[key] - else: - raise SubTaskValidationError( - f"Field path '{field_path}' not found in the action_args." - ) - - if value != expected_value: - raise SubTaskValidationError( - f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'." - ) - - return True + return self._check_nested_fields( + field_path=field_path, args=action_args, expected_value=expected_value + ) class Validator(ABC): diff --git a/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py index f9758499f..4dd584cb2 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py @@ -53,6 +53,499 @@ # the dict contains custom as well as couple other common interfaces COMMON_INTERFACES: Dict[str, str] = { + "std_srvs/srv/Empty": """# Empty service - no request or response +--- +""", + "std_srvs/srv/Trigger": """# Simple service to trigger an action +--- +bool success # indicate successful run of triggered service +string message # informational, e.g. for error messages +""", + "std_srvs/srv/SetBool": """bool data # e.g. for hardware enabling / disabling +--- +bool success # indicate successful run of triggered service +string message # informational, e.g. for error messages +""", + "std_srvs/srv/SetString": """string data +--- +bool success +string message +""", + "lifecycle_msgs/srv/ChangeState": """Transition transition + uint8 id + string label +--- +bool success +""", + "lifecycle_msgs/srv/GetState": """--- +State current_state + uint8 id + string label +""", + "lifecycle_msgs/srv/GetAvailableStates": """--- +State[] available_states + uint8 id + string label +""", + "lifecycle_msgs/srv/GetAvailableTransitions": """--- +TransitionDescription[] available_transitions + Transition transition + uint8 id + string label + State start_state + uint8 id + string label + State goal_state + uint8 id + string label +""", + "rcl_interfaces/msg/ParameterEvent": """# This message is published when parameters change for a node +Parameter[] changed_parameters + string name + ParameterValue value + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value + +Parameter[] deleted_parameters + string name + ParameterValue value + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value + +string node +""", + "rcl_interfaces/msg/Log": """# This message represents a log message published on the /rosout topic +# severity level constants +byte DEBUG=10 +byte INFO=20 +byte WARN=30 +byte ERROR=40 +byte FATAL=50 + +# message fields +builtin_interfaces/Time stamp + int32 sec + uint32 nanosec +byte level +string name # name of the node +string msg # message text +string file # file the message came from +string function # function the message came from +uint32 line # line the message came from +""", + "tf2_msgs/msg/TFMessage": """# An array of transforms with a header for the coordinate frame +geometry_msgs/TransformStamped[] transforms + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + string child_frame_id + Transform transform + Vector3 translation + float64 x + float64 y + float64 z + Quaternion rotation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +""", + "sensor_msgs/msg/JointState": """# This is a message that holds data to describe the state of a set of torque controlled joints. +# +# The state of each joint (revolute or prismatic) is defined by: +# * the position of the joint (rad or m), +# * the velocity of the joint (rad/s or m/s) and +# * the effort that is applied in the joint (Nm or N). +# +# Each joint is uniquely identified by its name +# The header specifies the time at which the joint states were recorded. All the joint states +# in one message have to be recorded at the same time. +# +# This message consists of a multiple arrays, one for each part of the joint state. +# The goal is to make each of the fields optional. When e.g. your joints have no +# velocity or effort sensors, you can leave the velocity and effort arrays empty. +# +# All arrays in this message should have the same size. + +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id +string[] name +float64[] position +float64[] velocity +float64[] effort +""", + "std_msgs/msg/String": """# Please look at the Standard ROS Messages documentation before using this. +# http://wiki.ros.org/std_msgs +string data +""", + "bond/msg/Status": """# An array of bond ids that this node is maintaining +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id +string id # unique identifier for the bond +string instance_id # identifier for this instance of the node +bool active # whether the bond is currently active +float32 heartbeat_timeout # timeout for heartbeat in seconds +float32 heartbeat_period # period for heartbeat messages in seconds +""", + "diagnostic_msgs/msg/DiagnosticArray": """# This message contains a list of diagnostic statuses +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + +DiagnosticStatus[] status + # Possible levels of operations + byte OK=0 + byte WARN=1 + byte ERROR=2 + byte STALE=3 + + byte level # level of operation enumerated above + string name # a description of the test/component reporting + string message # a description of the status + string hardware_id # a hardware unique string + KeyValue[] values # an array of values associated with the status + string key + string value +""", + "sensor_msgs/msg/PointCloud2": """# This message holds a collection of N-dimensional points, which may +# contain additional information such as normals, intensity, etc. The +# point data is stored as a binary blob, its format described by the +# contents of the \"fields\" array. + +# The point cloud data may be organized 2d (image-like) or 1d +# (unordered). Point clouds organized as 2d images may be produced by +# camera depth sensors such as stereo or time-of-flight. + +# Time of sensor data acquisition, and the coordinate frame ID (for 3d +# points). +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + +# 2D structure of the point cloud. If the cloud is unordered, height is +# 1 and width is the length of the point cloud. +uint32 height +uint32 width + +# Describes the channels and their layout in the binary data blob. +PointField[] fields + uint8 INT8 = 1 + uint8 UINT8 = 2 + uint8 INT16 = 3 + uint8 UINT16 = 4 + uint8 INT32 = 5 + uint8 UINT32 = 6 + uint8 FLOAT32 = 7 + uint8 FLOAT64 = 8 + + string name # Name of field + uint32 offset # Offset from start of point struct + uint8 datatype # Datatype enumeration, see above + uint32 count # How many elements in the field + +bool is_bigendian # Is this data bigendian? +uint32 point_step # Length of a point in bytes +uint32 row_step # Length of a row in bytes +uint8[] data # Actual point data, size is (row_step*height) + +bool is_dense # True if there are no invalid points +""", + "sensor_msgs/msg/LaserScan": """# Single scan from a planar laser range-finder +# +# If you have another ranging device with different behavior (e.g. a sonar +# array), please find or create a different message, since applications +# will make fairly laser-specific assumptions about this data + +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id +float32 angle_min # start angle of the scan [rad] +float32 angle_max # end angle of the scan [rad] +float32 angle_increment # angular distance between measurements [rad] + +float32 time_increment # time between measurements [seconds] - if your scanner + # is moving, this will be used in interpolating position + # of 3d points +float32 scan_time # time between scans [seconds] + +float32 range_min # minimum range value [m] +float32 range_max # maximum range value [m] + +float32[] ranges # range data [m] (Note: values < range_min or > range_max should be discarded) +float32[] intensities # intensity data [device-specific units]. If your + # device does not provide intensities, please leave + # the array empty. +""", + "nav_msgs/msg/Odometry": """# This represents an estimate of a position and velocity in free space. +# The pose in this message should be specified in the coordinate frame given by header.frame_id. +# The twist in this message should be specified in the coordinate frame given by the child_frame_id +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id +string child_frame_id +geometry_msgs/PoseWithCovariance pose + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64[36] covariance # Row-major representation of the 6x6 covariance matrix +geometry_msgs/TwistWithCovariance twist + Twist twist + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z + float64[36] covariance # Row-major representation of the 6x6 covariance matrix +""", + # Services from COMMON_SERVICES_AND_TYPES that are missing + "tf2_msgs/srv/FrameGraph": """--- +string frame_yaml +""", + "composition_interfaces/srv/ListNodes": """--- +# All unique node names within the container +string[] unique_names +# Full node names corresponding to each unique node name +string[] full_node_names +""", + "composition_interfaces/srv/LoadNode": """LoadNodeRequest request + string package_name + string plugin_name + string node_name + string node_namespace + string[] remap_rules + Parameter[] parameters + string name + ParameterValue value + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value + string[] extra_arguments +--- +bool success +string error_message +string full_node_name +uint64 unique_id +""", + "composition_interfaces/srv/UnloadNode": """uint64 unique_id +--- +bool success +string error_message +""", + "rcl_interfaces/srv/DescribeParameters": """string[] names +--- +ParameterDescriptor[] descriptors + string name + uint8 type + string description + string additional_constraints + bool read_only + bool dynamic_typing + ParameterValue floating_point_range + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value + ParameterValue integer_range + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value +""", + "rcl_interfaces/srv/GetParameterTypes": """string[] names +--- +uint8[] types +# Possible parameter types: +uint8 PARAMETER_NOT_SET=0 +uint8 PARAMETER_BOOL=1 +uint8 PARAMETER_INTEGER=2 +uint8 PARAMETER_DOUBLE=3 +uint8 PARAMETER_STRING=4 +uint8 PARAMETER_BYTE_ARRAY=5 +uint8 PARAMETER_BOOL_ARRAY=6 +uint8 PARAMETER_INTEGER_ARRAY=7 +uint8 PARAMETER_DOUBLE_ARRAY=8 +uint8 PARAMETER_STRING_ARRAY=9 +""", + "rcl_interfaces/srv/GetParameters": """string[] names +--- +ParameterValue[] values + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value +""", + "rcl_interfaces/srv/ListParameters": """ListParametersRequest request + string[] prefixes + uint64 depth +--- +ListParametersResult result + string[] names + string[] prefixes +""", + "rcl_interfaces/srv/SetParametersAtomically": """Parameter[] parameters + string name + ParameterValue value + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value +--- +SetParametersResult result + bool successful + string reason +""", + "gazebo_msgs/srv/GetWorldProperties": """# Service to get world properties +--- +string[] model_names +string[] light_names +bool rendering_enabled +bool physics_enabled +bool physics_paused +float64 sim_time +""", + "gazebo_msgs/srv/GetModelState": """string model_name +string relative_entity_name # return pose relative to this entity + # an empty string will return world relative pose +--- +geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +geometry_msgs/Twist twist + Vector3 linear + float64 x + float64 y + float64 z + Vector3 angular + float64 x + float64 y + float64 z +bool success +string status_message +""", + "gazebo_msgs/srv/DeleteEntity": """string name # Name of the Gazebo entity to be deleted. This can be either + # a model or a light. +--- +bool success # Return true if deletion is successful. +string status_message # Comments if available. +""", + "gazebo_msgs/srv/SpawnEntity": """string name # Name of the entity to be spawned (optional). +string xml # Entity XML description as a string, either URDF or SDF. +string robot_namespace # Spawn robot and all ROS interfaces under this namespace +geometry_msgs/Pose initial_pose # Initial entity pose. + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +string reference_frame # initial_pose is defined relative to the frame of this entity. + # If left empty or "world" or "map", then gazebo world frame is + # used. + # If non-existent entity is specified, an error is returned + # and the entity is not spawned. +--- +bool success # Return true if spawned successfully. +string status_message # Comments if available.""", + "rcl_interfaces/srv/SetParameters": """# A list of parameters to set. +Parameter[] parameters + string name + ParameterValue value + uint8 type + bool bool_value + int64 integer_value + float64 double_value + string string_value + byte[] byte_array_value + bool[] bool_array_value + int64[] integer_array_value + float64[] double_array_value + string[] string_array_value + +--- +# Indicates whether setting each parameter succeeded or not and why.""", "sensor_msgs/msg/CameraInfo": """ # This message defines meta information for a camera. It should be in a # camera namespace on topic "camera_info" and accompanied by up to five @@ -2888,8 +3381,6 @@ "/depth_image5": "sensor_msgs/msg/Image", "/pointcloud": "sensor_msgs/msg/PointCloud2", "/scan": "sensor_msgs/msg/LaserScan", - "/odom": "nav_msgs/msg/Odometry", - "/odometry/filtered": "nav_msgs/msg/Odometry", } MANIPULATION_TOPICS_AND_TYPES: Dict[str, str] = { diff --git a/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py b/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py index 85505df8e..443a620a6 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import uuid from threading import Lock -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type from unittest.mock import MagicMock import numpy as np import numpy.typing as npt from langchain_core.tools import BaseTool -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError, computed_field +from rai.communication.ros2.api.conversion import import_message_from_str from rai.communication.ros2.connectors import ROS2Connector from rai.communication.ros2.messages import ROS2Message from rai.messages import MultimodalArtifact, preprocess_image @@ -48,6 +50,7 @@ GetDistanceToObjectsTool, GetGrabbingPointTool, ) +from rosidl_runtime_py import set_message_fields class MockGetROS2TopicsNamesAndTypesTool(GetROS2TopicsNamesAndTypesTool): @@ -265,21 +268,116 @@ def _run(self, msg_type: str) -> str: raise ImportError(f"Module {msg_type} not found.") +class ServiceValidator: + """ + Validator that is responsible for checking if given service type exists + and if it is used correctly. + Validator uses ROS 2 native types when available, + falls back to Pydantic models of custom interfaces when not. + """ + + def __init__(self, custom_models: Dict[str, Type[BaseModel]]): + self.custom_models = custom_models + self.ros2_services_cache: Dict[str, Any] = {} + + def validate_with_ros2(self, service_type: str, args: Dict[str, Any]): + """Validate using installed ROS2 packages services definition + + Parameters + ---------- + service_type : str + args : Dict[str, Any] + Dictionary of arguments to validate against the service definition. + + Raises + ------ + TypeError + When service type does not exist in ROS2 installed packages + """ + service_class = import_message_from_str(service_type) + if not service_class: + raise TypeError(f"Service type: {service_type} does not exist.") + + request = service_class.Request() + # set message fields converts them to object so we need deepcopy to avoid it + args_to_validate = copy.deepcopy(args) + set_message_fields(request, args_to_validate) + + def validate_with_custom(self, service_type: str, args: Dict[str, Any]): + """ + Validate using Pydantic model of custom messages. + + Parameters + ---------- + service_type : str + args : Dict[str, Any] + Dictionary of arguments to validate against the Pydantic model. + + Raises + ------ + ValueError + If service_type is not found in custom_models or if Pydantic + validation fails. + """ + if service_type not in self.custom_models: + raise ValueError(f"Service type: {service_type} is invalid custom type") + + model = self.custom_models[service_type] + try: + model.model_validate(args) + except ValidationError as e: + raise ValueError(f"Pydantic validation failed: {e}") from e + + def validate(self, service_type: str, args: Dict[str, Any]): + """ + Try ROS 2 validation first, fall back to Pydantic models. + + Parameters + ---------- + service_type : str + args : Dict[str, Any] + Dictionary of arguments to validate. + """ + if service_type in self.custom_models: + self.validate_with_custom(service_type, args) + else: + self.validate_with_ros2(service_type, args) + + class MockCallROS2ServiceTool(CallROS2ServiceTool): connector: ROS2Connector = MagicMock(spec=ROS2Connector) available_services: List[str] available_service_types: List[str] available_service_models: Dict[str, Type[BaseModel]] + @computed_field + @property + def models_validator(self) -> ServiceValidator: + """computed field for instancinating ServiceValidator with available service models""" + return ServiceValidator(self.available_service_models) + def _run( self, service_name: str, service_type: str, - service_args: Dict[str, Any], + service_args: Optional[Dict[str, Any]] = None, + timeout_sec: float = 1.0, ) -> str: + """ + Execute the mocked ROS2 service call with validation of service type and its args. + + Parameters + ---------- + service_name : str + Name of the service to call + service_type : str + Type of the service + service_args : Optional[Dict[str, Any]], optional + Arguments for the service call, by default None + """ if service_name not in self.available_services: raise ValueError( - f"Service {service_name} is not available within 1.0 seconds. Check if the service exists." + f"Service {service_name} is not available within {timeout_sec} seconds. Check if the service exists." ) if service_type not in self.available_service_types: raise TypeError( @@ -287,12 +385,10 @@ def _run( self.available_service_types, service_type ) ) - if service_type in self.available_service_models: - model = self.available_service_models[service_type] - try: - model.model_validate(service_args) - except ValidationError as e: - raise ValueError(f"Failed to populate fields: {e}") + if not service_args: + service_args = {} + try: + self.models_validator.validate(service_type, service_args) response = ROS2Message(payload={"response": "success"}) return str( { @@ -300,10 +396,8 @@ def _run( "metadata": response.metadata, } ) - else: - raise KeyError( - f"Model for service type {service_type} not included in models" - ) + except ValueError as e: + raise ValueError(f"Failed to populate fields: {e}") class MockGetROS2ServicesNamesAndTypesTool(GetROS2ServicesNamesAndTypesTool): diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py index 2a2ae7e88..aa8e99f1d 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import List, Literal from rai_bench.tool_calling_agent.interfaces import ( @@ -20,223 +19,112 @@ ) from rai_bench.tool_calling_agent.subtasks import ( CheckArgsToolCallSubTask, + CheckServiceFieldsToolCallSubTask, ) from rai_bench.tool_calling_agent.tasks.basic import ( - AssessSensorDataQualityTask, - CheckRobotHealthTask, + BOX1_ENTITY, + BOX1_POSITION, + BOX2_ENTITY, + BOX2_POSITION, + COLOR_IMAGE_TOPIC, + DEFAULT_DINO_CONFIDENCE, + DEFAULT_FPS, + DEFAULT_PUBLISH_FREQUENCY, + DEFAULT_SAM_CONFIDENCE, + DEPTH_IMAGE_TOPIC, + GET_SPAWNABLE_NAMES_SERVICE, + GET_WORLD_PROPERTIES_TYPE, + LIST_PARAMETERS_TYPE, + POINTCLOUD_TOPIC, + ROBOT_DESCRIPTION_TOPIC, + ROBOT_STATE_PUBLISHER_LIST_PARAMS, + TOMATO_ENTITY, + CheckSpawnableEntitiesTask, + ConfigureVisionPipelineTask, GetAllROS2CamerasTask, GetPointcloudTask, GetRobotDescriptionTask, GetROS2DepthCameraTask, GetROS2RGBCameraTask, + GetROS2ServicesTask, GetROS2TopicsTask, + GetSpecificParameterTask, + ListRobotParametersTask, + RespawnEntitiesTask, + SetRobotParameterTask, + SpawnEntityTask, ) from rai_bench.tool_calling_agent.validators import ( NotOrderedCallsValidator, OrderedCallsValidator, ) -########## SUBTASKS ################################################################# +########## SUBTASKS FOR TASKS WITHOUT REFACTORED VALIDATORS ################################################################# get_topics_subtask = CheckArgsToolCallSubTask( expected_tool_name="get_ros2_topics_names_and_types", expected_args={} ) color_image5_subtask = CheckArgsToolCallSubTask( expected_tool_name="get_ros2_image", - expected_args={"topic": "/color_image5"}, + expected_args={"topic": COLOR_IMAGE_TOPIC}, expected_optional_args={"timeout_sec": int}, ) depth_image5_subtask = CheckArgsToolCallSubTask( expected_tool_name="get_ros2_image", - expected_args={"topic": "/depth_image5"}, - expected_optional_args={"timeout_sec": int}, -) - -color_camera_info5_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_image", - expected_args={"topic": "/color_image5"}, - expected_optional_args={"timeout_sec": int}, -) -depth_camera_info5_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_image", - expected_args={"topic": "/depth_image5"}, - expected_optional_args={"timeout_sec": int}, -) - -receive_robot_desc_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description"}, + expected_args={"topic": DEPTH_IMAGE_TOPIC}, expected_optional_args={"timeout_sec": int}, ) receive_pointcloud_subtask = CheckArgsToolCallSubTask( expected_tool_name="receive_ros2_message", - expected_args={"topic": "/pointcloud"}, - expected_optional_args={"timeout_sec": int}, -) - -# System health subtasks -diagnostics_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/diagnostics"}, - expected_optional_args={"timeout_sec": int}, -) -rosout_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/rosout"}, - expected_optional_args={"timeout_sec": int}, -) -joint_states_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/joint_states"}, + expected_args={"topic": POINTCLOUD_TOPIC}, expected_optional_args={"timeout_sec": int}, ) -# Odometry subtasks -odom_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/odom"}, - expected_optional_args={"timeout_sec": int}, -) -filtered_odom_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/odometry/filtered"}, - expected_optional_args={"timeout_sec": int}, -) - -# Transform subtasks -tf_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/tf"}, - expected_optional_args={"timeout_sec": int}, -) -tf_static_subtask = CheckArgsToolCallSubTask( +receive_robot_desc_subtask = CheckArgsToolCallSubTask( expected_tool_name="receive_ros2_message", - expected_args={"topic": "/tf_static"}, + expected_args={"topic": ROBOT_DESCRIPTION_TOPIC}, expected_optional_args={"timeout_sec": int}, ) - -# Robot description subtasks -robot_description_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description"}, - expected_optional_args={"timeout_sec": int}, -) -robot_description_semantic_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description_semantic"}, - expected_optional_args={"timeout_sec": int}, +get_services_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_services_names_and_types", expected_args={} ) -# Sensor data subtasks -scan_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/scan"}, - expected_optional_args={"timeout_sec": int}, +list_parameters_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=ROBOT_STATE_PUBLISHER_LIST_PARAMS, + expected_service_type=LIST_PARAMETERS_TYPE, + expected_fields={"": {}}, ) -pointcloud_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/pointcloud"}, - expected_optional_args={"timeout_sec": int}, -) - -# Robot description subtasks -robot_description_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description"}, - expected_optional_args={"timeout_sec": int}, -) -robot_description_semantic_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description_semantic"}, - expected_optional_args={"timeout_sec": int}, +check_spawnable_entities_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GET_SPAWNABLE_NAMES_SERVICE, + expected_service_type=GET_WORLD_PROPERTIES_TYPE, + expected_fields={"": {}}, ) -# Sensor data subtasks -scan_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/scan"}, - expected_optional_args={"timeout_sec": int}, -) -pointcloud_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/pointcloud"}, - expected_optional_args={"timeout_sec": int}, -) - -# Robot description subtasks -robot_description_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description"}, - expected_optional_args={"timeout_sec": int}, -) -robot_description_semantic_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_description_semantic"}, - expected_optional_args={"timeout_sec": int}, -) - -######### VALIDATORS ######################################################################################### +######### VALIDATORS FOR TASKS WITHOUT REFACTORED VALIDATORS ######################################################################################### topics_ord_val = OrderedCallsValidator(subtasks=[get_topics_subtask]) color_image_ord_val = OrderedCallsValidator(subtasks=[color_image5_subtask]) depth_image_ord_val = OrderedCallsValidator(subtasks=[depth_image5_subtask]) -color_camera_info_ord_val = OrderedCallsValidator(subtasks=[color_camera_info5_subtask]) -depth_camera_info_ord_val = OrderedCallsValidator(subtasks=[depth_camera_info5_subtask]) - -color_image_with_info_ord_val = NotOrderedCallsValidator( - subtasks=[color_image5_subtask, color_camera_info5_subtask] -) -depth_image_with_info_ord_val = NotOrderedCallsValidator( - subtasks=[depth_image5_subtask, color_camera_info5_subtask] -) - all_camera_images_notord_val = NotOrderedCallsValidator( subtasks=[ color_image5_subtask, depth_image5_subtask, ] ) -all_camera_info_notord_val = NotOrderedCallsValidator( - subtasks=[ - color_camera_info5_subtask, - depth_camera_info5_subtask, - ] -) -all_camera_images_with_info_notord_val = NotOrderedCallsValidator( - subtasks=[ - color_image5_subtask, - depth_image5_subtask, - color_camera_info5_subtask, - depth_camera_info5_subtask, - ] -) - -joint_states_ord_val = OrderedCallsValidator(subtasks=[joint_states_subtask]) -diagnostics_ord_val = OrderedCallsValidator(subtasks=[diagnostics_subtask]) get_pointcloud_ord_val = OrderedCallsValidator(subtasks=[receive_pointcloud_subtask]) get_robot_desc_ord_val = OrderedCallsValidator(subtasks=[receive_robot_desc_subtask]) -robot_health_val = NotOrderedCallsValidator( - subtasks=[diagnostics_subtask, joint_states_subtask, rosout_subtask] -) - -odometry_comparison_val = NotOrderedCallsValidator( - subtasks=[odom_subtask, filtered_odom_subtask] -) -sensor_data_val = NotOrderedCallsValidator( - subtasks=[ - scan_subtask, - receive_pointcloud_subtask, - color_image5_subtask, - depth_image5_subtask, - color_camera_info5_subtask, - depth_camera_info5_subtask, - ] +services_ord_val = OrderedCallsValidator(subtasks=[get_services_subtask]) +list_parameters_val = OrderedCallsValidator(subtasks=[list_parameters_subtask]) +check_spawnable_entities_val = OrderedCallsValidator( + subtasks=[check_spawnable_entities_subtask] ) @@ -271,33 +159,67 @@ def get_basic_tasks( tasks.extend( [ GetROS2RGBCameraTask( - validators=[color_image_ord_val], task_args=task_args, + validators=[color_image_ord_val], ), GetROS2TopicsTask( - validators=[topics_ord_val], task_args=task_args, + validators=[topics_ord_val], ), GetROS2DepthCameraTask( - validators=[depth_image_ord_val], task_args=task_args, + validators=[depth_image_ord_val], ), GetAllROS2CamerasTask( - validators=[all_camera_images_notord_val], task_args=task_args, + validators=[all_camera_images_notord_val], ), GetPointcloudTask( - validators=[get_pointcloud_ord_val], task_args=task_args + task_args=task_args, + validators=[get_pointcloud_ord_val], ), GetRobotDescriptionTask( - validators=[get_robot_desc_ord_val], task_args=task_args + task_args=task_args, + validators=[get_robot_desc_ord_val], + ), + GetROS2ServicesTask( + task_args=task_args, + validators=[services_ord_val], + ), + ListRobotParametersTask( + task_args=task_args, + validators=[list_parameters_val], + ), + CheckSpawnableEntitiesTask( + task_args=task_args, + validators=[check_spawnable_entities_val], + ), + # Tasks with refactored validators - now use defaults + GetSpecificParameterTask( + parameter="publish_frequency", + task_args=task_args, + ), + SpawnEntityTask( + entity=TOMATO_ENTITY, + task_args=task_args, + ), + SetRobotParameterTask( + value=DEFAULT_PUBLISH_FREQUENCY, + task_args=task_args, + ), + SetRobotParameterTask( + value=25.0, + task_args=task_args, ), - CheckRobotHealthTask( - validators=[robot_health_val], + ConfigureVisionPipelineTask( + sam_confidence_threshold=DEFAULT_SAM_CONFIDENCE, + dino_confidence_threshold=DEFAULT_DINO_CONFIDENCE, + fps=DEFAULT_FPS, task_args=task_args, ), - AssessSensorDataQualityTask( - validators=[sensor_data_val], + RespawnEntitiesTask( + names=[BOX1_ENTITY, BOX2_ENTITY], + coords=[BOX1_POSITION, BOX2_POSITION], task_args=task_args, ), ] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py index 9eb1a159c..929f027b8 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py @@ -14,21 +14,97 @@ import logging from abc import ABC -from typing import List +from typing import List, Optional, Tuple from langchain_core.tools import BaseTool from rai_bench.tool_calling_agent.interfaces import ( + SubTask, Task, + TaskArgs, + Validator, +) +from rai_bench.tool_calling_agent.mocked_ros2_interfaces import ( + COMMON_INTERFACES, + COMMON_SERVICES_AND_TYPES, + COMMON_TOPICS_AND_TYPES, ) -from rai_bench.tool_calling_agent.mocked_ros2_interfaces import COMMON_TOPICS_AND_TYPES from rai_bench.tool_calling_agent.mocked_tools import ( + MockCallROS2ServiceTool, MockGetROS2ImageTool, + MockGetROS2MessageInterfaceTool, + MockGetROS2ServicesNamesAndTypesTool, MockGetROS2TopicsNamesAndTypesTool, MockReceiveROS2MessageTool, ) +from rai_bench.tool_calling_agent.subtasks import CheckServiceFieldsToolCallSubTask +from rai_bench.tool_calling_agent.validators import ( + NotOrderedCallsValidator, + OneFromManyValidator, + OrderedCallsValidator, +) + +COLOR_IMAGE_TOPIC = "/color_image5" +DEPTH_IMAGE_TOPIC = "/depth_image5" +COLOR_CAMERA_INFO_TOPIC = "/color_camera_info5" +DEPTH_CAMERA_INFO_TOPIC = "/depth_camera_info5" +ROBOT_DESCRIPTION_TOPIC = "/robot_description" +POINTCLOUD_TOPIC = "/pointcloud" +SCAN_TOPIC = "/scan" + +ROBOT_STATE_PUBLISHER_LIST_PARAMS = "/robot_state_publisher/list_parameters" +ROBOT_STATE_PUBLISHER_GET_PARAMS = "/robot_state_publisher/get_parameters" +ROBOT_STATE_PUBLISHER_SET_PARAMS = "/robot_state_publisher/set_parameters" +ROBOT_STATE_PUBLISHER_SET_PARAMS_ATOMICALLY = ( + "/robot_state_publisher/set_parameters_atomically" +) + +SPAWN_ENTITY_SERVICE = "/spawn_entity" +DELETE_ENTITY_SERVICE = "/delete_entity" +GET_SPAWNABLE_NAMES_SERVICE = "/get_available_spawnable_names" +GROUNDED_SAM_SET_PARAMS = "/grounded_sam/set_parameters" +GROUNDED_SAM_SET_PARAMS_ATOMICALLY = "/grounded_sam/set_parameters_atomically" +GROUNDING_DINO_SET_PARAMS = "/grounding_dino/set_parameters" +GROUNDING_DINO_SET_PARAMS_ATOMICALLY = "/grounding_dino/set_parameters_atomically" +O3DE_SET_PARAMS = "/o3de_ros2_node/set_parameters" +O3DE_SET_PARAMS_ATOMICALLY = "/o3de_ros2_node/set_parameters_atomically" + +LIST_PARAMETERS_TYPE = "rcl_interfaces/srv/ListParameters" +SET_PARAMETERS_TYPE = "rcl_interfaces/srv/SetParameters" +SET_PARAMETERS_ATOMICALLY_TYPE = "rcl_interfaces/srv/SetParametersAtomically" +GET_PARAMETERS_TYPE = "rcl_interfaces/srv/GetParameters" +SPAWN_ENTITY_TYPE = "gazebo_msgs/srv/SpawnEntity" +DELETE_ENTITY_TYPE = "gazebo_msgs/srv/DeleteEntity" +GET_WORLD_PROPERTIES_TYPE = "gazebo_msgs/srv/GetWorldProperties" + +DEFAULT_PUBLISH_FREQUENCY = 30.0 +DEFAULT_FPS = 30 +DEFAULT_SAM_CONFIDENCE = 0.8 +DEFAULT_DINO_CONFIDENCE = 0.7 +SAM_CONFIDENCE_2 = 0.6 +DINO_CONFIDENCE_2 = 0.6 +FPS_2 = 10 + +TOMATO_ENTITY = "tomato" +BOX1_ENTITY = "box1" +BOX2_ENTITY = "box2" +BOX1_POSITION = (0.2, 0.2, 0.2) +BOX2_POSITION = (0.4, 0.4, 0.2) + +CAMERA_TOPICS_AND_TYPES = [ + f"topic: {COLOR_CAMERA_INFO_TOPIC}\ntype: sensor_msgs/msg/CameraInfo\n", + f"topic: {COLOR_IMAGE_TOPIC}\ntype: sensor_msgs/msg/Image\n", + f"topic: {DEPTH_CAMERA_INFO_TOPIC}\ntype: sensor_msgs/msg/CameraInfo\n", + f"topic: {DEPTH_IMAGE_TOPIC}\ntype: sensor_msgs/msg/Image\n", +] + +CAMERA_TOPICS = [ + COLOR_CAMERA_INFO_TOPIC, + COLOR_IMAGE_TOPIC, + DEPTH_CAMERA_INFO_TOPIC, + DEPTH_IMAGE_TOPIC, +] -loggers_type = logging.Logger PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT = """You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. Be proactive and use the tools to answer questions.""" @@ -36,16 +112,33 @@ PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT + """ Example of tool calls: -- get_ros2_message_interface, args: {'msg_type': 'geometry_msgs/msg/Twist'} -- publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}}""" +- name: get_ros2_topics_names_and_types, args: {} +- name: get_ros2_message_interface, args: {"msg_type": "tf2_msgs/srv/LookupTransform"}""" ) PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT = ( PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + """ -- get_ros2_topics_names_and_types, args: {} -- get_ros2_image, args: {'topic': '/camera/image_raw', 'timeout_sec': 10} -- publish_ros2_message, args: {'topic': '/turtle1/teleport_absolute', 'message_type': 'turtlesim/srv/TeleportAbsolute', 'message': {x: 5.0, y: 2.0, theta: 1.57}}""" +- name: get_ros2_image, args: {'topic': '/camera/image_raw', 'timeout_sec': 10} +- name: receive_ros2_message, args: {'topic': '/cmd_vel', 'timeout_sec': 10} +- name: call_ros2_service, args: { + "service_name": "/execute_trajectory", + "service_type": "moveit_msgs/srv/ExecuteKnownTrajectory", + "service_args": { + "trajectory": { + "joint_trajectory": { + "header": {"frame_id": "base_link"}, + "joint_names": ["joint1", "joint2"], + "points": [{ + "positions": [0.0, 1.57], + "time_from_start": {"sec": 2, "nanosec": 0} + }] + } + }, + "wait_for_execution": True + } + } +""" ) TOPIC_STRINGS = [ @@ -53,6 +146,11 @@ for topic, topic_type in COMMON_TOPICS_AND_TYPES.items() ] +SERVICE_STRINGS = [ + f"service: {service}\ntype: {msg_type}\n" + for service, msg_type in COMMON_SERVICES_AND_TYPES.items() +] + class BasicTask(Task, ABC): type = "basic" @@ -67,6 +165,15 @@ def available_tools(self) -> List[BaseTool]: MockReceiveROS2MessageTool( available_topics=list(COMMON_TOPICS_AND_TYPES.keys()) ), + MockGetROS2ServicesNamesAndTypesTool( + mock_service_names_and_types=SERVICE_STRINGS + ), + MockGetROS2MessageInterfaceTool(mock_interfaces=COMMON_INTERFACES), + MockCallROS2ServiceTool( + available_services=list(COMMON_SERVICES_AND_TYPES.keys()), + available_service_types=list(COMMON_SERVICES_AND_TYPES.values()), + available_service_models={}, + ), ] @property @@ -115,7 +222,7 @@ def get_prompt(self) -> str: else: return ( f"{self.get_base_prompt()} " - "You can explore available camera topics and capture the RGB color image." + "You can list available camera topics and capture the RGB color image." ) @@ -131,7 +238,7 @@ def get_prompt(self) -> str: else: return ( f"{self.get_base_prompt()} " - "You can explore available camera topics and capture the depth image data." + "You can list available camera topics and capture the depth image data." ) @@ -147,7 +254,7 @@ def get_prompt(self) -> str: else: return ( f"{self.get_base_prompt()} " - "You can discover available sensor topics and receive the pointcloud information." + "You can list available topics to find appropriate topic and receive the pointcloud information from it." ) @@ -162,8 +269,8 @@ def get_prompt(self) -> str: return self.get_base_prompt() else: return ( - f"{self.get_base_prompt()} You can explore the system " - "to find robot description data." + f"{self.get_base_prompt()}. You can list available topics to find appropriate topic " + "and receive robot description data from it." ) @@ -180,40 +287,413 @@ def get_prompt(self) -> str: return ( f"{self.get_base_prompt()} from all available camera sources in the system. " "This includes both RGB color images and depth images. " - "You can discover what camera topics are available and capture images from each." + "You can list what camera topics are available and capture images from each." + ) + + +#### calling services #### +class GetROS2ServicesTask(BasicTask): + complexity = "easy" + + @property + def optional_tool_calls_number(self) -> int: + return 0 + + def get_base_prompt(self) -> str: + return "Get all services" + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} available in the ROS2 system with their names and service types. " + "You can list what services are currently available in the system." ) -class CheckRobotHealthTask(BasicTask): +class ListRobotParametersTask(BasicTask): + complexity = "easy" + + def get_base_prompt(self) -> str: + return "List robot state publisher parameters" + + def get_prompt(self) -> str: + base_prompt = "List robot state publisher parameters" + if self.prompt_detail == "brief": + return base_prompt + else: + return ( + f"{self.get_base_prompt()} available for configuration. " + "You can list available services to find the appropriate service and receive the parameters from it" + ) + + +class GetSpecificParameterTask(BasicTask): + complexity = "easy" + + def __init__( + self, + parameter: str, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, + ) -> None: + self.parameter = parameter + if validators is None: + # Default validator for this task + get_parameters_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=ROBOT_STATE_PUBLISHER_GET_PARAMS, + expected_service_type=GET_PARAMETERS_TYPE, + expected_fields={"names.0": parameter}, + ) + validators = [OrderedCallsValidator(subtasks=[get_parameters_subtask])] + super().__init__(validators, task_args, logger) + + @property + def optional_tool_calls_number(self) -> int: + # list services and get interfaces + return 2 + + def get_base_prompt(self) -> str: + return f"Get robot `{self.parameter}` parameter" + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} value from the robot state publisher. " + "You can list available services to find the appropriate service, " + f"check its type's interface and retrieve the {self.parameter} parameter value." + ) + + +class SetRobotParameterTask(BasicTask): complexity = "medium" + def __init__( + self, + value: float, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, + ) -> None: + self.value = value + if validators is None: + # Default validators for this task - allow either regular or atomic set + set_robot_state_params_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=ROBOT_STATE_PUBLISHER_SET_PARAMS, + expected_service_type=SET_PARAMETERS_TYPE, + expected_fields={ + "parameters.0.name": "publish_frequency", + "parameters.0.value.type": "3", + "parameters.0.value.double_value": value, + }, + ) + set_robot_state_params_atomically_subtask = ( + CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=ROBOT_STATE_PUBLISHER_SET_PARAMS_ATOMICALLY, + expected_service_type=SET_PARAMETERS_ATOMICALLY_TYPE, + expected_fields={ + "parameters.0.name": "publish_frequency", + "parameters.0.value.type": "3", + "parameters.0.value.double_value": value, + }, + ) + ) + validators = [ + OneFromManyValidator( + subtasks=[ + set_robot_state_params_subtask, + set_robot_state_params_atomically_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) + + @property + def optional_tool_calls_number(self) -> int: + # list services, get interfaces + return 2 + def get_base_prompt(self) -> str: - return "Check robot health status" + return f"Set robot state parameter `publish frequency` to {self.value} Hz" def get_prompt(self) -> str: if self.prompt_detail == "brief": return self.get_base_prompt() else: return ( - f"{self.get_base_prompt()} by examining system diagnostics and monitoring data. " - "You can explore available diagnostic topics and gather information " - "about robot health, joint states, and system logs." + f"{self.get_base_prompt()} using parameter service. " + "You can list available services to find the appropriate service, " + f"check its type's interface and set the publish_frequency parameter to {self.value}." ) -class AssessSensorDataQualityTask(BasicTask): +class CheckSpawnableEntitiesTask(BasicTask): + complexity = "easy" + + @property + def optional_tool_calls_number(self) -> int: + # list services + return 1 + + def get_base_prompt(self) -> str: + return "Check available spawnable entities" + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} in the simulation environment. " + "You can list available services to find the appropriate " + "service and see what entities can be spawned." + ) + + +class SpawnEntityTask(BasicTask): + complexity = "medium" + + def __init__( + self, + entity: str, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, + ) -> None: + self.entity = entity + if validators is None: + # Default validator for this task + spawn_entity_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=SPAWN_ENTITY_SERVICE, + expected_service_type=SPAWN_ENTITY_TYPE, + expected_fields={ + "name": entity, + }, + ) + validators = [OrderedCallsValidator(subtasks=[spawn_entity_subtask])] + super().__init__(validators, task_args, logger) + + @property + def optional_tool_calls_number(self) -> int: + # list services, get interface + return 2 + + def get_base_prompt(self) -> str: + return f"Spawn a {self.entity} entity" + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} in the simulation environment. " + "You can list available services to find the appropriate service, " + f"check its type's interface and add a {self.entity} with any name and SDF/XML description." + ) + + +class ConfigureVisionPipelineTask(BasicTask): complexity = "hard" + def __init__( + self, + sam_confidence_threshold: float, + dino_confidence_threshold: float, + fps: int, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, + ) -> None: + self.sam_confidence_threshold = sam_confidence_threshold + self.dino_confidence_threshold = dino_confidence_threshold + self.fps = fps + + if validators is None: + # Default validators for this task + set_grounded_sam_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDED_SAM_SET_PARAMS, + expected_service_type=SET_PARAMETERS_TYPE, + expected_fields={ + "parameters.0.name": "confidence_threshold", + "parameters.0.value.type": 3, + "parameters.0.value.double_value": sam_confidence_threshold, + }, + ) + set_grounded_sam_atomically_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDED_SAM_SET_PARAMS_ATOMICALLY, + expected_service_type=SET_PARAMETERS_ATOMICALLY_TYPE, + expected_fields={ + "parameters.0.name": "confidence_threshold", + "parameters.0.value.type": 3, + "parameters.0.value.double_value": sam_confidence_threshold, + }, + ) + + set_grounded_dino_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDING_DINO_SET_PARAMS, + expected_service_type=SET_PARAMETERS_TYPE, + expected_fields={ + "parameters.0.name": "confidence_threshold", + "parameters.0.value.type": 3, + "parameters.0.value.double_value": dino_confidence_threshold, + }, + ) + set_grounding_dino_atomically_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDING_DINO_SET_PARAMS_ATOMICALLY, + expected_service_type=SET_PARAMETERS_ATOMICALLY_TYPE, + expected_fields={ + "parameters.0.name": "confidence_threshold", + "parameters.0.value.type": 3, + "parameters.0.value.double_value": dino_confidence_threshold, + }, + ) + + set_o3de_fps_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=O3DE_SET_PARAMS, + expected_service_type=SET_PARAMETERS_TYPE, + expected_fields={ + "parameters.0.name": "fps", + "parameters.0.value.type": 2, + "parameters.0.value.integer_value": fps, + }, + ) + set_o3de_fps_atomically_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=O3DE_SET_PARAMS_ATOMICALLY, + expected_service_type=SET_PARAMETERS_ATOMICALLY_TYPE, + expected_fields={ + "parameters.0.name": "fps", + "parameters.0.value.type": 2, + "parameters.0.value.integer_value": fps, + }, + ) + + validators = [ + OneFromManyValidator( + subtasks=[ + set_grounded_sam_subtask, + set_grounded_sam_atomically_subtask, + ] + ), + OneFromManyValidator( + subtasks=[ + set_grounded_dino_subtask, + set_grounding_dino_atomically_subtask, + ] + ), + OneFromManyValidator( + subtasks=[set_o3de_fps_subtask, set_o3de_fps_atomically_subtask] + ), + ] + + super().__init__(validators, task_args, logger) + + @property + def optional_tool_calls_number(self) -> int: + return 2 # list services, get interface + def get_base_prompt(self) -> str: - return "Assess sensor data quality" + return ( + f"Configure AI vision pipeline: set grounded_sam `confidence_threshold` " + f"to {self.sam_confidence_threshold}, grounding_dino `confidence_threshold` " + f"to {self.dino_confidence_threshold}, o3de_ros2_node `fps` to {self.fps}. " + "Ensure that each parameter is set in separate service call and in the order specified above " + ) + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} using parameter services. " + "You can list parameter services to find appropriate services " + "check their type's interface and set appropriate parameters." + ) + + +class RespawnEntitiesTask(BasicTask): + complexity = "hard" + + def __init__( + self, + names: List[str], + coords: List[Tuple[float, float, float]], + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, + ) -> None: + self.names = names + self.coords = coords + + if validators is None: + # Default validators for this task + delete_subtasks: List[SubTask] = [] + spawn_subtasks: List[SubTask] = [] + + for name, coord in zip(names, coords): + delete_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=DELETE_ENTITY_SERVICE, + expected_service_type=DELETE_ENTITY_TYPE, + expected_fields={ + "name": name, + }, + ) + spawn_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=SPAWN_ENTITY_SERVICE, + expected_service_type=SPAWN_ENTITY_TYPE, + expected_fields={ + "name": name, + "initial_pose.position.x": coord[0], + "initial_pose.position.y": coord[1], + "initial_pose.position.z": coord[2], + }, + ) + delete_subtasks.append(delete_subtask) + spawn_subtasks.append(spawn_subtask) + + validators = [ + NotOrderedCallsValidator(subtasks=delete_subtasks), + NotOrderedCallsValidator(subtasks=spawn_subtasks), + ] + + super().__init__(validators, task_args, logger) + + @property + def optional_tool_calls_number(self) -> int: + return 3 # list services, get interfaces of spawn and despawn + + def get_base_prompt(self) -> str: + names_str = ", ".join(self.names) + positions: List[str] = [] + for coord in self.coords: + positions.append(f"({coord[0]}, {coord[1]}, {coord[2]})") + positions_str = ", ".join(positions) + + return ( + f"Reconfigure simulation: remove old `cube` entities named {names_str}, " + f"then respawn them at positions [{positions_str}]" + ) def get_prompt(self) -> str: if self.prompt_detail == "brief": return self.get_base_prompt() else: return ( - f"{self.get_base_prompt()} across all available sensors in the robot system. " - "You can explore sensor topics and gather data from various sources " - "including laser scans, cameras, pointclouds, and odometry to evaluate " - "overall sensor performance." + f"{self.get_base_prompt()} using entity management services. " + "You can list services to find appropriate services, check their type's interface " + "and use them to delete old and spawn new `cube` entities with specific names and positions." ) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/validators.py b/src/rai_bench/rai_bench/tool_calling_agent/validators.py index 55399b2e6..666df9faf 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/validators.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/validators.py @@ -144,3 +144,54 @@ def validate(self, tool_calls: List[ToolCall]) -> Tuple[bool, List[ToolCall]]: if len(tool_calls) > self.required_calls: self.extra_calls_used = len(tool_calls) - self.required_calls return False, [] + + +class OneFromManyValidator(Validator): + """ + Validator that passes when any one of the given subtasks passes. + """ + + def __init__( + self, subtasks: List[SubTask], logger: loggers_type | None = None + ) -> None: + super().__init__(subtasks=subtasks, logger=logger) + if len(self.subtasks) < 1: + raise ValueError("Validator must have at least 1 subtask.") + + @property + def type(self) -> str: + return "optional" + + @property + def required_calls(self) -> int: + # For optional validator, we only need 1 call to pass any subtask + return 1 + + def validate(self, tool_calls: List[ToolCall]) -> Tuple[bool, List[ToolCall]]: + self.reset() + if len(tool_calls) < 1: + self.logger.debug("Not a single tool call to validate") + self.passed = False + return False, tool_calls + + for i, tool_call in enumerate(tool_calls): + # Check if this tool call matches any subtask + for u, subtask in enumerate(self.subtasks): + try: + if subtask.validate(tool_call=tool_call): + # Found a matching subtask - validation succeeds + self.subtasks_passed[u] = True + self.passed = True + self.extra_calls_used = ( + len(tool_calls) - 1 + ) # We only needed 1 call + return True, tool_calls[i + 1 :] + except SubTaskValidationError as e: + # Store error but continue trying other subtasks + self.add_subtask_errors(idx=u, msgs=[str(e)]) + + # No tool call matched any subtask + self.logger.debug("Validation failed - no tool call matched any subtask") + self.passed = False + self.extra_calls_used = len(tool_calls) - 1 if len(tool_calls) > 1 else 0 + return False, [] diff --git a/tests/rai_bench/tool_calling_agent/test_mock_tools.py b/tests/rai_bench/tool_calling_agent/test_mock_tools.py new file mode 100644 index 000000000..cec31df32 --- /dev/null +++ b/tests/rai_bench/tool_calling_agent/test_mock_tools.py @@ -0,0 +1,163 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict + +import pytest +from rai.types.rai_interfaces import ManipulatorMoveToRequest + +from rai_bench.tool_calling_agent.mocked_tools import ServiceValidator + + +class TestServiceValidator: + """Test suite for ServiceValidator using real ROS 2 interfaces and custom models""" + + @pytest.fixture + def custom_models(self) -> Dict[str, Any]: + """Fixture providing actual custom Pydantic models""" + return { + "rai_interfaces/srv/ManipulatorMoveTo": ManipulatorMoveToRequest, + } + + @pytest.fixture + def validator(self, custom_models: Dict[str, Any]) -> ServiceValidator: + """Fixture providing ServiceValidator instance with real models""" + return ServiceValidator(custom_models) + + def test_validate_with_ros2_setparameters_valid_args( + self, validator: ServiceValidator + ): + """Test successful validation with valid SetParameters arguments""" + args: Dict[str, Any] = { + "parameters": [ + { + "name": "test_param", + "value": { + "type": 2, # PARAMETER_INTEGER + "bool_value": False, + "integer_value": 42, + "double_value": 0.0, + "string_value": "", + }, + } + ] + } + + validator.validate_with_ros2("rcl_interfaces/srv/SetParameters", args) + + def test_validate_with_ros2_setparameters_minimal_valid_args( + self, validator: ServiceValidator + ): + """Test validation with minimal valid SetParameters arguments""" + args: Dict[str, Any] = { + "parameters": [ + { + "name": "test_param", + "value": { + "type": 4, # PARAMETER_STRING + "string_value": "test_value", + }, + } + ] + } + + validator.validate_with_ros2("rcl_interfaces/srv/SetParameters", args) + + def test_validate_with_ros2_invalid_field_name(self, validator: ServiceValidator): + """Test validation with invalid field in SetParameters""" + args: Dict[str, Any] = {"parameters": [], "invalid_field": "should_not_exist"} + + with pytest.raises(AttributeError): # set_message_fields will raise + validator.validate_with_ros2("rcl_interfaces/srv/SetParameters", args) + + def test_validate_with_ros2_wrong_parameter_type(self, validator: ServiceValidator): + """Test validation with wrong parameter structure""" + args: Dict[str, Any] = { + "parameters": [{"name": "test_param", "value": "should_be_dict_not_string"}] + } + + with pytest.raises(TypeError): + validator.validate_with_ros2("rcl_interfaces/srv/SetParameters", args) + + def test_validate_with_ros2_empty_args(self, validator: ServiceValidator): + """Test validation with empty args (should use defaults)""" + args: Dict[str, Any] = {} + + # Should work - ROS 2 messages have default values + validator.validate_with_ros2("rcl_interfaces/srv/GetParameters", args) + + def test_validate_with_custom_model_valid_args(self, validator: ServiceValidator): + """Test successful validation with valid ManipulatorMoveTo arguments""" + args: Dict[str, Any] = { + "initial_gripper_state": True, + "final_gripper_state": False, + "target_pose": { + "header": { + "stamp": {"sec": 0, "nanosec": 0}, + "frame_id": "base_link", + }, + "pose": { + "position": {"x": 1.0, "y": 2.0, "z": 3.0}, + "orientation": {"x": 0.0, "y": 0.0, "z": 0.0, "w": 1.0}, + }, + }, + } + + validator.validate_with_custom("rai_interfaces/srv/ManipulatorMoveTo", args) + + def test_validate_with_custom_manipulator_minimal_args( + self, validator: ServiceValidator + ): + """Test validation with minimal ManipulatorMoveTo arguments (using defaults)""" + args: Dict[str, Any] = {} # All fields have defaults + + validator.validate_with_custom("rai_interfaces/srv/ManipulatorMoveTo", args) + + def test_validate_with_custom_manipulator_invalid_type( + self, validator: ServiceValidator + ): + """Test validation with invalid type for ManipulatorMoveTo""" + args: Dict[str, Any] = { + "initial_gripper_state": "should_be_bool_not_string", + "final_gripper_state": False, + } + + with pytest.raises(ValueError, match="Pydantic validation failed"): + validator.validate_with_custom("rai_interfaces/srv/ManipulatorMoveTo", args) + + def test_validate_with_custom_service_not_in_models( + self, validator: ServiceValidator + ): + """Test custom validation when service type not in custom models""" + args: Dict[str, Any] = {"some_field": "value"} + + with pytest.raises(ValueError, match="is invalid custom type"): + validator.validate_with_custom("unknown/srv/Service", args) + + def test_validate_routes_to_custom_when_available( + self, validator: ServiceValidator + ): + """Test that validate() uses custom models when service type is in custom_models""" + args = {"initial_gripper_state": True} + + # Should route to custom validation (not ROS 2) + validator.validate("rai_interfaces/srv/ManipulatorMoveTo", args) + + def test_validate_service_not_in_custom_or_ros2(self, validator: ServiceValidator): + """Test validation when service exists in neither custom models nor ROS 2""" + args = {"some_field": "value"} + + with pytest.raises(ImportError): + validator.validate("nonexistent_package/srv/NonexistentService", args) diff --git a/tests/rai_bench/tool_calling_agent/test_predefined_tasks.py b/tests/rai_bench/tool_calling_agent/test_predefined_tasks.py new file mode 100644 index 000000000..792db0ee9 --- /dev/null +++ b/tests/rai_bench/tool_calling_agent/test_predefined_tasks.py @@ -0,0 +1,1362 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List + +import pytest + +from rai_bench.tool_calling_agent.interfaces import ( + TaskArgs, +) +from rai_bench.tool_calling_agent.predefined.basic_tasks import ( + all_camera_images_notord_val, + check_spawnable_entities_val, + color_image_ord_val, + depth_image_ord_val, + get_pointcloud_ord_val, + get_robot_desc_ord_val, + list_parameters_val, + services_ord_val, + topics_ord_val, +) +from rai_bench.tool_calling_agent.tasks.basic import ( + BOX1_ENTITY, + BOX1_POSITION, + BOX2_ENTITY, + BOX2_POSITION, + COLOR_IMAGE_TOPIC, + DEFAULT_DINO_CONFIDENCE, + DEFAULT_FPS, + DEFAULT_PUBLISH_FREQUENCY, + DEFAULT_SAM_CONFIDENCE, + DELETE_ENTITY_SERVICE, + DELETE_ENTITY_TYPE, + DEPTH_IMAGE_TOPIC, + DINO_CONFIDENCE_2, + FPS_2, + GET_PARAMETERS_TYPE, + GET_SPAWNABLE_NAMES_SERVICE, + GET_WORLD_PROPERTIES_TYPE, + GROUNDED_SAM_SET_PARAMS, + GROUNDED_SAM_SET_PARAMS_ATOMICALLY, + GROUNDING_DINO_SET_PARAMS, + GROUNDING_DINO_SET_PARAMS_ATOMICALLY, + LIST_PARAMETERS_TYPE, + O3DE_SET_PARAMS, + POINTCLOUD_TOPIC, + ROBOT_DESCRIPTION_TOPIC, + ROBOT_STATE_PUBLISHER_GET_PARAMS, + ROBOT_STATE_PUBLISHER_LIST_PARAMS, + ROBOT_STATE_PUBLISHER_SET_PARAMS, + SAM_CONFIDENCE_2, + SET_PARAMETERS_ATOMICALLY_TYPE, + SET_PARAMETERS_TYPE, + SPAWN_ENTITY_SERVICE, + SPAWN_ENTITY_TYPE, + TOMATO_ENTITY, + CheckSpawnableEntitiesTask, + ConfigureVisionPipelineTask, + GetAllROS2CamerasTask, + GetPointcloudTask, + GetRobotDescriptionTask, + GetROS2DepthCameraTask, + GetROS2RGBCameraTask, + GetROS2ServicesTask, + GetROS2TopicsTask, + GetSpecificParameterTask, + ListRobotParametersTask, + RespawnEntitiesTask, + SetRobotParameterTask, + SpawnEntityTask, +) + + +@pytest.fixture +def task_args() -> TaskArgs: + """Create basic task arguments for testing.""" + return TaskArgs( + extra_tool_calls=0, + prompt_detail="brief", + examples_in_system_prompt=0, + ) + + +class TestSetParameterTask: + """Test SetRobotParameterTask validation.""" + + def test_set_parameter_task_valid(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": SET_PARAMETERS_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "publish_frequency", + "value": { + "type": "3", + "double_value": DEFAULT_PUBLISH_FREQUENCY, + }, + } + ] + }, + }, + }, + ] + + # Test with refactored task (uses internal validators) + task = SetRobotParameterTask( + value=DEFAULT_PUBLISH_FREQUENCY, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 # All validators should pass + + def test_set_parameter_task_wrong_parameter_type(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "publish_frequency", + "value": { + "type": 2, # Wrong type (integer instead of double) + "integer_value": 30, + }, + } + ] + }, + }, + }, + ] + + task = SetRobotParameterTask( + value=DEFAULT_PUBLISH_FREQUENCY, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_set_parameter_task_wrong_parameter_missing_type( + self, task_args: TaskArgs + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "publish_frequency", + "value": { + # missing type field + "integer_value": 30, + }, + } + ] + }, + }, + }, + ] + + task = SetRobotParameterTask( + value=DEFAULT_PUBLISH_FREQUENCY, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_set_parameter_task_wrong_parameter_name(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "wrong_parameter_name", # Wrong parameter name + "value": { + "type": "3", + "double_value": DEFAULT_PUBLISH_FREQUENCY, + }, + } + ] + }, + }, + }, + ] + + task = SetRobotParameterTask( + value=DEFAULT_PUBLISH_FREQUENCY, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_set_parameter_task_wrong_tools(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": SET_PARAMETERS_TYPE}, + }, + {"name": "get_ros2_services_names_and_types", "args": {}}, + {"name": "get_ros2_services_names_and_types", "args": {}}, + ] + + task = SetRobotParameterTask( + value=DEFAULT_PUBLISH_FREQUENCY, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetTopicsTask: + """Test GetROS2TopicsTask validation.""" + + def test_get_topics_task_valid(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}} + ] + + task = GetROS2TopicsTask(validators=[topics_ord_val], task_args=task_args) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_topics_task_wrong_tool(self, task_args: TaskArgs) -> None: + """Test get ROS2 topics task with wrong tool name.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "wrong_tool_name", "args": {}} # Wrong tool name + ] + + task = GetROS2TopicsTask(validators=[topics_ord_val], task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_topics_task_unexpected_args(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_topics_names_and_types", + "args": {"unexpected": "arg"}, + } # Unexpected args + ] + + task = GetROS2TopicsTask(validators=[topics_ord_val], task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetRGBCameraTask: + """Test GetROS2RGBCameraTask validation.""" + + def test_get_rgb_camera_task_valid(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + { + "name": "get_ros2_image", + "args": {"topic": COLOR_IMAGE_TOPIC, "timeout_sec": 5}, + }, + ] + + task = GetROS2RGBCameraTask( + validators=[color_image_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_rgb_camera_task_wrong_topic(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_image", + "args": {"topic": "/wrong_topic", "timeout_sec": 5}, # Wrong topic + } + ] + + task = GetROS2RGBCameraTask( + validators=[color_image_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_rgb_camera_task_missing_required_arg( + self, task_args: TaskArgs + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_image", + "args": {"timeout_sec": 5}, # Missing required topic arg + } + ] + + task = GetROS2RGBCameraTask( + validators=[color_image_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_rgb_camera_task_wrong_timeout_type(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_image", + "args": { + "topic": COLOR_IMAGE_TOPIC, + "timeout_sec": "not_an_int", + }, # Wrong type + } + ] + + task = GetROS2RGBCameraTask( + validators=[color_image_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetDepthCameraTask: + """Test GetROS2DepthCameraTask validation.""" + + def test_get_depth_camera_task_valid(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + { + "name": "get_ros2_image", + "args": {"topic": DEPTH_IMAGE_TOPIC, "timeout_sec": 5}, + }, + ] + + task = GetROS2DepthCameraTask( + validators=[depth_image_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_depth_camera_task_wrong_topic(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_image", + "args": { + "topic": COLOR_IMAGE_TOPIC, # Wrong topic (color instead of depth) + "timeout_sec": 5, + }, + } + ] + + task = GetROS2DepthCameraTask( + validators=[depth_image_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetAllCamerasTask: + """Test GetAllROS2CamerasTask validation.""" + + def test_get_all_cameras_task_valid(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + { + "name": "get_ros2_image", + "args": {"topic": COLOR_IMAGE_TOPIC, "timeout_sec": 5}, + }, + { + "name": "get_ros2_image", + "args": {"topic": DEPTH_IMAGE_TOPIC, "timeout_sec": 5}, + }, + ] + + task = GetAllROS2CamerasTask( + validators=[all_camera_images_notord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_all_cameras_task_missing_depth(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_image", + "args": {"topic": COLOR_IMAGE_TOPIC, "timeout_sec": 5}, + } + # Missing depth camera call + ] + + task = GetAllROS2CamerasTask( + validators=[all_camera_images_notord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_all_cameras_task_wrong_order_should_pass( + self, task_args: TaskArgs + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + # Reversed order - depth first, then color + { + "name": "get_ros2_image", + "args": {"topic": DEPTH_IMAGE_TOPIC, "timeout_sec": 5}, + }, + { + "name": "get_ros2_image", + "args": {"topic": COLOR_IMAGE_TOPIC, "timeout_sec": 5}, + }, + ] + + task = GetAllROS2CamerasTask( + validators=[all_camera_images_notord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 # Should pass with NotOrderedCallsValidator + + +class TestGetPointcloudTask: + """Test GetPointcloudTask validation.""" + + def test_get_pointcloud_task_valid(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + { + "name": "receive_ros2_message", + "args": {"topic": POINTCLOUD_TOPIC, "timeout_sec": 10}, + }, + ] + + task = GetPointcloudTask( + validators=[get_pointcloud_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_pointcloud_task_wrong_topic(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "receive_ros2_message", + "args": { + "topic": "/wrong_pointcloud_topic", + "timeout_sec": 10, + }, # Wrong topic + } + ] + + task = GetPointcloudTask( + validators=[get_pointcloud_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetRobotDescriptionTask: + """Test GetRobotDescriptionTask validation.""" + + def test_get_robot_description_task_valid(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + { + "name": "receive_ros2_message", + "args": {"topic": ROBOT_DESCRIPTION_TOPIC, "timeout_sec": 10}, + }, + ] + + task = GetRobotDescriptionTask( + validators=[get_robot_desc_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_robot_description_task_wrong_tool_name( + self, task_args: TaskArgs + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message", # Wrong tool name (missing "receive_") + "args": {"topic": ROBOT_DESCRIPTION_TOPIC, "timeout_sec": 10}, + } + ] + + task = GetRobotDescriptionTask( + validators=[get_robot_desc_ord_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetROS2ServicesTask: + """Test GetROS2ServicesTask validation.""" + + def test_get_services_task_valid(self, task_args: TaskArgs) -> None: + """Test get ROS2 services task with valid call.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}} + ] + + task = GetROS2ServicesTask(validators=[services_ord_val], task_args=task_args) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_services_task_wrong_tool_name(self, task_args: TaskArgs) -> None: + """Test get ROS2 services task with wrong tool name.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "wrong_tool_name", "args": {}} # Wrong tool name + ] + + task = GetROS2ServicesTask(validators=[services_ord_val], task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_services_task_unexpected_args(self, task_args: TaskArgs) -> None: + """Test get ROS2 services task with unexpected arguments.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_services_names_and_types", + "args": {"unexpected": "arg"}, + } # Unexpected args + ] + + task = GetROS2ServicesTask(validators=[services_ord_val], task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestListRobotParametersTask: + """Test ListRobotParametersTask validation.""" + + def test_list_parameters_task_valid(self, task_args: TaskArgs) -> None: + """Test list parameters task with valid service call.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_LIST_PARAMS, + "service_type": LIST_PARAMETERS_TYPE, + "service_args": {}, + }, + }, + ] + + task = ListRobotParametersTask( + validators=[list_parameters_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_list_parameters_task_wrong_service_name(self, task_args: TaskArgs) -> None: + """Test list parameters task with wrong service name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": "/wrong_node/list_parameters", # Wrong service name + "service_type": LIST_PARAMETERS_TYPE, + "service_args": {}, + }, + } + ] + + task = ListRobotParametersTask( + validators=[list_parameters_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_list_parameters_task_wrong_tool_name(self, task_args: TaskArgs) -> None: + """Test list parameters task with wrong tool name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "wrong_tool_name", # Wrong tool name + "args": { + "service_name": ROBOT_STATE_PUBLISHER_LIST_PARAMS, + "service_type": LIST_PARAMETERS_TYPE, + "service_args": {}, + }, + } + ] + + task = ListRobotParametersTask( + validators=[list_parameters_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetSpecificParameterTask: + """Test GetSpecificParameterTask validation.""" + + def test_get_parameter_task_valid(self, task_args: TaskArgs) -> None: + """Test get specific parameter task with valid service call.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": GET_PARAMETERS_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_GET_PARAMS, + "service_type": GET_PARAMETERS_TYPE, + "service_args": {"names": ["publish_frequency"]}, + }, + }, + ] + + # Test with refactored task (uses internal validators) + task = GetSpecificParameterTask( + parameter="publish_frequency", + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_parameter_task_wrong_parameter_name(self, task_args: TaskArgs) -> None: + """Test get specific parameter task with wrong parameter name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_GET_PARAMS, + "service_type": GET_PARAMETERS_TYPE, + "service_args": { + "names": ["wrong_parameter_name"] # Wrong parameter name + }, + }, + } + ] + + task = GetSpecificParameterTask( + parameter="publish_frequency", + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_parameter_task_missing_names_field(self, task_args: TaskArgs) -> None: + """Test get specific parameter task with missing names field.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": ROBOT_STATE_PUBLISHER_GET_PARAMS, + "service_type": GET_PARAMETERS_TYPE, + "service_args": {}, # Missing names field + }, + } + ] + + task = GetSpecificParameterTask( + parameter="publish_frequency", + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestCheckSpawnableEntitiesTask: + """Test CheckSpawnableEntitiesTask validation.""" + + def test_check_spawnable_entities_task_valid(self, task_args: TaskArgs) -> None: + """Test check spawnable entities task with valid service call.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "call_ros2_service", + "args": { + "service_name": GET_SPAWNABLE_NAMES_SERVICE, + "service_type": GET_WORLD_PROPERTIES_TYPE, + "service_args": {}, + }, + }, + ] + + task = CheckSpawnableEntitiesTask( + validators=[check_spawnable_entities_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_check_spawnable_entities_task_wrong_service_name( + self, task_args: TaskArgs + ) -> None: + """Test check spawnable entities task with wrong service name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": "/wrong_service_name", # Wrong service name + "service_type": GET_WORLD_PROPERTIES_TYPE, + "service_args": {}, + }, + } + ] + + task = CheckSpawnableEntitiesTask( + validators=[check_spawnable_entities_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_check_spawnable_entities_task_wrong_tool_name( + self, task_args: TaskArgs + ) -> None: + """Test check spawnable entities task with wrong tool name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "wrong_tool_name", # Wrong tool name + "args": { + "service_name": GET_SPAWNABLE_NAMES_SERVICE, + "service_type": GET_WORLD_PROPERTIES_TYPE, + "service_args": {}, + }, + } + ] + + task = CheckSpawnableEntitiesTask( + validators=[check_spawnable_entities_val], task_args=task_args + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestSpawnEntityTask: + """Test SpawnEntityTask validation.""" + + def test_spawn_entity_task_valid_tomato(self, task_args: TaskArgs) -> None: + """Test spawn entity task with tomato entity.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": SPAWN_ENTITY_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": TOMATO_ENTITY, + "xml": "tomato model", + }, + }, + }, + ] + + # Test with refactored task (uses internal validators) + task = SpawnEntityTask(entity=TOMATO_ENTITY, task_args=task_args) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_spawn_entity_task_wrong_service_name(self, task_args: TaskArgs) -> None: + """Test spawn entity task with wrong service name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": "/wrong_spawn_service", # Wrong service name + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": TOMATO_ENTITY, + "xml": "test", + }, + }, + } + ] + + task = SpawnEntityTask(entity=TOMATO_ENTITY, task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_spawn_entity_task_wrong_entity_name(self, task_args: TaskArgs) -> None: + """Test spawn entity task with wrong entity name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": "wrong_entity_name", # Wrong entity name + "xml": "test", + }, + }, + } + ] + + task = SpawnEntityTask(entity=TOMATO_ENTITY, task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_spawn_entity_task_missing_service_args(self, task_args: TaskArgs) -> None: + """Test spawn entity task with missing service args.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + # Missing service_args + }, + } + ] + + task = SpawnEntityTask(entity=TOMATO_ENTITY, task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_spawn_entity_task_wrong_tool_name(self, task_args: TaskArgs) -> None: + """Test spawn entity task with wrong tool name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "wrong_tool_name", # Wrong tool name + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": TOMATO_ENTITY, + "xml": "test", + }, + }, + } + ] + + task = SpawnEntityTask(entity=TOMATO_ENTITY, task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestConfigureVisionPipelineTask: + """Test ConfigureVisionPipelineTask validation.""" + + def test_configure_vision_pipeline_task_valid_config1( + self, task_args: TaskArgs + ) -> None: + """Test configure vision pipeline task with first configuration.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": SET_PARAMETERS_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDED_SAM_SET_PARAMS_ATOMICALLY, + "service_type": SET_PARAMETERS_ATOMICALLY_TYPE, + "service_args": { + "parameters": [ + { + "name": "confidence_threshold", + "value": { + "type": 3, + "double_value": DEFAULT_SAM_CONFIDENCE, + }, + } + ] + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SET_PARAMS_ATOMICALLY, + "service_type": SET_PARAMETERS_ATOMICALLY_TYPE, + "service_args": { + "parameters": [ + { + "name": "confidence_threshold", + "value": { + "type": 3, + "double_value": DEFAULT_DINO_CONFIDENCE, + }, + } + ] + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": O3DE_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "fps", + "value": {"type": 2, "integer_value": DEFAULT_FPS}, + } + ] + }, + }, + }, + ] + + # Test with refactored task (uses internal validators) + task = ConfigureVisionPipelineTask( + sam_confidence_threshold=DEFAULT_SAM_CONFIDENCE, + dino_confidence_threshold=DEFAULT_DINO_CONFIDENCE, + fps=DEFAULT_FPS, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_configure_vision_pipeline_task_valid_config2( + self, task_args: TaskArgs + ) -> None: + """Test configure vision pipeline task with second configuration.""" + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": SET_PARAMETERS_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDED_SAM_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "confidence_threshold", + "value": { + "type": 3, + "double_value": SAM_CONFIDENCE_2, + }, + } + ] + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "confidence_threshold", + "value": { + "type": 3, + "double_value": DINO_CONFIDENCE_2, + }, + } + ] + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": O3DE_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "fps", + "value": {"type": 2, "integer_value": FPS_2}, + } + ] + }, + }, + }, + ] + + task = ConfigureVisionPipelineTask( + sam_confidence_threshold=SAM_CONFIDENCE_2, + dino_confidence_threshold=DINO_CONFIDENCE_2, + fps=FPS_2, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_configure_vision_pipeline_task_missing_calls( + self, task_args: TaskArgs + ) -> None: + """Test configure vision pipeline task with missing service calls.""" + + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDED_SAM_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "confidence_threshold", + "value": { + "type": 3, + "double_value": DEFAULT_SAM_CONFIDENCE, + }, + } + ] + }, + }, + } + ] + + task = ConfigureVisionPipelineTask( + sam_confidence_threshold=DEFAULT_SAM_CONFIDENCE, + dino_confidence_threshold=DEFAULT_DINO_CONFIDENCE, + fps=DEFAULT_FPS, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert abs(score - 0.3333333333333333) < 0.01 + + def test_configure_vision_pipeline_task_setting_in_one_call( + self, task_args: TaskArgs + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDED_SAM_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": { + "parameters": [ + { + "name": "confidence_threshold", + "value": { + "type": 3, + "double_value": DEFAULT_SAM_CONFIDENCE, + }, + }, + { + "name": "fps", + "value": {"type": 2, "integer_value": FPS_2}, + }, + { + "name": "confidence_threshold", + "value": { + "type": 3, + "double_value": DEFAULT_DINO_CONFIDENCE, + }, + }, + ] + }, + }, + } + ] + + task = ConfigureVisionPipelineTask( + sam_confidence_threshold=DEFAULT_SAM_CONFIDENCE, + dino_confidence_threshold=DEFAULT_DINO_CONFIDENCE, + fps=DEFAULT_FPS, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert abs(score - 0.3333333333333333) < 0.01 + + def test_configure_vision_pipeline_task_empty_call( + self, task_args: TaskArgs + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDED_SAM_SET_PARAMS, + "service_type": SET_PARAMETERS_TYPE, + "service_args": {"parameters": []}, + }, + } + ] + + task = ConfigureVisionPipelineTask( + sam_confidence_threshold=DEFAULT_SAM_CONFIDENCE, + dino_confidence_threshold=DEFAULT_DINO_CONFIDENCE, + fps=DEFAULT_FPS, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0 + + +class TestRespawnEntitiesTask: + """Test RespawnEntitiesTask validation.""" + + def test_respawn_entities_task_valid(self, task_args: TaskArgs) -> None: + """Test respawn entities task with valid calls.""" + + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_services_names_and_types", "args": {}}, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": DELETE_ENTITY_TYPE}, + }, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": SPAWN_ENTITY_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": DELETE_ENTITY_SERVICE, + "service_type": DELETE_ENTITY_TYPE, + "service_args": {"name": BOX1_ENTITY}, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": DELETE_ENTITY_SERVICE, + "service_type": DELETE_ENTITY_TYPE, + "service_args": {"name": BOX2_ENTITY}, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": BOX1_ENTITY, + "xml": "box1 model", + "initial_pose": { + "position": { + "x": BOX1_POSITION[0], + "y": BOX1_POSITION[1], + "z": BOX1_POSITION[2], + } + }, + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": BOX2_ENTITY, + "xml": "box2 model", + "initial_pose": { + "position": { + "x": BOX2_POSITION[0], + "y": BOX2_POSITION[1], + "z": BOX2_POSITION[2], + } + }, + }, + }, + }, + ] + + # Test with refactored task (uses internal validators) + task = RespawnEntitiesTask( + names=[BOX1_ENTITY, BOX2_ENTITY], + coords=[BOX1_POSITION, BOX2_POSITION], + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_respawn_entities_task_missing_delete(self, task_args: TaskArgs) -> None: + """Test respawn entities task with missing delete calls.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": DELETE_ENTITY_SERVICE, + "service_type": DELETE_ENTITY_TYPE, + "service_args": {"name": BOX1_ENTITY}, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": BOX1_ENTITY, + "xml": "box1 model", + "initial_pose": { + "position": { + "x": BOX1_POSITION[0], + "y": BOX1_POSITION[1], + "z": BOX1_POSITION[2], + } + }, + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": BOX2_ENTITY, + "xml": "box2 model", + "initial_pose": { + "position": { + "x": BOX2_POSITION[0], + "y": BOX2_POSITION[1], + "z": BOX2_POSITION[2], + } + }, + }, + }, + }, + ] + + task = RespawnEntitiesTask( + names=[BOX1_ENTITY, BOX2_ENTITY], + coords=[BOX1_POSITION, BOX2_POSITION], + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 # first validator fail so second fail too + + def test_respawn_entities_task_missing_spawn(self, task_args: TaskArgs) -> None: + """Test respawn entities task with missing spawn calls.""" + + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": DELETE_ENTITY_SERVICE, + "service_type": DELETE_ENTITY_TYPE, + "service_args": {"name": BOX1_ENTITY}, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": DELETE_ENTITY_SERVICE, + "service_type": DELETE_ENTITY_TYPE, + "service_args": {"name": BOX2_ENTITY}, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": SPAWN_ENTITY_SERVICE, + "service_type": SPAWN_ENTITY_TYPE, + "service_args": { + "name": BOX1_ENTITY, + "xml": "box1 model", + "initial_pose": { + "position": { + "x": BOX1_POSITION[0], + "y": BOX1_POSITION[1], + "z": BOX1_POSITION[2], + } + }, + }, + }, + }, + ] + + task = RespawnEntitiesTask( + names=[BOX1_ENTITY, BOX2_ENTITY], + coords=[BOX1_POSITION, BOX2_POSITION], + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.5 # Only delete validators pass, spawn validators fail + + +class TestMultiValidatorScoring: + """Test scoring with multiple validators to ensure proper fraction calculation.""" + + def test_three_validators_all_pass(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + { + "name": "get_ros2_image", + "args": {"topic": COLOR_IMAGE_TOPIC, "timeout_sec": 5}, + }, + { + "name": "get_ros2_image", + "args": {"topic": DEPTH_IMAGE_TOPIC, "timeout_sec": 5}, + }, + ] + + # Create multiple validators for testing + val1 = topics_ord_val + val2 = color_image_ord_val + val3 = depth_image_ord_val + + task = GetROS2RGBCameraTask(validators=[val1, val2, val3], task_args=task_args) + score = task.validate(tool_calls) + assert score == 1.0 # 3/3 validators pass + + def test_three_validators_two_pass(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + { + "name": "get_ros2_image", + "args": {"topic": COLOR_IMAGE_TOPIC, "timeout_sec": 5}, + }, + ] + + # Create multiple validators for testing + val1 = topics_ord_val + val2 = color_image_ord_val + val3 = depth_image_ord_val + + task = GetROS2RGBCameraTask(validators=[val1, val2, val3], task_args=task_args) + score = task.validate(tool_calls) + # Should be 2/3 = 0.6666... + assert abs(score - 0.6666666666666666) < 0.01 + + def test_three_validators_one_pass(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_ros2_topics_names_and_types", "args": {}}, + ] + + # Create multiple validators for testing + val1 = topics_ord_val + val2 = color_image_ord_val + val3 = depth_image_ord_val + + task = GetROS2RGBCameraTask(validators=[val1, val2, val3], task_args=task_args) + score = task.validate(tool_calls) + # Should be 1/3 = 0.3333... + assert abs(score - 0.3333333333333333) < 0.01 + + def test_three_validators_none_pass(self, task_args: TaskArgs) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_image", + "args": {"topic": COLOR_IMAGE_TOPIC, "timeout_sec": 5}, + }, + ] + + # Create multiple validators for testing + val1 = topics_ord_val + val2 = color_image_ord_val + val3 = depth_image_ord_val + + task = GetROS2RGBCameraTask(validators=[val1, val2, val3], task_args=task_args) + score = task.validate(tool_calls) + assert score == 0.0 diff --git a/tests/rai_bench/tool_calling_agent/test_subtasks.py b/tests/rai_bench/tool_calling_agent/test_subtasks.py index 96339b9ad..eeb034360 100644 --- a/tests/rai_bench/tool_calling_agent/test_subtasks.py +++ b/tests/rai_bench/tool_calling_agent/test_subtasks.py @@ -26,41 +26,48 @@ ) -class TestSubTaskHelpers: - """Test the helper methods in the abstract SubTask class.""" +class ConcreteSubTask(SubTask): + def validate(self, tool_call: ToolCall) -> bool: + return True - @pytest.fixture - def mock_subtask(self): - """Create a concrete implementation of the abstract SubTask for testing""" + def dump(self) -> Dict[str, Any]: + return {} - class ConcreteSubTask(SubTask): - def validate(self, tool_call: ToolCall) -> bool: - return True + @property + def info(self) -> Dict[str, Any]: + return {"name": "blybly"} - def dump(self) -> Dict[str, Any]: - return {} - @property - def info(self) -> Dict[str, Any]: - return {"name": "blybly"} +@pytest.fixture +def mock_subtask() -> ConcreteSubTask: + """Create a concrete implementation of the abstract SubTask for testing""" + return ConcreteSubTask() - return ConcreteSubTask() - def test_check_tool_call_valid(self, mock_subtask): +class TestSubTaskHelpers: + """Test the helper methods in the abstract SubTask class.""" + + def test_check_tool_call_valid(self, mock_subtask: ConcreteSubTask) -> None: """Test _check_tool_call with valid inputs.""" - tool_call = {"name": "test_tool", "args": {"arg1": "value1", "arg2": 42}} + tool_call: Dict[str, Any] = { + "name": "test_tool", + "args": {"arg1": "value1", "arg2": 42}, + } - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} assert mock_subtask._check_tool_call( tool_call=tool_call, expected_name="test_tool", expected_args=expected_args ) - def test_check_tool_call_wrong_name(self, mock_subtask): + def test_check_tool_call_wrong_name(self, mock_subtask: ConcreteSubTask) -> None: """Test _check_tool_call fails with wrong tool name.""" - tool_call = {"name": "wrong_tool", "args": {"arg1": "value1", "arg2": 42}} + tool_call: Dict[str, Any] = { + "name": "wrong_tool", + "args": {"arg1": "value1", "arg2": 42}, + } - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} with pytest.raises( SubTaskValidationError, @@ -72,11 +79,11 @@ def test_check_tool_call_wrong_name(self, mock_subtask): expected_args=expected_args, ) - def test_check_tool_call_missing_arg(self, mock_subtask): + def test_check_tool_call_missing_arg(self, mock_subtask: ConcreteSubTask) -> None: """Test _check_tool_call fails with missing argument.""" - tool_call = {"name": "test_tool", "args": {"arg1": "value1"}} + tool_call: Dict[str, Any] = {"name": "test_tool", "args": {"arg1": "value1"}} - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} with pytest.raises( SubTaskValidationError, match="Required argument 'arg2' missing" @@ -87,11 +94,16 @@ def test_check_tool_call_missing_arg(self, mock_subtask): expected_args=expected_args, ) - def test_check_tool_call_wrong_arg_value(self, mock_subtask): + def test_check_tool_call_wrong_arg_value( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_tool_call fails with wrong argument value.""" - tool_call = {"name": "test_tool", "args": {"arg1": "wrong_value", "arg2": 42}} + tool_call: Dict[str, Any] = { + "name": "test_tool", + "args": {"arg1": "wrong_value", "arg2": 42}, + } - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} with pytest.raises( SubTaskValidationError, @@ -103,14 +115,16 @@ def test_check_tool_call_wrong_arg_value(self, mock_subtask): expected_args=expected_args, ) - def test_check_tool_call_unexpected_arg(self, mock_subtask): + def test_check_tool_call_unexpected_arg( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_tool_call fails with unexpected argument.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "test_tool", "args": {"arg1": "value1", "arg2": 42, "unexpected": "surprise"}, } - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} with pytest.raises( SubTaskValidationError, match="Unexpected argument 'unexpected' found" @@ -121,9 +135,11 @@ def test_check_tool_call_unexpected_arg(self, mock_subtask): expected_args=expected_args, ) - def test_check_topic_tool_call_field_valid(self, mock_subtask): + def test_check_topic_tool_call_field_valid( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_topic_tool_call_field with valid inputs.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -147,9 +163,11 @@ def test_check_topic_tool_call_field_valid(self, mock_subtask): expected_value="base_link", ) - def test_check_topic_tool_call_field_wrong_name(self, mock_subtask): + def test_check_topic_tool_call_field_wrong_name( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_topic_tool_call_field fails with wrong tool name.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "wrong_name", "args": { "topic": "/test_topic", @@ -168,9 +186,11 @@ def test_check_topic_tool_call_field_wrong_name(self, mock_subtask): expected_value="test", ) - def test_check_topic_tool_call_field_wrong_topic(self, mock_subtask): + def test_check_topic_tool_call_field_wrong_topic( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_topic_tool_call_field fails with wrong topic.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/wrong_topic", @@ -189,9 +209,11 @@ def test_check_topic_tool_call_field_wrong_topic(self, mock_subtask): expected_value="test", ) - def test_check_topic_tool_call_field_wrong_message_type(self, mock_subtask): + def test_check_topic_tool_call_field_wrong_message_type( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_topic_tool_call_field fails with wrong message type.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -210,9 +232,11 @@ def test_check_topic_tool_call_field_wrong_message_type(self, mock_subtask): expected_value="test", ) - def test_check_topic_tool_call_field_missing_message(self, mock_subtask): + def test_check_topic_tool_call_field_missing_message( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_topic_tool_call_field fails with missing message.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -233,9 +257,11 @@ def test_check_topic_tool_call_field_missing_message(self, mock_subtask): expected_value="test", ) - def test_check_topic_tool_call_field_invalid_path(self, mock_subtask): + def test_check_topic_tool_call_field_invalid_path( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_topic_tool_call_field fails with invalid field path.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -256,9 +282,11 @@ def test_check_topic_tool_call_field_invalid_path(self, mock_subtask): expected_value="test", ) - def test_check_topic_tool_call_field_wrong_value(self, mock_subtask): + def test_check_topic_tool_call_field_wrong_value( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_topic_tool_call_field fails with wrong field value.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -279,9 +307,11 @@ def test_check_topic_tool_call_field_wrong_value(self, mock_subtask): expected_value="test", ) - def test_check_service_tool_call_field_valid(self, mock_subtask): + def test_check_service_tool_call_field_valid( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_service_tool_call_field with valid inputs.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_service", "args": { "service_name": "/test_service", @@ -299,9 +329,11 @@ def test_check_service_tool_call_field_valid(self, mock_subtask): expected_value=True, ) - def test_check_service_tool_call_field_empty_args(self, mock_subtask): + def test_check_service_tool_call_field_empty_args( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_service_tool_call_field with empty service_args.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_service", "args": { "service_name": "/test_service", @@ -319,9 +351,11 @@ def test_check_service_tool_call_field_empty_args(self, mock_subtask): expected_value={}, ) - def test_check_service_tool_call_field_wrong_service_name(self, mock_subtask): + def test_check_service_tool_call_field_wrong_service_name( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_service_tool_call_field fails with wrong service name.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_service", "args": { "service_name": "/wrong_service", @@ -340,9 +374,11 @@ def test_check_service_tool_call_field_wrong_service_name(self, mock_subtask): expected_value=True, ) - def test_check_service_tool_call_field_missing_service_args(self, mock_subtask): + def test_check_service_tool_call_field_missing_service_args( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_service_tool_call_field fails with missing service_args.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_service", "args": { "service_name": "/test_service", @@ -363,9 +399,11 @@ def test_check_service_tool_call_field_missing_service_args(self, mock_subtask): expected_value=True, ) - def test_check_action_tool_call_field_valid(self, mock_subtask): + def test_check_action_tool_call_field_valid( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_action_tool_call_field with valid inputs.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_action", "args": { "action_name": "/test_action", @@ -388,9 +426,11 @@ def test_check_action_tool_call_field_valid(self, mock_subtask): expected_value=["joint1", "joint2"], ) - def test_check_action_tool_call_field_empty_args(self, mock_subtask): + def test_check_action_tool_call_field_empty_args( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_action_tool_call_field with empty action_args.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_action", "args": { "action_name": "/test_action", @@ -408,9 +448,11 @@ def test_check_action_tool_call_field_empty_args(self, mock_subtask): expected_value={}, ) - def test_check_action_tool_call_field_wrong_action_name(self, mock_subtask): + def test_check_action_tool_call_field_wrong_action_name( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_action_tool_call_field fails with wrong action name.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_action", "args": { "action_name": "/wrong_action", @@ -429,9 +471,11 @@ def test_check_action_tool_call_field_wrong_action_name(self, mock_subtask): expected_value=True, ) - def test_check_tool_call_with_type_check_optional_args(self, mock_subtask): + def test_check_tool_call_with_type_check_optional_args( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_tool_call with type checking for optional arguments.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "test_tool", "args": { "arg1": "value1", @@ -443,9 +487,9 @@ def test_check_tool_call_with_type_check_optional_args(self, mock_subtask): }, } - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} - expected_optional_args = { + expected_optional_args: Dict[str, Any] = { "optional1": str, # expect string type "optional2": int, # expect int type "optional3": list, # expect list type @@ -460,9 +504,11 @@ def test_check_tool_call_with_type_check_optional_args(self, mock_subtask): expected_optional_args=expected_optional_args, ) - def test_check_tool_call_wrong_optional_arg_type(self, mock_subtask): + def test_check_tool_call_wrong_optional_arg_type( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_tool_call fails with wrong optional argument type.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "test_tool", "args": { "arg1": "value1", @@ -471,9 +517,11 @@ def test_check_tool_call_wrong_optional_arg_type(self, mock_subtask): }, } - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} - expected_optional_args = {"optional1": str} # expect string type + expected_optional_args: Dict[str, Any] = { + "optional1": str + } # expect string type with pytest.raises(SubTaskValidationError, match="has incorrect type"): mock_subtask._check_tool_call( @@ -483,16 +531,18 @@ def test_check_tool_call_wrong_optional_arg_type(self, mock_subtask): expected_optional_args=expected_optional_args, ) - def test_check_tool_call_multiple_types(self, mock_subtask): + def test_check_tool_call_multiple_types( + self, mock_subtask: ConcreteSubTask + ) -> None: """Test _check_tool_call with optional arguments accepting multiple types.""" - tool_call = { + tool_call: Dict[str, Any] = { "name": "test_tool", "args": {"arg1": "value1", "arg2": 42, "optional1": 123}, # int type } - expected_args = {"arg1": "value1", "arg2": 42} + expected_args: Dict[str, Any] = {"arg1": "value1", "arg2": 42} - expected_optional_args = { + expected_optional_args: Dict[str, Any] = { "optional1": (str, int) # accept either string or int } @@ -507,23 +557,29 @@ def test_check_tool_call_multiple_types(self, mock_subtask): class TestCheckArgsToolCallSubTask: """Test the CheckArgsToolCallSubTask implementation.""" - def test_validate_valid_args(self): + def test_validate_valid_args(self) -> None: """Test validate with valid arguments.""" subtask = CheckArgsToolCallSubTask( expected_tool_name="test_tool", expected_args={"arg1": "value1", "arg2": 42} ) - tool_call = {"name": "test_tool", "args": {"arg1": "value1", "arg2": 42}} + tool_call: Dict[str, Any] = { + "name": "test_tool", + "args": {"arg1": "value1", "arg2": 42}, + } assert subtask.validate(tool_call) - def test_validate_invalid_args(self): + def test_validate_invalid_args(self) -> None: """Test validate with invalid arguments.""" subtask = CheckArgsToolCallSubTask( expected_tool_name="test_tool", expected_args={"arg1": "value1", "arg2": 42} ) - tool_call = {"name": "test_tool", "args": {"arg1": "wrong_value", "arg2": 42}} + tool_call: Dict[str, Any] = { + "name": "test_tool", + "args": {"arg1": "wrong_value", "arg2": 42}, + } with pytest.raises(SubTaskValidationError): subtask.validate(tool_call) @@ -532,7 +588,7 @@ def test_validate_invalid_args(self): class TestCheckTopicFieldsToolCallSubTask: """Test the CheckTopicFieldsToolCallSubTask implementation.""" - def test_validate_valid_fields(self): + def test_validate_valid_fields(self) -> None: """Test validate with valid fields.""" subtask = CheckTopicFieldsToolCallSubTask( expected_tool_name="publish_ros2_message", @@ -541,7 +597,7 @@ def test_validate_valid_fields(self): expected_fields={"data": "test message"}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -552,7 +608,7 @@ def test_validate_valid_fields(self): assert subtask.validate(tool_call) - def test_validate_multiple_fields(self): + def test_validate_multiple_fields(self) -> None: """Test validate with multiple fields.""" subtask = CheckTopicFieldsToolCallSubTask( expected_tool_name="publish_ros2_message", @@ -565,7 +621,7 @@ def test_validate_multiple_fields(self): }, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -579,7 +635,7 @@ def test_validate_multiple_fields(self): assert subtask.validate(tool_call) - def test_validate_invalid_field(self): + def test_validate_invalid_field(self) -> None: """Test validate with invalid field value.""" subtask = CheckTopicFieldsToolCallSubTask( expected_tool_name="publish_ros2_message", @@ -588,7 +644,7 @@ def test_validate_invalid_field(self): expected_fields={"data": "expected message"}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "publish_ros2_message", "args": { "topic": "/test_topic", @@ -604,7 +660,7 @@ def test_validate_invalid_field(self): class TestCheckServiceFieldsToolCallSubTask: """Test the CheckServiceFieldsToolCallSubTask implementation.""" - def test_validate_valid_fields(self): + def test_validate_valid_fields(self) -> None: """Test validate with valid fields.""" subtask = CheckServiceFieldsToolCallSubTask( expected_tool_name="call_ros2_service", @@ -613,7 +669,7 @@ def test_validate_valid_fields(self): expected_fields={"data": True}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_service", "args": { "service_name": "/test_service", @@ -624,7 +680,7 @@ def test_validate_valid_fields(self): assert subtask.validate(tool_call) - def test_validate_multiple_fields(self): + def test_validate_multiple_fields(self) -> None: """Test validate with multiple fields.""" subtask = CheckServiceFieldsToolCallSubTask( expected_tool_name="call_ros2_service", @@ -633,7 +689,7 @@ def test_validate_multiple_fields(self): expected_fields={"request_field.subfield": "value", "flag": True}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_service", "args": { "service_name": "/test_service", @@ -644,7 +700,27 @@ def test_validate_multiple_fields(self): assert subtask.validate(tool_call) - def test_validate_empty_args(self): + def test_validate_empty_args(self) -> None: + """Test validate with empty service args.""" + subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service="/test_service", + expected_service_type="std_srvs/srv/SetParam", + expected_fields={"value.0.any": 30}, + ) + + tool_call: Dict[str, Any] = { + "name": "call_ros2_service", + "args": { + "service_name": "/test_service", + "service_type": "std_srvs/srv/SetParam", + "service_args": {"value": [{"any": 30}]}, + }, + } + + assert subtask.validate(tool_call) + + def test_validate_list_in_params(self) -> None: """Test validate with empty service args.""" subtask = CheckServiceFieldsToolCallSubTask( expected_tool_name="call_ros2_service", @@ -653,7 +729,7 @@ def test_validate_empty_args(self): expected_fields={"": {}}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_service", "args": { "service_name": "/test_service", @@ -668,7 +744,7 @@ def test_validate_empty_args(self): class TestCheckActionFieldsToolCallSubTask: """Test the CheckActionFieldsToolCallSubTask implementation.""" - def test_validate_valid_fields(self): + def test_validate_valid_fields(self) -> None: """Test validate with valid fields.""" subtask = CheckActionFieldsToolCallSubTask( expected_tool_name="call_ros2_action", @@ -677,7 +753,7 @@ def test_validate_valid_fields(self): expected_fields={"command.position": 0.5}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_action", "args": { "action_name": "/test_action", @@ -688,7 +764,7 @@ def test_validate_valid_fields(self): assert subtask.validate(tool_call) - def test_validate_multiple_fields(self): + def test_validate_multiple_fields(self) -> None: """Test validate with multiple fields.""" subtask = CheckActionFieldsToolCallSubTask( expected_tool_name="call_ros2_action", @@ -697,7 +773,7 @@ def test_validate_multiple_fields(self): expected_fields={"goal.x": 1.0, "goal.y": 2.0, "speed": 0.5}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_action", "args": { "action_name": "/test_action", @@ -708,7 +784,7 @@ def test_validate_multiple_fields(self): assert subtask.validate(tool_call) - def test_validate_empty_args(self): + def test_validate_empty_args(self) -> None: """Test validate with empty action args.""" subtask = CheckActionFieldsToolCallSubTask( expected_tool_name="call_ros2_action", @@ -717,7 +793,7 @@ def test_validate_empty_args(self): expected_fields={"": {}}, ) - tool_call = { + tool_call: Dict[str, Any] = { "name": "call_ros2_action", "args": { "action_name": "/test_action", diff --git a/tests/rai_bench/tool_calling_agent/test_validators.py b/tests/rai_bench/tool_calling_agent/test_validators.py index 98ee84252..3e21c0d39 100644 --- a/tests/rai_bench/tool_calling_agent/test_validators.py +++ b/tests/rai_bench/tool_calling_agent/test_validators.py @@ -20,6 +20,7 @@ from rai_bench.tool_calling_agent.interfaces import SubTaskValidationError, Validator from rai_bench.tool_calling_agent.validators import ( NotOrderedCallsValidator, + OneFromManyValidator, OrderedCallsValidator, ) @@ -674,3 +675,345 @@ def test_validate_reset(self): expected_subtasks_passed=[False, False], expected_errors_counts=[2, 2], ) + + +class TestOptionalValidator: + def test_init_with_empty_subtasks(self): + with pytest.raises(ValueError, match="Validator must have at least 1 subtask"): + OneFromManyValidator(subtasks=[]) + + def test_validate_empty_tool_calls(self): + subtasks = [DummySubTask("task1")] + validator = OneFromManyValidator(subtasks=subtasks) + + success, remaining = validator.validate(tool_calls=[]) + + assert not success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 0 + assert validator.subtasks_passed[0] is False + assert validator.passed is False + assert validator.extra_calls_used == 0 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=False, + expected_extra_calls=0, + expected_subtasks_passed=[False], + expected_errors_counts=[0], + ) + + def test_validate_successful_first_subtask_matches(self): + subtasks = [ + DummySubTask("task1", specific_tool="tool1"), + DummySubTask("task2", specific_tool="tool2"), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ToolCall(name="tool1")] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 0 + assert len(validator.subtasks_errors[1]) == 0 + assert validator.subtasks_passed[0] is True + assert validator.subtasks_passed[1] is False + assert validator.passed is True + assert validator.extra_calls_used == 0 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=True, + expected_extra_calls=0, + expected_subtasks_passed=[True, False], + expected_errors_counts=[0, 0], + ) + + def test_validate_successful_second_subtask_matches(self): + subtasks = [ + DummySubTask("task1", specific_tool="tool1"), + DummySubTask("task2", specific_tool="tool2"), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ToolCall(name="tool2")] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 1 # Error from trying tool1 + assert len(validator.subtasks_errors[1]) == 0 + assert validator.subtasks_passed[0] is False + assert validator.subtasks_passed[1] is True + assert validator.passed is True + assert validator.extra_calls_used == 0 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=True, + expected_extra_calls=0, + expected_subtasks_passed=[False, True], + expected_errors_counts=[1, 0], + ) + + def test_validate_successful_with_excess_tool_calls(self): + subtasks = [ + DummySubTask("task1", specific_tool="tool1"), + DummySubTask("task2", specific_tool="tool2"), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ + ToolCall(name="tool1"), + ToolCall(name="extra_tool"), + ToolCall(name="another_extra"), + ] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert success + assert len(remaining) == 2 + assert remaining[0].name == "extra_tool" + assert remaining[1].name == "another_extra" + assert len(validator.subtasks_errors[0]) == 0 + assert len(validator.subtasks_errors[1]) == 0 + assert validator.subtasks_passed[0] is True + assert validator.subtasks_passed[1] is False + assert validator.passed is True + assert validator.extra_calls_used == 2 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=True, + expected_extra_calls=2, + expected_subtasks_passed=[True, False], + expected_errors_counts=[0, 0], + ) + + def test_validate_successful_after_failed_attempts(self): + subtasks = [ + DummySubTask("task1", specific_tool="tool1"), + DummySubTask("task2", specific_tool="tool2"), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ + ToolCall(name="wrong_tool"), + ToolCall(name="another_wrong"), + ToolCall(name="tool2"), + ToolCall(name="extra_tool"), + ] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert success + assert len(remaining) == 1 + assert remaining[0].name == "extra_tool" + assert len(validator.subtasks_errors[0]) == 3 # 3 failed attempts + assert ( + len(validator.subtasks_errors[1]) == 2 + ) # 2 failed attempts before success + assert validator.subtasks_passed[0] is False + assert validator.subtasks_passed[1] is True + assert validator.passed is True + assert validator.extra_calls_used == 3 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=True, + expected_extra_calls=3, + expected_subtasks_passed=[False, True], + expected_errors_counts=[3, 2], + ) + + def test_validate_failure_no_subtask_matches(self): + subtasks = [ + DummySubTask("task1", specific_tool="tool1"), + DummySubTask("task2", specific_tool="tool2"), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ + ToolCall(name="wrong_tool"), + ToolCall(name="another_wrong"), + ] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert not success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 2 + assert len(validator.subtasks_errors[1]) == 2 + assert "Expected tool tool1, got wrong_tool" in validator.subtasks_errors[0][0] + assert "Expected tool tool2, got wrong_tool" in validator.subtasks_errors[1][0] + assert validator.subtasks_passed[0] is False + assert validator.subtasks_passed[1] is False + assert validator.passed is False + assert validator.extra_calls_used == 1 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=False, + expected_extra_calls=1, + expected_subtasks_passed=[False, False], + expected_errors_counts=[2, 2], + ) + + def test_validate_failure_subtask_validation_error(self): + subtasks = [ + DummySubTask("task1", outcomes=[False]), + DummySubTask("task2", outcomes=[False]), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ToolCall()] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert not success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 1 + assert len(validator.subtasks_errors[1]) == 1 + assert "error in task1" in validator.subtasks_errors[0][0] + assert "error in task2" in validator.subtasks_errors[1][0] + assert validator.subtasks_passed[0] is False + assert validator.subtasks_passed[1] is False + assert validator.passed is False + assert validator.extra_calls_used == 0 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=False, + expected_extra_calls=0, + expected_subtasks_passed=[False, False], + expected_errors_counts=[1, 1], + ) + + def test_validate_single_subtask_success(self): + subtasks = [DummySubTask("task1")] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ToolCall()] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 0 + assert validator.subtasks_passed[0] is True + assert validator.passed is True + assert validator.extra_calls_used == 0 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=True, + expected_extra_calls=0, + expected_subtasks_passed=[True], + expected_errors_counts=[0], + ) + + def test_validate_single_subtask_failure(self): + subtasks = [DummySubTask("task1", outcomes=[False])] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ToolCall()] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert not success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 1 + assert "error in task1" in validator.subtasks_errors[0][0] + assert validator.subtasks_passed[0] is False + assert validator.passed is False + assert validator.extra_calls_used == 0 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=False, + expected_extra_calls=0, + expected_subtasks_passed=[False], + expected_errors_counts=[1], + ) + + def test_validate_many_subtasks_last_one_succeeds(self): + subtasks = [ + DummySubTask("task1", specific_tool="tool1"), + DummySubTask("task2", specific_tool="tool2"), + DummySubTask("task3", specific_tool="tool3"), + DummySubTask("task4", specific_tool="tool4"), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ToolCall(name="tool4")] + + success, remaining = validator.validate(tool_calls=tool_calls) + + assert success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 1 + assert len(validator.subtasks_errors[1]) == 1 + assert len(validator.subtasks_errors[2]) == 1 + assert len(validator.subtasks_errors[3]) == 0 + assert validator.subtasks_passed[0] is False + assert validator.subtasks_passed[1] is False + assert validator.subtasks_passed[2] is False + assert validator.subtasks_passed[3] is True + assert validator.passed is True + assert validator.extra_calls_used == 0 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=True, + expected_extra_calls=0, + expected_subtasks_passed=[False, False, False, True], + expected_errors_counts=[1, 1, 1, 0], + ) + + def test_validate_reset(self): + subtasks = [ + DummySubTask("task1", outcomes=4 * [False]), + DummySubTask("task2", outcomes=4 * [False]), + ] + validator = OneFromManyValidator(subtasks=subtasks) + tool_calls = [ToolCall(), ToolCall()] + + # First validation call + validator.validate(tool_calls=tool_calls) + # Second validation call (should reset) + success, remaining = validator.validate(tool_calls=tool_calls) + + assert not success + assert remaining == [] + assert len(validator.subtasks_errors[0]) == 2 + assert len(validator.subtasks_errors[1]) == 2 + assert "error in task1" in validator.subtasks_errors[0][0] + assert "error in task2" in validator.subtasks_errors[1][0] + assert validator.subtasks_passed[0] is False + assert validator.subtasks_passed[1] is False + assert validator.passed is False + assert validator.extra_calls_used == 1 + + assert_dumped( + validator, + expected_type="optional", + expected_passed=False, + expected_extra_calls=1, + expected_subtasks_passed=[False, False], + expected_errors_counts=[2, 2], + ) + + def test_required_calls_property(self): + subtasks = [ + DummySubTask("task1"), + DummySubTask("task2"), + DummySubTask("task3"), + ] + validator = OneFromManyValidator(subtasks=subtasks) + + # OptionalValidator should only require 1 call + assert validator.required_calls == 1 From 5340ca9ba564e4dc2dd9828337f5a416b4ce7560 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 25 Jul 2025 11:02:48 +0200 Subject: [PATCH 03/13] feat: tool calling custom interfaces tasks extension (#636) --- src/rai_bench/rai_bench/test_models.py | 1 + .../rai_bench/tool_calling_agent/benchmark.py | 2 +- .../tool_calling_agent/interfaces.py | 2 +- .../mocked_ros2_interfaces.py | 5 +- .../tool_calling_agent/mocked_tools.py | 4 +- .../predefined/basic_tasks.py | 23 +- .../predefined/custom_interfaces_tasks.py | 255 ++- .../tool_calling_agent/tasks/basic.py | 27 - .../tasks/custom_interfaces.py | 1123 ++++++++++--- src/rai_core/rai/types/rai_interfaces.py | 1 + ...asks.py => test_predefined_basic_tasks.py} | 24 +- ...test_predefined_custom_interfaces_tasks.py | 1411 +++++++++++++++++ 12 files changed, 2548 insertions(+), 330 deletions(-) rename tests/rai_bench/tool_calling_agent/{test_predefined_tasks.py => test_predefined_basic_tasks.py} (100%) create mode 100644 tests/rai_bench/tool_calling_agent/test_predefined_custom_interfaces_tasks.py diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index def97016c..d4949f7cd 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -218,6 +218,7 @@ def test_models( vendor=vendors[i], **additional_model_args[i], ) + # TODO (jmatejcz) take param to set log level bench_logger = define_benchmark_logger(out_dir=Path(curr_out_dir)) try: if isinstance(bench_conf, ToolCallingAgentBenchmarkConfig): diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index cf73af238..aeb059d6b 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -113,7 +113,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: messages: List[BaseMessage] = [] prev_count: int = 0 try: - with self.time_limit(60): + with self.time_limit(20 * task.max_tool_calls_number): if isinstance(task, SpatialReasoningAgentTask): for state in agent.stream( { diff --git a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py index 3ee930bbf..7d8cb3187 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py @@ -273,7 +273,7 @@ def _check_service_tool_call_field( if args.get("service_name") != expected_service: raise SubTaskValidationError( - f"Expected service '{expected_service}', but got '{args.get('service')}'." + f"Expected service '{expected_service}', but got '{args.get('service_name')}'." ) if args.get("service_type") != expected_service_type: diff --git a/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py index 4dd584cb2..4971fc514 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py @@ -50,7 +50,6 @@ # dict of interfaces where keys are interfaces types and values are output # of GetROS2MessageInterfaceTool which are same as ros2 interface show outputs -# the dict contains custom as well as couple other common interfaces COMMON_INTERFACES: Dict[str, str] = { "std_srvs/srv/Empty": """# Empty service - no request or response @@ -3486,7 +3485,7 @@ } CUSTOM_TOPICS_AND_TYPES: Dict[str, str] = { - "/hri_message": "rai_interfaces/msg/HRIMessage", + "/to_human": "rai_interfaces/msg/HRIMessage", "/audio_message": "rai_interfaces/msg/AudioMessage", "/detection_array": "rai_interfaces/msg/RAIDetectionArray", } @@ -3826,7 +3825,7 @@ "/manipulator_move_to": "rai_interfaces/srv/ManipulatorMoveTo", "/get_log_digest": "rai_interfaces/srv/StringList", "/rai_whoami_documentation_service": "rai_interfaces/srv/VectorStoreRetrieval", - "/rai/whatisee/get": "rai_interfaces/srv/WhatISee", + "/rai_whatisee_get": "rai_interfaces/srv/WhatISee", } MANIPULATION_ACTIONS_AND_TYPES: Dict[str, str] = { diff --git a/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py b/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py index 443a620a6..fc80d2a82 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py @@ -15,7 +15,7 @@ import copy import uuid from threading import Lock -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Tuple, Type from unittest.mock import MagicMock import numpy as np @@ -360,7 +360,7 @@ def _run( self, service_name: str, service_type: str, - service_args: Optional[Dict[str, Any]] = None, + service_args: Dict[str, Any] = {}, timeout_sec: float = 1.0, ) -> str: """ diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py index aa8e99f1d..b0c18276a 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py @@ -22,15 +22,7 @@ CheckServiceFieldsToolCallSubTask, ) from rai_bench.tool_calling_agent.tasks.basic import ( - BOX1_ENTITY, - BOX1_POSITION, - BOX2_ENTITY, - BOX2_POSITION, COLOR_IMAGE_TOPIC, - DEFAULT_DINO_CONFIDENCE, - DEFAULT_FPS, - DEFAULT_PUBLISH_FREQUENCY, - DEFAULT_SAM_CONFIDENCE, DEPTH_IMAGE_TOPIC, GET_SPAWNABLE_NAMES_SERVICE, GET_WORLD_PROPERTIES_TYPE, @@ -38,7 +30,6 @@ POINTCLOUD_TOPIC, ROBOT_DESCRIPTION_TOPIC, ROBOT_STATE_PUBLISHER_LIST_PARAMS, - TOMATO_ENTITY, CheckSpawnableEntitiesTask, ConfigureVisionPipelineTask, GetAllROS2CamerasTask, @@ -59,6 +50,20 @@ OrderedCallsValidator, ) +DEFAULT_PUBLISH_FREQUENCY = 30.0 +DEFAULT_FPS = 30 +DEFAULT_SAM_CONFIDENCE = 0.8 +DEFAULT_DINO_CONFIDENCE = 0.7 +SAM_CONFIDENCE_2 = 0.6 +DINO_CONFIDENCE_2 = 0.6 +FPS_2 = 10 + +TOMATO_ENTITY = "tomato" +BOX1_ENTITY = "box1" +BOX2_ENTITY = "box2" +BOX1_POSITION = (0.2, 0.2, 0.2) +BOX2_POSITION = (0.4, 0.4, 0.2) + ########## SUBTASKS FOR TASKS WITHOUT REFACTORED VALIDATORS ################################################################# get_topics_subtask = CheckArgsToolCallSubTask( expected_tool_name="get_ros2_topics_names_and_types", expected_args={} diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py index 36f84a31b..e284ac452 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/custom_interfaces_tasks.py @@ -18,58 +18,77 @@ Task, TaskArgs, ) -from rai_bench.tool_calling_agent.subtasks import ( - CheckArgsToolCallSubTask, - CheckTopicFieldsToolCallSubTask, -) from rai_bench.tool_calling_agent.tasks.custom_interfaces import ( + CallGetLogDigestTask, + CallGetLogDigestTaskIndirect, + CallGroundedSAMSegmentTask, + CallGroundedSAMSegmentTaskIndirect, + CallGroundingDinoClassify, + CallGroundingDinoClassifyIndirect, + CallROS2ManipulatorMoveToServiceTask, + CallROS2ManipulatorMoveToServiceTaskIndirect, + CallVectorStoreRetrievalTask, + CallVectorStoreRetrievalTaskIndirect, + CompleteObjectInteractionTask, + CompleteObjectInteractionTaskIndirect, + EmergencyResponseProtocolTask, + EmergencyResponseProtocolTaskIndirect, + MultiModalSceneDocumentationTask, + MultiModalSceneDocumentationTaskIndirect, + PublishROS2AudioMessageTask, + PublishROS2AudioMessageTaskIndirect, + PublishROS2DetectionArrayTask, + PublishROS2DetectionArrayTaskIndirect, PublishROS2HRIMessageTextTask, -) -from rai_bench.tool_calling_agent.validators import ( - OrderedCallsValidator, + PublishROS2HRIMessageTextTaskIndirect, ) -########## SUBTASKS ################################################################# -pub_HRIMessage_text_subtask = CheckTopicFieldsToolCallSubTask( - expected_tool_name="publish_ros2_message", - expected_topic="/to_human", - expected_message_type="rai_interfaces/msg/HRIMessage", - expected_fields={"text": "Hello!"}, -) +# Object Classes +PERSON_CLASS = "person" +BOTTLE_CLASS = "bottle" -get_HRIMessage_interface_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_message_interface", - expected_args={"msg_type": "rai_interfaces/msg/HRIMessage"}, -) +# Text Messages +HRI_TEXT = "Hello!" +# Audio Parameters +BASIC_AUDIO_SAMPLES = [123, 456, 789] +BASIC_SAMPLE_RATE = 44100 +BASIC_CHANNELS = 2 -######### VALIDATORS ######################################################################################### -pub_HRIMessage_text_ord_val = OrderedCallsValidator( - subtasks=[pub_HRIMessage_text_subtask] -) -get_interface_publish_ord_val = OrderedCallsValidator( - subtasks=[ - get_HRIMessage_interface_subtask, - pub_HRIMessage_text_subtask, - ] -) +# Position Parameters +STANDARD_TARGET_POSITION = (1.0, 2.0, 3.0) + +# Query Strings +ROBOT_PURPOSE_QUERY = "What is the purpose of this robot?" +GROUNDING_DINO_CLASSES = "person, bottle" + +# Default scene objects for documentation +DEFAULT_SCENE_OBJECTS = ["person", "bottle"] def get_custom_interfaces_tasks( extra_tool_calls: List[int] = [0], prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], + include_indirect: bool = True, ) -> List[Task]: """Get predefined custom_interfaces tasks. Parameters ---------- - Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. - See the class documentation for parameter descriptions. + extra_tool_calls : List[int] + Number of extra tool calls allowed beyond the minimum required. + prompt_detail : List[Literal["brief", "descriptive"]] + Level of detail in task prompts. + n_shots : List[Literal[0, 2, 5]] + Number of examples in system prompt. + include_indirect : bool + Whether to include indirect (natural language) task variants. Returns ------- - Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. + List[Task] + List of task instances for benchmarking. """ tasks: List[Task] = [] @@ -81,15 +100,179 @@ def get_custom_interfaces_tasks( prompt_detail=detail, examples_in_system_prompt=shots, ) + + # Direct tasks (original) tasks.append( PublishROS2HRIMessageTextTask( - topic="/to_human", - validators=[ - get_interface_publish_ord_val, - ], task_args=task_args, - text="Hello!", + text=HRI_TEXT, ), ) + tasks.append( + PublishROS2AudioMessageTask( + task_args=task_args, + audio=BASIC_AUDIO_SAMPLES, + sample_rate=BASIC_SAMPLE_RATE, + channels=BASIC_CHANNELS, + ) + ) + + tasks.append( + PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=[PERSON_CLASS], + ) + ) + tasks.append( + PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=[BOTTLE_CLASS, PERSON_CLASS], + ) + ) + + tasks.append( + CallROS2ManipulatorMoveToServiceTask( + task_args=task_args, + target_x=STANDARD_TARGET_POSITION[0], + target_y=STANDARD_TARGET_POSITION[1], + target_z=STANDARD_TARGET_POSITION[2], + initial_gripper_state=True, + final_gripper_state=False, + ) + ) + + tasks.append( + CallGroundedSAMSegmentTask( + task_args=task_args, + detection_classes=[BOTTLE_CLASS], + ) + ) + + tasks.append( + CallGroundingDinoClassify( + task_args=task_args, + classes=GROUNDING_DINO_CLASSES, + ) + ) + + tasks.append( + CallGetLogDigestTask( + task_args=task_args, + ) + ) + tasks.append( + CallVectorStoreRetrievalTask( + task_args=task_args, + query=ROBOT_PURPOSE_QUERY, + ) + ) + + tasks.append( + CompleteObjectInteractionTask( + task_args=task_args, + target_class=BOTTLE_CLASS, + ) + ) + + tasks.append( + MultiModalSceneDocumentationTask( + task_args=task_args, + objects=DEFAULT_SCENE_OBJECTS, + ) + ) + + tasks.append( + EmergencyResponseProtocolTask( + task_args=task_args, + target_class=PERSON_CLASS, + ) + ) + + if include_indirect: + tasks.append( + PublishROS2HRIMessageTextTaskIndirect( + task_args=task_args, + text=HRI_TEXT, + ), + ) + tasks.append( + PublishROS2AudioMessageTaskIndirect( + task_args=task_args, + audio=BASIC_AUDIO_SAMPLES, + sample_rate=BASIC_SAMPLE_RATE, + channels=BASIC_CHANNELS, + ) + ) + + tasks.append( + PublishROS2DetectionArrayTaskIndirect( + task_args=task_args, + detection_classes=[PERSON_CLASS], + ) + ) + tasks.append( + PublishROS2DetectionArrayTaskIndirect( + task_args=task_args, + detection_classes=[BOTTLE_CLASS, PERSON_CLASS], + ) + ) + + tasks.append( + CallROS2ManipulatorMoveToServiceTaskIndirect( + task_args=task_args, + target_x=STANDARD_TARGET_POSITION[0], + target_y=STANDARD_TARGET_POSITION[1], + target_z=STANDARD_TARGET_POSITION[2], + initial_gripper_state=True, + final_gripper_state=False, + ) + ) + + tasks.append( + CallGroundedSAMSegmentTaskIndirect( + task_args=task_args, + detection_classes=[BOTTLE_CLASS], + ) + ) + + tasks.append( + CallGroundingDinoClassifyIndirect( + task_args=task_args, + classes=GROUNDING_DINO_CLASSES, + ) + ) + + tasks.append( + CallGetLogDigestTaskIndirect( + task_args=task_args, + ) + ) + tasks.append( + CallVectorStoreRetrievalTaskIndirect( + task_args=task_args, + query=ROBOT_PURPOSE_QUERY, + ) + ) + + tasks.append( + CompleteObjectInteractionTaskIndirect( + task_args=task_args, + target_class=BOTTLE_CLASS, + ) + ) + + tasks.append( + MultiModalSceneDocumentationTaskIndirect( + task_args=task_args, + objects=DEFAULT_SCENE_OBJECTS, + ) + ) + + tasks.append( + EmergencyResponseProtocolTaskIndirect( + task_args=task_args, + target_class=PERSON_CLASS, + ) + ) return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py index 929f027b8..0116f2b01 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/basic.py @@ -77,33 +77,6 @@ DELETE_ENTITY_TYPE = "gazebo_msgs/srv/DeleteEntity" GET_WORLD_PROPERTIES_TYPE = "gazebo_msgs/srv/GetWorldProperties" -DEFAULT_PUBLISH_FREQUENCY = 30.0 -DEFAULT_FPS = 30 -DEFAULT_SAM_CONFIDENCE = 0.8 -DEFAULT_DINO_CONFIDENCE = 0.7 -SAM_CONFIDENCE_2 = 0.6 -DINO_CONFIDENCE_2 = 0.6 -FPS_2 = 10 - -TOMATO_ENTITY = "tomato" -BOX1_ENTITY = "box1" -BOX2_ENTITY = "box2" -BOX1_POSITION = (0.2, 0.2, 0.2) -BOX2_POSITION = (0.4, 0.4, 0.2) - -CAMERA_TOPICS_AND_TYPES = [ - f"topic: {COLOR_CAMERA_INFO_TOPIC}\ntype: sensor_msgs/msg/CameraInfo\n", - f"topic: {COLOR_IMAGE_TOPIC}\ntype: sensor_msgs/msg/Image\n", - f"topic: {DEPTH_CAMERA_INFO_TOPIC}\ntype: sensor_msgs/msg/CameraInfo\n", - f"topic: {DEPTH_IMAGE_TOPIC}\ntype: sensor_msgs/msg/Image\n", -] - -CAMERA_TOPICS = [ - COLOR_CAMERA_INFO_TOPIC, - COLOR_IMAGE_TOPIC, - DEPTH_CAMERA_INFO_TOPIC, - DEPTH_IMAGE_TOPIC, -] PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_0_SHOT = """You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system. Be proactive and use the tools to answer questions.""" diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py index 0cff01ff7..370333d87 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/custom_interfaces.py @@ -14,23 +14,9 @@ import logging from abc import ABC -from typing import Any, List +from typing import Any, Dict, List, Optional, Tuple from langchain_core.tools import BaseTool -from rai.types import ( - BoundingBox2D, - Detection2D, - Header, - Point, - Pose, - Pose2D, - PoseStamped, - Quaternion, - Time, -) -from rai.types.rai_interfaces import ( - RAIDetectionArray, -) from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator from rai_bench.tool_calling_agent.mocked_ros2_interfaces import ( @@ -51,8 +37,66 @@ MockGetROS2TopicsNamesAndTypesTool, MockPublishROS2MessageTool, ) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, + CheckServiceFieldsToolCallSubTask, + CheckTopicFieldsToolCallSubTask, +) +from rai_bench.tool_calling_agent.validators import ( + OrderedCallsValidator, +) + +HRI_TOPIC = "/to_human" +AUDIO_TOPIC = "/audio_message" +DETECTIONS_TOPIC = "/detection_array" +MANIPULATOR_SERVICE = "/manipulator_move_to" +GROUNDED_SAM_SERVICE = "/grounded_sam_segment" +GROUNDING_DINO_SERVICE = "/grounding_dino_classify" +LOG_DIGEST_SERVICE = "/get_log_digest" +VECTOR_STORE_SERVICE = "/rai_whoami_documentation_service" + +HRI_MESSAGE_TYPE = "rai_interfaces/msg/HRIMessage" +AUDIO_MESSAGE_TYPE = "rai_interfaces/msg/AudioMessage" +DETECTION_ARRAY_MESSAGE_TYPE = "rai_interfaces/msg/RAIDetectionArray" +MANIPULATOR_SERVICE_TYPE = "rai_interfaces/srv/ManipulatorMoveTo" +GROUNDED_SAM_SERVICE_TYPE = "rai_interfaces/srv/RAIGroundedSam" +GROUNDING_DINO_SERVICE_TYPE = "rai_interfaces/srv/RAIGroundingDino" +STRING_LIST_SERVICE_TYPE = "rai_interfaces/srv/StringList" +VECTOR_STORE_SERVICE_TYPE = "rai_interfaces/srv/VectorStoreRetrieval" + +STANDARD_IMAGE_WIDTH = 640 +STANDARD_IMAGE_HEIGHT = 480 +STANDARD_IMAGE_ENCODING = "rgb8" + +PERSON_BBOX_CENTER = (320.0, 320.0) +PERSON_BBOX_SIZE = (50.0, 50.0) +PERSON_SCORE = 0.85 +PERSON_POSITION_3D = (2.0, 0.0, 0.0) + +BOTTLE_BBOX_CENTER = (320.0, 240.0) +BOTTLE_BBOX_SIZE = (80.0, 120.0) +BOTTLE_SCORE = 0.87 +BOTTLE_POSITION_3D = (1.2, 0.0, 0.5) + +DETECTION_DEFAULTS: Dict[str, Any] = { + "person": { + "bbox_center": PERSON_BBOX_CENTER, + "bbox_size": PERSON_BBOX_SIZE, + "score": PERSON_SCORE, + "position_3d": PERSON_POSITION_3D, + }, + "bottle": { + "bbox_center": BOTTLE_BBOX_CENTER, + "bbox_size": BOTTLE_BBOX_SIZE, + "score": BOTTLE_SCORE, + "position_3d": BOTTLE_POSITION_3D, + }, +} + +DEFAULT_BOX_THRESHOLD: float = 0.4 +DEFAULT_TEXT_THRESHOLD: float = 0.25 + -loggers_type = logging.Logger INTERFACES = COMMON_INTERFACES | CUSTOM_INTERFACES TOPICS_AND_TYPES = COMMON_TOPICS_AND_TYPES | CUSTOM_TOPICS_AND_TYPES @@ -85,13 +129,34 @@ PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT = ( PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + """ -- get_ros2_topics_names_and_types, args: {} -- get_ros2_message_interface, args: {'msg_type': 'rai_interfaces/msg/HRIMessage'} -- call_ros2_service, args: {'service': '/grounding_dino_classify', 'service_type': 'rai_interfaces/srv/RAIGroundingDino', 'request': {'classes': 'bottle, book', 'box_threshold': 0.4, 'text_threshold': 0.25}}""" +- get_ros2_services_names_and_types, args: {} +- get_ros2_message_interface, args: {'msg_type': 'moveit_msgs/srv/ExecuteKnownTrajectory'} +- name: call_ros2_service, args: { + "service_name": "/execute_trajectory", + "service_type": "moveit_msgs/srv/ExecuteKnownTrajectory", + "service_args": { + "trajectory": { + "joint_trajectory": { + "header": {"frame_id": "base_link"}, + "joint_names": ["joint1", "joint2"], + "points": [{ + "positions": [0.0, 1.57], + "time_from_start": {"sec": 2, "nanosec": 0} + }] + } + }, + "wait_for_execution": True + } + }""" ) class CustomInterfaceTask(Task, ABC): + """Custom Interface Tasks are designed around out custom interfaces in RAI + In these tasks we want to evaulate how well agent can understand these interfaes + and fill them as requested + """ + type = "custom_interface" @property @@ -108,14 +173,6 @@ def get_system_prompt(self) -> str: else: return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT - -class CustomInterfacesTopicTask(CustomInterfaceTask, ABC): - def __init__( - self, topic: str, validators: List[Validator], task_args: TaskArgs - ) -> None: - super().__init__(validators=validators, task_args=task_args) - self.topic = topic - @property def available_tools(self) -> List[BaseTool]: return [ @@ -128,24 +185,6 @@ def available_tools(self) -> List[BaseTool]: available_message_types=list(TOPICS_AND_TYPES.values()), available_topic_models=TOPIC_MODELS, ), - ] - - -class CustomInterfacesServiceTask(CustomInterfaceTask, ABC): - def __init__( - self, - service: str, - service_args: dict[str, Any], - validators: List[Validator], - task_args: TaskArgs, - ) -> None: - super().__init__(validators=validators, task_args=task_args) - self.service = service - self.service_args = service_args - - @property - def available_tools(self) -> List[BaseTool]: - return [ MockGetROS2ServicesNamesAndTypesTool( mock_service_names_and_types=SERVICE_STRINGS ), @@ -158,55 +197,141 @@ def available_tools(self) -> List[BaseTool]: ] -class PublishROS2HRIMessageTextTask(CustomInterfacesTopicTask): +class CustomInterfacesServiceTask(CustomInterfaceTask, ABC): + """ + Base class for tasks that involve calling SINGLE service with custom interface. + """ + + descriptive_sufix = ( + " Examine the required service interface and call " + "it with appropriate arguments." + ) + + def get_prompt(self) -> str: + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return self.get_base_prompt() + self.descriptive_sufix + + +class CustomInterfacesServicesTask(CustomInterfacesServiceTask, ABC): + """ + Base class for tasks that involve calling MULITPLE services with custom interface. + """ + + descriptive_sufix = ( + " Examine the required services interfaces and call " + "them with appropriate arguments." + ) + + +#### NOTE(jmatejcz) Tasks come in 2 versions: +#### - basic one that gives required topic or service and field values directly, +#### they verify if model can handle custom interface +#### - indirect, where prompts are more natural and indirect, +#### they verify if model can on top understand and use the topic or service in suitable situation + + +class PublishROS2HRIMessageTextTask(CustomInterfaceTask): complexity = "easy" def __init__( self, - topic: str, - validators: List[Validator], + text: str, task_args: TaskArgs, - text: str = "Hello!", + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, ) -> None: - super().__init__(topic, validators=validators, task_args=task_args) self.text = text + if validators is None: + # Default validator for this task + get_HRIMessage_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": HRI_MESSAGE_TYPE}, + ) + pub_HRIMessage_text_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=HRI_TOPIC, + expected_message_type=HRI_MESSAGE_TYPE, + expected_fields={"text": text}, + ) + validators = [ + OrderedCallsValidator( + subtasks=[ + get_HRIMessage_interface_subtask, + pub_HRIMessage_text_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) def get_base_prompt(self) -> str: - return f"Publish message to topic '{self.topic}' with text: '{self.text}'." + return f"Publish message to topic '{HRI_TOPIC}' with text '{self.text}'." def get_prompt(self) -> str: if self.prompt_detail == "brief": - return self.get_base_prompt() + return f"{self.get_base_prompt()}." else: return ( - f"{self.get_base_prompt()} " - "You can discover available topics, examine the message interface " - f"structure, and publish an HRI message containing the text '{self.text}'." + f"{self.get_base_prompt()}" + " Examine the message interface " + f"structure, and publish an HRI message with appropriate arguments." ) -class PublishROS2AudioMessageTask(CustomInterfacesTopicTask): +class PublishROS2HRIMessageTextTaskIndirect(PublishROS2HRIMessageTextTask): + complexity = "easy" + + def get_base_prompt(self) -> str: + return f"Publish message '{self.text}' to human." + + +class PublishROS2AudioMessageTask(CustomInterfaceTask): complexity = "easy" def __init__( self, - topic: str, - validators: List[Validator], + audio: List[int], + sample_rate: int, + channels: int, task_args: TaskArgs, - audio: List[int] = [123, 456, 789], - sample_rate: int = 44100, - channels: int = 2, + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, ) -> None: - super().__init__(topic, validators=validators, task_args=task_args) - self.expected_audio = audio - self.expected_sample_rate = sample_rate - self.expected_channels = channels + self.audio = audio + self.sample_rate = sample_rate + self.channels = channels + if validators is None: + # Default validator for this task + get_audio_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": AUDIO_MESSAGE_TYPE}, + ) + pub_audio_message_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=AUDIO_TOPIC, + expected_message_type=AUDIO_MESSAGE_TYPE, + expected_fields={ + "samples": audio, + "sample_rate": sample_rate, + "channels": channels, + }, + ) + validators = [ + OrderedCallsValidator( + subtasks=[ + get_audio_interface_subtask, + pub_audio_message_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) def get_base_prompt(self) -> str: return ( - f"Publish audio message to topic '{self.topic}' with samples " - f"{self.expected_audio}, sample rate {self.expected_sample_rate}, " - f"channels {self.expected_channels}." + f"Publish audio message to topic '{AUDIO_TOPIC}' with samples " + f"{self.audio}, sample rate {self.sample_rate} and " + f"channels {self.channels}." ) def get_prompt(self) -> str: @@ -214,186 +339,386 @@ def get_prompt(self) -> str: return self.get_base_prompt() else: return ( - f"{self.get_base_prompt()} " - "You can explore available audio topics, examine the message " - f"interface, and publish audio data with samples={self.expected_audio}, " - f"sample_rate={self.expected_sample_rate}, and channels={self.expected_channels}." + f"{self.get_base_prompt()}" + f" Examine the message interface, and publish audio data with appropriate arguments." ) -class PublishROS2DetectionArrayTask(CustomInterfacesTopicTask): +class PublishROS2AudioMessageTaskIndirect(PublishROS2AudioMessageTask): complexity = "easy" + def get_base_prompt(self) -> str: + return ( + f"Publish audio with samples {self.audio}, " + f"sample rate {self.sample_rate} and {self.channels} channels." + ) + + +class PublishROS2DetectionArrayTask(CustomInterfaceTask): + complexity = "medium" + def __init__( self, - topic: str, - validators: List[Validator], task_args: TaskArgs, - detection_classes: List[str] = ["person", "car"], - bbox_center_x: float = 320.0, - bbox_center_y: float = 320.0, - bbox_size_x: float = 50.0, - bbox_size_y: float = 50.0, + detection_classes: List[str], + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, ) -> None: - super().__init__(topic, validators=validators, task_args=task_args) - self.expected_detection_classes = detection_classes - self.expected_detections = [ - Detection2D( - bbox=BoundingBox2D( - center=Pose2D(x=bbox_center_x, y=bbox_center_y, theta=0.0), - size_x=bbox_size_x, - size_y=bbox_size_y, + self.detection_classes = detection_classes + + self.bbox_centers: List[Tuple[float, float]] = [] + self.bbox_sizes: List[Tuple[float, float]] = [] + for obj in detection_classes: + if obj not in DETECTION_DEFAULTS: + # use existing values + obj = "person" + defaults = DETECTION_DEFAULTS[obj] + self.bbox_centers.append(defaults["bbox_center"]) + self.bbox_sizes.append(defaults["bbox_size"]) + + if validators is None: + get_detection_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": DETECTION_ARRAY_MESSAGE_TYPE}, + ) + + expected_fields: Dict[str, Any] = {} + for i, obj in enumerate(detection_classes): + expected_fields.update( + { + f"detections.{i}.results.0.hypothesis.class_id": obj, + f"detections.{i}.bbox.center.x": self.bbox_centers[i][0], + f"detections.{i}.bbox.center.y": self.bbox_centers[i][1], + f"detections.{i}.bbox.size_x": self.bbox_sizes[i][0], + f"detections.{i}.bbox.size_y": self.bbox_sizes[i][1], + } ) + + pub_detection_array_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=DETECTIONS_TOPIC, + expected_message_type=DETECTION_ARRAY_MESSAGE_TYPE, + expected_fields=expected_fields, ) - ] + + validators = [ + OrderedCallsValidator( + subtasks=[ + get_detection_interface_subtask, + pub_detection_array_subtask, + ] + ) + ] + + super().__init__(validators=validators, task_args=task_args, logger=logger) def get_base_prompt(self) -> str: - bbox_center = self.expected_detections[0].bbox.center - bbox_size = self.expected_detections[0].bbox + detection_summaries: List[str] = [] + for _, (cls, center, size) in enumerate( + zip( + self.detection_classes, + self.bbox_centers, + self.bbox_sizes, + ) + ): + detection_summaries.append( + f"{cls} with bbox at center({center[0]}, {center[1]}) and size {size[0]}x{size[1]}" + ) + return ( - f"Publish detection array to topic '{self.topic}' with classes " - f"{self.expected_detection_classes} and bbox center " - f"({bbox_center.x}, {bbox_center.y}) size {bbox_size.size_x}x{bbox_size.size_y}." + f"Publish detection array to topic '{DETECTIONS_TOPIC}' with {len(self.detection_classes)} detections: " + f"{'; '.join(detection_summaries)}." ) def get_prompt(self) -> str: if self.prompt_detail == "brief": return self.get_base_prompt() else: - bbox_center = self.expected_detections[0].bbox.center - bbox_size = self.expected_detections[0].bbox return ( - f"{self.get_base_prompt()} " - "You can explore available detection topics, examine the message " - f"interface, and publish detection data with classes={self.expected_detection_classes} " - f"and bounding box at center ({bbox_center.x}, {bbox_center.y}) " - f"with size_x={bbox_size.size_x}, size_y={bbox_size.size_y}." + f"{self.get_base_prompt()} Examine the message interface " + "and publish detection data with appropriate arguments." ) +class PublishROS2DetectionArrayTaskIndirect(PublishROS2DetectionArrayTask): + complexity = "hard" + + def get_base_prompt(self) -> str: + detection_summaries: List[str] = [] + for _, (cls, center, size) in enumerate( + zip( + self.detection_classes, + self.bbox_centers, + self.bbox_sizes, + ) + ): + detection_summaries.append( + f"{cls} at position ({center[0]}, {center[1]}) with size {size[0]}x{size[1]}" + ) + + return ( + f"Report detected objects in the scene: {'; '.join(detection_summaries)}." + ) + + class CallROS2ManipulatorMoveToServiceTask(CustomInterfacesServiceTask): - complexity = "easy" + complexity = "medium" def __init__( self, - service: str, - service_args: dict[str, Any], - validators: List[Validator], + target_x: float, + target_y: float, + target_z: float, + initial_gripper_state: bool, + final_gripper_state: bool, task_args: TaskArgs, - target_x: float = 1.0, - target_y: float = 2.0, - target_z: float = 3.0, - initial_gripper_state: bool = True, - final_gripper_state: bool = False, - frame_id: str = "base_link", + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, ) -> None: - super().__init__( - service, service_args, validators=validators, task_args=task_args - ) - self.expected_initial_gripper_state = initial_gripper_state - self.expected_final_gripper_state = final_gripper_state - self.expected_target_pose = PoseStamped( - header=Header(frame_id=frame_id), - pose=Pose( - position=Point(x=target_x, y=target_y, z=target_z), - orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0), - ), - ) + self.target_x = target_x + self.target_y = target_y + self.target_z = target_z + self.initial_gripper_state = initial_gripper_state + self.final_gripper_state = final_gripper_state + if validators is None: + # Default validator for this task + get_manipulator_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": MANIPULATOR_SERVICE_TYPE}, + ) + call_manipulator_service_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=MANIPULATOR_SERVICE, + expected_service_type=MANIPULATOR_SERVICE_TYPE, + expected_fields={ + "target_pose.pose.position.x": target_x, + "target_pose.pose.position.y": target_y, + "target_pose.pose.position.z": target_z, + "initial_gripper_state": initial_gripper_state, + "final_gripper_state": final_gripper_state, + }, + ) + validators = [ + OrderedCallsValidator( + subtasks=[ + get_manipulator_interface_subtask, + call_manipulator_service_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) def get_base_prompt(self) -> str: - pos = self.expected_target_pose.pose.position return ( - f"Call service '{self.service}' to move manipulator to pose " - f"({pos.x}, {pos.y}, {pos.z}) with initial_gripper={self.expected_initial_gripper_state}, " - f"final_gripper={self.expected_final_gripper_state}." + f"Call service '{MANIPULATOR_SERVICE}' to move manipulator to pose " + f"({self.target_x}, {self.target_y}, {self.target_z}) with initial gripper state {self.initial_gripper_state} " + f"and final gripper state {self.final_gripper_state}." ) - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - pos = self.expected_target_pose.pose.position - return ( - f"{self.get_base_prompt()} " - "You can discover available manipulation services, examine the service " - f"interface, and call the service with target_pose position (x={pos.x}, " - f"y={pos.y}, z={pos.z}), initial_gripper_state={self.expected_initial_gripper_state}, " - f"and final_gripper_state={self.expected_final_gripper_state}." - ) + +class CallROS2ManipulatorMoveToServiceTaskIndirect( + CallROS2ManipulatorMoveToServiceTask +): + complexity = "medium" + + def get_base_prompt(self) -> str: + init_gripper_action = "close" if self.final_gripper_state else "open" + final_gripper_action = "close" if self.final_gripper_state else "open" + return ( + f"Move robot arm to position ({self.target_x}, {self.target_y}, {self.target_z}) " + f"At the start keep griper {init_gripper_action}, at the end {final_gripper_action}. " + ) class CallGroundedSAMSegmentTask(CustomInterfacesServiceTask): - complexity = "easy" + complexity = "medium" def __init__( self, - service: str, - service_args: dict[str, Any], - validators: List[Validator], task_args: TaskArgs, - frame_id: str = "camera_frame", + detection_classes: List[str], + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, ) -> None: - super().__init__( - service, service_args, validators=validators, task_args=task_args - ) - self.expected_detections = RAIDetectionArray( - header=Header(stamp=Time(sec=0, nanosec=0), frame_id=frame_id), - detections=[], - ) + self.detection_classes = detection_classes + + # Get default parameters for each detection class + self.bbox_centers: List[Tuple[float, float]] = [] + self.bbox_sizes: List[Tuple[float, float]] = [] + self.scores: List[Tuple[float, float]] = [] + self.positions_3d: List[Tuple[float, float]] = [] + + for obj in detection_classes: + if obj not in DETECTION_DEFAULTS: + # use existing values + obj = "person" + defaults = DETECTION_DEFAULTS[obj] + self.bbox_centers.append(defaults["bbox_center"]) + self.bbox_sizes.append(defaults["bbox_size"]) + self.scores.append(defaults["score"]) + self.positions_3d.append(defaults["position_3d"]) + + if validators is None: + get_grounded_sam_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": GROUNDED_SAM_SERVICE_TYPE}, + ) + + expected_fields: Dict[str, Any] = {} + for i, obj in enumerate(detection_classes): + defaults = DETECTION_DEFAULTS[obj] + expected_fields.update( + { + f"detections.detections.{i}.results.0.hypothesis.class_id": obj, + f"detections.detections.{i}.results.0.hypothesis.score": defaults[ + "score" + ], + f"detections.detections.{i}.results.0.pose.pose.position.x": defaults[ + "position_3d" + ][0], + f"detections.detections.{i}.results.0.pose.pose.position.y": defaults[ + "position_3d" + ][1], + f"detections.detections.{i}.results.0.pose.pose.position.z": defaults[ + "position_3d" + ][2], + f"detections.detections.{i}.bbox.center.x": defaults[ + "bbox_center" + ][0], + f"detections.detections.{i}.bbox.center.y": defaults[ + "bbox_center" + ][1], + f"detections.detections.{i}.bbox.size_x": defaults["bbox_size"][ + 0 + ], + f"detections.detections.{i}.bbox.size_y": defaults["bbox_size"][ + 1 + ], + } + ) + + expected_fields.update( + { + "source_img.width": STANDARD_IMAGE_WIDTH, + "source_img.height": STANDARD_IMAGE_HEIGHT, + "source_img.encoding": STANDARD_IMAGE_ENCODING, + } + ) + + call_grounded_sam_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDED_SAM_SERVICE, + expected_service_type=GROUNDED_SAM_SERVICE_TYPE, + expected_fields=expected_fields, + ) + + validators = [ + OrderedCallsValidator( + subtasks=[ + get_grounded_sam_interface_subtask, + call_grounded_sam_subtask, + ] + ) + ] + + super().__init__(validators=validators, task_args=task_args, logger=logger) def get_base_prompt(self) -> str: - return f"Call service '{self.service}' for image segmentation." + detection_summary: List[str] = [] + for cls in self.detection_classes: + defaults = DETECTION_DEFAULTS[cls] + center = defaults["bbox_center"] + size = defaults["bbox_size"] + score = defaults["score"] + pos3d = defaults["position_3d"] + detection_summary.append( + f"{cls} with score {score} at 3D position ({pos3d[0]}, {pos3d[1]}, {pos3d[2]}) " + f"bbox ({center[0]}, {center[1]}) size {size[0]}x{size[1]}" + ) - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - frame_id = self.expected_detections.header.frame_id - return ( - f"{self.get_base_prompt()} " - "You can discover available AI vision services, examine the service " - f"interface, and call the segmentation service with detections array " - f"(empty detections, header frame_id='{frame_id}') and source image." + return ( + f"Call service '{GROUNDED_SAM_SERVICE}' for image segmentation with {len(self.detection_classes)} detections: " + f"{', '.join(detection_summary)} on {STANDARD_IMAGE_WIDTH}x{STANDARD_IMAGE_HEIGHT} {STANDARD_IMAGE_ENCODING} image." + ) + + +class CallGroundedSAMSegmentTaskIndirect(CallGroundedSAMSegmentTask): + complexity = "medium" + + def get_base_prompt(self) -> str: + detection_summary: List[str] = [] + for cls in self.detection_classes: + defaults = DETECTION_DEFAULTS[cls] + center = defaults["bbox_center"] + size = defaults["bbox_size"] + score = defaults["score"] + pos3d = defaults["position_3d"] + detection_summary.append( + f"{cls} with score {score} at 3D position ({pos3d[0]}, {pos3d[1]}, {pos3d[2]}) " + f"bbox ({center[0]}, {center[1]}) size {size[0]}x{size[1]}" ) + return ( + f"Segment detected objects: " + f"{', '.join(detection_summary)}." + f"on {STANDARD_IMAGE_WIDTH}x{STANDARD_IMAGE_HEIGHT} {STANDARD_IMAGE_ENCODING} image." + ) + class CallGroundingDinoClassify(CustomInterfacesServiceTask): complexity = "easy" def __init__( self, - service: str, - service_args: dict[str, Any], - validators: List[Validator], task_args: TaskArgs, - classes: str = "bottle, book, chair", - box_threshold: float = 0.4, - text_threshold: float = 0.25, + classes: str, # Comma-separated string of classes like "person, bottle" + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, ) -> None: - super().__init__( - service, service_args, validators=validators, task_args=task_args - ) - self.expected_classes = classes - self.expected_box_threshold = box_threshold - self.expected_text_threshold = text_threshold + self.classes = classes + + if validators is None: + get_grounding_dino_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": GROUNDING_DINO_SERVICE_TYPE}, + ) + call_grounding_dino_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDING_DINO_SERVICE, + expected_service_type=GROUNDING_DINO_SERVICE_TYPE, + expected_fields={ + "classes": classes, + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_TEXT_THRESHOLD, + }, + ) + validators = [ + OrderedCallsValidator( + subtasks=[ + get_grounding_dino_interface_subtask, + call_grounding_dino_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) def get_base_prompt(self) -> str: return ( - f"Call service '{self.service}' for object classification with classes " - f"'{self.expected_classes}', box_threshold {self.expected_box_threshold}, " - f"text_threshold {self.expected_text_threshold}." + f"Call service '{GROUNDING_DINO_SERVICE}' for object classification with classes " + f"'{self.classes}', box_threshold {DEFAULT_BOX_THRESHOLD} and " + f"text_threshold {DEFAULT_TEXT_THRESHOLD}." ) - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can discover available AI detection services, examine the service " - f"interface, and call the classification service with classes='{self.expected_classes}', " - f"box_threshold={self.expected_box_threshold}, and text_threshold={self.expected_text_threshold}." - ) + +class CallGroundingDinoClassifyIndirect(CallGroundingDinoClassify): + complexity = "medium" + + def get_base_prompt(self) -> str: + return ( + f"Identify these objects in the scene: {self.classes}. " + f"box_threshold should be {DEFAULT_BOX_THRESHOLD} and " + f"text_threshold should be {DEFAULT_TEXT_THRESHOLD}." + ) class CallGetLogDigestTask(CustomInterfacesServiceTask): @@ -401,28 +726,41 @@ class CallGetLogDigestTask(CustomInterfacesServiceTask): def __init__( self, - service: str, - service_args: dict[str, Any], - validators: List[Validator], task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, ) -> None: - super().__init__( - service, service_args, validators=validators, task_args=task_args - ) + if validators is None: + # Default validator for this task + get_log_digest_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": STRING_LIST_SERVICE_TYPE}, + ) + call_log_digest_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=LOG_DIGEST_SERVICE, + expected_service_type=STRING_LIST_SERVICE_TYPE, + expected_fields={"": {}}, + ) + validators = [ + OrderedCallsValidator( + subtasks=[ + get_log_digest_interface_subtask, + call_log_digest_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) def get_base_prompt(self) -> str: - return f"Call service '{self.service}' to get log digest." + return f"Call service '{LOG_DIGEST_SERVICE}' to get log digest." - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can discover available logging services, examine the service " - "interface, and call the service with an empty request to retrieve " - "system log information." - ) + +class CallGetLogDigestTaskIndirect(CallGetLogDigestTask): + complexity = "medium" + + def get_base_prompt(self) -> str: + return "Get a summary of recent system logs." class CallVectorStoreRetrievalTask(CustomInterfacesServiceTask): @@ -430,56 +768,363 @@ class CallVectorStoreRetrievalTask(CustomInterfacesServiceTask): def __init__( self, - service: str, - service_args: dict[str, Any], - validators: List[Validator], + query: str, task_args: TaskArgs, - query: str = "What is the purpose of this robot?", + validators: Optional[List[Validator]] = None, + logger: logging.Logger | None = None, ) -> None: - super().__init__( - service, service_args, validators=validators, task_args=task_args - ) - self.expected_query = query + self.query = query + if validators is None: + # Default validator for this task + get_vector_store_interface_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_ros2_message_interface", + expected_args={"msg_type": VECTOR_STORE_SERVICE_TYPE}, + ) + call_vector_store_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=VECTOR_STORE_SERVICE, + expected_service_type=VECTOR_STORE_SERVICE_TYPE, + expected_fields={ + "query": query, + }, + ) + validators = [ + OrderedCallsValidator( + subtasks=[ + get_vector_store_interface_subtask, + call_vector_store_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) def get_base_prompt(self) -> str: - return f"Call service '{self.service}' with query '{self.expected_query}'" + return f"Call service '{VECTOR_STORE_SERVICE}' with query '{self.query}'." - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can discover available knowledge services, examine the service " - f"interface, and call the retrieval service with query='{self.expected_query}' " - "to search the robot's knowledge base." + +class CallVectorStoreRetrievalTaskIndirect(CallVectorStoreRetrievalTask): + complexity = "medium" + + def get_base_prompt(self) -> str: + return f"Search the documentation for information about: {self.query}." + + +class CompleteObjectInteractionTask(CustomInterfacesServicesTask): + complexity = "hard" + + @property + def optional_tool_calls_number(self) -> int: + # list services and get interface for all required services + return 5 + + def __init__( + self, + task_args: TaskArgs, + target_class: str = "bottle", + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self.target_class = target_class + + # Get default parameters for the target class + obj_key = target_class if target_class in DETECTION_DEFAULTS else "person" + defaults = DETECTION_DEFAULTS[obj_key] + + if validators is None: + self.initial_gripper = False + self.final_gripper = True + self.interaction_message = ( + f"Initiating object interaction sequence with detected {target_class}" ) + call_grounding_dino_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDING_DINO_SERVICE, + expected_service_type=GROUNDING_DINO_SERVICE_TYPE, + expected_fields={ + "classes": target_class, + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_BOX_THRESHOLD, + }, + ) -class CallWhatISeeTask(CustomInterfacesServiceTask): - complexity = "easy" + call_grounded_sam_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDED_SAM_SERVICE, + expected_service_type=GROUNDED_SAM_SERVICE_TYPE, + expected_fields={ + "detections.detections.0.results.0.hypothesis.class_id": target_class, + "detections.detections.0.results.0.hypothesis.score": defaults[ + "score" + ], + "detections.detections.0.results.0.pose.pose.position.x": defaults[ + "position_3d" + ][0], + "detections.detections.0.results.0.pose.pose.position.y": defaults[ + "position_3d" + ][1], + "detections.detections.0.results.0.pose.pose.position.z": defaults[ + "position_3d" + ][2], + "detections.detections.0.bbox.center.x": defaults["bbox_center"][0], + "detections.detections.0.bbox.center.y": defaults["bbox_center"][1], + "detections.detections.0.bbox.size_x": defaults["bbox_size"][0], + "detections.detections.0.bbox.size_y": defaults["bbox_size"][1], + "source_img.width": STANDARD_IMAGE_WIDTH, + "source_img.height": STANDARD_IMAGE_HEIGHT, + "source_img.encoding": STANDARD_IMAGE_ENCODING, + }, + ) + + call_manipulator_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=MANIPULATOR_SERVICE, + expected_service_type=MANIPULATOR_SERVICE_TYPE, + expected_fields={ + "target_pose.pose.position.x": defaults["position_3d"][0], + "target_pose.pose.position.y": defaults["position_3d"][1], + "target_pose.pose.position.z": defaults["position_3d"][2], + "initial_gripper_state": self.initial_gripper, + "final_gripper_state": self.final_gripper, + }, + ) + + pub_hri_interaction_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=HRI_TOPIC, + expected_message_type=HRI_MESSAGE_TYPE, + expected_fields={"text": self.interaction_message}, + ) + + validators = [ + OrderedCallsValidator( + subtasks=[ + call_grounding_dino_subtask, + call_grounded_sam_subtask, + call_manipulator_subtask, + pub_hri_interaction_subtask, + ] + ) + ] + + super().__init__(validators=validators, task_args=task_args, logger=logger) + + def get_base_prompt(self) -> str: + obj_key = ( + self.target_class if self.target_class in DETECTION_DEFAULTS else "person" + ) + defaults = DETECTION_DEFAULTS[obj_key] + return f"""Perform complete object interaction workflow with {self.target_class}: + 1) Call service '{GROUNDING_DINO_SERVICE}' to classify '{self.target_class}' with box_threshold={DEFAULT_BOX_THRESHOLD}, text_threshold={DEFAULT_TEXT_THRESHOLD} + 2) Call service '{GROUNDED_SAM_SERVICE}' to segment {self.target_class} at bbox({defaults["bbox_center"][0]}, {defaults["bbox_center"][1]}) + on {STANDARD_IMAGE_WIDTH}x{STANDARD_IMAGE_HEIGHT} {STANDARD_IMAGE_ENCODING} image + 3) Call service '{MANIPULATOR_SERVICE}' to move to position ({defaults["position_3d"][0]}, {defaults["position_3d"][1]}, {defaults["position_3d"][2]}) + with initial gripper state {self.initial_gripper} and final griper state {self.final_gripper} + 4) Publish to topic '{HRI_TOPIC}' HRI message: '{self.interaction_message}' + """ + + +class CompleteObjectInteractionTaskIndirect(CompleteObjectInteractionTask): + complexity = "hard" + + def get_base_prompt(self) -> str: + obj_key = ( + self.target_class if self.target_class in DETECTION_DEFAULTS else "person" + ) + defaults = DETECTION_DEFAULTS[obj_key] + return f"""Find and pick up the {self.target_class} in front of you by following these steps: + 1) First, look for and identify the {self.target_class} in the scene + 2) Then, get a segmentation of the {self.target_class} to understand its exact location + 3) Move the robot arm to reach the {self.target_class} at position ({defaults["position_3d"][0]}, {defaults["position_3d"][1]}, {defaults["position_3d"][2]}) and grab it + 4) Finally, inform human with '{self.interaction_message}` message that you've completed the task. + """ + + +class MultiModalSceneDocumentationTask(CustomInterfacesServiceTask): + complexity = "hard" + safety_query = "What safety protocols apply when humans and robots share workspace?" + doc_report = "Scene Documentation Complete: Recorded objects with safety analysis" def __init__( self, - service: str, - service_args: dict[str, Any], - validators: List[Validator], task_args: TaskArgs, + objects: List[str] = ["person", "bottle"], + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, ) -> None: - super().__init__( - service, service_args, validators=validators, task_args=task_args + self.objects = objects + + self.bbox_centers: List[Tuple[float, float]] = [] + self.bbox_sizes: List[Tuple[float, float]] = [] + + for obj in objects: + obj_key = obj if obj in DETECTION_DEFAULTS else "person" + defaults = DETECTION_DEFAULTS[obj_key] + self.bbox_centers.append(defaults["bbox_center"]) + self.bbox_sizes.append(defaults["bbox_size"]) + + if validators is None: + expected_fields: Dict[str, Any] = {} + for i, obj in enumerate(objects): + obj_key = obj if obj in DETECTION_DEFAULTS else "person" + defaults = DETECTION_DEFAULTS[obj_key] + expected_fields.update( + { + f"detections.{i}.results.0.hypothesis.class_id": obj, + f"detections.{i}.bbox.center.x": defaults["bbox_center"][0], + f"detections.{i}.bbox.center.y": defaults["bbox_center"][1], + f"detections.{i}.bbox.size_x": defaults["bbox_size"][0], + f"detections.{i}.bbox.size_y": defaults["bbox_size"][1], + } + ) + + pub_detection_array_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=DETECTIONS_TOPIC, + expected_message_type=DETECTION_ARRAY_MESSAGE_TYPE, + expected_fields=expected_fields, + ) + + call_vector_store_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=VECTOR_STORE_SERVICE, + expected_service_type=VECTOR_STORE_SERVICE_TYPE, + expected_fields={ + "query": self.safety_query, + }, + ) + + pub_hri_documentation_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=HRI_TOPIC, + expected_message_type=HRI_MESSAGE_TYPE, + expected_fields={"text": self.doc_report}, + ) + + validators = [ + OrderedCallsValidator( + subtasks=[ + pub_detection_array_subtask, + call_vector_store_subtask, + pub_hri_documentation_subtask, + ] + ) + ] + + super().__init__(validators=validators, task_args=task_args, logger=logger) + + @property + def optional_tool_calls_number(self) -> int: + # list services and get interface for all required services + return 5 + + def get_base_prompt(self) -> str: + object_summary = ", ".join( + [ + f"{obj} at ({center[0]}, {center[1]}) and size {size[0]}x{size[1]}" + for obj, center, size in zip( + self.objects, self.bbox_centers, self.bbox_sizes + ) + ] ) + return f"""Perform comprehensive scene documentation using multiple services: + 1) Publish to topic '{DETECTIONS_TOPIC}' detection array with {len(self.objects)} objects: {object_summary} + 2) Call service '{VECTOR_STORE_SERVICE}' to query: '{self.safety_query}' + 3) Publish to topic '{HRI_TOPIC}' final HRI report: '{self.doc_report}' + """ + + +class MultiModalSceneDocumentationTaskIndirect(MultiModalSceneDocumentationTask): + complexity = "hard" + def get_base_prompt(self) -> str: - return f"Call service '{self.service}' to get visual observations." + object_summary = ", ".join( + [ + f"{obj} at position ({center[0]}, {center[1]})" + for obj, center in zip(self.objects, self.bbox_centers) + ] + ) + return f"""Document the current scene and ensure safety protocols are followed: + 1) Record all detected objects in the scene: {object_summary} + 2) Look up safety guidelines for human-robot workspace sharing with query: `{self.safety_query}` + 3) Provide a final documentation report confirming scene analysis and safety compliance with report: `{self.doc_report}` + """ - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can discover available vision services, examine the service " - "interface, and call the service with an empty request to get " - "visual observations and camera pose information." + +class EmergencyResponseProtocolTask(CustomInterfacesServiceTask): + complexity = "hard" + emergence_audio_samples = [880, 880, 880, 1760] + emergence_sample_rate = 8000 + emergency_channels = 1 + emergency_message = "Person detected, alarm started!" + + def __init__( + self, + task_args: TaskArgs, + target_class: str = "person", + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self.target_class = target_class + if validators is None: + call_grounding_dino_emergency_subtask = CheckServiceFieldsToolCallSubTask( + expected_tool_name="call_ros2_service", + expected_service=GROUNDING_DINO_SERVICE, + expected_service_type=GROUNDING_DINO_SERVICE_TYPE, + expected_fields={ + "classes": target_class, + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_TEXT_THRESHOLD, + }, + ) + + pub_audio_emergency_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=AUDIO_TOPIC, + expected_message_type=AUDIO_MESSAGE_TYPE, + expected_fields={ + "samples": self.emergence_audio_samples, + "sample_rate": self.emergence_sample_rate, + "channels": self.emergency_channels, + }, + ) + pub_hri_emergency_subtask = CheckTopicFieldsToolCallSubTask( + expected_tool_name="publish_ros2_message", + expected_topic=HRI_TOPIC, + expected_message_type=HRI_MESSAGE_TYPE, + expected_fields={"text": self.emergency_message}, ) + validators = [ + OrderedCallsValidator( + subtasks=[ + call_grounding_dino_emergency_subtask, + pub_audio_emergency_subtask, + pub_hri_emergency_subtask, + ] + ) + ] + super().__init__(validators, task_args, logger) + + @property + def optional_tool_calls_number(self) -> int: + # list services and get interface for all required services + return 5 + + def get_base_prompt(self) -> str: + return f"""Execute emergency response protocol for {self.target_class} detection: + 1) Call service '{GROUNDING_DINO_SERVICE}' to detect emergency: '{self.target_class}' with high thresholds box={DEFAULT_BOX_THRESHOLD}, text={DEFAULT_TEXT_THRESHOLD} + 2) Publish to topic '{AUDIO_TOPIC}' emergency alert: {self.emergence_audio_samples} at {self.emergence_sample_rate}Hz + 3) Publish to topic '{HRI_TOPIC}' emergency message: '{self.emergency_message}' + """ + + +class EmergencyResponseProtocolTaskIndirect(EmergencyResponseProtocolTask): + complexity = "hard" + + def get_base_prompt(self) -> str: + return f"""Execute emergency response protocol when {self.target_class} is detected: + 1) Detect any {self.target_class} presence. + 2) Sound an emergency audio alert with samples {self.emergence_audio_samples} at {self.emergence_sample_rate}Hz + 3) Send emergency notification: '{self.emergency_message}' + """ diff --git a/src/rai_core/rai/types/rai_interfaces.py b/src/rai_core/rai/types/rai_interfaces.py index 8b231cddd..6125a5c80 100644 --- a/src/rai_core/rai/types/rai_interfaces.py +++ b/src/rai_core/rai/types/rai_interfaces.py @@ -56,6 +56,7 @@ class RAIGroundedSamResponse(BaseRaiSrv): class RAIGroundingDinoRequest(BaseRaiSrv): _prefix: str = "rai_interfaces/srv" + # TODO (jmatejcz) is it a bug that classes is str ? classes: str = "" box_threshold: float = 0.0 text_threshold: float = 0.0 diff --git a/tests/rai_bench/tool_calling_agent/test_predefined_tasks.py b/tests/rai_bench/tool_calling_agent/test_predefined_basic_tasks.py similarity index 100% rename from tests/rai_bench/tool_calling_agent/test_predefined_tasks.py rename to tests/rai_bench/tool_calling_agent/test_predefined_basic_tasks.py index 792db0ee9..c98a24466 100644 --- a/tests/rai_bench/tool_calling_agent/test_predefined_tasks.py +++ b/tests/rai_bench/tool_calling_agent/test_predefined_basic_tasks.py @@ -20,6 +20,18 @@ TaskArgs, ) from rai_bench.tool_calling_agent.predefined.basic_tasks import ( + BOX1_ENTITY, + BOX1_POSITION, + BOX2_ENTITY, + BOX2_POSITION, + DEFAULT_DINO_CONFIDENCE, + DEFAULT_FPS, + DEFAULT_PUBLISH_FREQUENCY, + DEFAULT_SAM_CONFIDENCE, + DINO_CONFIDENCE_2, + FPS_2, + SAM_CONFIDENCE_2, + TOMATO_ENTITY, all_camera_images_notord_val, check_spawnable_entities_val, color_image_ord_val, @@ -31,20 +43,10 @@ topics_ord_val, ) from rai_bench.tool_calling_agent.tasks.basic import ( - BOX1_ENTITY, - BOX1_POSITION, - BOX2_ENTITY, - BOX2_POSITION, COLOR_IMAGE_TOPIC, - DEFAULT_DINO_CONFIDENCE, - DEFAULT_FPS, - DEFAULT_PUBLISH_FREQUENCY, - DEFAULT_SAM_CONFIDENCE, DELETE_ENTITY_SERVICE, DELETE_ENTITY_TYPE, DEPTH_IMAGE_TOPIC, - DINO_CONFIDENCE_2, - FPS_2, GET_PARAMETERS_TYPE, GET_SPAWNABLE_NAMES_SERVICE, GET_WORLD_PROPERTIES_TYPE, @@ -59,12 +61,10 @@ ROBOT_STATE_PUBLISHER_GET_PARAMS, ROBOT_STATE_PUBLISHER_LIST_PARAMS, ROBOT_STATE_PUBLISHER_SET_PARAMS, - SAM_CONFIDENCE_2, SET_PARAMETERS_ATOMICALLY_TYPE, SET_PARAMETERS_TYPE, SPAWN_ENTITY_SERVICE, SPAWN_ENTITY_TYPE, - TOMATO_ENTITY, CheckSpawnableEntitiesTask, ConfigureVisionPipelineTask, GetAllROS2CamerasTask, diff --git a/tests/rai_bench/tool_calling_agent/test_predefined_custom_interfaces_tasks.py b/tests/rai_bench/tool_calling_agent/test_predefined_custom_interfaces_tasks.py new file mode 100644 index 000000000..32b2c5aef --- /dev/null +++ b/tests/rai_bench/tool_calling_agent/test_predefined_custom_interfaces_tasks.py @@ -0,0 +1,1411 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +from typing import Any, Dict, List + +import pytest + +from rai_bench.tool_calling_agent.interfaces import TaskArgs + +# Import constants from predefined tasks +from rai_bench.tool_calling_agent.predefined.custom_interfaces_tasks import ( + BASIC_AUDIO_SAMPLES, + BASIC_CHANNELS, + BASIC_SAMPLE_RATE, + BOTTLE_CLASS, + DEFAULT_SCENE_OBJECTS, + GROUNDING_DINO_CLASSES, + HRI_TEXT, + PERSON_CLASS, + ROBOT_PURPOSE_QUERY, + STANDARD_TARGET_POSITION, +) +from rai_bench.tool_calling_agent.tasks.custom_interfaces import ( + AUDIO_MESSAGE_TYPE, + AUDIO_TOPIC, + BOTTLE_BBOX_CENTER, + BOTTLE_BBOX_SIZE, + BOTTLE_POSITION_3D, + BOTTLE_SCORE, + DEFAULT_BOX_THRESHOLD, + DEFAULT_TEXT_THRESHOLD, + DETECTION_ARRAY_MESSAGE_TYPE, + DETECTION_DEFAULTS, + DETECTIONS_TOPIC, + GROUNDED_SAM_SERVICE, + GROUNDED_SAM_SERVICE_TYPE, + GROUNDING_DINO_SERVICE, + GROUNDING_DINO_SERVICE_TYPE, + HRI_MESSAGE_TYPE, + HRI_TOPIC, + LOG_DIGEST_SERVICE, + MANIPULATOR_SERVICE, + MANIPULATOR_SERVICE_TYPE, + STANDARD_IMAGE_ENCODING, + STANDARD_IMAGE_HEIGHT, + STANDARD_IMAGE_WIDTH, + STRING_LIST_SERVICE_TYPE, + VECTOR_STORE_SERVICE, + VECTOR_STORE_SERVICE_TYPE, + CallGetLogDigestTask, + CallGroundedSAMSegmentTask, + CallGroundingDinoClassify, + CallROS2ManipulatorMoveToServiceTask, + CallVectorStoreRetrievalTask, + CompleteObjectInteractionTask, + EmergencyResponseProtocolTask, + MultiModalSceneDocumentationTask, + PublishROS2AudioMessageTask, + PublishROS2DetectionArrayTask, + PublishROS2HRIMessageTextTask, +) + + +@pytest.fixture +def task_args() -> TaskArgs: + return TaskArgs( + extra_tool_calls=0, + prompt_detail="brief", + examples_in_system_prompt=0, + ) + + +class TestPublishROS2HRIMessageTextTask: + """Test PublishROS2HRIMessageTextTask validation.""" + + def test_publish_hri_message_valid(self, task_args: TaskArgs) -> None: + """Test valid HRI message publication.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": HRI_MESSAGE_TYPE}, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": HRI_TOPIC, + "message": {"text": HRI_TEXT}, + "message_type": HRI_MESSAGE_TYPE, + }, + }, + ] + + task = PublishROS2HRIMessageTextTask( + task_args=task_args, + text=HRI_TEXT, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_publish_hri_message_wrong_text(self, task_args: TaskArgs) -> None: + """Test HRI message with wrong text content.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": HRI_MESSAGE_TYPE}, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": HRI_TOPIC, + "message": {"text": "Goodbye!"}, # Wrong text + "message_type": HRI_MESSAGE_TYPE, + }, + }, + ] + + task = PublishROS2HRIMessageTextTask( + task_args=task_args, + text=HRI_TEXT, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_publish_hri_message_missing_interface_call( + self, task_args: TaskArgs + ) -> None: + """Test missing interface call.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "publish_ros2_message", + "args": { + "topic": HRI_TOPIC, + "message": {"text": HRI_TEXT}, + "message_type": HRI_MESSAGE_TYPE, + }, + }, + ] + + task = PublishROS2HRIMessageTextTask( + task_args=task_args, + text=HRI_TEXT, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_publish_hri_message_too_much_calls(self, task_args: TaskArgs) -> None: + """Test too many calls.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": HRI_MESSAGE_TYPE}, + }, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": HRI_MESSAGE_TYPE}, + }, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": HRI_MESSAGE_TYPE}, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": HRI_TOPIC, + "message": {"text": HRI_TEXT}, + "message_type": HRI_MESSAGE_TYPE, + }, + }, + ] + + task = PublishROS2HRIMessageTextTask( + task_args=task_args, + text=HRI_TEXT, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestPublishROS2AudioMessageTask: + """Test PublishROS2AudioMessageTask validation.""" + + def test_publish_audio_message_valid(self, task_args: TaskArgs) -> None: + """Test valid audio message publication.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": AUDIO_MESSAGE_TYPE}, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": AUDIO_TOPIC, + "message": { + "samples": BASIC_AUDIO_SAMPLES, + "sample_rate": BASIC_SAMPLE_RATE, + "channels": BASIC_CHANNELS, + }, + "message_type": AUDIO_MESSAGE_TYPE, + }, + }, + ] + + task = PublishROS2AudioMessageTask( + task_args=task_args, + audio=BASIC_AUDIO_SAMPLES, + sample_rate=BASIC_SAMPLE_RATE, + channels=BASIC_CHANNELS, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_publish_audio_message_wrong_param_value(self, task_args: TaskArgs) -> None: + """Test audio message with wrong sample rate.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": AUDIO_MESSAGE_TYPE}, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": AUDIO_TOPIC, + "message": { + "samples": BASIC_AUDIO_SAMPLES, + "sample_rate": 48000, # Wrong sample rate + "channels": BASIC_CHANNELS, + }, + "message_type": AUDIO_MESSAGE_TYPE, + }, + }, + ] + + task = PublishROS2AudioMessageTask( + task_args=task_args, + audio=BASIC_AUDIO_SAMPLES, + sample_rate=BASIC_SAMPLE_RATE, + channels=BASIC_CHANNELS, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_publish_audio_message_missing_call(self, task_args: TaskArgs) -> None: + """Test audio message with missing interface call.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "publish_ros2_message", + "args": { + "topic": AUDIO_TOPIC, + "message": { + "samples": BASIC_AUDIO_SAMPLES, + "sample_rate": BASIC_SAMPLE_RATE, + "channels": BASIC_CHANNELS, + }, + "message_type": AUDIO_MESSAGE_TYPE, + }, + }, + ] + + task = PublishROS2AudioMessageTask( + task_args=task_args, + audio=BASIC_AUDIO_SAMPLES, + sample_rate=BASIC_SAMPLE_RATE, + channels=BASIC_CHANNELS, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestPublishROS2DetectionArrayTask: + """Test PublishROS2DetectionArrayTask validation.""" + + def valid_template(self, obj: str) -> List[Dict[str, Any]]: + return [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": DETECTION_ARRAY_MESSAGE_TYPE}, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": DETECTIONS_TOPIC, + "message": { + "detections": [ + { + "results": [{"hypothesis": {"class_id": obj}}], + "bbox": { + "center": { + "x": DETECTION_DEFAULTS[obj]["bbox_center"][0], + "y": DETECTION_DEFAULTS[obj]["bbox_center"][1], + }, + "size_x": DETECTION_DEFAULTS[obj]["bbox_size"][0], + "size_y": DETECTION_DEFAULTS[obj]["bbox_size"][1], + }, + } + ] + }, + "message_type": DETECTION_ARRAY_MESSAGE_TYPE, + }, + }, + ] + + def valid_template_multiple_classes( + self, classes: List[str] + ) -> List[Dict[str, Any]]: + """Generate valid tool calls for multiple detection classes.""" + detections: list[Dict[str, Any]] = [] + for obj in classes: + detections.append( + { + "results": [{"hypothesis": {"class_id": obj}}], + "bbox": { + "center": { + "x": DETECTION_DEFAULTS[obj]["bbox_center"][0], + "y": DETECTION_DEFAULTS[obj]["bbox_center"][1], + }, + "size_x": DETECTION_DEFAULTS[obj]["bbox_size"][0], + "size_y": DETECTION_DEFAULTS[obj]["bbox_size"][1], + }, + } + ) + + return [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": DETECTION_ARRAY_MESSAGE_TYPE}, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": DETECTIONS_TOPIC, + "message": {"detections": detections}, + "message_type": DETECTION_ARRAY_MESSAGE_TYPE, + }, + }, + ] + + def test_publish_detection_array_person_valid(self, task_args: TaskArgs) -> None: + """Test valid detection array publication with person.""" + + task = PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=[PERSON_CLASS], + ) + score = task.validate(self.valid_template(PERSON_CLASS)) + assert score == 1.0 + + def test_publish_detection_array_multiple_classes_valid( + self, task_args: TaskArgs + ) -> None: + """Test valid detection array publication with both bottle and person classes.""" + + task = PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=[BOTTLE_CLASS, PERSON_CLASS], + ) + score = task.validate( + self.valid_template_multiple_classes([BOTTLE_CLASS, PERSON_CLASS]) + ) + assert score == 1.0 + + def test_publish_detection_array_wrong_class(self, task_args: TaskArgs) -> None: + """Test detection array with wrong class.""" + tool_calls = copy.deepcopy(self.valid_template(PERSON_CLASS)) + + # Modify the class_id to wrong value + tool_calls[1]["args"]["message"]["detections"][0]["results"][0]["hypothesis"][ + "class_id" + ] = BOTTLE_CLASS + + task = PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=[PERSON_CLASS], + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_publish_detection_array_wrong_param_value( + self, task_args: TaskArgs + ) -> None: + """Test detection array with wrong bounding box parameters.""" + tool_calls = copy.deepcopy(self.valid_template(PERSON_CLASS)) + + # Modify the bbox center and size to wrong values + tool_calls[1]["args"]["message"]["detections"][0]["bbox"]["center"]["x"] = 100.0 + tool_calls[1]["args"]["message"]["detections"][0]["bbox"]["center"]["y"] = 100.0 + tool_calls[1]["args"]["message"]["detections"][0]["bbox"]["size_x"] = 25.0 + tool_calls[1]["args"]["message"]["detections"][0]["bbox"]["size_y"] = 25.0 + + task = PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=[PERSON_CLASS], + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_publish_detection_array_missing_tool_call( + self, task_args: TaskArgs + ) -> None: + """Test detection array with missing interface call.""" + tool_calls = copy.deepcopy(self.valid_template(PERSON_CLASS)) + + # Remove the interface call + tool_calls.pop(0) + + task = PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=[PERSON_CLASS], + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_publish_detection_array_unknown_class_fallback( + self, task_args: TaskArgs + ) -> None: + """Test detection array with unknown class falls back to person defaults.""" + tool_calls = copy.deepcopy(self.valid_template(PERSON_CLASS)) + + # Modify the class_id to unknown class (should use person defaults) + tool_calls[1]["args"]["message"]["detections"][0]["results"][0]["hypothesis"][ + "class_id" + ] = "unknown_class" + + task = PublishROS2DetectionArrayTask( + task_args=task_args, + detection_classes=["unknown_class"], # Uses person defaults + ) + score = task.validate(tool_calls) + assert score == 1.0 + + +class TestCallGroundedSAMSegmentTask: + """Test CallGroundedSAMSegmentTask validation.""" + + VALID_TOOL_CALLS_TEMPLATE: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": GROUNDED_SAM_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDED_SAM_SERVICE, + "service_type": GROUNDED_SAM_SERVICE_TYPE, + "service_args": { + "detections": { + "detections": [ + { + "results": [ + { + "hypothesis": { + "class_id": BOTTLE_CLASS, + "score": BOTTLE_SCORE, + }, + "pose": { + "pose": { + "position": { + "x": BOTTLE_POSITION_3D[0], + "y": BOTTLE_POSITION_3D[1], + "z": BOTTLE_POSITION_3D[2], + } + } + }, + } + ], + "bbox": { + "center": { + "x": BOTTLE_BBOX_CENTER[0], + "y": BOTTLE_BBOX_CENTER[1], + }, + "size_x": BOTTLE_BBOX_SIZE[0], + "size_y": BOTTLE_BBOX_SIZE[1], + }, + } + ] + }, + "source_img": { + "width": STANDARD_IMAGE_WIDTH, + "height": STANDARD_IMAGE_HEIGHT, + "encoding": STANDARD_IMAGE_ENCODING, + }, + }, + }, + }, + ] + + def test_call_grounded_sam_bottle_valid(self, task_args: TaskArgs) -> None: + """Test valid grounded SAM call with bottle.""" + task = CallGroundedSAMSegmentTask( + task_args=task_args, + detection_classes=[BOTTLE_CLASS], + ) + score = task.validate(self.VALID_TOOL_CALLS_TEMPLATE) + assert score == 1.0 + + def test_call_grounded_sam_wrong_class(self, task_args: TaskArgs) -> None: + """Test grounded SAM call with wrong class ID.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the class_id to wrong value + tool_calls[1]["args"]["service_args"]["detections"]["detections"][0]["results"][ + 0 + ]["hypothesis"]["class_id"] = PERSON_CLASS + + task = CallGroundedSAMSegmentTask( + task_args=task_args, + detection_classes=[BOTTLE_CLASS], + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_grounded_sam_wrong_param_value(self, task_args: TaskArgs) -> None: + """Test grounded SAM call with wrong parameter value.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the score to wrong value + tool_calls[1]["args"]["service_args"]["detections"]["detections"][0]["results"][ + 0 + ]["hypothesis"]["score"] = 0.95 + + task = CallGroundedSAMSegmentTask( + task_args=task_args, + detection_classes=[BOTTLE_CLASS], + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_grounded_sam_missing_tool_call(self, task_args: TaskArgs) -> None: + """Test grounded SAM call missing the interface call.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Remove the interface call + tool_calls.pop(0) + + task = CallGroundedSAMSegmentTask( + task_args=task_args, + detection_classes=[BOTTLE_CLASS], + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestCallGroundingDinoClassify: + """Test CallGroundingDinoClassify validation.""" + + def test_call_grounding_dino_valid(self, task_args: TaskArgs) -> None: + """Test valid grounding DINO call.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": GROUNDING_DINO_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SERVICE, + "service_type": GROUNDING_DINO_SERVICE_TYPE, + "service_args": { + "classes": GROUNDING_DINO_CLASSES, + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_TEXT_THRESHOLD, + }, + }, + }, + ] + + task = CallGroundingDinoClassify( + task_args=task_args, + classes=GROUNDING_DINO_CLASSES, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_call_grounding_dino_wrong_classes(self, task_args: TaskArgs) -> None: + """Test grounding DINO call with wrong classes.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": GROUNDING_DINO_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SERVICE, + "service_type": GROUNDING_DINO_SERVICE_TYPE, + "service_args": { + "classes": "cat, dog", # Wrong classes - expecting GROUNDING_DINO_CLASSES + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_TEXT_THRESHOLD, + }, + }, + }, + ] + + task = CallGroundingDinoClassify( + task_args=task_args, + classes=GROUNDING_DINO_CLASSES, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_grounding_dino_wrong_param_value(self, task_args: TaskArgs) -> None: + """Test grounding DINO call with wrong parameter value.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": GROUNDING_DINO_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SERVICE, + "service_type": GROUNDING_DINO_SERVICE_TYPE, + "service_args": { + "classes": GROUNDING_DINO_CLASSES, + "box_threshold": 0.8, # Wrong threshold - expecting DEFAULT_BOX_THRESHOLD + "text_threshold": DEFAULT_TEXT_THRESHOLD, + }, + }, + }, + ] + + task = CallGroundingDinoClassify( + task_args=task_args, + classes=GROUNDING_DINO_CLASSES, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_grounding_dino_missing_tool_call(self, task_args: TaskArgs) -> None: + """Test grounding DINO call missing the interface call.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SERVICE, + "service_type": GROUNDING_DINO_SERVICE_TYPE, + "service_args": { + "classes": GROUNDING_DINO_CLASSES, + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_TEXT_THRESHOLD, + }, + }, + }, + ] + + task = CallGroundingDinoClassify( + task_args=task_args, + classes=GROUNDING_DINO_CLASSES, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestCompleteObjectInteractionTask: + """Test CompleteObjectInteractionTask validation.""" + + VALID_TOOL_CALLS_TEMPLATE: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SERVICE, + "service_type": GROUNDING_DINO_SERVICE_TYPE, + "service_args": { + "classes": "bottle", + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_BOX_THRESHOLD, + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDED_SAM_SERVICE, + "service_type": GROUNDED_SAM_SERVICE_TYPE, + "service_args": { + "detections": { + "detections": [ + { + "results": [ + { + "hypothesis": { + "class_id": "bottle", + "score": DETECTION_DEFAULTS["bottle"][ + "score" + ], + }, + "pose": { + "pose": { + "position": { + "x": DETECTION_DEFAULTS["bottle"][ + "position_3d" + ][0], + "y": DETECTION_DEFAULTS["bottle"][ + "position_3d" + ][1], + "z": DETECTION_DEFAULTS["bottle"][ + "position_3d" + ][2], + } + } + }, + } + ], + "bbox": { + "center": { + "x": DETECTION_DEFAULTS["bottle"][ + "bbox_center" + ][0], + "y": DETECTION_DEFAULTS["bottle"][ + "bbox_center" + ][1], + }, + "size_x": DETECTION_DEFAULTS["bottle"]["bbox_size"][ + 0 + ], + "size_y": DETECTION_DEFAULTS["bottle"]["bbox_size"][ + 1 + ], + }, + } + ] + }, + "source_img": { + "width": STANDARD_IMAGE_WIDTH, + "height": STANDARD_IMAGE_HEIGHT, + "encoding": STANDARD_IMAGE_ENCODING, + }, + }, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": MANIPULATOR_SERVICE, + "service_type": MANIPULATOR_SERVICE_TYPE, + "service_args": { + "target_pose": { + "pose": { + "position": { + "x": DETECTION_DEFAULTS["bottle"]["position_3d"][0], + "y": DETECTION_DEFAULTS["bottle"]["position_3d"][1], + "z": DETECTION_DEFAULTS["bottle"]["position_3d"][2], + } + } + }, + "initial_gripper_state": False, + "final_gripper_state": True, + }, + }, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": HRI_TOPIC, + "message": { + "text": "Initiating object interaction sequence with detected bottle" + }, + "message_type": HRI_MESSAGE_TYPE, + }, + }, + ] + + def test_complete_object_interaction_bottle_valid( + self, task_args: TaskArgs + ) -> None: + """Test valid complete object interaction with bottle.""" + task = CompleteObjectInteractionTask( + task_args=task_args, + target_class="bottle", + ) + score = task.validate(self.VALID_TOOL_CALLS_TEMPLATE) + assert score == 1.0 + + def test_complete_object_interaction_wrong_class(self, task_args: TaskArgs) -> None: + """Test complete object interaction with wrong target class.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the classes to wrong value + tool_calls[0]["args"]["service_args"]["classes"] = PERSON_CLASS + + task = CompleteObjectInteractionTask( + task_args=task_args, + target_class="bottle", + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_complete_object_interaction_wrong_param_value( + self, task_args: TaskArgs + ) -> None: + """Test complete object interaction with wrong parameter value.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the box_threshold to wrong value + tool_calls[0]["args"]["service_args"]["box_threshold"] = 0.8 + + task = CompleteObjectInteractionTask( + task_args=task_args, + target_class="bottle", + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_complete_object_interaction_missing_tool_call( + self, task_args: TaskArgs + ) -> None: + """Test complete object interaction missing a required tool call.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Remove the Grounded SAM service call (index 1) + tool_calls.pop(1) + + task = CompleteObjectInteractionTask( + task_args=task_args, + target_class="bottle", + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestCallROS2ManipulatorMoveToServiceTask: + """Test CallROS2ManipulatorMoveToServiceTask validation.""" + + def test_call_manipulator_service_valid(self, task_args: TaskArgs) -> None: + """Test valid manipulator service call.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": MANIPULATOR_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": MANIPULATOR_SERVICE, + "service_type": MANIPULATOR_SERVICE_TYPE, + "service_args": { + "target_pose": { + "pose": { + "position": { + "x": STANDARD_TARGET_POSITION[0], + "y": STANDARD_TARGET_POSITION[1], + "z": STANDARD_TARGET_POSITION[2], + } + } + }, + "initial_gripper_state": True, + "final_gripper_state": False, + }, + }, + }, + ] + + task = CallROS2ManipulatorMoveToServiceTask( + task_args=task_args, + target_x=STANDARD_TARGET_POSITION[0], + target_y=STANDARD_TARGET_POSITION[1], + target_z=STANDARD_TARGET_POSITION[2], + initial_gripper_state=True, + final_gripper_state=False, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_call_manipulator_service_wrong_position(self, task_args: TaskArgs) -> None: + """Test manipulator service with wrong position values.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": MANIPULATOR_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": MANIPULATOR_SERVICE, + "service_type": MANIPULATOR_SERVICE_TYPE, + "service_args": { + "target_pose": { + "pose": { + "position": { + "x": 2.0, + "y": 3.0, + "z": 4.0, + } # Wrong position + } + }, + "initial_gripper_state": True, + "final_gripper_state": False, + }, + }, + }, + ] + + task = CallROS2ManipulatorMoveToServiceTask( + task_args=task_args, + target_x=STANDARD_TARGET_POSITION[0], + target_y=STANDARD_TARGET_POSITION[1], + target_z=STANDARD_TARGET_POSITION[2], + initial_gripper_state=True, + final_gripper_state=False, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_manipulator_service_wrong_message_type( + self, task_args: TaskArgs + ) -> None: + """Test manipulator service with wrong message type.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": "wrong_message_type"}, # Wrong message type + }, + { + "name": "call_ros2_service", + "args": { + "service_name": MANIPULATOR_SERVICE, + "service_type": MANIPULATOR_SERVICE_TYPE, + "service_args": { + "target_pose": { + "pose": { + "position": { + "x": STANDARD_TARGET_POSITION[0], + "y": STANDARD_TARGET_POSITION[1], + "z": STANDARD_TARGET_POSITION[2], + } + } + }, + "initial_gripper_state": True, + "final_gripper_state": False, + }, + }, + }, + ] + + task = CallROS2ManipulatorMoveToServiceTask( + task_args=task_args, + target_x=STANDARD_TARGET_POSITION[0], + target_y=STANDARD_TARGET_POSITION[1], + target_z=STANDARD_TARGET_POSITION[2], + initial_gripper_state=True, + final_gripper_state=False, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_manipulator_service_wrong_tool_order( + self, task_args: TaskArgs + ) -> None: + """Test manipulator service with wrong tool call order.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": MANIPULATOR_SERVICE, + "service_type": MANIPULATOR_SERVICE_TYPE, + "service_args": { + "target_pose": { + "pose": { + "position": { + "x": STANDARD_TARGET_POSITION[0], + "y": STANDARD_TARGET_POSITION[1], + "z": STANDARD_TARGET_POSITION[2], + } + } + }, + "initial_gripper_state": True, + "final_gripper_state": False, + }, + }, + }, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": MANIPULATOR_SERVICE_TYPE}, + }, + ] + + task = CallROS2ManipulatorMoveToServiceTask( + task_args=task_args, + target_x=STANDARD_TARGET_POSITION[0], + target_y=STANDARD_TARGET_POSITION[1], + target_z=STANDARD_TARGET_POSITION[2], + initial_gripper_state=True, + final_gripper_state=False, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestCallGetLogDigestTask: + """Test CallGetLogDigestTask validation.""" + + def test_call_log_digest_valid(self, task_args: TaskArgs) -> None: + """Test valid log digest service call.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": STRING_LIST_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": LOG_DIGEST_SERVICE, + "service_type": STRING_LIST_SERVICE_TYPE, + "service_args": {}, + }, + }, + ] + + task = CallGetLogDigestTask( + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_call_log_digest_wrong_service_name(self, task_args: TaskArgs) -> None: + """Test log digest with wrong service name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": STRING_LIST_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": "/wrong_log_service", # Wrong service name + "service_type": STRING_LIST_SERVICE_TYPE, + "service_args": {}, + }, + }, + ] + + task = CallGetLogDigestTask( + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_log_digest_wrong_message_type(self, task_args: TaskArgs) -> None: + """Test log digest with wrong message type.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": "wrong_message_type"}, # Wrong message type + }, + { + "name": "call_ros2_service", + "args": { + "service_name": LOG_DIGEST_SERVICE, + "service_type": STRING_LIST_SERVICE_TYPE, + "service_args": {}, + }, + }, + ] + + task = CallGetLogDigestTask( + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_log_digest_wrong_tool_order(self, task_args: TaskArgs) -> None: + """Test log digest with wrong tool call order.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": LOG_DIGEST_SERVICE, + "service_type": STRING_LIST_SERVICE_TYPE, + "service_args": {}, + }, + }, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": STRING_LIST_SERVICE_TYPE}, + }, + ] + + task = CallGetLogDigestTask( + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestCallVectorStoreRetrievalTask: + """Test CallVectorStoreRetrievalTask validation.""" + + def test_call_vector_store_valid(self, task_args: TaskArgs) -> None: + """Test valid vector store service call.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": VECTOR_STORE_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": VECTOR_STORE_SERVICE, + "service_type": VECTOR_STORE_SERVICE_TYPE, + "service_args": { + "query": "What is the purpose of this robot?", + }, + }, + }, + ] + + task = CallVectorStoreRetrievalTask( + task_args=task_args, + query="What is the purpose of this robot?", + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_call_vector_store_wrong_query(self, task_args: TaskArgs) -> None: + """Test vector store with wrong query.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": VECTOR_STORE_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": VECTOR_STORE_SERVICE, + "service_type": VECTOR_STORE_SERVICE_TYPE, + "service_args": { + "query": "Wrong query text", # Wrong query + }, + }, + }, + ] + + task = CallVectorStoreRetrievalTask( + task_args=task_args, + query=ROBOT_PURPOSE_QUERY, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_vector_store_wrong_service_name(self, task_args: TaskArgs) -> None: + """Test vector store with wrong service name.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "get_ros2_message_interface", + "args": {"msg_type": VECTOR_STORE_SERVICE_TYPE}, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": "/wrong_vector_store_service", # Wrong service name + "service_type": VECTOR_STORE_SERVICE_TYPE, + "service_args": { + "query": ROBOT_PURPOSE_QUERY, + }, + }, + }, + ] + + task = CallVectorStoreRetrievalTask( + task_args=task_args, + query=ROBOT_PURPOSE_QUERY, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_call_vector_store_wrong_tool_order(self, task_args: TaskArgs) -> None: + """Test vector store with wrong tool call order.""" + tool_calls: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": VECTOR_STORE_SERVICE, + "service_type": VECTOR_STORE_SERVICE_TYPE, + "service_args": { + "query": ROBOT_PURPOSE_QUERY, + }, + }, + }, + { + "name": "get_ros2_message_interface", + "args": {"msg_type": VECTOR_STORE_SERVICE_TYPE}, + }, + ] + + task = CallVectorStoreRetrievalTask( + task_args=task_args, + query=ROBOT_PURPOSE_QUERY, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestMultiModalSceneDocumentationTask: + """Test MultiModalSceneDocumentationTask validation.""" + + VALID_TOOL_CALLS_TEMPLATE: List[Dict[str, Any]] = [ + { + "name": "publish_ros2_message", + "args": { + "topic": DETECTIONS_TOPIC, + "message": { + "detections": [ + { + "results": [{"hypothesis": {"class_id": PERSON_CLASS}}], + "bbox": { + "center": { + "x": DETECTION_DEFAULTS["person"]["bbox_center"][0], + "y": DETECTION_DEFAULTS["person"]["bbox_center"][1], + }, + "size_x": DETECTION_DEFAULTS["person"]["bbox_size"][0], + "size_y": DETECTION_DEFAULTS["person"]["bbox_size"][1], + }, + }, + { + "results": [{"hypothesis": {"class_id": BOTTLE_CLASS}}], + "bbox": { + "center": { + "x": DETECTION_DEFAULTS["bottle"]["bbox_center"][0], + "y": DETECTION_DEFAULTS["bottle"]["bbox_center"][1], + }, + "size_x": DETECTION_DEFAULTS["bottle"]["bbox_size"][0], + "size_y": DETECTION_DEFAULTS["bottle"]["bbox_size"][1], + }, + }, + ] + }, + "message_type": DETECTION_ARRAY_MESSAGE_TYPE, + }, + }, + { + "name": "call_ros2_service", + "args": { + "service_name": VECTOR_STORE_SERVICE, + "service_type": VECTOR_STORE_SERVICE_TYPE, + "service_args": { + "query": "What safety protocols apply when humans and robots share workspace?", + }, + }, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": HRI_TOPIC, + "message": { + "text": "Scene Documentation Complete: Recorded objects with safety analysis" + }, + "message_type": HRI_MESSAGE_TYPE, + }, + }, + ] + + def test_multimodal_scene_documentation_valid(self, task_args: TaskArgs) -> None: + """Test valid multimodal scene documentation.""" + task = MultiModalSceneDocumentationTask( + task_args=task_args, + objects=DEFAULT_SCENE_OBJECTS, + ) + score = task.validate(self.VALID_TOOL_CALLS_TEMPLATE) + assert score == 1.0 + + def test_multimodal_scene_documentation_wrong_class( + self, task_args: TaskArgs + ) -> None: + """Test multimodal scene documentation with wrong object class.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the first detection class to wrong value + tool_calls[0]["args"]["message"]["detections"][0]["results"][0]["hypothesis"][ + "class_id" + ] = "unknown_object" + + task = MultiModalSceneDocumentationTask( + task_args=task_args, + objects=DEFAULT_SCENE_OBJECTS, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_multimodal_scene_documentation_wrong_param_value( + self, task_args: TaskArgs + ) -> None: + """Test multimodal scene documentation with wrong parameter value.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the safety query to wrong value + tool_calls[1]["args"]["service_args"]["query"] = "What is the weather like?" + + task = MultiModalSceneDocumentationTask( + task_args=task_args, + objects=DEFAULT_SCENE_OBJECTS, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_multimodal_scene_documentation_missing_tool_call( + self, task_args: TaskArgs + ) -> None: + """Test multimodal scene documentation missing a required tool call.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Remove the vector store service call (index 2) + tool_calls.pop(2) + + task = MultiModalSceneDocumentationTask( + task_args=task_args, + objects=DEFAULT_SCENE_OBJECTS, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestEmergencyResponseProtocolTask: + """Test EmergencyResponseProtocolTask validation.""" + + VALID_TOOL_CALLS_TEMPLATE: List[Dict[str, Any]] = [ + { + "name": "call_ros2_service", + "args": { + "service_name": GROUNDING_DINO_SERVICE, + "service_type": GROUNDING_DINO_SERVICE_TYPE, + "service_args": { + "classes": "person", + "box_threshold": DEFAULT_BOX_THRESHOLD, + "text_threshold": DEFAULT_TEXT_THRESHOLD, + }, + }, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": AUDIO_TOPIC, + "message": { + "samples": [880, 880, 880, 1760], + "sample_rate": 8000, + "channels": 1, + }, + "message_type": AUDIO_MESSAGE_TYPE, + }, + }, + { + "name": "publish_ros2_message", + "args": { + "topic": HRI_TOPIC, + "message": {"text": "Person detected, alarm started!"}, + "message_type": HRI_MESSAGE_TYPE, + }, + }, + ] + + def test_emergency_response_protocol_valid(self, task_args: TaskArgs) -> None: + """Test valid emergency response protocol.""" + task = EmergencyResponseProtocolTask( + task_args=task_args, + target_class=PERSON_CLASS, + ) + score = task.validate(self.VALID_TOOL_CALLS_TEMPLATE) + assert score == 1.0 + + def test_emergency_response_wrong_class(self, task_args: TaskArgs) -> None: + """Test emergency response with wrong target class.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the classes to wrong value + tool_calls[0]["args"]["service_args"]["classes"] = BOTTLE_CLASS + + task = EmergencyResponseProtocolTask( + task_args=task_args, + target_class=PERSON_CLASS, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_emergency_response_wrong_param_value(self, task_args: TaskArgs) -> None: + """Test emergency response with wrong parameter value.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + # Modify the box_threshold to wrong value + tool_calls[0]["args"]["service_args"]["box_threshold"] = 0.8 + + task = EmergencyResponseProtocolTask( + task_args=task_args, + target_class=PERSON_CLASS, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_emergency_response_missing_tool_call(self, task_args: TaskArgs) -> None: + """Test emergency response missing a required tool call.""" + tool_calls = copy.deepcopy(self.VALID_TOOL_CALLS_TEMPLATE) + + tool_calls.pop(1) + + task = EmergencyResponseProtocolTask( + task_args=task_args, + target_class=PERSON_CLASS, + ) + score = task.validate(tool_calls) + assert score == 0.0 From 6fd5671c8fac5eeb782c30b7acce33479ce855d4 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 25 Jul 2025 13:48:49 +0200 Subject: [PATCH 04/13] feat: tool calling spatial reasoning tasks extension (#637) --- .../predefined/spatial_reasoning_tasks.py | 104 +++++++----------- .../tool_calling_agent/tasks/spatial.py | 14 +-- 2 files changed, 47 insertions(+), 71 deletions(-) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py index f67d19792..d3ccbfa5e 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py @@ -32,66 +32,6 @@ ) IMG_PATH = "src/rai_bench/rai_bench/tool_calling_agent/predefined/images/" -true_response_inputs: List[BoolImageTaskInput] = [ - BoolImageTaskInput( - question="Is the door on the left from the desk?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is the light on in the room?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Are there any pictures on the wall?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="Are there 3 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant behind the rack?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Is there a pillow on the armchain?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), -] -false_response_inputs: List[BoolImageTaskInput] = [ - BoolImageTaskInput( - question="Is the door open?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is someone in the room?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="Are there 4 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a rack on the left from the sofa?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant on the right from the window?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is there a red pillow on the armchair?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), -] ########## SUBTASKS ################################################################# return_true_subtask = CheckArgsToolCallSubTask( expected_tool_name="return_bool_response", expected_args={"response": True} @@ -127,8 +67,8 @@ def get_spatial_tasks( easy_true_inputs = [ # Single object presence/detection BoolImageTaskInput( - question="Is the light on in the room?", - images_paths=[IMG_PATH + "image_2.jpg"], + question="Is the chair in the room?", + images_paths=[IMG_PATH + "image_1.jpg"], ), BoolImageTaskInput( question="Do you see the plant?", images_paths=[IMG_PATH + "image_2.jpg"] @@ -138,8 +78,8 @@ def get_spatial_tasks( images_paths=[IMG_PATH + "image_3.jpg"], ), BoolImageTaskInput( - question="Is there a pillow on the armchain?", - images_paths=[IMG_PATH + "image_7.jpg"], + question="is there a TV in the room?", + images_paths=[IMG_PATH + "image_4.jpg"], ), ] @@ -149,6 +89,14 @@ def get_spatial_tasks( question="Are there 3 pictures on the wall?", images_paths=[IMG_PATH + "image_4.jpg"], ), + BoolImageTaskInput( + question="Is the light on in the room?", + images_paths=[IMG_PATH + "image_2.jpg"], + ), + BoolImageTaskInput( + question="Is there something to sit on?", + images_paths=[IMG_PATH + "image_7.jpg"], + ), ] hard_true_inputs = [ @@ -161,6 +109,14 @@ def get_spatial_tasks( question="Is there a plant behind the rack?", images_paths=[IMG_PATH + "image_5.jpg"], ), + BoolImageTaskInput( + question="Is there a rug under the bed?", + images_paths=[IMG_PATH + "image_2.jpg"], + ), + BoolImageTaskInput( + question="Is there a pillow on the armchain?", + images_paths=[IMG_PATH + "image_7.jpg"], + ), ] easy_false_inputs = [ @@ -175,6 +131,14 @@ def get_spatial_tasks( question="Is there a red pillow on the armchair?", images_paths=[IMG_PATH + "image_7.jpg"], ), + BoolImageTaskInput( + question="Is there a red desk with chair in the room?", + images_paths=[IMG_PATH + "image_5.jpg"], + ), + BoolImageTaskInput( + question="Do you see the bed?", + images_paths=[IMG_PATH + "image_6.jpg"], + ), ] medium_false_inputs = [ @@ -186,6 +150,14 @@ def get_spatial_tasks( question="Are there 4 pictures on the wall?", images_paths=[IMG_PATH + "image_4.jpg"], ), + BoolImageTaskInput( + question="Is the TV switched on?", + images_paths=[IMG_PATH + "image_6.jpg"], + ), + BoolImageTaskInput( + question="Is the window opened?", + images_paths=[IMG_PATH + "image_6.jpg"], + ), ] hard_false_inputs = [ @@ -198,6 +170,10 @@ def get_spatial_tasks( question="Is there a plant on the right from the window?", images_paths=[IMG_PATH + "image_6.jpg"], ), + BoolImageTaskInput( + question="Is the chair next to a bed?", + images_paths=[IMG_PATH + "image_1.jpg"], + ), ] for extra_calls in extra_tool_calls: diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py index 94e81742e..2f9b58e0d 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py @@ -25,24 +25,24 @@ loggers_type = logging.Logger -SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT = """You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response with the use of the provided tools.""" +SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT = """You are a helpful and knowledgeable AI assistant that specializes +in interpreting and analyzing visual content. Your task is to answer questions based +on the images provided to you. Please response with the use of the provided tools.""" +# NOTE (jmatejcz) In this case we are using only one tool so there is no difference bettween 2 and 5 shot +# so I made 1 example in '2 shot' and 2 examples in '5 shot' prompt SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT = ( SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT + """ Example of tool calls: -- return_bool_response, args: {'response': True} -- return_bool_response, args: {'response': False}""" +- return_bool_response, args: {'response': True}""" ) -# NOTE (jmatejcz) In this case we are using only one tool so there is no difference bettween 2 and 5 shot SPATIAL_REASONING_SYSTEM_PROMPT_5_SHOT = ( SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT + """ -- return_bool_response, args: {'response': True} # When object is clearly visible -- return_bool_response, args: {'response': False} # When object is not present -- return_bool_response, args: {'response': True} # When spatial relationship is correct""" +- return_bool_response, args: {'response': False}""" ) From 0b7aa991b786a67b827e3dfa6a56b872d790b3ea Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 25 Jul 2025 14:53:43 +0200 Subject: [PATCH 05/13] refactor: remove navigation tasks (#638) --- docs/simulation_and_benchmarking/rai_bench.md | 1 - .../docs/tool_calling_agent_benchmark.md | 2 +- .../rai_bench/examples/benchmarking_models.py | 1 - src/rai_bench/rai_bench/test_models.py | 2 - .../mocked_ros2_interfaces.py | 737 +----------------- .../tool_calling_agent/predefined/__init__.py | 2 - .../predefined/navigation_tasks.py | 118 --- .../tool_calling_agent/predefined/tasks.py | 9 - .../tool_calling_agent/tasks/navigation.py | 231 ------ src/rai_bench/rai_bench/utils.py | 2 - 10 files changed, 2 insertions(+), 1103 deletions(-) delete mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py delete mode 100644 src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index be31f4a05..1969be88e 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -130,7 +130,6 @@ The ToolCallingAgentBenchmark class manages the execution of tasks and collects There are predefined Tasks available which are grouped by categories: - Basic - require retrieving info from certain topics -- Navigation - Spatial reasoning - questions about surroundings with images attached - Manipulation - Custom Interfaces - requires using messages with custom interfaces diff --git a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md index 0472170c5..0eb6664b5 100644 --- a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md +++ b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md @@ -14,4 +14,4 @@ Implementations can be found: - Validators [Validators](../tool_calling_agent/validators.py) - Subtasks [Validators](../tool_calling_agent/tasks/subtasks.py) -- Tasks, including navigation, spatial, custom interfaces and other [Tasks](../tool_calling_agent/tasks/) +- Tasks, including basic, spatial, custom interfaces and manipulation [Tasks](../tool_calling_agent/tasks/) diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index c4008eef4..c97d11cbd 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -36,7 +36,6 @@ task_types=[ # what types of tasks to include "basic", "spatial_reasoning", - # "navigation", "custom_interfaces", "manipulation", ], diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index d4949f7cd..0501d811c 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -96,14 +96,12 @@ class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): Literal[ "basic", "manipulation", - "navigation", "custom_interfaces", "spatial_reasoning", ] ] = [ "basic", "manipulation", - "navigation", "custom_interfaces", "spatial_reasoning", ] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py index 4971fc514..566791b85 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/mocked_ros2_interfaces.py @@ -26,20 +26,6 @@ RAIGroundingDinoRequest, ) -from rai_bench.tool_calling_agent.messages.actions import ( - AssistedTeleopGoal, - BackUpGoal, - ComputePathThroughPosesGoal, - ComputePathToPoseGoal, - DriveOnHeadingGoal, - FollowPathGoal, - FollowWaypointsGoal, - NavigateThroughPosesGoal, - NavigateToPoseGoal, - SmoothPathGoal, - SpinGoal, - WaitGoal, -) from rai_bench.tool_calling_agent.messages.base import Clock from rai_bench.tool_calling_agent.messages.services import ( StringListRequest, @@ -2459,410 +2445,6 @@ } -NAVIGATION_INTERFACES: Dict[str, str] = { - "nav2_msgs/action/NavigateToPose": """#goal definition -geometry_msgs/PoseStamped pose - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string behavior_tree ---- -#result definition -std_msgs/Empty result ---- -#feedback definition -geometry_msgs/PoseStamped current_pose - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration navigation_time - int32 sec - uint32 nanosec -builtin_interfaces/Duration estimated_time_remaining - int32 sec - uint32 nanosec -int16 number_of_recoveries -float32 distance_remaining -""", - "nav2_msgs/action/AssistedTeleop": """#goal definition -builtin_interfaces/Duration time_allowance - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback -builtin_interfaces/Duration current_teleop_duration - int32 sec - uint32 nanosec""", - "nav2_msgs/action/BackUp": """#goal definition -geometry_msgs/Point target - float64 x - float64 y - float64 z -float32 speed -builtin_interfaces/Duration time_allowance - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback definition -float32 distance_traveled""", - "nav2_msgs/action/ComputePathThroughPoses": """#goal definition -geometry_msgs/PoseStamped[] goals - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -geometry_msgs/PoseStamped start - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string planner_id -bool use_start # If false, use current robot pose as path start, if true, use start above instead ---- -#result definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration planning_time - int32 sec - uint32 nanosec ---- -#feedback definition""", - "nav2_msgs/action/ComputePathToPose": """#goal definition -geometry_msgs/PoseStamped goal - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -geometry_msgs/PoseStamped start - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string planner_id -bool use_start # If false, use current robot pose as path start, if true, use start above instead ---- -#result definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration planning_time - int32 sec - uint32 nanosec ---- -#feedback definition""", - "nav2_msgs/action/DriveOnHeading": """#goal definition -geometry_msgs/Point target - float64 x - float64 y - float64 z -float32 speed -builtin_interfaces/Duration time_allowance - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback definition -float32 distance_traveled""", - "nav2_msgs/action/FollowPath": """#goal definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string controller_id -string goal_checker_id ---- -#result definition -std_msgs/Empty result ---- -#feedback definition -float32 distance_to_goal -float32 speed""", - "nav2_msgs/action/FollowWaypoints": """#goal definition -geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 ---- -#result definition -int32[] missed_waypoints ---- -#feedback definition -uint32 current_waypoint""", - "nav2_msgs/action/NavigateThroughPoses": """#goal definition -geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string behavior_tree ---- -#result definition -std_msgs/Empty result ---- -#feedback definition -geometry_msgs/PoseStamped current_pose - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration navigation_time - int32 sec - uint32 nanosec -builtin_interfaces/Duration estimated_time_remaining - int32 sec - uint32 nanosec -int16 number_of_recoveries -float32 distance_remaining -int16 number_of_poses_remaining -""", - "nav2_msgs/action/SmoothPath": """#goal definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -string smoother_id -builtin_interfaces/Duration max_smoothing_duration - int32 sec - uint32 nanosec -bool check_for_collisions ---- -#result definition -nav_msgs/Path path - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - geometry_msgs/PoseStamped[] poses - std_msgs/Header header - builtin_interfaces/Time stamp - int32 sec - uint32 nanosec - string frame_id - Pose pose - Point position - float64 x - float64 y - float64 z - Quaternion orientation - float64 x 0 - float64 y 0 - float64 z 0 - float64 w 1 -builtin_interfaces/Duration smoothing_duration - int32 sec - uint32 nanosec -bool was_completed ---- -#feedback definition -""", - "nav2_msgs/action/Wait": """#goal definition -builtin_interfaces/Duration time - int32 sec - uint32 nanosec ---- -#result definition -builtin_interfaces/Duration total_elapsed_time - int32 sec - uint32 nanosec ---- -#feedback definition -builtin_interfaces/Duration time_left - int32 sec - uint32 nanosec""", -} - CUSTOM_INTERFACES: Dict[str, str] = { "rai_interfaces/msg/HRIMessage": """ # @@ -3406,84 +2988,6 @@ } -NAVIGATION_TOPICS_AND_TYPES: Dict[str, str] = { - # Main navigation actions - "/navigate_to_pose/_action/feedback": "nav2_msgs/action/NavigateToPose_FeedbackMessage", - "/navigate_to_pose/_action/status": "action_msgs/msg/GoalStatusArray", - "/navigate_through_poses/_action/feedback": "nav2_msgs/action/NavigateThroughPoses_FeedbackMessage", - "/navigate_through_poses/_action/status": "action_msgs/msg/GoalStatusArray", - "/follow_path/_action/feedback": "nav2_msgs/action/FollowPath_FeedbackMessage", - "/follow_path/_action/status": "action_msgs/msg/GoalStatusArray", - "/follow_waypoints/_action/feedback": "nav2_msgs/action/FollowWaypoints_FeedbackMessage", - "/follow_waypoints/_action/status": "action_msgs/msg/GoalStatusArray", - # Path planning actions - "/compute_path_to_pose/_action/feedback": "nav2_msgs/action/ComputePathToPose_FeedbackMessage", - "/compute_path_to_pose/_action/status": "action_msgs/msg/GoalStatusArray", - "/compute_path_through_poses/_action/feedback": "nav2_msgs/action/ComputePathThroughPoses_FeedbackMessage", - "/compute_path_through_poses/_action/status": "action_msgs/msg/GoalStatusArray", - "/smooth_path/_action/feedback": "nav2_msgs/action/SmoothPath_FeedbackMessage", - "/smooth_path/_action/status": "action_msgs/msg/GoalStatusArray", - # Behavior actions - "/assisted_teleop/_action/feedback": "nav2_msgs/action/AssistedTeleop_FeedbackMessage", - "/assisted_teleop/_action/status": "action_msgs/msg/GoalStatusArray", - "/backup/_action/feedback": "nav2_msgs/action/BackUp_FeedbackMessage", - "/backup/_action/status": "action_msgs/msg/GoalStatusArray", - "/drive_on_heading/_action/feedback": "nav2_msgs/action/DriveOnHeading_FeedbackMessage", - "/drive_on_heading/_action/status": "action_msgs/msg/GoalStatusArray", - "/spin/_action/feedback": "nav2_msgs/action/Spin_FeedbackMessage", - "/spin/_action/status": "action_msgs/msg/GoalStatusArray", - "/wait/_action/feedback": "nav2_msgs/action/Wait_FeedbackMessage", - "/wait/_action/status": "action_msgs/msg/GoalStatusArray", - # Costmaps and mapping - "/global_costmap/costmap": "nav_msgs/msg/OccupancyGrid", - "/global_costmap/costmap_raw": "nav2_msgs/msg/Costmap", - "/global_costmap/costmap_updates": "map_msgs/msg/OccupancyGridUpdate", - "/global_costmap/footprint": "geometry_msgs/msg/Polygon", - "/global_costmap/published_footprint": "geometry_msgs/msg/PolygonStamped", - "/global_costmap/scan": "sensor_msgs/msg/LaserScan", - "/local_costmap/costmap": "nav_msgs/msg/OccupancyGrid", - "/local_costmap/costmap_raw": "nav2_msgs/msg/Costmap", - "/local_costmap/costmap_updates": "map_msgs/msg/OccupancyGridUpdate", - "/local_costmap/footprint": "geometry_msgs/msg/Polygon", - "/local_costmap/published_footprint": "geometry_msgs/msg/PolygonStamped", - "/local_costmap/scan": "sensor_msgs/msg/LaserScan", - "/map": "nav_msgs/msg/OccupancyGrid", - "/map_metadata": "nav_msgs/msg/MapMetaData", - # SLAM - "/slam_toolbox/feedback": "visualization_msgs/msg/InteractiveMarkerFeedback", - "/slam_toolbox/graph_visualization": "visualization_msgs/msg/MarkerArray", - "/slam_toolbox/scan_visualization": "sensor_msgs/msg/LaserScan", - "/slam_toolbox/update": "visualization_msgs/msg/InteractiveMarkerUpdate", - # Path planning and visualization - "/plan": "nav_msgs/msg/Path", - "/plan_smoothed": "nav_msgs/msg/Path", - "/unsmoothed_plan": "nav_msgs/msg/Path", - "/transformed_global_plan": "nav_msgs/msg/Path", - "/trajectories": "visualization_msgs/msg/MarkerArray", - # Control and goals - "/cmd_vel_nav": "geometry_msgs/msg/Twist", - "/cmd_vel_teleop": "geometry_msgs/msg/Twist", - "/goal_pose": "geometry_msgs/msg/PoseStamped", - "/pose": "geometry_msgs/msg/PoseWithCovarianceStamped", - "/preempt_teleop": "std_msgs/msg/Empty", - "/speed_limit": "nav2_msgs/msg/SpeedLimit", - # Behavior tree - "/behavior_tree_log": "nav2_msgs/msg/BehaviorTreeLog", - # Other - "/led_strip": "sensor_msgs/msg/Image", - # Lifecycle management - "/behavior_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/bt_navigator/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/controller_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/global_costmap/global_costmap/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/local_costmap/local_costmap/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/map_saver/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/planner_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/smoother_server/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/velocity_smoother/transition_event": "lifecycle_msgs/msg/TransitionEvent", - "/waypoint_follower/transition_event": "lifecycle_msgs/msg/TransitionEvent", -} - CUSTOM_TOPICS_AND_TYPES: Dict[str, str] = { "/to_human": "rai_interfaces/msg/HRIMessage", "/audio_message": "rai_interfaces/msg/AudioMessage", @@ -3608,217 +3112,6 @@ "/state_controller/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", } -NAVIGATION_SERVICES_AND_TYPES: Dict[str, str] = { - # Action services for navigation behaviors - "/assisted_teleop/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/assisted_teleop/_action/get_result": "nav2_msgs/action/AssistedTeleop_GetResult", - "/assisted_teleop/_action/send_goal": "nav2_msgs/action/AssistedTeleop_SendGoal", - "/backup/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/backup/_action/get_result": "nav2_msgs/action/BackUp_GetResult", - "/backup/_action/send_goal": "nav2_msgs/action/BackUp_SendGoal", - "/drive_on_heading/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/drive_on_heading/_action/get_result": "nav2_msgs/action/DriveOnHeading_GetResult", - "/drive_on_heading/_action/send_goal": "nav2_msgs/action/DriveOnHeading_SendGoal", - "/follow_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/follow_path/_action/get_result": "nav2_msgs/action/FollowPath_GetResult", - "/follow_path/_action/send_goal": "nav2_msgs/action/FollowPath_SendGoal", - "/follow_waypoints/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/follow_waypoints/_action/get_result": "nav2_msgs/action/FollowWaypoints_GetResult", - "/follow_waypoints/_action/send_goal": "nav2_msgs/action/FollowWaypoints_SendGoal", - "/spin/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/spin/_action/get_result": "nav2_msgs/action/Spin_GetResult", - "/spin/_action/send_goal": "nav2_msgs/action/Spin_SendGoal", - "/wait/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/wait/_action/get_result": "nav2_msgs/action/Wait_GetResult", - "/wait/_action/send_goal": "nav2_msgs/action/Wait_SendGoal", - # Path planning action services - "/compute_path_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/compute_path_through_poses/_action/get_result": "nav2_msgs/action/ComputePathThroughPoses_GetResult", - "/compute_path_through_poses/_action/send_goal": "nav2_msgs/action/ComputePathThroughPoses_SendGoal", - "/compute_path_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/compute_path_to_pose/_action/get_result": "nav2_msgs/action/ComputePathToPose_GetResult", - "/compute_path_to_pose/_action/send_goal": "nav2_msgs/action/ComputePathToPose_SendGoal", - "/smooth_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/smooth_path/_action/get_result": "nav2_msgs/action/SmoothPath_GetResult", - "/smooth_path/_action/send_goal": "nav2_msgs/action/SmoothPath_SendGoal", - # Main navigation action services - "/navigate_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/navigate_through_poses/_action/get_result": "nav2_msgs/action/NavigateThroughPoses_GetResult", - "/navigate_through_poses/_action/send_goal": "nav2_msgs/action/NavigateThroughPoses_SendGoal", - "/navigate_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", - "/navigate_to_pose/_action/get_result": "nav2_msgs/action/NavigateToPose_GetResult", - "/navigate_to_pose/_action/send_goal": "nav2_msgs/action/NavigateToPose_SendGoal", - # Costmap management services - "/global_costmap/clear_around_global_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", - "/global_costmap/clear_entirely_global_costmap": "nav2_msgs/srv/ClearEntireCostmap", - "/global_costmap/clear_except_global_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", - "/global_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", - "/local_costmap/clear_around_local_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", - "/local_costmap/clear_entirely_local_costmap": "nav2_msgs/srv/ClearEntireCostmap", - "/local_costmap/clear_except_local_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", - "/local_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", - # Path validation - "/is_path_valid": "nav2_msgs/srv/IsPathValid", - # SLAM services - "/slam_toolbox/clear_changes": "slam_toolbox/srv/Clear", - "/slam_toolbox/clear_queue": "slam_toolbox/srv/ClearQueue", - "/slam_toolbox/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/slam_toolbox/deserialize_map": "slam_toolbox/srv/DeserializePoseGraph", - "/slam_toolbox/dynamic_map": "nav_msgs/srv/GetMap", - "/slam_toolbox/get_interactive_markers": "visualization_msgs/srv/GetInteractiveMarkers", - "/slam_toolbox/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/slam_toolbox/get_parameters": "rcl_interfaces/srv/GetParameters", - "/slam_toolbox/list_parameters": "rcl_interfaces/srv/ListParameters", - "/slam_toolbox/manual_loop_closure": "slam_toolbox/srv/LoopClosure", - "/slam_toolbox/pause_new_measurements": "slam_toolbox/srv/Pause", - "/slam_toolbox/save_map": "slam_toolbox/srv/SaveMap", - "/slam_toolbox/serialize_map": "slam_toolbox/srv/SerializePoseGraph", - "/slam_toolbox/set_parameters": "rcl_interfaces/srv/SetParameters", - "/slam_toolbox/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/slam_toolbox/toggle_interactive_mode": "slam_toolbox/srv/ToggleInteractive", - # Map saving - "/map_saver/change_state": "lifecycle_msgs/srv/ChangeState", - "/map_saver/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/map_saver/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/map_saver/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/map_saver/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/map_saver/get_parameters": "rcl_interfaces/srv/GetParameters", - "/map_saver/get_state": "lifecycle_msgs/srv/GetState", - "/map_saver/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/map_saver/list_parameters": "rcl_interfaces/srv/ListParameters", - "/map_saver/save_map": "nav2_msgs/srv/SaveMap", - "/map_saver/set_parameters": "rcl_interfaces/srv/SetParameters", - "/map_saver/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - # Navigation server lifecycle and parameter services - "/behavior_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/behavior_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/behavior_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/behavior_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/behavior_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/behavior_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/behavior_server/get_state": "lifecycle_msgs/srv/GetState", - "/behavior_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/behavior_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/behavior_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/behavior_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/bt_navigator/change_state": "lifecycle_msgs/srv/ChangeState", - "/bt_navigator/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/bt_navigator/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/bt_navigator/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/bt_navigator/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/bt_navigator/get_parameters": "rcl_interfaces/srv/GetParameters", - "/bt_navigator/get_state": "lifecycle_msgs/srv/GetState", - "/bt_navigator/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/bt_navigator/list_parameters": "rcl_interfaces/srv/ListParameters", - "/bt_navigator/set_parameters": "rcl_interfaces/srv/SetParameters", - "/bt_navigator/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/bt_navigator_navigate_through_poses_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", - "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/bt_navigator_navigate_to_pose_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", - "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/controller_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/controller_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/controller_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/controller_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/controller_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/controller_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/controller_server/get_state": "lifecycle_msgs/srv/GetState", - "/controller_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/controller_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/controller_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/controller_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/global_costmap/global_costmap/change_state": "lifecycle_msgs/srv/ChangeState", - "/global_costmap/global_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/global_costmap/global_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/global_costmap/global_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/global_costmap/global_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/global_costmap/global_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", - "/global_costmap/global_costmap/get_state": "lifecycle_msgs/srv/GetState", - "/global_costmap/global_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/global_costmap/global_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", - "/global_costmap/global_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", - "/global_costmap/global_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/local_costmap/local_costmap/change_state": "lifecycle_msgs/srv/ChangeState", - "/local_costmap/local_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/local_costmap/local_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/local_costmap/local_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/local_costmap/local_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/local_costmap/local_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", - "/local_costmap/local_costmap/get_state": "lifecycle_msgs/srv/GetState", - "/local_costmap/local_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/local_costmap/local_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", - "/local_costmap/local_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", - "/local_costmap/local_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/planner_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/planner_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/planner_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/planner_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/planner_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/planner_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/planner_server/get_state": "lifecycle_msgs/srv/GetState", - "/planner_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/planner_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/planner_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/planner_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/smoother_server/change_state": "lifecycle_msgs/srv/ChangeState", - "/smoother_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/smoother_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/smoother_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/smoother_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/smoother_server/get_parameters": "rcl_interfaces/srv/GetParameters", - "/smoother_server/get_state": "lifecycle_msgs/srv/GetState", - "/smoother_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/smoother_server/list_parameters": "rcl_interfaces/srv/ListParameters", - "/smoother_server/set_parameters": "rcl_interfaces/srv/SetParameters", - "/smoother_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/velocity_smoother/change_state": "lifecycle_msgs/srv/ChangeState", - "/velocity_smoother/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/velocity_smoother/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/velocity_smoother/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/velocity_smoother/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/velocity_smoother/get_parameters": "rcl_interfaces/srv/GetParameters", - "/velocity_smoother/get_state": "lifecycle_msgs/srv/GetState", - "/velocity_smoother/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/velocity_smoother/list_parameters": "rcl_interfaces/srv/ListParameters", - "/velocity_smoother/set_parameters": "rcl_interfaces/srv/SetParameters", - "/velocity_smoother/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/waypoint_follower/change_state": "lifecycle_msgs/srv/ChangeState", - "/waypoint_follower/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/waypoint_follower/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", - "/waypoint_follower/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", - "/waypoint_follower/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/waypoint_follower/get_parameters": "rcl_interfaces/srv/GetParameters", - "/waypoint_follower/get_state": "lifecycle_msgs/srv/GetState", - "/waypoint_follower/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", - "/waypoint_follower/list_parameters": "rcl_interfaces/srv/ListParameters", - "/waypoint_follower/set_parameters": "rcl_interfaces/srv/SetParameters", - "/waypoint_follower/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - # Lifecycle management services - "/lifecycle_manager_navigation/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/lifecycle_manager_navigation/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/lifecycle_manager_navigation/get_parameters": "rcl_interfaces/srv/GetParameters", - "/lifecycle_manager_navigation/is_active": "std_srvs/srv/Trigger", - "/lifecycle_manager_navigation/list_parameters": "rcl_interfaces/srv/ListParameters", - "/lifecycle_manager_navigation/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", - "/lifecycle_manager_navigation/set_parameters": "rcl_interfaces/srv/SetParameters", - "/lifecycle_manager_navigation/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", - "/lifecycle_manager_slam/describe_parameters": "rcl_interfaces/srv/DescribeParameters", - "/lifecycle_manager_slam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", - "/lifecycle_manager_slam/get_parameters": "rcl_interfaces/srv/GetParameters", - "/lifecycle_manager_slam/is_active": "std_srvs/srv/Trigger", - "/lifecycle_manager_slam/list_parameters": "rcl_interfaces/srv/ListParameters", - "/lifecycle_manager_slam/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", - "/lifecycle_manager_slam/set_parameters": "rcl_interfaces/srv/SetParameters", - "/lifecycle_manager_slam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", -} CUSTOM_SERVICES_AND_TYPES: Dict[str, str] = { "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam", "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino", @@ -3838,21 +3131,7 @@ "/pickup": "moveit_msgs/action/Pickup", "/place": "moveit_msgs/action/Place", } -NAVIGATION_ACTIONS_AND_TYPES: Dict[str, str] = { - "/navigate_to_pose": "nav2_msgs/action/NavigateToPose", - "/navigate_through_poses": "nav2_msgs/action/Nmoveit_msgs/action/MoveGroupmoveit_msgs/action/MoveGroupavigateThroughPoses", - "/follow_path": "nav2_msgs/action/FollowPath", - "/follow_waypoints": "nav2_msgs/action/FollowWaypoints", - "/compute_path_to_pose": "nav2_msgs/action/ComputePathToPose", - "/compute_path_through_poses": "nav2_msgs/action/ComputePathThroughPoses", - "/smooth_path": "nav2_msgs/action/SmoothPath", - "/spin": "nav2_msgs/action/Spin", - "/backup": "nav2_msgs/action/BackUp", - "/drive_on_heading": "nav2_msgs/action/DriveOnHeading", - "/wait": "nav2_msgs/action/Wait", - "/assisted_teleop": "nav2_msgs/action/AssistedTeleop", - "/clear_costmap": "nav2_msgs/action/ClearEntireCostmap", -} + COMMON_TOPIC_MODELS: Dict[str, Type[BaseModel]] = { "sensor_msgs/msg/CameraInfo": CameraInfo, "sensor_msgs/msg/Image": Image, @@ -3874,17 +3153,3 @@ "rai_interfaces/srv/WhatISee": WhatISeeRequest, } MANIPULATION_ACTION_MODELS: Dict[str, Type[BaseModel]] = {} -NAVIGATION_ACTION_MODELS: Dict[str, Type[BaseModel]] = { - "nav2_msgs/action/NavigateToPose": NavigateToPoseGoal, - "nav2_msgs/action/Spin": SpinGoal, - "nav2_msgs/action/AssistedTeleop": AssistedTeleopGoal, - "nav2_msgs/action/BackUp": BackUpGoal, - "nav2_msgs/action/ComputePathThroughPoses": ComputePathThroughPosesGoal, - "nav2_msgs/action/ComputePathToPose": ComputePathToPoseGoal, - "nav2_msgs/action/DriveOnHeading": DriveOnHeadingGoal, - "nav2_msgs/action/FollowPath": FollowPathGoal, - "nav2_msgs/action/FollowWaypoints": FollowWaypointsGoal, - "nav2_msgs/action/NavigateThroughPoses": NavigateThroughPosesGoal, - "nav2_msgs/action/SmoothPath": SmoothPathGoal, - "nav2_msgs/action/Wait": WaitGoal, -} diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py index 953dea515..53ff03a2a 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py @@ -15,13 +15,11 @@ from .basic_tasks import get_basic_tasks from .custom_interfaces_tasks import get_custom_interfaces_tasks from .manipulation_tasks import get_manipulation_tasks -from .navigation_tasks import get_navigation_tasks from .spatial_reasoning_tasks import get_spatial_tasks __all__ = [ "get_basic_tasks", "get_custom_interfaces_tasks", "get_manipulation_tasks", - "get_navigation_tasks", "get_spatial_tasks", ] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py deleted file mode 100644 index 4099e5a22..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/navigation_tasks.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Literal - -from rai_bench.tool_calling_agent.interfaces import ( - Task, - TaskArgs, -) -from rai_bench.tool_calling_agent.subtasks import ( - CheckActionFieldsToolCallSubTask, -) -from rai_bench.tool_calling_agent.tasks.navigation import ( - MoveToBedTask, - MoveToFrontTask, - NavigateToPointTask, - SpinAroundTask, -) -from rai_bench.tool_calling_agent.validators import ( - OrderedCallsValidator, -) - -########## SUBTASKS ################################################################# - -start_nav_action_subtask = CheckActionFieldsToolCallSubTask( - expected_tool_name="start_ros2_action", - expected_action="/navigate_to_pose", - expected_action_type="nav2_msgs/action/NavigateToPose", - expected_fields={ - "pose": { - "header": {"frame_id": "map"}, - "pose": { - "position": {"x": 2.0, "y": 2.0, "z": 0.0}, - }, - }, - }, -) -start_spin_action_subtask = CheckActionFieldsToolCallSubTask( - expected_tool_name="start_ros2_action", - expected_action="/spin", - expected_action_type="nav2_msgs/action/Spin", - expected_fields={"target_yaw": 3}, -) -start_move_front_action_subtask = CheckActionFieldsToolCallSubTask( - expected_tool_name="start_ros2_action", - expected_action="/drive_on_heading", - expected_action_type="nav2_msgs/action/DriveOnHeading", - expected_fields={ - "target": {"y": 0.0, "z": 0.0}, - }, -) -######### VALIDATORS ######################################################################################### -start_navigate_action_ord_val = OrderedCallsValidator( - subtasks=[start_nav_action_subtask] -) -start_spin_action_ord_val = OrderedCallsValidator(subtasks=[start_spin_action_subtask]) -move_ahead_ord_val = OrderedCallsValidator(subtasks=[start_move_front_action_subtask]) - - -def get_navigation_tasks( - extra_tool_calls: List[int] = [0], - prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], - n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], -) -> List[Task]: - """Get predefined navigation tasks. - - Parameters - ---------- - Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. - See the class documentation for parameter descriptions. - - Returns - ------- - Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. - """ - tasks: List[Task] = [] - - for extra_calls in extra_tool_calls: - for detail in prompt_detail: - for shots in n_shots: - task_args = TaskArgs( - extra_tool_calls=extra_calls, - prompt_detail=detail, - examples_in_system_prompt=shots, - ) - tasks.extend( - [ - NavigateToPointTask( - validators=[start_navigate_action_ord_val], - task_args=task_args, - ), - SpinAroundTask( - validators=[start_spin_action_ord_val], - task_args=task_args, - ), - MoveToBedTask( - validators=[move_ahead_ord_val], - task_args=task_args, - ), - MoveToFrontTask( - validators=[move_ahead_ord_val], - task_args=task_args, - ), - ] - ) - - return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py index f3a166c8a..5699841cd 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py @@ -22,7 +22,6 @@ get_basic_tasks, get_custom_interfaces_tasks, get_manipulation_tasks, - get_navigation_tasks, get_spatial_tasks, ) @@ -36,14 +35,12 @@ def get_tasks( Literal[ "basic", "manipulation", - "navigation", "custom_interfaces", "spatial_reasoning", ] ] = [ "basic", "manipulation", - "navigation", "custom_interfaces", "spatial_reasoning", ], @@ -81,12 +78,6 @@ def get_tasks( prompt_detail=prompt_detail, n_shots=n_shots, ) - if "navigation" in task_types: - all_tasks += get_navigation_tasks( - extra_tool_calls=extra_tool_calls, - prompt_detail=prompt_detail, - n_shots=n_shots, - ) if "spatial_reasoning" in task_types: all_tasks += get_spatial_tasks( extra_tool_calls=extra_tool_calls, diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py deleted file mode 100644 index 79514d256..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/navigation.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -from langchain_core.tools import BaseTool - -from rai_bench.tool_calling_agent.interfaces import Task -from rai_bench.tool_calling_agent.mocked_ros2_interfaces import ( - COMMON_SERVICES_AND_TYPES, - COMMON_TOPICS_AND_TYPES, - NAVIGATION_ACTION_MODELS, - NAVIGATION_ACTIONS_AND_TYPES, - NAVIGATION_INTERFACES, - NAVIGATION_SERVICES_AND_TYPES, - NAVIGATION_TOPICS_AND_TYPES, -) -from rai_bench.tool_calling_agent.mocked_tools import ( - MockActionsToolkit, - MockGetROS2MessageInterfaceTool, - MockGetROS2ServicesNamesAndTypesTool, - MockGetROS2TopicsNamesAndTypesTool, -) - -ROBOT_NAVIGATION_SYSTEM_PROMPT_0_SHOT = """You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. - Do not make assumptions about the environment you are currently in. - You can use ros2 topics, services and actions to operate. - - As a first step check transforms by getting 1 message from /tf topic - use /cmd_vel topic very carefully. Obstacle detection works only with nav2 stack, so be careful when it is not used. > - be patient with running ros2 actions. usually the take some time to run. - Always check your transform before and after you perform ros2 actions, so that you can verify if it worked. - - Navigation tips: - - it's good to start finding objects by rotating, then navigating to some diverse location with occasional rotations. Remember to frequency detect objects. - - for driving forward/backward or to some coordinates, ros2 actions are better. - - for driving for some specific time or in specific manner (like shaper or turns) it good to use /cmd_vel topic - - you are currently unable to read map or point-cloud, so please avoid subscribing to such topics. - - if you are asked to drive towards some object, it's good to: - 1. check the camera image and verify if objects can be seen - 2. if only driving forward is required, do it - 3. if obstacle avoidance might be required, use ros2 actions navigate_*, but first check your current position, then very accurately estimate the goal pose. - - it is good to verify using given information if the robot is not stuck - - navigation actions sometimes fail. Their output can be read from rosout. You can also tell if they partially worked by checking the robot position and rotation. - - before using any ros2 interfaces, always make sure to check you are using the right interface - - processing camera image takes 5-10s. Take it into account that if the robot is moving, the information can be outdated. Handle it by good planning of your movements. - - you are encouraged to use wait tool in between checking the status of actions - - to find some object navigate around and check the surrounding area - - when the goal is accomplished please make sure to cancel running actions - - when you reach the navigation goal - double check if you reached it by checking the current position - - if you detect collision, please stop operation - - you will be given your camera image description. Based on this information you can reason about positions of objects. - - be careful and aboid obstacles - - Here are the corners of your environment: - (-2.76,9.04, 0.0), - (4.62, 9.07, 0.0), - (-2.79, -3.83, 0.0), - (4.59, -3.81, 0.0) - - This is location of places: - Kitchen: - (2.06, -0.23, 0.0), - (2.07, -1.43, 0.0), - (-2.44, -0.38, 0.0), - (-2.56, -1.47, 0.0) - - # Living room: - (-2.49, 1.87, 0.0), - (-2.50, 5.49, 0.0), - (0.79, 5.73, 0.0), - (0.92, 1.01, 0.0) - - Before starting anything, make sure to load available topics, services and actions.""" - -ROBOT_NAVIGATION_SYSTEM_PROMPT_2_SHOT = ( - ROBOT_NAVIGATION_SYSTEM_PROMPT_0_SHOT - + """ - - Example tool calls: - - get_ros2_actions_names_and_types, args: {} - - start_ros2_action, args: {'action': '/navigate_to_pose', 'action_type': 'nav2_msgs/action/NavigateToPose', 'goal': {'pose': {'header': {'frame_id': 'map'}, 'pose': {'position': {'x': 2.0, 'y': 2.0, 'z': 0.0}}}}}""" -) - -ROBOT_NAVIGATION_SYSTEM_PROMPT_5_SHOT = ( - ROBOT_NAVIGATION_SYSTEM_PROMPT_2_SHOT - + """ - - get_ros2_message_interface, args: {'msg_type': 'nav2_msgs/action/Spin'} - - start_ros2_action, args: {'action': '/spin', 'action_type': 'nav2_msgs/action/Spin', 'goal': {'target_yaw': 3.14}} - - start_ros2_action, args: {'action': '/drive_on_heading', 'action_type': 'nav2_msgs/action/DriveOnHeading', 'goal': {'target': {'x': 1.0, 'y': 0.0, 'z': 0.0}, 'speed': 0.5}}""" -) -TOPICS_AND_TYPES = COMMON_TOPICS_AND_TYPES | NAVIGATION_TOPICS_AND_TYPES -SERVICES_AND_TYPES = COMMON_SERVICES_AND_TYPES | NAVIGATION_SERVICES_AND_TYPES - - -TOPIC_STRINGS = [ - f"topic: {topic}\ntype: {topic_type}\n" - for topic, topic_type in COMMON_TOPICS_AND_TYPES.items() -] - -ACTION_STRINGS = [ - f"action: {action}\ntype: {act_type}\n" - for action, act_type in NAVIGATION_ACTIONS_AND_TYPES.items() -] - -SERVICE_STRINGS = [ - f"service: {service}\ntype: {srv_type}\n" - for service, srv_type in SERVICES_AND_TYPES.items() -] - - -class NavigationTask(Task): - type = "navigation" - - @property - def available_tools(self) -> List[BaseTool]: - tools = MockActionsToolkit( - mock_actions_names_and_types=ACTION_STRINGS, - available_actions=list(NAVIGATION_ACTIONS_AND_TYPES.keys()), - available_action_types=list(NAVIGATION_ACTIONS_AND_TYPES.values()), - available_action_models=NAVIGATION_ACTION_MODELS, - ).get_tools() - tools.append( - MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=TOPIC_STRINGS - ) - ) - tools.append( - MockGetROS2ServicesNamesAndTypesTool( - mock_service_names_and_types=SERVICE_STRINGS - ) - ) - tools.append( - MockGetROS2MessageInterfaceTool(mock_interfaces=NAVIGATION_INTERFACES) - ) - return tools - - @property - def optional_tool_calls_number(self) -> int: - # list topics and get interface - return 2 - - def get_system_prompt(self) -> str: - if self.n_shots == 0: - return ROBOT_NAVIGATION_SYSTEM_PROMPT_0_SHOT - elif self.n_shots == 2: - return ROBOT_NAVIGATION_SYSTEM_PROMPT_2_SHOT - else: - return ROBOT_NAVIGATION_SYSTEM_PROMPT_5_SHOT - - -class NavigateToPointTask(NavigationTask): - complexity = "easy" - - def get_base_prompt(self) -> str: - return "Navigate to point (2.0, 2.0, 0.0)." - - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can use the navigation tools to move the robot to the specified coordinates. " - "First get the available actions, then set up the navigation goal to reach point (2.0, 2.0, 0.0)." - ) - - -class SpinAroundTask(NavigationTask): - recursion_limit = 50 - complexity = "medium" - - def get_base_prompt(self) -> str: - return "Spin around by 3 radians." - - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can locate the robot's current orientation and execute a spinning motion " - "to rotate the robot by 3 radians from its current heading." - ) - - -class MoveToFrontTask(NavigationTask): - recursion_limit = 50 - complexity = "medium" - - def get_base_prompt(self) -> str: - return "Move 2 meters to the front." - - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can determine the robot's current position and orientation, " - "then move it 2 meters forward in the direction it is currently facing." - ) - - -class MoveToBedTask(NavigationTask): - recursion_limit = 50 - complexity = "hard" - - def get_base_prompt(self) -> str: - return "Move closer to the bed leaving 1 meter space." - - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can locate the bed in the environment, calculate the appropriate position " - "that maintains 1 meter distance from the bed, and navigate to that position." - ) diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index 87629ac43..cdd1b2f67 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -68,14 +68,12 @@ def parse_tool_calling_benchmark_args(): choices=[ "basic", "manipulation", - "navigation", "custom_interfaces", "spatial_reasoning", ], default=[ "basic", "manipulation", - "navigation", "custom_interfaces", "spatial_reasoning", ], From 13244ef27e65ef6078c3cc52670931e24cda48e0 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:41:08 +0200 Subject: [PATCH 06/13] refactor: o3de config (#630) --- docs/tutorials/benchmarking.md | 24 +------------------ .../rai_bench/manipulation_o3de/benchmark.py | 3 +-- .../predefined/configs/o3de_config.yaml | 18 ++++++++++++++ src/rai_bench/rai_bench/utils.py | 2 +- tests/rai_sim/conftest.py | 1 - 5 files changed, 21 insertions(+), 27 deletions(-) create mode 100644 src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index deb9d20a9..fa65de41f 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -16,32 +16,10 @@ If your goal is creating custom tasks and scenarios, visit [Creating Custom Task ## Manipulation O3DE - Follow setup from [Manipulation demo Setup](../demos/manipulation.md#setup) -- Create the O3DE config: - ```yaml - binary_path: /path/to/binary/RAIManipulationDemo.GameLauncher - level: RoboticManipulationBenchmark - robotic_stack_command: ros2 launch examples/manipulation-demo-no-binary.launch.py - required_simulation_ros2_interfaces: - services: - - /spawn_entity - - /delete_entity - topics: - - /color_image5 - - /depth_image5 - - /color_camera_info5 - actions: [] - required_robotic_ros2_interfaces: - services: - - /grounding_dino_classify - - /grounded_sam_segment - - /manipulator_move_to - topics: [] - actions: [] - ``` - Run the benchmark with: ```bash - python src/rai_bench/rai_bench/examples/manipulation_o3de.py --model-name --vendor --o3de-config-path --levels + python src/rai_bench/rai_bench/examples/manipulation_o3de.py --model-name --vendor --levels ``` !!! warning diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index 4dce98527..ed53cad85 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -137,7 +137,7 @@ def __init__( self.simulation_bridge.init_simulation(simulation_config=simulation_config) self.simulation_bridge.launch_robotic_stack( required_robotic_ros2_interfaces=simulation_config.required_robotic_ros2_interfaces, - launch_description=self.launch_description, + launch_description=self.launch_description(), ) self.num_of_scenarios = len(scenarios) self.scenarios = enumerate(iter(scenarios)) @@ -146,7 +146,6 @@ def __init__( self.score_tracing_handler = ScoreTracingHandler() self.csv_initialize(self.results_filename, ScenarioResult) - @property def launch_description(self): launch_moveit = IncludeLaunchDescription( PythonLaunchDescriptionSource( diff --git a/src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml b/src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml new file mode 100644 index 000000000..c826a71f4 --- /dev/null +++ b/src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml @@ -0,0 +1,18 @@ +binary_path: demo_assets/manipulation/RAIManipulationDemo/RAIManipulationDemo.GameLauncher +level: RoboticManipulationBenchmark +required_simulation_ros2_interfaces: + services: + - /spawn_entity + - /delete_entity + topics: + - /color_image5 + - /depth_image5 + - /color_camera_info5 + actions: [] +required_robotic_ros2_interfaces: + services: + - /grounding_dino_classify + - /grounded_sam_segment + - /manipulator_move_to + topics: [] + actions: [] diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index cdd1b2f67..60fbac038 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -101,7 +101,7 @@ def parse_manipulation_o3de_benchmark_args(): parser.add_argument( "--o3de-config-path", type=str, - required=True, + default="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", help="Path to the O3DE configuration file", ) parser.add_argument( diff --git a/tests/rai_sim/conftest.py b/tests/rai_sim/conftest.py index e440c938a..bc6f2d953 100644 --- a/tests/rai_sim/conftest.py +++ b/tests/rai_sim/conftest.py @@ -52,7 +52,6 @@ def sample_base_yaml_config(tmp_path: Path) -> Path: def sample_o3dexros2_config(tmp_path: Path) -> Path: yaml_content = """ binary_path: /path/to/binary - robotic_stack_command: "ros2 launch robotic_stack.launch.py" required_simulation_ros2_interfaces: services: - /spawn_entity From c26026ab221426edac8a7ff288e9731c4d59144c Mon Sep 17 00:00:00 2001 From: Bartek Boczek Date: Tue, 12 Aug 2025 16:49:26 +0200 Subject: [PATCH 07/13] refactor(`nav2_toolkit`): remove unused `action_client` (#670) --- src/rai_core/rai/tools/ros2/navigation/nav2.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/rai_core/rai/tools/ros2/navigation/nav2.py b/src/rai_core/rai/tools/ros2/navigation/nav2.py index b3d364192..0ca4417fe 100644 --- a/src/rai_core/rai/tools/ros2/navigation/nav2.py +++ b/src/rai_core/rai/tools/ros2/navigation/nav2.py @@ -23,7 +23,6 @@ from nav2_msgs.action import NavigateToPose from nav_msgs.msg import OccupancyGrid from pydantic import BaseModel, Field -from rclpy.action import ActionClient from tf_transformations import euler_from_quaternion, quaternion_from_euler from rai.communication.ros2 import ROS2Message @@ -31,7 +30,6 @@ from rai.messages import MultimodalArtifact from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit -action_client: Optional[ActionClient] = None current_action_id: Optional[str] = None current_feedback: Optional[NavigateToPose.Feedback] = None current_result: Optional[NavigateToPose.Result] = None @@ -88,12 +86,6 @@ def on_done(self, result: NavigateToPose.Result) -> None: current_result = result def _run(self, x: float, y: float, z: float, yaw: float) -> str: - global action_client - if action_client is None: - action_client = ActionClient( - self.connector.node, NavigateToPose, self.action_name - ) - pose = PoseStamped() pose.header.frame_id = self.frame_id pose.header.stamp = self.connector.node.get_clock().now().to_msg() From 4c31d5d3daf35e4b2fa29a9c8d070488b431ed85 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Mon, 1 Sep 2025 11:40:34 +0200 Subject: [PATCH 08/13] fix: manipulaiton bench fixes (#653) --- src/rai_bench/rai_bench/manipulation_o3de/benchmark.py | 3 +-- src/rai_bench/rai_bench/test_models.py | 9 +++++++-- src/rai_bench/rai_bench/tool_calling_agent/benchmark.py | 3 +-- src/rai_core/rai/tools/ros2/manipulation/custom.py | 6 +++--- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index ed53cad85..be0e86408 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -308,8 +308,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: self.logger.error(msg=f"Task timeout: {e}") except GraphRecursionError as e: self.logger.error(msg=f"Reached recursion limit {e}") - except Exception as e: - self.logger.error(msg=f"Unexpected errot occured: {e}") + te = time.perf_counter() try: score = scenario.task.calculate_score(self.simulation_bridge) diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index 0501d811c..2528596e9 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -250,7 +250,12 @@ def test_models( bench_logger=bench_logger, ) except Exception as e: - bench_logger.critical(f"BENCHMARK RUN FAILED: {e}") + import traceback + bench_logger.critical( - f"{bench_conf.name} benchmark for {model_name}, vendor: {vendors[i]}, execution number: {u + 1}" + f"{bench_conf.name} benchmark for {model_name}, vendor: {vendors[i]}, repeat number: {u + 1}" ) + bench_logger.critical(f"BENCHMARK RUN FAILED: {e}") + error_msg = traceback.format_exc() + bench_logger.critical(error_msg) + print(error_msg) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index aeb059d6b..44923c860 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -148,8 +148,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: self.logger.error(msg=f"Task timeout: {e}") except GraphRecursionError as e: self.logger.error(msg=f"Reached recursion limit {e}") - except Exception as e: - self.logger.error(msg=f"Unexpected error occured: {e}") + tool_calls = task.get_tool_calls_from_messages(messages=messages) score = task.validate(tool_calls=tool_calls) te = time.perf_counter() diff --git a/src/rai_core/rai/tools/ros2/manipulation/custom.py b/src/rai_core/rai/tools/ros2/manipulation/custom.py index 1e7a38606..6e0d9655c 100644 --- a/src/rai_core/rai/tools/ros2/manipulation/custom.py +++ b/src/rai_core/rai/tools/ros2/manipulation/custom.py @@ -245,12 +245,12 @@ def _run( response = get_future_result(future, timeout_sec=20.0) if response is None: - return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})." + return f"Service call failed for point ({x1:.2f}, {y1:.2f}, {z1:.2f})." if response.success: - return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement." + return f"End effector successfully positioned at coordinates ({x1:.2f}, {y1:.2f}, {z1:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement." else: - return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})." + return f"Failed to position end effector at coordinates ({x1:.2f}, {y1:.2f}, {z1:.2f})." class GetObjectPositionsToolInput(BaseModel): From 45baf9a1b4a931923967db2c682436dcf4ca8160 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 5 Sep 2025 11:04:19 +0200 Subject: [PATCH 09/13] docs: rai simbench docs update (#665) --- docs/simulation_and_benchmarking/rai_bench.md | 5 +- docs/tutorials/benchmarking.md | 279 ++++++++++++------ src/rai_bench/README.md | 211 +------------ .../rai_bench/examples/benchmarking_models.py | 19 +- .../rai_bench/examples/custom_scenario.py | 128 ++++++++ .../rai_bench/examples/custom_task.py | 128 ++++++++ .../rai_bench/manipulation_o3de/benchmark.py | 2 +- .../rai_bench/manipulation_o3de/interfaces.py | 57 ++-- .../tasks/build_tower_task.py | 4 +- .../tasks/group_objects_task.py | 4 +- .../tasks/move_object_to_left_task.py | 4 +- .../tasks/place_at_coord_task.py | 4 +- .../tasks/place_cubes_task.py | 4 +- .../tasks/rotate_object_task.py | 4 +- src/rai_bench/rai_bench/test_models.py | 3 +- src/rai_bench/rai_bench/utils.py | 1 + src/rai_sim/rai_sim/o3de/o3de_bridge.py | 2 +- src/rai_sim/rai_sim/simulation_bridge.py | 15 +- tests/rai_sim/test_simulation_bridge.py | 2 +- 19 files changed, 509 insertions(+), 367 deletions(-) create mode 100644 src/rai_bench/rai_bench/examples/custom_scenario.py create mode 100644 src/rai_bench/rai_bench/examples/custom_task.py diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index 1969be88e..d3072a783 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -73,7 +73,7 @@ score = (correctly_placed_now - correctly_placed_initially) / initially_incorrec You can find predefined scene configs in `rai_bench/manipulation_o3de/predefined/configs/`. -Predefined scenarios can be imported like: +Predefined scenarios can be imported, for example, choosing tasks by difficulty: ```python from rai_bench.manipulation_o3de import get_scenarios @@ -81,8 +81,6 @@ from rai_bench.manipulation_o3de import get_scenarios get_scenarios(levels=["easy", "medium"]) ``` -Choose which task you want by selecting the difficulty, from trivial to very hard scenarios. - ## Tool Calling Agent Benchmark Evaluates agent performance independently from any simulation, based only on tool calls that the agent makes. To make it independent from simulations, this benchmark introduces tool mocks which can be adjusted for different tasks. This makes the benchmark more universal and a lot faster. @@ -106,6 +104,7 @@ The `Validator` class can combine single or multiple subtasks to create a single - OrderedCallsValidator - requires a strict order of subtasks. The next subtask will be validated only when the previous one was completed. Validator passes when all subtasks pass. - NotOrderedCallsValidator - doesn't enforce order of subtasks. Every subtask will be validated against every tool call. Validator passes when all subtasks pass. +- OneFromManyValidator - passes when any one of the given subtasks passes. ### Task diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index fa65de41f..85dfca688 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -15,60 +15,61 @@ If your goal is creating custom tasks and scenarios, visit [Creating Custom Task ## Manipulation O3DE -- Follow setup from [Manipulation demo Setup](../demos/manipulation.md#setup) -- Run the benchmark with: +- Follow the main setup [Basic Setup](../setup/install.md) and setup from [Manipulation demo Setup](../demos/manipulation.md#setup) +- To see available options run: + ```bash + python src/rai_bench/rai_bench/examples/manipulation_o3de.py --help + ``` +- Example usage: ```bash - python src/rai_bench/rai_bench/examples/manipulation_o3de.py --model-name --vendor --levels + python src/rai_bench/rai_bench/examples/manipulation_o3de.py --model-name qwen2.5:7b --vendor ollama --levels trivial ``` + !!! note + + When using Ollama, be sure to pull the model first. + !!! warning - Running all scenarios will take a while. If you want to just try it out, we recommend choosing just one level of difficulty. + Running all scenarios will take a while. If you want to just try it out, we recommend choosing just one level of difficulty. ## Tool Calling Agent -This benchmark does not require any additional setup besides the main one [Basic Setup](../setup/install.md), just run: +- This benchmark does not require any additional setup besides the main one [Basic Setup](../setup/install.md) +- To see available options run: + ```bash + python src/rai_bench/rai_bench/examples/tool_calling_agent.py --help + ``` +- Example usage: ```bash -python src/rai_bench/rai_bench/examples/tool_calling_agent.py --model-name --vendor --extra-tool-calls <0 5> --task-types basic --n-shots <0 2> --prompt-detail --complexities --out-dir +python src/rai_bench/rai_bench/examples/tool_calling_agent.py --model-name qwen2.5:7b --vendor ollama --extra-tool-calls 5 --task-types basic --n-shots 5 --prompt-detail descriptive --complexities easy ``` -!!! note - - This Benchmark is significantly faster, but still, if just trying out, we recommend choosing just one parameter per flag as every combination on params will create more tasks. - ## Testing Models -The best way of benchmarking your models is using the `rai_bench.test_models` function with benchmark configs. - -??? info "test_models function definition" - - ::: rai_bench.test_models.test_models +The best way of benchmarking your models is using the `src/rai_bench/rai_bench/examples/benchmarking_models.py` -Example usage: +Feel free to modify the benchmark configs to suit your needs, you can choose every possible set of params +and the benchmark will be run tasks with every combination: ```python -from rai_bench import ( - ManipulationO3DEBenchmarkConfig, - ToolCallingAgentBenchmarkConfig, - test_models, -) - if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen2.5:7b", "llama3.2:3b"] + model_names = ["qwen3:4b", "llama3.2:3b"] vendors = ["ollama", "ollama"] # Define benchmarks that will be used - man_conf = ManipulationO3DEBenchmarkConfig( - o3de_config_path="path/to/your/o3de_config.yaml", # path to your O3DE config + mani_conf = ManipulationO3DEBenchmarkConfig( + o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", levels=[ # define what difficulty of tasks to include in benchmark "trivial", + "easy", ], repeats=1, # how many times to repeat ) - tool_conf = ToolCallingAgentBenchmarkConfig( + tool_conf = ToolCallingAgentBenchmarkConfig( extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", @@ -76,10 +77,7 @@ if __name__ == "__main__": "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt - prompt_detail=[ # how descriptive should task prompt be - "brief", - "descriptive" - ], + prompt_detail=["brief", "descriptive"], # how descriptive should task prompt be repeats=1, ) @@ -87,11 +85,22 @@ if __name__ == "__main__": test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[man_conf, tool_conf], + benchmark_configs=[mani_conf, tool_conf], out_dir=out_dir, + # if you want to pass any additinal args to model + additional_model_args=[ + {"reasoning": False}, + {}, + ], ) ``` +Based on the example above the `Tool Calling` benchmark will run basic, spatial_reasoning and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. + +!!! note + + When using ollama vendor make sure to pull used models first + ## Viewing Results From every benchmark run, there will be results saved in the provided output directory: @@ -100,7 +109,7 @@ From every benchmark run, there will be results saved in the provided output dir - results_summary.csv - for overall metrics - results.csv - for detailed results of every task/scenario -When using `test_models`, the output directories will be saved as `////...` and this format can be visualized with our Streamlit script: +When using `test_models`, the output directories will be saved as `////...` and this format can be visualized with our Streamlit script: ```bash streamlit run src/rai_bench/rai_bench/examples/visualise_streamlit.py @@ -110,16 +119,29 @@ streamlit run src/rai_bench/rai_bench/examples/visualise_streamlit.py ### Manipulation O3DE Scenarios -To create your own Scenarios, you will need a Scene Config and Task. You can combine already existing Scene and existing Task to create a new Scenario like: +To create your own Scenarios, you will need a Scene Config and Task - check out example `src/rai_bench/rai_bench/examples/custom_scenario.py`. +You can combine already existing Scene and existing Task to create a new Scenario like: ```python +import logging from pathlib import Path -from rai_bench.manipulation_o3de.tasks import PlaceObjectAtCoordTask -from rai_sim.simulation_bridge import SceneConfig +from typing import List, Sequence, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + from rai_bench.manipulation_o3de.benchmark import Scenario +from rai_bench.manipulation_o3de.interfaces import ( + ManipulationTask, +) +from rai_bench.manipulation_o3de.tasks import PlaceObjectAtCoordTask +from rai_sim.simulation_bridge import Entity, SceneConfig +loggers_type = Union[RcutilsLogger, logging.Logger] -path_to_your_config = "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/1a.yaml" +### Define your scene setup ####################3 +path_to_your_config = ( + "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/1a.yaml" +) scene_config = SceneConfig.load_base_config(Path(path_to_your_config)) # configure existing Task with different params @@ -156,17 +178,6 @@ entities: Creating your own Task will require slightly more effort. Let's start with something simple - a Task that will require throwing given objects off the table: ```python -import logging -from typing import List, Tuple, Union -from rclpy.impl.rcutils_logger import RcutilsLogger -from rai_bench.manipulation_o3de.interfaces import ( - ManipulationTask, -) -from rai_sim.simulation_bridge import Entity, SimulationConfig - -loggers_type = Union[RcutilsLogger, logging.Logger] - - class ThrowObjectsOffTableTask(ManipulationTask): def __init__(self, obj_types: List[str], logger: loggers_type | None = None): super().__init__(logger=logger) @@ -180,11 +191,9 @@ class ThrowObjectsOffTableTask(ManipulationTask): # define prompt obj_names = ", ".join(obj + "s" for obj in self.obj_types).replace("_", " ") # 0.0 z is the level of table, so any coord below that means it is off the table - return f"Manipulate objects, so that all of the {obj_names} are thrown off the table (negative z)" + return f"Manipulate objects, so that all of the {obj_names} are dropped outside of the table (for example y<-0.75)." - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: # Validate if any required objects are present in sim config # if there is not a single object of provided type, there is no point in running # this task of given scene config @@ -193,7 +202,7 @@ class ThrowObjectsOffTableTask(ManipulationTask): ) return count > 1 - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: selected_type_objects = self.filter_entities_by_object_type( entities=entities, object_types=self.obj_types ) @@ -206,88 +215,178 @@ class ThrowObjectsOffTableTask(ManipulationTask): incorrect: int = len(selected_type_objects) - correct return correct, incorrect + # configure existing Task with different params target_coords = (0.1, 0.1) disp = 0.1 -task = PlaceObjectAtCoordTask( - obj_type="apple", - target_position=target_coords, - allowable_displacement=disp, +task = ThrowObjectsOffTableTask( + obj_types=["apple"], ) -Scenario( - task=task, - scene_config=scene_config, - scene_config_path=path_to_your_config +super_scenario = Scenario( + task=task, scene_config=scene_config, scene_config_path=path_to_your_config ) ``` As `obj_types` is parameterizable, it enables various variants of this Task. In combination with a lot of simulation configs available, it means that a single Task can provide dozens of scenarios. -Congratulations, you just created your first Scenario from scratch! +Then yo test it simply run: + +```python +##### Now you can run it in benchmark ################## +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.manipulation_o3de import run_benchmark + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path(out_dir="src/rai_bench/experiments/custom_task/") + + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + # use your scenario + scenarios=[super_scenario], + bench_logger=bench_logger, + ) + +``` + +Congratulations, you just created and launched your first Scenario from scratch! ### Tool Calling Tasks To create a Tool Calling Task, you will need to define Subtasks, Validators, and Task itself. +Check the example `src/rai_bench/rai_bench/examples/custom_task.py`. Let's create a basic task that requires using a tool to receive a message from a specific topic. ```python +from typing import List + +from langchain_core.tools import BaseTool + +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs +from rai_bench.tool_calling_agent.mocked_tools import ( + MockGetROS2TopicsNamesAndTypesTool, + MockReceiveROS2MessageTool, +) from rai_bench.tool_calling_agent.subtasks import ( CheckArgsToolCallSubTask, ) from rai_bench.tool_calling_agent.validators import ( OrderedCallsValidator, ) -from rai_bench.tool_calling_agent.mocked_tools import ( - MockGetROS2TopicsNamesAndTypesTool, -) -from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs -from langchain_core.tools import BaseTool -from typing import List - - - -# define subtask that requires -receive_robot_pos_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_position"}, - expected_optional_args={ - "timeout_sec": int - }, # if there is not exact value expected, you can pass type -) -# use OrderedCallValidator as there is only 1 subtask -topics_ord_val = OrderedCallsValidator(subtasks=[receive_robot_pos_subtask]) +# This Task will check if robot can receive msessage from specified topic class GetROS2RobotPositionTask(Task): complexity = "easy" + type = "custom" @property def available_tools(self) -> List[BaseTool]: + # define topics that will be seen by agent + TOPICS = [ + "/robot_position", + "/attached_collision_object", + "/clock", + "/collision_object", + ] + + TOPICS_STRING = [ + "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", + "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", + "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", + "topic: /robot_position\n type: sensor_msgs/msg/RobotPosition", + ] + # define which tools will be available for agent return [ - # define which topics will be seen by agent MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /robot_position\n type: sensor_msgs/msg/RobotPosition", - ] + mock_topics_names_and_types=TOPICS_STRING ), + MockReceiveROS2MessageTool(available_topics=TOPICS), ] def get_system_prompt(self) -> str: return "You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system." - def get_prompt(self) -> str: + def get_base_prompt(self) -> str: return "Get the position of the robot." + def get_prompt(self) -> str: + # Create versions for different levels + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover what topics are currently active." + ) + @property def optional_tool_calls_number(self) -> int: - # Listing topics before getting any message + # Listing topics before getting any message is fine return 1 + +# define subtask +receive_robot_pos_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_position"}, + expected_optional_args={ + "timeout_sec": int # if there is not exact value expected, you can pass type + }, +) +# use OrderedCallValidator as there is only 1 subtask to check +topics_ord_val = OrderedCallsValidator(subtasks=[receive_robot_pos_subtask]) + + # optionally pass number of extra tool calls args = TaskArgs(extra_tool_calls=0) -task = GetROS2RobotPositionTask(validators=[topics_ord_val], task_args=args) +super_task = GetROS2RobotPositionTask(validators=[topics_ord_val], task_args=args) +``` + +Then run it with: + +```python +##### Now you can run it in benchmark ################## +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.tool_calling_agent import ( + run_benchmark, + ) + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path("src/rai_bench/rai_bench/experiments/custom_task") + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + super_task.set_logger(bench_logger) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + tasks=[super_task], + bench_logger=bench_logger, + ) ``` diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md index d475ad42c..bdaebe9c4 100644 --- a/src/rai_bench/README.md +++ b/src/rai_bench/README.md @@ -1,209 +1,4 @@ -# RAI Benchmarks +# RAI Bench -The RAI Bench is a package including benchmarks and providing frame for creating new benchmarks - -## Manipulation O3DE Benchmark - -The Manipulation O3DE Benchmark [manipulation_o3de_benchmark_module](./rai_bench//manipulation_o3de/) provides tasks and scene configurations for robotic arm manipulation simulation in O3DE. The tasks use a common `ManipulationTask` logic and can be parameterized, which allows for many task variants. The current tasks include: - -- **MoveObjectToLeftTask** -- **GroupObjectsTask** -- **BuildCubeTowerTask** -- **PlaceObjectAtCoordTask** -- **RotateObjectTask** (currently not applicable due to limitations in the ManipulatorMoveTo tool) - -The result of a task is a value between 0 and 1, calculated like initially_misplaced_now_correct / initially_misplaced. This score is calculated at the end of each scenario. - -### Frame Components - -- `Task` -- `Scenario` -- `Benchmark` - -For more information about these classes go to -> [benchmark](./rai_bench//manipulation_o3de/benchmark.py) and [Task](./rai_bench//manipulation_o3de//interfaces.py) and - -### Example usage - -Example of how to load scenes, define scenarios and run benchmark can be found in [manipulation_o3de_benchmark_example](rai_bench/examples/manipulation_o3de/main.py) - -Scenarios can be loaded manually like: - -```python -one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( - base_config_path=Path("path_to_scene.yaml"), - connector_config_path=Path("path_to_o3de_config.yaml"), - ) - -Scenario(task=GrabCarrotTask(logger=some_logger), simulation_config=one_carrot_simulation_config) -``` - -or automatically like: - -```python -scenarios = Benchmark.create_scenarios( - tasks=tasks, simulation_configs=simulations_configs - ) -``` - -which will result in list of scenarios with combination of every possible task and scene(task decides if scene config is suitable for it). - -or can be imported from exisitng packets [scenarios_packets](rai_bench/examples/manipulation_o3de/scenarios.py): - -```python -t_scenarios = trivial_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger - ) -e_scenarios = easy_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -m_scenarios = medium_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -h_scenarios = hard_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -vh_scenarios = very_hard_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -``` - -which are grouped by their subjective difficulty. For now there are 10 trivial, 42 easy, 23 medium, 38 hard and 47 very hard scenarios. -Check docstrings and code in [scenarios_packets](rai_bench/examples/manipulation_o3de/scenarios.py) if you want to know how scenarios are assigned to difficulty level. - -### Running - -1. Download O3DE simulation binary and unzip it. - - - [ros2-humble](https://robotec-ml-rai-public.s3.eu-north-1.amazonaws.com/RAIManipulationDemo_jammyhumble.zip) - - [ros2-jazzy](https://robotec-ml-rai-public.s3.eu-north-1.amazonaws.com/RAIManipulationDemo_noblejazzy.zip) - -2. Follow step 2 from [Manipulation demo Setup section](../../docs/demos/manipulation.md#setup) - -3. Adjust the path to the binary in: [o3de_config.yaml](./rai_bench/examples/manipulation_o3de/configs/o3de_config.yaml) -4. Choose the model you want to run and a vendor. - > [!NOTE] - > The configs of vendors are defined in [config.toml](../../config.toml) Change ithem if needed. -5. Run benchmark with: - -```bash -cd rai -source setup_shell.sh -python src/rai_bench/rai_bench/examples/manipulation_o3de/main.py --model-name llama3.2 --vendor ollama -``` - -> [!NOTE] -> For now benchmark runs all available scenarios (~160). See [Examples](#example-usege) -> section for details. - -### Development - -When creating new task or changing existing ones, make sure to add unit tests for score calculation in [rai_bench_tests](../../tests/rai_bench/manipulation_o3de/tasks/). -This applies also when you are adding or changing the helper methods in `Task` or `ManipulationTask`. - -The number of scenarios can be easily extened without writing new tasks, by increasing number of variants of the same task and adding more simulation configs but it won't improve variety of scenarios as much as creating new tasks. - -## Tool Calling Agent Benchmark - -The Tool Calling Agent Benchmark is the benchmark for LangChain tool calling agents. It includes a set of tasks and a benchmark that evaluates the performance of the agent on those tasks by verifying the correctness of the tool calls requested by the agent. The benchmark is integrated with LangSmith and Langfuse tracing backends to easily track the performance of the agents. - -### Frame Components - -- [Tool Calling Agent Benchmark](rai_bench//tool_calling_agent/benchmark.py) - Benchmark for LangChain tool calling agents -- [Scores tracing](rai_bench/tool_calling_agent_bench/scores_tracing.py) - Component handling sending scores to tracing backends -- [Interfaces](rai_bench//tool_calling_agent/interfaces.py) - Interfaces for validation classes - Task, Validator, SubTask - For detailed description of validation visit -> [Validation](.//rai_bench/docs/tool_calling_agent_benchmark.md) - -[tool_calling_agent_test_bench.py](rai_bench/examples/tool_calling_agent/main.py) - Script providing benchmark on tasks based on the ROS2 tools usage. - -### Example Usage - -Validators can be constructed from any SubTasks, Tasks can be validated by any numer of Validators, which makes whole validation process incredibly versital. - -```python -# subtasks -get_topics_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_topics_names_and_types" -) -color_image_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_image", expected_args={"topic": "/camera_image_color"} -) -# validators - consist of subtasks -topics_ord_val = OrderedCallsValidator(subtasks=[get_topics_subtask]) -color_image_ord_val = OrderedCallsValidator(subtasks=[color_image_subtask]) -topics_and_color_image_ord_val = OrderedCallsValidator( - subtasks=[ - get_topics_subtask, - color_image_subtask, - ] -) -# tasks - validated by list of validators -GetROS2TopicsTask(validators=[topics_ord_val]) -GetROS2RGBCameraTask(validators=[topics_and_color_image_ord_val]), -GetROS2RGBCameraTask(validators=[topics_ord_val, color_image_ord_val]), -``` - -### Running - -To set up tracing backends, please follow the instructions in the [tracing.md](../../docs/tracing.md) document. - -To run the benchmark: - -```bash -cd rai -source setup_shell.sh -python src/rai_bench/rai_bench/examples/tool_calling_agent/main.py -``` - -There is also flags to declare model type and vendor: - -```bash -python src/rai_bench/rai_bench/examples/tool_calling_agent/main.py --model-name llama3.2 --vendor ollama -``` - -> [!NOTE] -> The configs of vendors are defined in [config.toml](../../config.toml) Change ithem if needed. - -## Testing Models - -To test multiple models, different benchamrks or couple repeats in one go - use script [test_models](./rai_bench/examples/test_models.py) - -Modify these params: - -```python -models_name = ["llama3.2", "qwen2.5:7b"] -vendors = ["ollama", "ollama"] -benchmarks = ["tool_calling_agent"] -repeats = 1 -``` - -to your liking and run the script! - -```bash -python src/rai_bench/rai_bench/examples/test_models.py -``` - -### Results and Visualization - -All results from running benchmarks will be saved to folder [experiments](./rai_bench/experiments/) - -If you run single benchmark test like: - -```bash -python src/rai_bench/rai_bench/examples//main.py -``` - -Results will be saved to dedicated directory named `` - -When you run a test via: - -```bash -python src/rai_bench/rai_bench/examples/test_models.py -``` - -results will be saved to separate folder in [results](./rai_bench/experiments/), with prefix `run_` - -To visualise the results run: - -```bash -streamlit run src/rai_bench/rai_bench/results_processing/visualise.py -``` +For tutorial see [RAI Bench Tutorial](../../docs/tutorials/benchmarking.md) +For understanding the structure of the package visit [RAI Bench Overview](../../docs/simulation_and_benchmarking) diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index c97d11cbd..2a0cc188c 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -20,26 +20,26 @@ if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen2.5:7b"] - vendors = ["ollama"] + model_names = ["qwen3:4b", "llama3.2:3b"] + vendors = ["ollama", "ollama"] # Define benchmarks that will be used mani_conf = ManipulationO3DEBenchmarkConfig( - o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config + o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", levels=[ # define what difficulty of tasks to include in benchmark "trivial", + "easy", ], repeats=1, # how many times to repeat ) tool_conf = ToolCallingAgentBenchmarkConfig( - extra_tool_calls=[0], # how many extra tool calls allowed to still pass + extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", "spatial_reasoning", "custom_interfaces", - "manipulation", ], - N_shots=[2], # examples in system prompt + N_shots=[0, 2], # examples in system prompt prompt_detail=["brief", "descriptive"], # how descriptive should task prompt be repeats=1, ) @@ -48,6 +48,11 @@ test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[tool_conf], + benchmark_configs=[mani_conf, tool_conf], out_dir=out_dir, + # if you want to pass any additinal args to model + additional_model_args=[ + {"reasoning": False}, + {}, + ], ) diff --git a/src/rai_bench/rai_bench/examples/custom_scenario.py b/src/rai_bench/rai_bench/examples/custom_scenario.py new file mode 100644 index 000000000..60b0ea3c3 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/custom_scenario.py @@ -0,0 +1,128 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import List, Sequence, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.manipulation_o3de.benchmark import Scenario +from rai_bench.manipulation_o3de.interfaces import ( + ManipulationTask, +) +from rai_bench.manipulation_o3de.tasks import PlaceObjectAtCoordTask +from rai_sim.simulation_bridge import Entity, SceneConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] + +### Define your scene setup ####################3 +path_to_your_config = ( + "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/1a.yaml" +) +scene_config = SceneConfig.load_base_config(Path(path_to_your_config)) + +# Configure predefined task to place an apple on the table +target_coords = (0.1, 0.1) +disp = 0.1 +task = PlaceObjectAtCoordTask( + obj_type="apple", + target_position=target_coords, + allowable_displacement=disp, +) + +# Create scene with apple on the table +Scenario(task=task, scene_config=scene_config, scene_config_path=path_to_your_config) + + +######### Define your task ################### +class ThrowObjectsOffTableTask(ManipulationTask): + def __init__(self, obj_types: List[str], logger: loggers_type | None = None): + super().__init__(logger=logger) + # obj_types is a list of objects that are subject of the task + # In this case, it will mean which objects should be thrown off the table + # can be any objects + self.obj_types = obj_types + + @property + def task_prompt(self) -> str: + # define prompt + obj_names = ", ".join(obj + "s" for obj in self.obj_types).replace("_", " ") + # 0.0 z is the level of table, so any coord below that means it is off the table + return f"Manipulate objects, so that all of the {obj_names} are dropped outside of the table (for example y<-0.75)." + + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: + # Validate if any required objects are present in sim config + # if there is not a single object of provided type, there is no point in running + # this task of given scene config + count = sum( + 1 for ent in simulation_config.entities if ent.prefab_name in self.obj_types + ) + return count > 1 + + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: + selected_type_objects = self.filter_entities_by_object_type( + entities=entities, object_types=self.obj_types + ) + + # check how many objects are below table, that will be our metric + correct = sum( + 1 for ent in selected_type_objects if ent.pose.pose.position.z < 0.0 + ) + + incorrect: int = len(selected_type_objects) - correct + return correct, incorrect + + +# Task, throw apple off the table +remove_obj_from_table_task = ThrowObjectsOffTableTask( + obj_types=["apple"], +) +super_scenario = Scenario( + task=task, scene_config=scene_config, scene_config_path=path_to_your_config +) + + +super_scenario = Scenario( + task=remove_obj_from_table_task, + scene_config=scene_config, + scene_config_path=path_to_your_config, +) + +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.manipulation_o3de import run_benchmark + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path("src/rai_bench/rai_bench/experiments/custom_task/") + + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + # use your scenario + scenarios=[super_scenario], + bench_logger=bench_logger, + ) diff --git a/src/rai_bench/rai_bench/examples/custom_task.py b/src/rai_bench/rai_bench/examples/custom_task.py new file mode 100644 index 000000000..fdfb9b763 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/custom_task.py @@ -0,0 +1,128 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +from langchain_core.tools import BaseTool + +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs +from rai_bench.tool_calling_agent.mocked_tools import ( + MockGetROS2TopicsNamesAndTypesTool, + MockReceiveROS2MessageTool, +) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, +) +from rai_bench.tool_calling_agent.validators import ( + OrderedCallsValidator, +) + + +# This Task will check if robot can receive msessage from specified topic +class GetROS2RobotPositionTask(Task): + complexity = "easy" + type = "custom" + + @property + def available_tools(self) -> List[BaseTool]: + # define topics that will be seen by agent + TOPICS = [ + "/robot_position", + "/attached_collision_object", + "/clock", + "/collision_object", + ] + + TOPICS_STRING = [ + "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", + "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", + "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", + "topic: /robot_position\n type: sensor_msgs/msg/RobotPosition", + ] + # define which tools will be available for agent + return [ + MockGetROS2TopicsNamesAndTypesTool( + mock_topics_names_and_types=TOPICS_STRING + ), + MockReceiveROS2MessageTool(available_topics=TOPICS), + ] + + def get_system_prompt(self) -> str: + return "You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system." + + def get_base_prompt(self) -> str: + return "Get the position of the robot." + + def get_prompt(self) -> str: + # Create versions for different levels + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover what topics are currently active." + ) + + @property + def optional_tool_calls_number(self) -> int: + # Listing topics before getting any message is fine + return 1 + + +# define subtask +receive_robot_pos_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_position"}, + expected_optional_args={ + "timeout_sec": int # if there is not exact value expected, you can pass type + }, +) +# use OrderedCallValidator as there is only 1 subtask to check +topics_ord_val = OrderedCallsValidator(subtasks=[receive_robot_pos_subtask]) + + +# optionally pass number of extra tool calls +args = TaskArgs(extra_tool_calls=0) +super_task = GetROS2RobotPositionTask(validators=[topics_ord_val], task_args=args) + +##### Now you can run it in benchmark ################## +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.tool_calling_agent import ( + run_benchmark, + ) + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path("src/rai_bench/rai_bench/experiments/custom_task") + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + super_task.set_logger(bench_logger) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + tasks=[super_task], + bench_logger=bench_logger, + ) diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index be0e86408..bb4730ad8 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -422,9 +422,9 @@ def _setup_benchmark_environment( def run_benchmark( llm: BaseChatModel, out_dir: Path, - o3de_config_path: str, scenarios: List[Scenario], bench_logger: logging.Logger, + o3de_config_path: str = "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", experiment_id: uuid.UUID = uuid.uuid4(), ): connector, o3de, benchmark, tools = _setup_benchmark_environment( diff --git a/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py b/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py index 91e595320..7e07c1591 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py @@ -15,7 +15,7 @@ import math from abc import ABC, abstractmethod from collections import defaultdict -from typing import Dict, List, Set, Tuple, TypeVar, Union +from typing import Dict, List, Sequence, Set, Tuple, Union from rai.types import Pose from rclpy.impl.rcutils_logger import RcutilsLogger @@ -24,12 +24,9 @@ Entity, SceneConfig, SimulationBridge, - SimulationConfigT, - SpawnedEntity, ) loggers_type = Union[RcutilsLogger, logging.Logger] -EntityT = TypeVar("EntityT", bound=Entity) class EntitiesMismatchException(Exception): @@ -79,9 +76,7 @@ def validate_config(self, simulation_config: SceneConfig) -> bool: pass @abstractmethod - def calculate_score( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> float: + def calculate_score(self, simulation_bridge: SimulationBridge) -> float: """ Calculate the task score based on the simulation information. @@ -98,8 +93,8 @@ def calculate_score( pass def filter_entities_by_object_type( - self, entities: List[EntityT], object_types: List[str] - ) -> List[EntityT]: + self, entities: Sequence[Entity], object_types: List[str] + ) -> List[Entity]: """ Filter and return only the entities that match the provided prefab types. @@ -198,14 +193,14 @@ def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> in return adjacent_count def build_neighbourhood_list( - self, entities: List[EntityT], threshold_distance: float = 0.15 - ) -> Dict[EntityT, List[EntityT]]: + self, entities: Sequence[Entity], threshold_distance: float = 0.15 + ) -> Dict[Entity, List[Entity]]: """ Build a neighbourhood list assigning a list of neighbours to every entity based on a threshold distance. Parameters ---------- - entities : List[EntityT] + entities : Sequence[EntityT] # Changed from List[EntityT] The list of entities. threshold_distance : float, optional The maximum distance between entities to consider them neighbours. Default is 0.15. @@ -215,7 +210,7 @@ def build_neighbourhood_list( Dict[EntityT, List[EntityT]] A dictionary mapping each entity to a list of neighbouring entities. """ - neighbourhood_graph: Dict[EntityT, List[EntityT]] = { + neighbourhood_graph: Dict[Entity, List[Entity]] = { entity: [] for entity in entities } for entity in entities: @@ -230,8 +225,8 @@ def build_neighbourhood_list( return neighbourhood_graph def group_entities_by_type( - self, entities: List[EntityT] - ) -> Dict[str, List[EntityT]]: + self, entities: Sequence[Entity] + ) -> Dict[str, List[Entity]]: """ Group entities by their prefab type. @@ -245,14 +240,14 @@ def group_entities_by_type( Dict[str, List[EntityT]] A dictionary with keys as prefab names and values as lists of entities of that type. """ - entities_by_type: Dict[str, List[EntityT]] = defaultdict(list) + entities_by_type: Dict[str, List[Entity]] = defaultdict(list) for entity in entities: entities_by_type[entity.prefab_name].append(entity) return entities_by_type def check_neighbourhood_types( self, - neighbourhood: List[EntityT], + neighbourhood: Sequence[Entity], allowed_types: List[str], ) -> bool: """ @@ -275,8 +270,8 @@ def check_neighbourhood_types( ) def find_clusters( - self, neighbourhood_list: Dict[EntityT, List[EntityT]] - ) -> List[List[EntityT]]: + self, neighbourhood_list: Dict[Entity, List[Entity]] + ) -> List[List[Entity]]: """ Identify clusters of entities using a DFS algorithm. @@ -293,10 +288,10 @@ def find_clusters( List[List[EntityT]] A list of clusters, where each cluster is a list of connected entities. """ - visited: Set[EntityT] = set() - clusters: List[List[EntityT]] = [] + visited: Set[Entity] = set() + clusters: List[List[Entity]] = [] - def dfs(node: EntityT, cluster: List[EntityT]): + def dfs(node: Entity, cluster: List[Entity]): visited.add(node) cluster.append(node) for neighbor in neighbourhood_list.get(node, []): @@ -305,7 +300,7 @@ def dfs(node: EntityT, cluster: List[EntityT]): for node in neighbourhood_list.keys(): if node not in visited: - component: List[EntityT] = [] + component: List[Entity] = [] dfs(node, component) clusters.append(component) @@ -314,9 +309,9 @@ def dfs(node: EntityT, cluster: List[EntityT]): def group_entities_along_z_axis( # NOTE (jmatejcz) figure out how to group by other coords and orientation, without reapeting code self, - entities: List[EntityT], + entities: List[Entity], margin: float, - ) -> List[List[EntityT]]: + ) -> List[List[Entity]]: """ Group entities that are aligned along the z axis based on their x and y coordinates. @@ -347,7 +342,7 @@ def group_entities_along_z_axis( key=lambda ent: (ent.pose.pose.position.x, ent.pose.pose.position.y), ) - groups: List[List[EntityT]] = [] + groups: List[List[Entity]] = [] for entity in entities: placed = False for group in groups: @@ -440,9 +435,7 @@ def validate_config(self, simulation_config: SceneConfig) -> bool: return False @abstractmethod - def calculate_correct( - self, entities: List[Entity] | List[SpawnedEntity] - ) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """Method to calculate how many objects are placed correctly Parameters @@ -458,7 +451,7 @@ def calculate_correct( pass def calculate_current_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] + self, simulation_bridge: SimulationBridge ) -> tuple[int, int]: """ Get the current placements of objects in the simulation @@ -485,9 +478,7 @@ def calculate_current_placements( ) return current_correct, current_incorrect - def calculate_score( - self, simulation_bridge: SimulationBridge[SceneConfig] - ) -> float: + def calculate_score(self, simulation_bridge: SimulationBridge) -> float: """ Calculate the task score based on the difference between initial and current placements. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py index 771b99387..58f4cf4e9 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -94,7 +94,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) return cube_count > 1 - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of correctly and incorrectly placed cubes. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py index 41ce103ec..18b2dcea1 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -76,7 +76,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b return set(self.obj_types) <= object_types_present - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Count correctly and incorrectly clustered objects based on clustering rules. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py index b6032ec1b..239306acd 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -51,7 +51,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) return set(self.obj_types) <= object_types_present.keys() - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of objects correctly moved to the left side of the table. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py index cd0f9fc28..69811fd3a 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py @@ -14,7 +14,7 @@ import logging import math -from typing import List, Tuple, Union +from typing import Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -67,7 +67,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) return count >= 1 - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of correctly and incorrectly placed objects. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py index 5ec7c5d1b..4337f7b19 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -65,7 +65,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b return False - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of correctly and incorrectly placed cubes based on adjacency. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py index e91af00c8..893619f1e 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py @@ -14,7 +14,7 @@ import logging import math -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rai.types import Quaternion from rclpy.impl.rcutils_logger import RcutilsLogger @@ -72,7 +72,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) def calculate_correct( - self, entities: List[Entity], allowable_rotation_error: float = 5.0 + self, entities: Sequence[Entity], allowable_rotation_error: float = 5.0 ) -> Tuple[int, int]: """ Calculate the number of correctly rotated objects and incorrectly rotated objects, diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index 2528596e9..dbccc393b 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -15,9 +15,8 @@ from abc import abstractmethod from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional -from git import Optional from langchain.chat_models.base import BaseChatModel from pydantic import BaseModel diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index 60fbac038..e1150082c 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -34,6 +34,7 @@ def parse_tool_calling_benchmark_args(): parser.add_argument( "--extra-tool-calls", type=int, + nargs="+", help="Number of extra tools calls agent can make and still pass the task", default=0, ) diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index dac0e1027..4a3e4a2e7 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -61,7 +61,7 @@ def load_config(cls, config_path: Path) -> "O3DExROS2SimulationConfig": return cls(**connector_content) -class O3DExROS2Bridge(SimulationBridge[O3DExROS2SimulationConfig]): +class O3DExROS2Bridge(SimulationBridge): def __init__( self, connector: ROS2Connector, logger: Optional[logging.Logger] = None ): diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index bb514e270..a39ff2906 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -15,7 +15,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, List, Optional, TypeVar +from typing import List, Optional import yaml from pydantic import BaseModel, Field, field_validator @@ -167,25 +167,22 @@ class SceneState(BaseModel): class SimulationConfig(BaseModel): ... -SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) - - -class SimulationBridge(ABC, Generic[SimulationConfigT]): +class SimulationBridge(ABC): """ Responsible for communication with simulation. """ def __init__(self, logger: Optional[logging.Logger] = None): - self.spawned_entities: List[ - SpawnedEntity - ] = [] # list of spawned entities with their initial poses + self.spawned_entities: List[SpawnedEntity] = ( + [] + ) # list of spawned entities with their initial poses if logger is None: self.logger = logging.getLogger(__name__) else: self.logger = logger @abstractmethod - def init_simulation(self, simulation_config: SimulationConfigT): + def init_simulation(self, simulation_config: SimulationConfig): """ Initialize simulation binary and all other required processes, for example ros2 nodes diff --git a/tests/rai_sim/test_simulation_bridge.py b/tests/rai_sim/test_simulation_bridge.py index 5e383fb85..f68264d01 100644 --- a/tests/rai_sim/test_simulation_bridge.py +++ b/tests/rai_sim/test_simulation_bridge.py @@ -148,7 +148,7 @@ def test_load_base_config(sample_base_yaml_config: Path): assert len(config.entities) == 2 -class MockSimulationBridge(SimulationBridge[SimulationConfig]): +class MockSimulationBridge(SimulationBridge): """Mock implementation of SimulationBridge for testing.""" def init_simulation(self, simulation_config: SimulationConfig): From 7d4d1363fd35e7e0b8b4a863de10e0918b5a7b98 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:00:31 +0200 Subject: [PATCH 10/13] feat: planning task and megamind agent (#679) --- .../rai_bench/examples/manipulation_o3de.py | 1 - .../examples/tool_calling_custom_agent.py | 100 ++++ .../rai_bench/tool_calling_agent/benchmark.py | 73 ++- .../tool_calling_agent/interfaces.py | 8 + .../tool_calling_agent/tasks/warehouse.py | 533 ++++++++++++++++++ src/rai_bench/rai_bench/utils.py | 6 +- src/rai_core/rai/agents/__init__.py | 5 +- .../rai/agents/langchain/core/__init__.py | 7 +- .../rai/agents/langchain/core/megamind.py | 305 ++++++++++ .../rai/agents/langchain/core/plan_agent.py | 275 +++++++++ .../rai/agents/langchain/core/react_agent.py | 2 +- .../rai/agents/langchain/core/tool_runner.py | 32 +- src/rai_sim/rai_sim/simulation_bridge.py | 6 +- 13 files changed, 1298 insertions(+), 55 deletions(-) create mode 100644 src/rai_bench/rai_bench/examples/tool_calling_custom_agent.py create mode 100644 src/rai_bench/rai_bench/tool_calling_agent/tasks/warehouse.py create mode 100644 src/rai_core/rai/agents/langchain/core/megamind.py create mode 100644 src/rai_core/rai/agents/langchain/core/plan_agent.py diff --git a/src/rai_bench/rai_bench/examples/manipulation_o3de.py b/src/rai_bench/rai_bench/examples/manipulation_o3de.py index bd0257c5c..675ea8c01 100644 --- a/src/rai_bench/rai_bench/examples/manipulation_o3de.py +++ b/src/rai_bench/rai_bench/examples/manipulation_o3de.py @@ -31,7 +31,6 @@ model_name=args.model_name, vendor=args.vendor, ) - run_benchmark( llm=llm, out_dir=experiment_dir, diff --git a/src/rai_bench/rai_bench/examples/tool_calling_custom_agent.py b/src/rai_bench/rai_bench/examples/tool_calling_custom_agent.py new file mode 100644 index 000000000..7fc878781 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/tool_calling_custom_agent.py @@ -0,0 +1,100 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import uuid +from datetime import datetime +from pathlib import Path + +from rai.agents.langchain.core import ( + Executor, + create_megamind, + get_initial_megamind_state, +) + +from rai_bench import ( + define_benchmark_logger, +) +from rai_bench.tool_calling_agent.benchmark import ToolCallingAgentBenchmark +from rai_bench.tool_calling_agent.interfaces import TaskArgs +from rai_bench.tool_calling_agent.tasks.warehouse import SortingTask +from rai_bench.utils import get_llm_for_benchmark + +if __name__ == "__main__": + now = datetime.now() + out_dir = f"src/rai_bench/rai_bench/experiments/tool_calling/{now.strftime('%Y-%m-%d_%H-%M-%S')}" + experiment_dir = Path(out_dir) + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir, level=logging.DEBUG) + + task = SortingTask(task_args=TaskArgs(extra_tool_calls=50)) + task.set_logger(bench_logger) + + supervisor_name = "gpt-4o" + + executor_name = "gpt-4o-mini" + model_name = f"supervisor-{supervisor_name}_executor-{executor_name}" + supervisor_llm = get_llm_for_benchmark(model_name=supervisor_name, vendor="openai") + executor_llm = get_llm_for_benchmark( + model_name=executor_name, + vendor="openai", + ) + + benchmark = ToolCallingAgentBenchmark( + tasks=[task], + logger=bench_logger, + model_name=model_name, + results_dir=experiment_dir, + ) + manipulation_system_prompt = """You are a manipulation specialist robot agent. +Your role is to handle object manipulation tasks including picking up and droping objects using provided tools. + +Ask the VLM for objects detection and positions before perfomring any manipulation action. +If VLM doesn't see objects that are objectives of the task, return this information, without proceeding""" + + navigation_system_prompt = """You are a navigation specialist robot agent. +Your role is to handle navigation tasks in space using provided tools. + +After performing navigation action, always check your current position to ensure success""" + + executors = [ + Executor( + name="manipulation", + llm=executor_llm, + tools=task.manipulation_tools(), + system_prompt=manipulation_system_prompt, + ), + Executor( + name="navigation", + llm=executor_llm, + tools=task.navigation_tools(), + system_prompt=navigation_system_prompt, + ), + ] + agent = create_megamind( + megamind_llm=supervisor_llm, + megamind_system_prompt=task.get_system_prompt(), + executors=executors, + task_planning_prompt=task.get_planning_prompt(), + ) + + experiment_id = uuid.uuid4() + benchmark.run_next( + agent=agent, + initial_state=get_initial_megamind_state(task=task.get_prompt()), + experiment_id=experiment_id, + ) + + bench_logger.info("===============================================================") + bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") + bench_logger.info("===============================================================") diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index 44923c860..d1843b843 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -19,14 +19,14 @@ from typing import Iterator, List, Optional, Sequence, Tuple from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage +from langchain_core.messages import AIMessage, BaseMessage from langchain_core.runnables.config import RunnableConfig from langgraph.errors import GraphRecursionError from langgraph.graph.state import CompiledStateGraph from rai.agents.langchain.core import ( create_conversational_agent, ) -from rai.messages import HumanMultimodalMessage +from rai.agents.langchain.core.react_agent import ReActAgentState from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, TimeoutException @@ -38,9 +38,6 @@ TaskResult, ToolCallingAgentRunSummary, ) -from rai_bench.tool_calling_agent.tasks.spatial import ( - SpatialReasoningAgentTask, -) from rai_bench.utils import get_llm_model_name @@ -67,7 +64,12 @@ def __init__( self.tasks_results: List[TaskResult] = [] self.csv_initialize(self.results_filename, TaskResult) - def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: + def run_next( + self, + agent: CompiledStateGraph, + initial_state: ReActAgentState, + experiment_id: uuid.UUID, + ) -> None: """Runs the next task of the benchmark. Parameters @@ -87,14 +89,16 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: ) callbacks = self.score_tracing_handler.get_callbacks() run_id = uuid.uuid4() - # NOTE (jmatejcz) recursion limit calculated as all_nodes_num -> one pass though whole node - # plus (task.max_tool_calls_number-1 because the first pass is already added in) - # times number of nodes - 2 because we dont cout start and end node - # this can be to much for larger graphs that dont use all nodes on extra calls - # in such ase adjust this value - recurssion_limit = len(agent.get_graph().nodes) + ( - task.max_tool_calls_number - 1 - ) * (len(agent.get_graph().nodes) - 2) + # NOTE (jmatejcz) recursion limit calculated as (all_nodes_num - 2) * required tool calls + # -2 because we don't want to include START and END node + # then we add numer of additional calls that can be made + # and +2 as we have to pass once though START and END + + recurssion_limit = ( + (len(agent.get_graph().nodes) - 2) * task.required_calls + + task.additional_calls + + 2 + ) config: RunnableConfig = { "run_id": run_id, "callbacks": callbacks, @@ -113,40 +117,27 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: messages: List[BaseMessage] = [] prev_count: int = 0 try: - with self.time_limit(20 * task.max_tool_calls_number): - if isinstance(task, SpatialReasoningAgentTask): - for state in agent.stream( - { - "messages": [ - HumanMultimodalMessage( - content=task.get_prompt(), images=task.get_images() - ) - ] - }, - config=config, - ): - node = next(iter(state)) - all_messages = state[node]["messages"] - for new_msg in all_messages[prev_count:]: - messages.append(new_msg) - prev_count = len(messages) - else: - for state in agent.stream( - { - "messages": [ - HumanMultimodalMessage(content=task.get_prompt()) - ] - }, - config=config, - ): - node = next(iter(state)) + with self.time_limit(200 * task.max_tool_calls_number): + for state in agent.stream( + initial_state, + config=config, + ): + node = next(iter(state)) + if "messages" in state[node]: all_messages = state[node]["messages"] for new_msg in all_messages[prev_count:]: messages.append(new_msg) + if isinstance(new_msg, AIMessage): + self.logger.debug( + f"Message from node '{node}': {new_msg.content}, tool_calls: {new_msg.tool_calls}" + ) prev_count = len(messages) except TimeoutException as e: self.logger.error(msg=f"Task timeout: {e}") except GraphRecursionError as e: + tool_calls = task.get_tool_calls_from_messages(messages=messages) + score = task.validate(tool_calls=tool_calls) + score = 0.0 self.logger.error(msg=f"Reached recursion limit {e}") tool_calls = task.get_tool_calls_from_messages(messages=messages) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py index 7d8cb3187..a92fbd57b 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/interfaces.py @@ -585,6 +585,14 @@ def max_tool_calls_number(self) -> int: + self.extra_tool_calls ) + @property + def additional_calls(self) -> int: + """number of additional calls that can be done to still pass task. + Includes extra tool calls params. + and optional tool calls number which depends on task. + """ + return self.optional_tool_calls_number + self.extra_tool_calls + @property def required_calls(self) -> int: """Minimal number of calls required to complete task""" diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/warehouse.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/warehouse.py new file mode 100644 index 000000000..423430fc7 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/warehouse.py @@ -0,0 +1,533 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +from langchain_core.tools import BaseTool, tool + +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator +from rai_bench.tool_calling_agent.subtasks import CheckArgsToolCallSubTask +from rai_bench.tool_calling_agent.validators import OrderedCallsValidator + +WAREHOUSE_ENVIRONMENT_DESCRIPTION = """ +WAREHOUSE LAYOUT: + +TABLE WITH SLOTS: +- Table location: x=10-11, y=1-7 +- Slot 1: (10.0, 1.5) +- Slot 2: (10.0, 3.0) +- Slot 3: (10.0, 4.5) +- Slot 4: (10.0, 6.0) +When navigating to the table remember that you can't navigate into it, +always approach from the side that is closer to rack (use x=10). + +Each slot can contain at most 1 item that can be picked up. +New Items won't appear during the task, so if you picked objects from a ceratin slot, +it will be empty for the rest of the task. + +STORAGE RACKS: +Storage Rack 1 location: x=2-6 y=5-6 +- Boxes: (3.0, 5.0), (5.0, 5.0) +When navigating to the tack remember that you can't navigate into it, +always approach from the side that is closer to starting position (use y=5). + +ROBOT STARTING POSITION: +- Robot starting location: (4.0, 2.0) +""" +SYSTEM_PROMPT = """You are a mobile robot operating in a warehouse environment for pick-and-place operations.""" + + +class EnvStateManager: + """Enhanced env state manager that tracks objects, boxes, and robot state""" + + def __init__(self): + self._state = { + "robot_position": (4.0, 2.0), + "gripper_state": "open", + } + + self._objects = { + "obj_1": { + "world_position": (10.5, 1.5), # Slot 1 position + "color": "blue", + # when picked up by the robot the obj will "disappear" from the vlm view + # when dropped the object will appear with different values + "picked_up": False, + "relative": (0.02, 0.1, 0.05), # relative to robot when at slot + }, + "obj_2": { + "world_position": (10.5, 3.0), # Slot 2 + "color": "red", + "picked_up": False, + "relative": (-0.2, 0.05, 0.05), + }, + "obj_3": { + "world_position": (10.5, 4.5), # Slot 3 + "color": "green", + "picked_up": False, + "relative": (0.1, 0.4, 0.05), + }, + "obj_4": { + "world_position": (10.5, 6.0), # Slot 4 + "color": "green", + "picked_up": False, + "relative": (0.15, -0.25, 0.05), + }, + } + + self._boxes = { + "box_1": { + "world_position": (3.0, 5.0), + "objects": [], # List of objects in this box + "relative": (0.2, 0, 0.05), # relative when robot is at box + }, + "box_2": { + "world_position": (5.0, 5.0), + "objects": [], + "relative": (0.1, -0.05, 0.05), + }, + } + + def get_position(self) -> Tuple[float, float]: + return self._state["robot_position"] + + def set_position(self, x: float, y: float): + self._state["robot_position"] = (x, y) + + def get_held_object(self) -> Optional[str]: + return self._state.get("held_object") + + def pick_up_object_at_position( + self, relative_pos: Tuple[float, float, float] + ) -> Optional[str]: + """Pick up object at relative position from current robot location""" + robot_x, robot_y = self.get_position() + + # Find object at the relative position + for obj, obj_data in self._objects.items(): + if not obj_data["picked_up"]: + # Check if this object is at the current location with matching relative position + if relative_pos == obj_data["relative"]: + # Check if robot is at the right slot for this object + if ( + abs(robot_x - obj_data["world_position"][0]) <= 0.5 + and abs(robot_y - obj_data["world_position"][1]) <= 0.5 + ): + obj_data["picked_up"] = True + self._state["held_object"] = obj + return obj + return None + + def drop_object_at_position(self, relative_pos: Tuple[float, float, float]) -> None: + """Drop held object at relative position from current robot location""" + # Check if placed in box, if yes, change env state + robot_x, robot_y = self.get_position() + # Find which box we're dropping into + for box_id, box_data in self._boxes.items(): + if relative_pos == box_data["relative"]: + # Check if robot is at the right position for this box + if ( + robot_x == box_data["world_position"][0] + and robot_y == box_data["world_position"][1] + ): + # Drop object into box + obj_id = self._state["held_object"] + box_data["objects"].append(obj_id) + + # Update object position to be in the box + self._objects[obj_id]["world_position"] = ( + box_data["world_position"][0] + relative_pos[0], + box_data["world_position"][1] + relative_pos[1], + ) + + self._state["held_object"] = None + + def get_visible_objects_at_position(self) -> List[Dict]: + """Get objects visible at current robot position""" + robot_x, robot_y = self.get_position() + visible_objects = [] + + # Check for objects at table slots + if abs(robot_x - 10.0) <= 0.5: # At sorting table + for obj_id, obj_data in self._objects.items(): + if not obj_data["picked_up"]: + obj_world_pos = obj_data["world_position"] + # Check if object is at current slot + expected_robot_y = obj_world_pos[1] - obj_data["relative"][1] + if abs(robot_y - expected_robot_y) <= 0.5: + visible_objects.append( + { + "id": obj_id, + "color": obj_data["color"], + "relative_position": obj_data["relative"], + } + ) + + return visible_objects + + def get_visible_boxes_at_position(self) -> List[Dict]: + """Get boxes visible at current robot position""" + robot_x, robot_y = self.get_position() + visible_boxes = [] + + # Check for boxes at storage rack + if 2 <= robot_x <= 6 and abs(robot_y - 5.5) <= 0.5: + for box_id, box_data in self._boxes.items(): + box_world_pos = box_data["world_position"] + if abs(robot_x - box_world_pos[0]) <= 0.5: + visible_boxes.append( + { + "id": box_id, + "relative_position": box_data["relative"], + "contents": [ + self._objects[obj_id]["color"] + for obj_id in box_data["objects"] + ], + } + ) + + return visible_boxes + + def get_state_summary(self) -> Dict: + """Get complete state for debugging""" + return { + "robot_position": self._state["robot_position"], + "gripper_state": self._state["gripper_state"], + "held_object": self._state.get("held_object"), + "objects": self._objects, + "boxes": self._boxes, + } + + +class SortingTask(Task): + complexity = "hard" + type = "warehouse" + + def __init__( + self, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + **kwargs: Any, + ) -> None: + if not validators: + # after every navigate call + # the where am i should probably be called? should it be mandatory? + # it is for now + # Should ask vlm be called after manipulaiton action? + # So robot can confirm if it pick or droppped object + where_am_i_subtask = CheckArgsToolCallSubTask( + expected_tool_name="where_am_i", + expected_args={}, # No parameters expected + ) + ask_vlm_subtask = CheckArgsToolCallSubTask( + expected_tool_name="ask_vlm", + expected_args={}, + ) + + #### navigate to table, detect and pick up object + navigate_to_slot1_subtask = CheckArgsToolCallSubTask( + expected_tool_name="nav_tool", + expected_args={ + "x": 10.0, + "y": 1.5, + }, + ) + pick_up_1_subtask = CheckArgsToolCallSubTask( + expected_tool_name="pick_up_object", + expected_args={"x": 0.02, "y": 0.1, "z": 0.05}, + ) + #### navigate to the box and drop object + navigate_to_box1_subtask = CheckArgsToolCallSubTask( + expected_tool_name="nav_tool", + expected_args={ + "x": 3.0, + "y": 5.0, + }, + ) + drop_subtask_1 = CheckArgsToolCallSubTask( + expected_tool_name="drop_object", + expected_args={"x": 0.2, "y": 0, "z": 0.05}, + ) + + #### navigate to the table and pick up second object + navigate_to_slot2_subtask = CheckArgsToolCallSubTask( + expected_tool_name="nav_tool", + expected_args={ + "x": 10.0, + "y": 3.0, + }, + ) + # there was no green or blue object so navigate to the next slot + navigate_to_slot3_subtask = CheckArgsToolCallSubTask( + expected_tool_name="nav_tool", + expected_args={ + "x": 10.0, + "y": 4.5, + }, + ) + pick_up_3_subtask = CheckArgsToolCallSubTask( + expected_tool_name="pick_up_object", + expected_args={"x": 0.1, "y": 0.4, "z": 0.05}, + ) + + #### navigate to the 2nd box and drop + navigate_to_box2_subtask = CheckArgsToolCallSubTask( + expected_tool_name="nav_tool", + expected_args={ + "x": 5.0, + "y": 5.0, + }, + ) + drop_subtask_2 = CheckArgsToolCallSubTask( + expected_tool_name="drop_object", + expected_args={"x": 0.1, "y": -0.05, "z": 0.05}, + ) + #### navigate to 4th slot and check for object, its empty so end the task + navigate_to_slot4_subtask = CheckArgsToolCallSubTask( + expected_tool_name="nav_tool", + expected_args={ + "x": 10.0, + "y": 3.0, + }, + ) + validators = [ + #### navigate to slot1, detect and pick up 1st object + OrderedCallsValidator( + subtasks=[ + navigate_to_slot1_subtask, + where_am_i_subtask, + ask_vlm_subtask, + pick_up_1_subtask, + ] + ), + #### navigate to the box1 and drop object + OrderedCallsValidator( + subtasks=[ + navigate_to_box1_subtask, + where_am_i_subtask, + ask_vlm_subtask, + drop_subtask_1, + ] + ), + #### navigate to slot2, detect - there is no blue or green obj + # so navigate to slot3, detect and pick up + OrderedCallsValidator( + subtasks=[ + navigate_to_slot2_subtask, + where_am_i_subtask, + ask_vlm_subtask, + navigate_to_slot3_subtask, + where_am_i_subtask, + ask_vlm_subtask, + pick_up_3_subtask, + ] + ), + #### navigate to the 2nd box and drop + OrderedCallsValidator( + subtasks=[ + navigate_to_box2_subtask, + where_am_i_subtask, + ask_vlm_subtask, + drop_subtask_2, + ] + ), + #### navigate to 4th slot and check for object, its empty so end the task + OrderedCallsValidator( + subtasks=[ + navigate_to_slot4_subtask, + where_am_i_subtask, + ask_vlm_subtask, + ] + ), + ] + super().__init__(validators=validators, task_args=task_args, **kwargs) + self.env_state = EnvStateManager() + + # define tools + @tool + def nav_tool(x: float, y: float): + """Navigate to certain coordinates in the warehouse.""" + self.env_state.set_position(x, y) + return ( + f"Navigating to x: {x}, y: {y} ...\n" + "Check you current position to ensure if movement was done properly" + ) + + @tool + def where_am_i() -> Dict[str, float]: + """Returns your current position""" + x, y = self.env_state.get_position() + return {"x": x, "y": y} + + @tool + def pick_up_object(x: float, y: float, z: float) -> str: + """Move gripper and close it to pick up object from a certain coordinates relative to you""" + held_obj = self.env_state.get_held_object() + if not held_obj: + obj_id = self.env_state.pick_up_object_at_position((x, y, z)) + if obj_id: + obj_color = self.env_state._objects[obj_id]["color"] + return f"Successfully picked up {obj_color} object ({obj_id}) at relative position x: {x}, y: {y}, z: {z}" + else: + return f"No object grabbed successfully at relative position x: {x}, y: {y}, z: {z}" + else: + return f"Can't perform pick up action as you are already holding an {held_obj} object." + + @tool + def drop_object(x: float, y: float, z: float) -> str: + """Move gripper and open it to drop object at a certain coordinates relative to you""" + held_obj = self.env_state.get_held_object() + if not held_obj: + return "Failed to drop - you are not holding any object." + else: + self.env_state.drop_object_at_position((x, y, z)) + return f"Successfully dropped object ({held_obj}) at relative position x: {x}, y: {y}, z: {z}" + + @tool + def ask_vlm() -> str: + """Ask VLM to detect objects at your current location and return their coordinates relative to you""" + visible_objects = self.env_state.get_visible_objects_at_position() + visible_boxes = self.env_state.get_visible_boxes_at_position() + + current_pos = self.env_state.get_position() + x, y = current_pos + + # Generate response based on what's actually visible + responses = [] + + if visible_objects: + for obj in visible_objects: + rel_pos = obj["relative_position"] + responses.append( + f"I see a {obj['color']} object at x: {rel_pos[0]}, y: {rel_pos[1]}, z: {rel_pos[2]} relative to you" + ) + + if visible_boxes: + for box in visible_boxes: + rel_pos = box["relative_position"] + box_num = "1" if "box_1" in box["id"] else "2" + contents_str = ( + f" (contains: {', '.join(box['contents'])} objects)" + if box["contents"] + else " (empty)" + ) + responses.append( + f"I see Box {box_num} at x: {rel_pos[0]}, y: {rel_pos[1]}, z: {rel_pos[2]} relative to you{contents_str}" + ) + + if not responses: + # Check what area we're in for context + if abs(x - 10.5) < 0.5: # At sorting table + slot_num = None + if 1 <= y <= 2: + slot_num = 1 + elif 2.5 <= y <= 3.5: + slot_num = 2 + elif 4 <= y <= 5: + slot_num = 3 + elif 5.5 <= y <= 6.5: + slot_num = 4 + + if slot_num: + return f"I see Slot {slot_num}, but it appears to be empty." + + elif 2 <= x <= 6 and abs(y - 5.5) < 0.5: # At storage rack + return "I see the storage rack area, but no objects or boxes are immediately visible from this position." + + return "I don't see any relevant objects here." + + return " ".join(responses) + + @tool + def get_slot_position(slot_id: str) -> Dict[str, float]: + """Returns the world position of a slot, pass in slot id as integer, for example 1""" + + x, y = self.env_state._slots[f"{slot_id}"]["world_position"] + return {"x": x, "y": y} + + @tool + def get_default_box_position() -> Dict[str, float]: + """Returns the world position of the default box""" + x, y = self.env_state._boxes["default"]["world_position"] + return {"x": x, "y": y} + + self.nav_tool = nav_tool + self.where_am_i = where_am_i + self.pick_up_object = pick_up_object + self.drop_object = drop_object + self.ask_vlm = ask_vlm + + self.nav_tool = nav_tool + self.where_am_i = where_am_i + self.pick_up_object = pick_up_object + self.drop_object = drop_object + self.ask_vlm = ask_vlm + self.get_slot_position = get_slot_position + self.get_default_box_position = get_default_box_position + + @property + def optional_tool_calls_number(self) -> int: + return 10 + + def get_base_prompt(self) -> str: + return ( + "Sort blue and green objects from slots to separate boxes on the rack. " + "Blue objects should go to the 1st box (x: 3.0, y: 5.0), green objects should go to the second box (x: 5.0, y: 5.0). " + "Check the slots in order. If you checked all of them and sorted all blue and green objects the task is done." + ) + + def get_prompt(self) -> str: + return self.get_base_prompt() + + def get_planning_prompt(self) -> str: + """ + Planning prompt help generate summary info for high level + task planning to undrestand the overall progress. + """ + return """ +Determine success and provide brief explanation of what happened by slot, +for example Slot 1: Object color: BLUE, actions: [NAVIGATED to SLOT->CHECKED OBJECTS->PICKED up OBJECT->NAVIGATED to BOX ->DROPPED OBJECT->COMPLETED]. +Mark a slot as COMPLETED only if object from this slot was dropped. +If the slot doesn't contain the right object, for example Slot 2: Object color: RED, actions: [NAVIGATED to SLOT->CHECKED OBJECTS->NOT THE RIGHT COLOR->NOTHING TO DO->COMPLETED]. +""" + + def manipulation_tools(self) -> List[BaseTool]: + return [ + self.pick_up_object, + self.drop_object, + self.ask_vlm, + ] + + def navigation_tools(self) -> List[BaseTool]: + return [ + self.nav_tool, + self.where_am_i, + ] + + @property + def available_tools(self) -> List[BaseTool]: + return [ + self.nav_tool, + self.where_am_i, + self.pick_up_object, + self.drop_object, + self.ask_vlm, + ] + + def get_system_prompt(self) -> str: + return SYSTEM_PROMPT + "\n" + WAREHOUSE_ENVIRONMENT_DESCRIPTION + + def report_sorting_status(self): + print("** Reporting sorting status") + self.env_state.report_sorting_status() diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index e1150082c..79e0e0f51 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -123,12 +123,12 @@ def parse_manipulation_o3de_benchmark_args(): return parser.parse_args() -def define_benchmark_logger(out_dir: Path) -> logging.Logger: +def define_benchmark_logger(out_dir: Path, level: int = logging.INFO) -> logging.Logger: log_file = out_dir / "benchmark.log" out_dir.mkdir(parents=True, exist_ok=True) file_handler = logging.FileHandler(log_file) - file_handler.setLevel(logging.INFO) + file_handler.setLevel(level) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) @@ -137,7 +137,7 @@ def define_benchmark_logger(out_dir: Path) -> logging.Logger: bench_logger = logging.getLogger("Benchmark logger") for handler in bench_logger.handlers: bench_logger.removeHandler(handler) - bench_logger.setLevel(logging.INFO) + bench_logger.setLevel(level) bench_logger.addHandler(file_handler) return bench_logger diff --git a/src/rai_core/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py index f439c7be8..ef4bc749e 100644 --- a/src/rai_core/rai/agents/__init__.py +++ b/src/rai_core/rai/agents/__init__.py @@ -13,7 +13,10 @@ # limitations under the License. from rai.agents.base import BaseAgent -from rai.agents.langchain import BaseStateBasedAgent, ReActAgent +from rai.agents.langchain import ( + BaseStateBasedAgent, + ReActAgent, +) from rai.agents.runner import AgentRunner, wait_for_shutdown __all__ = [ diff --git a/src/rai_core/rai/agents/langchain/core/__init__.py b/src/rai_core/rai/agents/langchain/core/__init__.py index ea6cb46bf..87c7b39bc 100644 --- a/src/rai_core/rai/agents/langchain/core/__init__.py +++ b/src/rai_core/rai/agents/langchain/core/__init__.py @@ -14,18 +14,23 @@ from .conversational_agent import State as ConversationalAgentState from .conversational_agent import create_conversational_agent +from .megamind import Executor, create_megamind, get_initial_megamind_state from .react_agent import ( ReActAgentState, create_react_runnable, ) from .state_based_agent import create_state_based_runnable -from .tool_runner import ToolRunner +from .tool_runner import SubAgentToolRunner, ToolRunner __all__ = [ "ConversationalAgentState", + "Executor", "ReActAgentState", + "SubAgentToolRunner", "ToolRunner", "create_conversational_agent", + "create_megamind", "create_react_runnable", "create_state_based_runnable", + "get_initial_megamind_state", ] diff --git a/src/rai_core/rai/agents/langchain/core/megamind.py b/src/rai_core/rai/agents/langchain/core/megamind.py new file mode 100644 index 000000000..68a5b40bb --- /dev/null +++ b/src/rai_core/rai/agents/langchain/core/megamind.py @@ -0,0 +1,305 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +### NOTE (jmatejcz) this agent is still in process of testing and refining +from dataclasses import dataclass +from functools import partial +from typing import ( + Annotated, + List, + Optional, +) + +from langchain.chat_models.base import BaseChatModel +from langchain_core.messages import ( + BaseMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.tools import BaseTool, InjectedToolCallId, tool +from langgraph.graph import END, START, MessagesState, StateGraph +from langgraph.graph.state import CompiledStateGraph +from langgraph.prebuilt import create_react_agent +from langgraph.types import Command +from pydantic import BaseModel, Field + +from rai.agents.langchain.core.tool_runner import SubAgentToolRunner +from rai.messages import ( + HumanMultimodalMessage, +) + + +class StepSuccess(BaseModel): + """Output of success attacher""" + + success: bool = Field(description="Whether the task was completed successfully") + explanation: str = Field(description="Explanation of what happened") + + +class MegamindState(MessagesState): + original_task: str + steps_done: List[str] + step: Optional[str] + step_success: StepSuccess + step_messages: List[BaseMessage] + + +def llm_node( + llm: BaseChatModel, + system_prompt: Optional[str], + state: MegamindState, +) -> MegamindState: + """Process messages using the LLM - returns the agent's response.""" + messages = state["step_messages"].copy() + if not state["step"]: + raise ValueError("Step should be defined at this point") + if system_prompt: + messages.insert(0, HumanMessage(state["step"])) + messages.insert(0, SystemMessage(content=system_prompt)) + + ai_msg = llm.invoke(messages) + # append to both + state["step_messages"].append(ai_msg) + state["messages"].append(ai_msg) + return state + + +def analyzer_node( + llm: BaseChatModel, + planning_prompt: Optional[str], + state: MegamindState, +) -> MegamindState: + """Analyze the conversation and return structured output.""" + if not planning_prompt: + planning_prompt = "" + analyzer = llm.with_structured_output(StepSuccess) + analysis = analyzer.invoke( + [ + SystemMessage( + content=f""" +Analyze if this task was completed successfully: + +Task: {state["step"]} + +{planning_prompt} +Below you have messages of agent doing the task:""" + ), + *state["step_messages"], + ] + ) + state["step_success"] = StepSuccess( + success=analysis.success, explanation=analysis.explanation + ) + state["steps_done"].append(f"{state['step_success'].explanation}") + return state + + +def should_continue_or_structure(state: MegamindState) -> str: + """Decide whether to continue with tools or return structured output.""" + last_message = state["step_messages"][-1] + + # If AI message has tool calls, continue to tools + if hasattr(last_message, "tool_calls") and last_message.tool_calls: + return "tools" + + # Otherwise, return structured output + return "structured_output" + + +def create_react_structured_agent( + llm: BaseChatModel, + tools: Optional[List[BaseTool]] = None, + system_prompt: Optional[str] = None, + planning_prompt: Optional[str] = None, +) -> CompiledStateGraph: + """Create a react agent that returns structured output.""" + + graph = StateGraph(MegamindState) + graph.add_edge(START, "llm") + + if tools: + tool_runner = SubAgentToolRunner(tools) + graph.add_node("tools", tool_runner) + + bound_llm = llm.bind_tools(tools) + graph.add_node("llm", partial(llm_node, bound_llm, system_prompt)) + + graph.add_node( + "structured_output", partial(analyzer_node, llm, planning_prompt) + ) + + graph.add_conditional_edges( + "llm", + should_continue_or_structure, + {"tools": "tools", "structured_output": "structured_output"}, + ) + graph.add_edge("tools", "llm") + graph.add_edge("structured_output", END) + else: + graph.add_node("llm", partial(llm_node, llm, system_prompt)) + graph.add_node( + "structured_output", partial(analyzer_node, llm, planning_prompt) + ) + graph.add_edge("llm", "structured_output") + graph.add_edge("structured_output", END) + + return graph.compile() + + +def create_handoff_tool(agent_name: str, description: str = None): + """Create a handoff tool for transferring tasks to specialist agents.""" + name = f"transfer_to_{agent_name}" + description = description or f" {agent_name} for help." + + @tool(name, description=description) + def handoff_tool( + task_instruction: str, # The specific task for the agent + tool_call_id: Annotated[str, InjectedToolCallId], + ) -> Command: + return Command( + goto=agent_name, + # Send only the task message to the specialist agent, not the full history + update={"step": task_instruction, "step_messages": []}, + graph=Command.PARENT, + ) + + return handoff_tool + + +@dataclass +class Executor: + name: str + llm: BaseChatModel + tools: List[BaseTool] + system_prompt: str + + +def get_initial_megamind_state(task: str): + return MegamindState( + { + "original_task": task, + "messages": [HumanMultimodalMessage(content=task)], + "step": "", + "steps_done": [], + "step_success": StepSuccess(success=False, explanation=""), + "step_messages": [], + } + ) + + +def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindState: + """Initial planning step.""" + if "original_task" not in state: + state["original_task"] = state["messages"][0].content[0]["text"] + if "steps_done" not in state: + state["steps_done"] = [] + if "step" not in state: + state["step"] = None + + megamind_prompt = f"You are given objective to complete: {state['original_task']}" + if state["steps_done"]: + megamind_prompt += "\n\n" + megamind_prompt += "Steps that were already done successfully:\n" + steps_done = "\n".join( + [f"{i + 1}. {step}" for i, step in enumerate(state["steps_done"])] + ) + megamind_prompt += steps_done + megamind_prompt += "\n" + + if state["step"]: + if not state["step_success"]: + raise ValueError("Step success should be specified at this point") + + megamind_prompt += "\nBased on that outcome and past steps come up with the next step and delegate it to selected agent." + + else: + megamind_prompt += "\n" + megamind_prompt += ( + "Come up with the fist step and delegate it to selected agent." + ) + + megamind_prompt += "\n\n" + megamind_prompt += ( + "When you decide that the objective is completed return response to user." + ) + messages = [ + HumanMultimodalMessage(content=megamind_prompt), + ] + # NOTE (jmatejcz) the response of megamind isnt appended to messages + # as Command from handoff instantly transitions to next node + megamind_agent.invoke({"messages": messages}) + return state + + +def create_megamind( + megamind_llm: BaseChatModel, + megamind_system_prompt: str, + executors: List[Executor], + task_planning_prompt: Optional[str] = None, +) -> CompiledStateGraph: + """Create a megamind langchain agent + + Args: + executors (List[Executor]): Subagents for megamind, each can be a specialist with + its own tools llm and system prompt + task_planning_prompt (Optional[str]): Prompt that helps summarize the step in a way + that helps planning task + """ + executor_agents = {} + handoff_tools = [] + for executor in executors: + executor_agents[executor.name] = create_react_structured_agent( + llm=executor.llm, + tools=executor.tools, + system_prompt=executor.system_prompt, + planning_prompt=task_planning_prompt, + ) + + handoff_tools.append( + create_handoff_tool( + agent_name=executor.name, + description=f"Assign task to {executor.name} agent.", + ) + ) + if not megamind_system_prompt: + # make a generic system prompt that list executors and their tools + specialists_info = [] + for executor in executors: + tool_names = [tool.name for tool in executor.tools] + tool_list = ", ".join(tool_names) + specialists_info.append(f"- {executor.name}: Available tools: {tool_list}") + + specialists_section = "\n".join(specialists_info) + megamind_system_prompt = f"""You manage specialists to whom you will delegate tasks to complete objective. +Available specialists and their capabilities: +{specialists_section} + +The single task should be delegated to only 1 agent and should be doable by only 1 agent.""" + + megamind_agent = create_react_agent( + megamind_llm, + tools=handoff_tools, + prompt=megamind_system_prompt, + name="megamind", + ) + + graph = StateGraph(MegamindState).add_node( + "megamind", partial(plan_step, megamind_agent) + ) + for agent_name, agent in executor_agents.items(): + graph.add_node(agent_name, agent) + graph.add_edge(agent_name, "megamind") + + graph.add_edge(START, "megamind") + return graph.compile() diff --git a/src/rai_core/rai/agents/langchain/core/plan_agent.py b/src/rai_core/rai/agents/langchain/core/plan_agent.py new file mode 100644 index 000000000..e01247034 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/core/plan_agent.py @@ -0,0 +1,275 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Tuple, Union + +from langchain.chat_models.base import BaseChatModel +from langchain_core.messages import BaseMessage, SystemMessage +from langchain_core.tools import BaseTool +from langgraph.graph import END, START, StateGraph +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel, Field + +from rai.agents.langchain.core import ReActAgentState +from rai.agents.langchain.core.react_agent import create_react_runnable +from rai.initialization import get_llm_model +from rai.messages import HumanMultimodalMessage + + +class Plan(BaseModel): + """A plan to help solve a user request.""" + + steps: List[str] = Field( + description="different steps to follow, should be in sorted order" + ) + + +class Response(BaseModel): + """Response to user.""" + + response: str + + +class Act(BaseModel): + """Action to take.""" + + action: Union[Response, Plan] = Field( + description="Action to perform. If you want to respond to user, use Response. " + "If you need to further use tools to get the answer, use Plan." + ) + + +class PlanExecuteState(ReActAgentState): + """State for the plan and execute agent.""" + + # NOTE (jmatejcz) should original_task be replaced with + # passing first message? The message can contain images etc. + original_task: str + plan: List[str] + past_steps: List[Tuple[str, str]] + response: str + + +def should_end(state: PlanExecuteState) -> str: + """Check if we should end or continue planning.""" + if state["response"]: + return END + else: + return "agent" + + +def create_plan_execute_agent( + tools: List[BaseTool], + planner_llm: Optional[BaseChatModel] = None, + executor_llm: Optional[BaseChatModel] = None, + replanner_llm: Optional[BaseChatModel] = None, + system_prompt: Optional[str] = None, +) -> CompiledStateGraph: + """Create a plan and execute agent that can break down complex tasks into steps. + + Parameters + ---------- + tools : List[BaseTool] + List of tools the agent can use during execution + llm : Optional[BaseChatModel], default=None + Language model to use. If None, will use complex_model from config + system_prompt : Optional[str | SystemMultimodalMessage], default=None + System prompt to use (currently not used in this implementation) + + Returns + ------- + CompiledStateGraph + Compiled state graph for the plan and execute agent + + Raises + ------ + ValueError + If tools are not provided or invalid + """ + if planner_llm is None: + planner_llm = get_llm_model("complex_model", streaming=True) + if executor_llm is None: + executor_llm = get_llm_model("complex_model", streaming=True) + if replanner_llm is None: + replanner_llm = get_llm_model("complex_model", streaming=True) + + if not tools: + raise ValueError("Tools must be provided for plan and execute agent") + if system_prompt is None: + system_prompt = "" + + planner_prompt = """For the given objective, come up with a simple step by step plan. + +When creating your plan: +- Design each step to leverage the most appropriate tool from the list above +- Be specific about what information each step should gather or what action it should perform +- Frame steps as clear instructions that can be executed using the available tools +- Do NOT actually call or use any tools yourself - only create the plan +- Each step should be actionable and tool-appropriate + +This plan should involve individual tasks, that if executed correctly will yield the correct answer. +Do not add any superfluous steps. The result of the final step should be the final answer. +Make sure that each step has all the information needed - do not skip steps.""" + + agent_executor = create_react_runnable( + llm=executor_llm, system_prompt=system_prompt, tools=tools + ) + # the prompt will be filled with values when passed to invoke + planner_llm_with_tools = planner_llm.bind_tools(tools) + planner = planner_llm_with_tools.with_structured_output(Plan) # type: ignore + replanner = replanner_llm.with_structured_output(Act) # type: ignore + + def execute_step(state: PlanExecuteState): + """Execute the current step of the plan.""" + + plan = state["plan"] + if not plan: + return {} + task = plan[0] + task_formatted = f"""You are tasked with executing task: {task}.""" + + agent_response = agent_executor.invoke( + {"messages": [HumanMultimodalMessage(content=task_formatted)]}, + config={"recursion_limit": 50}, + ) + return { + "past_steps": [(task, agent_response["messages"][-1].content)], + } + + def plan_step(state: PlanExecuteState): + """Initial planning step.""" + messages = [ + SystemMessage(content=system_prompt + "\n" + planner_prompt), + HumanMultimodalMessage(content=state["original_task"]), + ] + plan = planner.invoke(messages) + return {"plan": plan.steps} + + def replan_step(state: PlanExecuteState): + """Replan based on execution results.""" + # Format past steps for the prompt + past_steps_str = "\n".join( + [ + f"{step}: {result}" + for i, (step, result) in enumerate(state["past_steps"]) + ] + ) + + # Format remaining plan + plan_str = "\n".join([step for i, step in enumerate(state["plan"])]) + + replanner_prompt = f"""For the given objective, come up with a simple step by step plan. +This plan should involve individual tasks, that if executed correctly will yield the correct answer. +Do not add any superfluous steps. The result of the final step should be the final answer. +Make sure that each step has all the information needed - do not skip steps. + +Your objective was this: +{state["original_task"]} + +Your current plan is: +{plan_str} + +You have currently done the following steps: +{past_steps_str} + +Update your plan accordingly if needed. If no more steps are needed and you can return to the user, then respond with that. Otherwise, fill out the plan. Only add steps to the plan that still NEED to be done. Do not return previously done steps as part of the plan.""" + + messages = [ + SystemMessage(content=system_prompt), + HumanMultimodalMessage(content=replanner_prompt), + ] + output = replanner.invoke(messages) + + if isinstance(output.action, Response): + return {"response": output.action.response} + else: + return {"plan": output.action.steps} + + workflow = StateGraph(PlanExecuteState) + + workflow.add_node("planner", plan_step) + workflow.add_node("agent", execute_step) + workflow.add_node("replan", replan_step) + + workflow.add_edge(START, "planner") + # From plan we go to agent + workflow.add_edge("planner", "agent") + # From agent, we replan + workflow.add_edge("agent", "replan") + + workflow.add_conditional_edges( + "replan", + should_end, + ["agent", END], + ) + + return workflow.compile() + + +def create_initial_plan_execute_state( + original_task: str, + messages: Optional[List[BaseMessage]] = None, +) -> PlanExecuteState: + """Create initial state for the plan and execute agent. + + Parameters + ---------- + input_text : str + The user's input/objective to accomplish + messages : Optional[List[BaseMessage]], default=None + Initial messages for the conversation + + Returns + ------- + PlanExecuteState + Initial state for the agent + """ + if messages is None: + messages = [] + + return PlanExecuteState( + messages=messages, + original_task=original_task, + plan=[], + past_steps=[], + response="", + ) + + +def run_plan_execute_agent( + agent: CompiledStateGraph, + original_task: str, + config: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Run the plan and execute agent on a given input. + + Parameters + ---------- + agent : CompiledStateGraph + The compiled plan and execute agent + input_text : str + The user's input/objective + config : Optional[Dict[str, Any]], default=None + Configuration for the agent execution + + Returns + ------- + Dict[str, Any] + Final state after execution + """ + initial_state = create_initial_plan_execute_state(original_task) + + # Execute the agent + result = agent.invoke(initial_state, config=config) + + return result diff --git a/src/rai_core/rai/agents/langchain/core/react_agent.py b/src/rai_core/rai/agents/langchain/core/react_agent.py index d49fd4617..34424a84e 100644 --- a/src/rai_core/rai/agents/langchain/core/react_agent.py +++ b/src/rai_core/rai/agents/langchain/core/react_agent.py @@ -16,7 +16,6 @@ from typing import ( List, Optional, - TypedDict, cast, ) @@ -26,6 +25,7 @@ from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.prebuilt.tool_node import tools_condition +from typing_extensions import TypedDict from rai.agents.langchain.core.tool_runner import ToolRunner from rai.initialization import get_llm_model diff --git a/src/rai_core/rai/agents/langchain/core/tool_runner.py b/src/rai_core/rai/agents/langchain/core/tool_runner.py index 156f19dac..216748e2d 100644 --- a/src/rai_core/rai/agents/langchain/core/tool_runner.py +++ b/src/rai_core/rai/agents/langchain/core/tool_runner.py @@ -47,15 +47,25 @@ def __init__( tool_ = create_tool(tool_) self.tools_by_name[tool_.name] = tool_ + def get_messages(self, input: dict[str, Any]) -> List: + """Get fields from from input that will be processed.""" + return input.get("messages", []) + + def update_input_with_outputs( + self, input: dict[str, Any], outputs: List[Any] + ) -> None: + """Update input with tool outputs.""" + input["messages"].extend(outputs) + def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any: config["max_concurrency"] = ( 1 # TODO(maciejmajek): use better mechanism for task queueing ) - if messages := input.get("messages", []): - message = messages[-1] - else: + messages = self.get_messages(input) + if not messages: raise ValueError("No message found in input") + message = messages[-1] if not isinstance(message, AIMessage): raise ValueError("Last message is not an AIMessage") @@ -142,5 +152,19 @@ def run_one(call: ToolCall): # we sort the messages by type so that the tool messages are sent first # for more information see implementation of ToolMultimodalMessage.postprocess outputs.sort(key=lambda x: x.__class__.__name__, reverse=True) - input["messages"].extend(outputs) + + self.update_input_with_outputs(input, outputs) return input + + +class SubAgentToolRunner(ToolRunner): + """ToolRunner that works with 'step_messages' key used by subagents""" + + def get_messages(self, input: dict[str, Any]) -> List: + return input.get("step_messages", []) + + def update_input_with_outputs( + self, input: dict[str, Any], outputs: List[Any] + ) -> None: + input["messages"].extend(outputs) + input["step_messages"].extend(outputs) diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index a39ff2906..21c9ebc2b 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -173,9 +173,9 @@ class SimulationBridge(ABC): """ def __init__(self, logger: Optional[logging.Logger] = None): - self.spawned_entities: List[SpawnedEntity] = ( - [] - ) # list of spawned entities with their initial poses + self.spawned_entities: List[ + SpawnedEntity + ] = [] # list of spawned entities with their initial poses if logger is None: self.logger = logging.getLogger(__name__) else: From 269f4f6ff43d8d6024536c4342e8d7fd10df477a Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Mon, 15 Sep 2025 14:18:15 +0200 Subject: [PATCH 11/13] feat: megamind context providers (#687) --- .../rai/agents/langchain/core/__init__.py | 8 ++++- .../rai/agents/langchain/core/megamind.py | 33 +++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/rai_core/rai/agents/langchain/core/__init__.py b/src/rai_core/rai/agents/langchain/core/__init__.py index 87c7b39bc..9aade0321 100644 --- a/src/rai_core/rai/agents/langchain/core/__init__.py +++ b/src/rai_core/rai/agents/langchain/core/__init__.py @@ -14,7 +14,12 @@ from .conversational_agent import State as ConversationalAgentState from .conversational_agent import create_conversational_agent -from .megamind import Executor, create_megamind, get_initial_megamind_state +from .megamind import ( + ContextProvider, + Executor, + create_megamind, + get_initial_megamind_state, +) from .react_agent import ( ReActAgentState, create_react_runnable, @@ -23,6 +28,7 @@ from .tool_runner import SubAgentToolRunner, ToolRunner __all__ = [ + "ContextProvider", "ConversationalAgentState", "Executor", "ReActAgentState", diff --git a/src/rai_core/rai/agents/langchain/core/megamind.py b/src/rai_core/rai/agents/langchain/core/megamind.py index 68a5b40bb..f1b9d6af1 100644 --- a/src/rai_core/rai/agents/langchain/core/megamind.py +++ b/src/rai_core/rai/agents/langchain/core/megamind.py @@ -13,6 +13,7 @@ # limitations under the License. ### NOTE (jmatejcz) this agent is still in process of testing and refining +from abc import ABC, abstractmethod from dataclasses import dataclass from functools import partial from typing import ( @@ -185,6 +186,14 @@ class Executor: system_prompt: str +class ContextProvider(ABC): + """Context provider are meant to inject exteral info to megamind prompt""" + + @abstractmethod + def get_context(self) -> str: + pass + + def get_initial_megamind_state(task: str): return MegamindState( { @@ -198,7 +207,11 @@ def get_initial_megamind_state(task: str): ) -def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindState: +def plan_step( + megamind_agent: BaseChatModel, + state: MegamindState, + context_providers: Optional[List[ContextProvider]] = None, +) -> MegamindState: """Initial planning step.""" if "original_task" not in state: state["original_task"] = state["messages"][0].content[0]["text"] @@ -208,6 +221,9 @@ def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindSt state["step"] = None megamind_prompt = f"You are given objective to complete: {state['original_task']}" + for provider in context_providers: + megamind_prompt += provider.get_context() + megamind_prompt += "\n" if state["steps_done"]: megamind_prompt += "\n\n" megamind_prompt += "Steps that were already done successfully:\n" @@ -244,17 +260,27 @@ def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindSt def create_megamind( megamind_llm: BaseChatModel, - megamind_system_prompt: str, executors: List[Executor], + megamind_system_prompt: Optional[str] = None, task_planning_prompt: Optional[str] = None, + context_providers: List[ContextProvider] = [], ) -> CompiledStateGraph: """Create a megamind langchain agent Args: executors (List[Executor]): Subagents for megamind, each can be a specialist with its own tools llm and system prompt + + megamind_system_prompt (Optional[str]): Prompt for megamind node. If not provided + it will default to informing agent of the avaialble executors and listing their tools. + task_planning_prompt (Optional[str]): Prompt that helps summarize the step in a way that helps planning task + + context_providers (List[ContextProvider]): Each ContextProvider can inject external info + to prompt during planning phase + + """ executor_agents = {} handoff_tools = [] @@ -295,7 +321,8 @@ def create_megamind( ) graph = StateGraph(MegamindState).add_node( - "megamind", partial(plan_step, megamind_agent) + "megamind", + partial(plan_step, megamind_agent, context_providers=context_providers), ) for agent_name, agent in executor_agents.items(): graph.add_node(agent_name, agent) From ab73ba7cd94d25f808115509193d3ff156399d81 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Mon, 15 Sep 2025 15:13:14 +0200 Subject: [PATCH 12/13] feat: tool calling bench - manipulation tasks extenstion (#656) --- .../tool_calling_agent/mocked_tools.py | 3 +- .../predefined/manipulation_tasks.py | 86 ++- .../tool_calling_agent/tasks/manipulation.py | 391 ++++++++---- tests/rai_bench/conftest.py | 13 +- .../test_predefined_manipulation_tasks.py | 561 ++++++++++++++++++ 5 files changed, 926 insertions(+), 128 deletions(-) create mode 100644 tests/rai_bench/tool_calling_agent/test_predefined_manipulation_tasks.py diff --git a/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py b/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py index fc80d2a82..b33d3d66e 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py @@ -184,7 +184,8 @@ class MockGetObjectPositionsTool(GetObjectPositionsTool): mock_objects: dict[str, List[Point]] def _run(self, object_name: str) -> str: - """Method that returns a mock message with the object positions if the object_name is present in the mock_objects dictionary. + """Method that returns a mock message with the object positions + if the object_name is present in the mock_objects dictionary. Parameters ---------- diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py index 7e4068c8f..8ab9090fe 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Literal +from typing import Any, Dict, List, Literal from rai.tools.ros2 import MoveToPointToolInput from rai.types import Point @@ -21,33 +21,25 @@ Task, TaskArgs, ) -from rai_bench.tool_calling_agent.subtasks import ( - CheckArgsToolCallSubTask, -) from rai_bench.tool_calling_agent.tasks.manipulation import ( + AlignTwoObjectsTask, + GetObjectPositionsTask, + GrabExistingObjectTask, + MoveExistingObjectFrontTask, + MoveExistingObjectLeftTask, MoveToPointTask, ) -from rai_bench.tool_calling_agent.validators import ( - OrderedCallsValidator, -) -########## SUBTASKS ################################################################# -move_to_point_subtask_grab = CheckArgsToolCallSubTask( - expected_tool_name="move_to_point", - expected_args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}, -) -move_to_point_subtask_drop = CheckArgsToolCallSubTask( - expected_tool_name="move_to_point", - expected_args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}, -) +BANANA_POSITION = Point(x=0.1, y=0.2, z=0.3) +BANANA_POSITION_2 = Point(x=0.4, y=0.5, z=0.6) +CUBE_POSITION = Point(x=0.7, y=0.8, z=0.9) -######### VALIDATORS ######################################################################################### -move_to_point_ord_val_grab = OrderedCallsValidator( - subtasks=[move_to_point_subtask_grab] -) -move_to_point_ord_val_drop = OrderedCallsValidator( - subtasks=[move_to_point_subtask_drop] -) +BANANA_OBJECT = "banana" +CUBE_OBJECT = "cube" +APPLE_OBJECT = "apple" + +MOVE_TO_GRAB_COORDS: Dict[str, Any] = {"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"} +MOVE_TO_DROP_COORDS: Dict[str, Any] = {"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"} def get_manipulation_tasks( @@ -69,9 +61,15 @@ def get_manipulation_tasks( tasks: List[Task] = [] objects = { - "banana": [Point(x=0.1, y=0.2, z=0.3), Point(x=0.4, y=0.5, z=0.6)], - "cube": [Point(x=0.7, y=0.8, z=0.9)], + BANANA_OBJECT: [BANANA_POSITION], + CUBE_OBJECT: [CUBE_POSITION], + } + + objects_with_multiple_bananas = { + BANANA_OBJECT: [BANANA_POSITION, BANANA_POSITION_2], + CUBE_OBJECT: [CUBE_POSITION], } + for extra_calls in extra_tool_calls: for detail in prompt_detail: for shots in n_shots: @@ -80,6 +78,7 @@ def get_manipulation_tasks( prompt_detail=detail, examples_in_system_prompt=shots, ) + tasks.extend( [ MoveToPointTask( @@ -87,7 +86,6 @@ def get_manipulation_tasks( move_to_tool_input=MoveToPointToolInput( x=1.0, y=2.0, z=3.0, task="grab" ), - validators=[move_to_point_ord_val_grab], task_args=task_args, ), MoveToPointTask( @@ -95,9 +93,43 @@ def get_manipulation_tasks( move_to_tool_input=MoveToPointToolInput( x=1.2, y=2.3, z=3.4, task="drop" ), - validators=[move_to_point_ord_val_drop], task_args=task_args, ), + GetObjectPositionsTask( + objects=objects_with_multiple_bananas, + task_args=task_args, + ), + GrabExistingObjectTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ), + GrabExistingObjectTask( + objects=objects, + object_to_grab=BANANA_OBJECT, + task_args=task_args, + ), + MoveExistingObjectLeftTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ), + MoveExistingObjectLeftTask( + objects=objects, + object_to_grab=BANANA_OBJECT, + task_args=task_args, + ), + MoveExistingObjectFrontTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ), + MoveExistingObjectFrontTask( + objects=objects, + object_to_grab=BANANA_OBJECT, + task_args=task_args, + ), + AlignTwoObjectsTask(objects=objects, task_args=task_args), ] ) diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py index bf6bea985..6e1e7230f 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/tasks/manipulation.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any, Dict, List +import logging +from abc import ABC +from typing import Any, Dict, List, Optional import inflect from langchain_core.tools import BaseTool from rai.tools.ros2 import MoveToPointToolInput from rai.types import Point -from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator +from rai_bench.tool_calling_agent.interfaces import SubTask, Task, TaskArgs, Validator from rai_bench.tool_calling_agent.mocked_ros2_interfaces import ( COMMON_INTERFACES, COMMON_SERVICES_AND_TYPES, @@ -37,6 +37,14 @@ MockGetROS2TopicsNamesAndTypesTool, MockMoveToPointTool, ) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, +) +from rai_bench.tool_calling_agent.validators import ( + NotOrderedCallsValidator, + OneFromManyValidator, + OrderedCallsValidator, +) INTERFACES = COMMON_INTERFACES | MANIPULATION_INTERFACES TOPCIS_AND_TYPES = COMMON_TOPICS_AND_TYPES | MANIPULATION_TOPICS_AND_TYPES @@ -63,6 +71,8 @@ x - front to back (positive is forward) y - left to right (positive is right) z - up to down (positive is up). + + 1 unit in system is equal to 1 meter in real environment. """ PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT = ( @@ -76,12 +86,15 @@ PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_5_SHOT = ( PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT_2_SHOT + """ -- move_to_point, args: {'x': 1.7, 'y': 1.8, 'z': 1.9, 'task': 'drop'} -- move_to_point, args: {'x': 0.1, 'y': -0.2, 'z': 0.1, 'task': 'grab'} +- get_ros2_topics_names_and_types, args: {} +- get_ros2_message_interface, args: {'msg_type': 'moveit_msgs/srv/ExecuteKnownTrajectory'} - move_to_point, args: {'x': 0.7, 'y': 0.8, 'z': 0.9, 'task': 'drop'} """ ) +LEFT_DISTANCE = 0.2 # 20cm +FRONT_DISTANCE = 0.6 # 60cm + class TaskParametrizationError(Exception): """Exception raised when the task parameters are not valid.""" @@ -97,11 +110,13 @@ def __init__( objects: Dict[str, List[Point]], validators: List[Validator], task_args: TaskArgs, + logger: Optional[logging.Logger] = None, **kwargs: Any, ) -> None: - super().__init__(validators=validators, task_args=task_args, **kwargs) + super().__init__( + validators=validators, task_args=task_args, logger=logger, **kwargs + ) self.objects = objects - self._verify_args() @property def optional_tool_calls_number(self) -> int: @@ -147,17 +162,29 @@ def __init__( object_to_grab: str, validators: List[Validator], task_args: TaskArgs, + logger: Optional[logging.Logger] = None, **kwargs: Any, ) -> None: super().__init__( - validators=validators, objects=objects, task_args=task_args, **kwargs + validators=validators, + objects=objects, + task_args=task_args, + logger=logger, + **kwargs, ) self.object_to_grab = object_to_grab self._verify_args() - @abstractmethod - def _verify_args(self) -> None: - pass + def _verify_args(self): + if self.object_to_grab not in self.objects: + error_message = f"Requested object to grab {self.object_to_grab} is not present in defined objects: {self.objects}." + self.logger.error(msg=error_message) + raise TaskParametrizationError(error_message) + + if len(self.objects[self.object_to_grab]) > 1: + error_message = f"Requested object to grab {self.object_to_grab} has more than one position in defined objects: {self.objects[self.object_to_grab]}." + self.logger.error(msg=error_message) + raise TaskParametrizationError(error_message) class MoveToPointTask(ManipulationTask): @@ -167,14 +194,32 @@ def __init__( self, objects: Dict[str, List[Point]], move_to_tool_input: MoveToPointToolInput, - validators: List[Validator], task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, **kwargs: Any, ) -> None: + self.move_to_tool_input = move_to_tool_input + + if validators is None: + move_to_point_subtask = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": move_to_tool_input.x, + "y": move_to_tool_input.y, + "z": move_to_tool_input.z, + "task": move_to_tool_input.task, + }, + ) + validators = [OrderedCallsValidator(subtasks=[move_to_point_subtask])] + super().__init__( - validators=validators, objects=objects, task_args=task_args, **kwargs + validators=validators, + objects=objects, + task_args=task_args, + logger=logger, + **kwargs, ) - self.move_to_tool_input = move_to_tool_input def get_base_prompt(self) -> str: return ( @@ -200,14 +245,29 @@ class GetObjectPositionsTask(ManipulationTask): def __init__( self, objects: Dict[str, List[Point]], - validators: List[Validator], task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, **kwargs: Any, ) -> None: + if validators is None: + subtasks: List[SubTask] = [] + for obj_name in objects.keys(): + subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_object_positions", + expected_args={"object_name": obj_name}, + ) + subtasks.append(subtask) + + validators = [NotOrderedCallsValidator(subtasks=subtasks)] + super().__init__( - validators=validators, objects=objects, task_args=task_args, **kwargs + validators=validators, + objects=objects, + task_args=task_args, + logger=logger, + **kwargs, ) - self.objects = objects def get_base_prompt(self) -> str: inflector = inflect.engine() @@ -223,22 +283,61 @@ def get_base_prompt(self) -> str: else: objects_list = formatted_objects[0] - return f"Get the {objects_list} positions." + return ( + f"Get the {objects_list} positions. Object name should be in singular form." + ) def get_prompt(self) -> str: if self.prompt_detail == "brief": return self.get_base_prompt() else: return ( - f"{self.get_base_prompt()} " - "You can detect all objects and retrieve their 3D coordinates " - "for manipulation planning." + f"{self.get_base_prompt()} in the robotic workspace environment. " + "You can detect all objects and retrieve their 3D coordinates." ) class GrabExistingObjectTask(GrabTask): complexity = "medium" + def __init__( + self, + objects: Dict[str, List[Point]], + object_to_grab: str, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, + **kwargs: Any, + ) -> None: + if validators is None: + get_object_positions_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_object_positions", + expected_args={"object_name": object_to_grab}, + ) + grab_move_subtask = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": objects[object_to_grab][0].x, + "y": objects[object_to_grab][0].y, + "z": objects[object_to_grab][0].z, + "task": "grab", + }, + ) + validators = [ + OrderedCallsValidator( + subtasks=[get_object_positions_subtask, grab_move_subtask] + ) + ] + + super().__init__( + objects=objects, + object_to_grab=object_to_grab, + validators=validators, + task_args=task_args, + logger=logger, + **kwargs, + ) + def get_base_prompt(self) -> str: return f"Grab {self.object_to_grab}." @@ -252,43 +351,56 @@ def get_prompt(self) -> str: "to grab it at the correct coordinates." ) - def _verify_args(self): - if self.object_to_grab not in self.objects: - error_message = f"Requested object to grab {self.object_to_grab} is not present in defined objects: {self.objects}." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - if len(self.objects[self.object_to_grab]) > 1: - error_message = f"Requested object to grab {self.object_to_grab} has more than one position in defined objects: {self.objects[self.object_to_grab]}." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - - -class GrabNotExistingObjectTask(GrabTask): +class MoveExistingObjectLeftTask(GrabTask): complexity = "medium" - def get_base_prompt(self) -> str: - return f"Grab {self.object_to_grab}." - - def get_prompt(self) -> str: - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()} " - "You can check if the object exists in the environment and " - "attempt to grab it if found." + def __init__( + self, + objects: Dict[str, List[Point]], + object_to_grab: str, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, + **kwargs: Any, + ) -> None: + if validators is None: + get_object_positions_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_object_positions", + expected_args={"object_name": object_to_grab}, ) + grab_subtask = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": objects[object_to_grab][0].x, + "y": objects[object_to_grab][0].y, + "z": objects[object_to_grab][0].z, + "task": "grab", + }, + ) + drop_subtask = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": objects[object_to_grab][0].x, + "y": round(objects[object_to_grab][0].y - LEFT_DISTANCE, 2), + "z": objects[object_to_grab][0].z, + "task": "drop", + }, + ) + validators = [ + OrderedCallsValidator( + subtasks=[get_object_positions_subtask, grab_subtask, drop_subtask] + ) + ] - def _verify_args(self): - if self.object_to_grab in self.objects: - error_message = f"Requested object to grab {self.object_to_grab} is present in defined objects: {self.objects} but should not be." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - - -class MoveExistingObjectLeftTask(GrabTask): - complexity = "hard" + super().__init__( + objects=objects, + object_to_grab=object_to_grab, + validators=validators, + task_args=task_args, + logger=logger, + **kwargs, + ) def get_base_prompt(self) -> str: return f"Move {self.object_to_grab} 20 cm to the left." @@ -303,20 +415,56 @@ def get_prompt(self) -> str: "and move it to a position 20 cm to the left of its current location." ) - def _verify_args(self): - if self.object_to_grab not in self.objects: - error_message = f"Requested object to grab {self.object_to_grab} is not present in defined objects: {self.objects}." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - if len(self.objects[self.object_to_grab]) > 1: - error_message = f"Requested object to grab {self.object_to_grab} has more than one position in defined objects: {self.objects[self.object_to_grab]}." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) +class MoveExistingObjectFrontTask(GrabTask): + complexity = "medium" + def __init__( + self, + objects: Dict[str, List[Point]], + object_to_grab: str, + task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, + **kwargs: Any, + ) -> None: + if validators is None: + get_object_positions_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_object_positions", + expected_args={"object_name": object_to_grab}, + ) + grab_subtask = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": objects[object_to_grab][0].x, + "y": objects[object_to_grab][0].y, + "z": objects[object_to_grab][0].z, + "task": "grab", + }, + ) + drop_subtask = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": round(objects[object_to_grab][0].x + FRONT_DISTANCE, 2), + "y": objects[object_to_grab][0].y, + "z": objects[object_to_grab][0].z, + "task": "drop", + }, + ) + validators = [ + OrderedCallsValidator( + subtasks=[get_object_positions_subtask, grab_subtask, drop_subtask] + ) + ] -class MoveExistingObjectFrontTask(GrabTask): - complexity = "hard" + super().__init__( + objects=objects, + object_to_grab=object_to_grab, + validators=validators, + task_args=task_args, + logger=logger, + **kwargs, + ) def get_base_prompt(self) -> str: return f"Move {self.object_to_grab} 60 cm to the front." @@ -331,61 +479,106 @@ def get_prompt(self) -> str: "and move it to a position 60 cm forward from its current location." ) - def _verify_args(self): - if self.object_to_grab not in self.objects: - error_message = f"Requested object to grab {self.object_to_grab} is not present in defined objects: {self.objects}." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - - if len(self.objects[self.object_to_grab]) > 1: - error_message = f"Requested object to grab {self.object_to_grab} has more than one position in defined objects: {self.objects[self.object_to_grab]}." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - -class SwapObjectsTask(ManipulationTask): +class AlignTwoObjectsTask(ManipulationTask): complexity = "hard" def __init__( self, objects: Dict[str, List[Point]], - objects_to_swap: List[str], - validators: List[Validator], task_args: TaskArgs, + validators: Optional[List[Validator]] = None, + logger: Optional[logging.Logger] = None, **kwargs: Any, ) -> None: + if validators is None: + # Get the two objects from the objects dict (first and second object) + object_names = list(objects.keys()) + obj1_name, obj2_name = object_names[0], object_names[1] + obj1_pos, obj2_pos = objects[obj1_name][0], objects[obj2_name][0] + + get_object_positions_subtask = CheckArgsToolCallSubTask( + expected_tool_name="get_object_positions", + expected_args={}, + ) + + # Two possible positions: 0.5 units to the right or left of obj2 + target_x_pos1 = round(obj2_pos.x + 0.5, 2) + target_x_pos2 = round(obj2_pos.x - 0.5, 2) + + grab_subtask = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": obj1_pos.x, + "y": obj1_pos.y, + "z": obj1_pos.z, + "task": "grab", + }, + ) + + # Create subtasks for dropping the first object at either valid position + drop_pos1 = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": target_x_pos1, + "y": obj1_pos.y, + "z": obj1_pos.z, + "task": "drop", + }, + ) + drop_pos2 = CheckArgsToolCallSubTask( + expected_tool_name="move_to_point", + expected_args={ + "x": target_x_pos2, + "y": obj1_pos.y, + "z": obj1_pos.z, + "task": "drop", + }, + ) + + val1 = OrderedCallsValidator( + subtasks=[ + get_object_positions_subtask, + grab_subtask, + ] + ) + val2 = OneFromManyValidator(subtasks=[drop_pos1, drop_pos2]) + validators = [val1, val2] super().__init__( - validators=validators, objects=objects, task_args=task_args, **kwargs + validators=validators, + objects=objects, + task_args=task_args, + logger=logger, + **kwargs, ) - self.objects = objects - self.objects_to_swap = objects_to_swap - self._verify_args() def get_base_prompt(self) -> str: - return f"Swap {self.objects_to_swap[0]} and {self.objects_to_swap[1]}." + object_names = list(self.objects.keys()) + return f"Move the first object ({object_names[0]}) so it is 50 cm apart from the second object ({object_names[1]}) along the x-axis." def get_prompt(self) -> str: if self.prompt_detail == "brief": return self.get_base_prompt() else: + object_names = list(self.objects.keys()) return ( f"{self.get_base_prompt()} " - "You can locate both objects in the workspace, then perform a sequence " - f"of grab and move operations to swap the positions of {self.objects_to_swap[0]} " - f"and {self.objects_to_swap[1]}." + f"You can locate both objects, grab the first object ({object_names[0]}) with the manipulator, " + f"and position it so that the distance between {object_names[0]} and {object_names[1]} along the x-axis is exactly 50 cm (0.5 units). " + f"You can move {object_names[0]} to either side of {object_names[1]} to achieve this distance." ) - def _verify_args(self): - for obj in self.objects_to_swap: - if obj not in self.objects: - error_message = f"Requested object to swap {obj} is not present in defined objects: {self.objects}." - self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - if len(self.objects[obj]) != 1: - error_message = f"Number of positions for object to swap ({obj}) should be equal to 1." + def _verify_args(self) -> None: + if len(self.objects) < 2: + error_message = f"AlignTwoObjectsTask requires at least 2 objects, but got {len(self.objects)}: {list(self.objects.keys())}" + if self.logger: self.logger.error(msg=error_message) - raise TaskParametrizationError(error_message) - if len(self.objects_to_swap) != 2: - error_message = f"Number of requested objects to swap {len(self.objects_to_swap)} should be equal to 2." - self.logger.error(msg=error_message) raise TaskParametrizationError(error_message) + + # Verify that objects are different so they can be distinguished + for obj_name, positions in self.objects.items(): + if len(positions) != 1: + error_message = f"Object {obj_name} must have exactly 1 position, but got {len(positions)}: {positions}" + if self.logger: + self.logger.error(msg=error_message) + raise TaskParametrizationError(error_message) diff --git a/tests/rai_bench/conftest.py b/tests/rai_bench/conftest.py index 3f369fd5c..da4765530 100644 --- a/tests/rai_bench/conftest.py +++ b/tests/rai_bench/conftest.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import pytest from rai.types import Header, Point, Pose, PoseStamped, Quaternion +from rai_bench.tool_calling_agent.interfaces import TaskArgs from rai_sim.simulation_bridge import Entity @@ -35,3 +36,13 @@ def create_entity( header=Header(frame_id="/test_frame"), ), ) + + +@pytest.fixture +def task_args() -> TaskArgs: + """Create basic task arguments for testing.""" + return TaskArgs( + extra_tool_calls=0, + prompt_detail="brief", + examples_in_system_prompt=0, + ) diff --git a/tests/rai_bench/tool_calling_agent/test_predefined_manipulation_tasks.py b/tests/rai_bench/tool_calling_agent/test_predefined_manipulation_tasks.py new file mode 100644 index 000000000..8ffe01271 --- /dev/null +++ b/tests/rai_bench/tool_calling_agent/test_predefined_manipulation_tasks.py @@ -0,0 +1,561 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List + +import pytest +from rai.tools.ros2 import MoveToPointToolInput + +from rai_bench.tool_calling_agent.interfaces import TaskArgs +from rai_bench.tool_calling_agent.predefined.manipulation_tasks import ( + BANANA_OBJECT, + BANANA_POSITION, + CUBE_OBJECT, + CUBE_POSITION, + MOVE_TO_DROP_COORDS, + MOVE_TO_GRAB_COORDS, +) +from rai_bench.tool_calling_agent.tasks.manipulation import ( + FRONT_DISTANCE, + LEFT_DISTANCE, + GetObjectPositionsTask, + GrabExistingObjectTask, + MoveExistingObjectFrontTask, + MoveExistingObjectLeftTask, + MoveToPointTask, +) + + +@pytest.fixture +def objects() -> Dict[str, Any]: + """Create test objects for manipulation tasks.""" + return { + BANANA_OBJECT: [BANANA_POSITION], + CUBE_OBJECT: [CUBE_POSITION], + } + + +class TestMoveToPointTask: + """Test MoveToPointTask validation.""" + + def test_move_to_point_grab_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "move_to_point", + "args": MOVE_TO_GRAB_COORDS, + } + ] + + task = MoveToPointTask( + objects=objects, + move_to_tool_input=MoveToPointToolInput(x=1.0, y=2.0, z=3.0, task="grab"), + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_move_to_point_drop_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "move_to_point", + "args": MOVE_TO_DROP_COORDS, + } + ] + + task = MoveToPointTask( + objects=objects, + move_to_tool_input=MoveToPointToolInput(x=1.2, y=2.3, z=3.4, task="drop"), + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_move_to_point_wrong_coordinates( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "move_to_point", + "args": { + "x": 0.0, + "y": 0.0, + "z": 0.0, + "task": "grab", + }, # Wrong coordinates + } + ] + + task = MoveToPointTask( + objects=objects, + move_to_tool_input=MoveToPointToolInput(x=1.0, y=2.0, z=3.0, task="grab"), + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_move_to_point_wrong_task( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "move_to_point", + "args": {"x": 1.0, "y": 2.0, "z": 3.0, "task": "drop"}, # Wrong task + } + ] + + task = MoveToPointTask( + objects=objects, + move_to_tool_input=MoveToPointToolInput(x=1.0, y=2.0, z=3.0, task="grab"), + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_move_to_point_wrong_tool( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "wrong_tool_name", + "args": MOVE_TO_GRAB_COORDS, + } + ] + + task = MoveToPointTask( + objects=objects, + move_to_tool_input=MoveToPointToolInput(x=1.0, y=2.0, z=3.0, task="grab"), + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGetObjectPositionsTask: + """Test GetObjectPositionsTask validation.""" + + def test_get_object_positions_valid_with_object_name( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": BANANA_OBJECT}}, + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + ] + + task = GetObjectPositionsTask( + objects=objects, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_get_object_positions_missing_object( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": BANANA_OBJECT}}, + {"name": "get_object_positions", "args": {"object_name": BANANA_OBJECT}}, + ] + + task = GetObjectPositionsTask( + objects=objects, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_object_positions_wrong_tool( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [{"name": "wrong_tool_name", "args": {}}] + + task = GetObjectPositionsTask( + objects=objects, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_get_object_positions_unexpected_args( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"unexpected": "arg"}} + ] + + task = GetObjectPositionsTask( + objects=objects, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestGrabExistingObjectTask: + """Test GrabExistingObjectTask validation.""" + + def test_grab_cube_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, + }, + ] + + task = GrabExistingObjectTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_grab_banana_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": BANANA_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": BANANA_POSITION.x, + "y": BANANA_POSITION.y, + "z": BANANA_POSITION.z, + "task": "grab", + }, + }, + ] + + task = GrabExistingObjectTask( + objects=objects, + object_to_grab=BANANA_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_grab_wrong_coordinates( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": 0.0, + "y": 0.0, + "z": 0.0, + "task": "grab", + }, # Wrong coordinates + }, + ] + + task = GrabExistingObjectTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_grab_missing_get_positions( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, + }, + ] + + task = GrabExistingObjectTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_grab_wrong_order( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, + }, + { + "name": "get_object_positions", + "args": {"object_name": CUBE_OBJECT}, + }, # Wrong order + ] + + task = GrabExistingObjectTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestMoveExistingObjectLeftTask: + """Test MoveExistingObjectLeftTask validation.""" + + def test_move_cube_left_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, + }, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": round(CUBE_POSITION.y - LEFT_DISTANCE, 2), + "z": CUBE_POSITION.z, + "task": "drop", + }, + }, + ] + + task = MoveExistingObjectLeftTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_move_banana_left_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": BANANA_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": BANANA_POSITION.x, + "y": BANANA_POSITION.y, + "z": BANANA_POSITION.z, + "task": "grab", + }, + }, + { + "name": "move_to_point", + "args": { + "x": BANANA_POSITION.x, + "y": round(BANANA_POSITION.y - LEFT_DISTANCE, 2), + "z": BANANA_POSITION.z, + "task": "drop", + }, + }, + ] + + task = MoveExistingObjectLeftTask( + objects=objects, + object_to_grab=BANANA_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_move_left_wrong_target_position( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, + }, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "drop", + }, # Same position, not left + }, + ] + + task = MoveExistingObjectLeftTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + def test_move_left_missing_drop( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, # missing drop + }, + ] + + task = MoveExistingObjectLeftTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 + + +class TestMoveExistingObjectFrontTask: + """Test MoveExistingObjectFrontTask validation.""" + + def test_move_cube_front_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, + }, + { + "name": "move_to_point", + "args": { + "x": round(CUBE_POSITION.x + FRONT_DISTANCE, 2), + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "drop", + }, + }, + ] + + task = MoveExistingObjectFrontTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_move_banana_front_valid( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": BANANA_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": BANANA_POSITION.x, + "y": BANANA_POSITION.y, + "z": BANANA_POSITION.z, + "task": "grab", + }, + }, + { + "name": "move_to_point", + "args": { + "x": round(BANANA_POSITION.x + FRONT_DISTANCE, 2), + "y": BANANA_POSITION.y, + "z": BANANA_POSITION.z, + "task": "drop", + }, + }, + ] + + task = MoveExistingObjectFrontTask( + objects=objects, + object_to_grab=BANANA_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 1.0 + + def test_move_front_wrong_direction( + self, task_args: TaskArgs, objects: Dict[str, Any] + ) -> None: + tool_calls: List[Dict[str, Any]] = [ + {"name": "get_object_positions", "args": {"object_name": CUBE_OBJECT}}, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "grab", + }, + }, + { + "name": "move_to_point", + "args": { + "x": CUBE_POSITION.x - FRONT_DISTANCE, + "y": CUBE_POSITION.y, + "z": CUBE_POSITION.z, + "task": "drop", + }, # Wrong direction (back instead of front) + }, + ] + + task = MoveExistingObjectFrontTask( + objects=objects, + object_to_grab=CUBE_OBJECT, + task_args=task_args, + ) + score = task.validate(tool_calls) + assert score == 0.0 From 219d6f54df549a72a7909d701fa1ade62fad1181 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:13:36 +0200 Subject: [PATCH 13/13] chore: resolving conflicts (#690) Co-authored-by: Julia Jia Co-authored-by: Magdalena Kotynia Co-authored-by: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> Co-authored-by: Pawel Kotowski Co-authored-by: Brian Tuan --- docs/simulation_and_benchmarking/rai_bench.md | 22 +- docs/tutorials/benchmarking.md | 3 +- src/rai_bench/rai_bench/__init__.py | 4 +- src/rai_bench/rai_bench/agents.py | 123 --------- .../docs/tool_calling_agent_benchmark.md | 2 +- .../rai_bench/examples/benchmarking_models.py | 11 +- .../rai_bench/examples/dual_agent.py | 53 ---- .../rai_bench/examples/vlm_benchmark.py | 48 ++++ .../rai_bench/manipulation_o3de/__init__.py | 4 +- .../rai_bench/manipulation_o3de/benchmark.py | 59 +--- .../langfuse_scores_tracing.py | 4 +- src/rai_bench/rai_bench/test_models.py | 78 +----- .../rai_bench/tool_calling_agent/__init__.py | 4 +- .../rai_bench/tool_calling_agent/benchmark.py | 55 +--- .../tool_calling_agent/predefined/__init__.py | 2 - .../predefined/spatial_reasoning_tasks.py | 254 ------------------ .../tool_calling_agent/predefined/tasks.py | 11 +- .../tool_calling_agent/tasks/spatial.py | 176 ------------ src/rai_bench/rai_bench/utils.py | 55 ++-- .../rai_bench/vlm_benchmark/__init__.py | 18 ++ .../rai_bench/vlm_benchmark/benchmark.py | 211 +++++++++++++++ .../rai_bench/vlm_benchmark/interfaces.py | 174 ++++++++++++ .../predefined/images/image_1.jpg | Bin .../predefined/images/image_2.jpg | Bin .../predefined/images/image_3.jpg | Bin .../predefined/images/image_4.jpg | Bin .../predefined/images/image_5.jpg | Bin .../predefined/images/image_6.jpg | Bin .../predefined/images/image_7.jpg | Bin .../vlm_benchmark/predefined/tasks.py | 112 ++++++++ .../vlm_benchmark/results_tracking.py | 35 +++ .../rai_bench/vlm_benchmark/tasks/tasks.py | 76 ++++++ src/rai_core/pyproject.toml | 2 +- src/rai_core/rai/agents/langchain/__init__.py | 2 + .../rai/agents/langchain/core/__init__.py | 2 + .../langchain/core/conversational_agent.py | 12 +- .../rai/agents/langchain/core/react_agent.py | 10 +- .../langchain/core/state_based_agent.py | 10 +- .../langchain/core/structured_output_agent.py | 60 +++++ .../agents/langchain/invocation_helpers.py | 76 ++++++ .../rai/communication/hri_connector.py | 3 +- .../rai/communication/ros2/api/service.py | 67 +++-- .../ros2/connectors/service_mixin.py | 5 + .../initialization/model_initialization.py | 13 +- src/rai_core/rai/messages/multimodal.py | 4 - .../rai/tools/ros2/navigation/__init__.py | 2 + .../tools/ros2/navigation/nav2_blocking.py | 69 +++++ .../agents/langchain/test_langchain_agent.py | 124 +++++++++ tests/communication/ros2/test_api.py | 36 ++- tests/conftest.py | 129 +++++++++ tests/initialization/test_tracing.py | 73 +++++ tests/messages/test_multimodal_message.py | 37 +++ 52 files changed, 1448 insertions(+), 882 deletions(-) delete mode 100644 src/rai_bench/rai_bench/agents.py delete mode 100644 src/rai_bench/rai_bench/examples/dual_agent.py create mode 100644 src/rai_bench/rai_bench/examples/vlm_benchmark.py delete mode 100644 src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py delete mode 100644 src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/__init__.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/benchmark.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/interfaces.py rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_1.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_2.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_3.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_4.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_5.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_6.jpg (100%) rename src/rai_bench/rai_bench/{tool_calling_agent => vlm_benchmark}/predefined/images/image_7.jpg (100%) create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py create mode 100644 src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py create mode 100644 src/rai_core/rai/agents/langchain/core/structured_output_agent.py create mode 100644 src/rai_core/rai/agents/langchain/invocation_helpers.py create mode 100644 src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py create mode 100644 tests/initialization/test_tracing.py create mode 100644 tests/messages/test_multimodal_message.py diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index d3072a783..24c033a1a 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -6,6 +6,7 @@ RAI Bench is a comprehensive package that both provides benchmarks with ready-to - [Manipulation O3DE Benchmark](#manipulation-o3de-benchmark) - [Tool Calling Agent Benchmark](#tool-calling-agent-benchmark) +- [VLM Benchmark](#vlm-benchmark) ## Manipulation O3DE Benchmark @@ -94,9 +95,9 @@ Evaluates agent performance independently from any simulation, based only on too The `SubTask` class is used to validate just one tool call. Following classes are available: - `CheckArgsToolCallSubTask` - verify if a certain tool was called with expected arguments -- `CheckTopicFieldsToolCallSubTask` - verify if a message published to ROS 2topic was of proper type and included expected fields -- `CheckServiceFieldsToolCallSubTask` - verify if a message published to ROS 2service was of proper type and included expected fields -- `CheckActionFieldsToolCallSubTask` - verify if a message published to ROS 2action was of proper type and included expected fields +- `CheckTopicFieldsToolCallSubTask` - verify if a message published to ROS2 topic was of proper type and included expected fields +- `CheckServiceFieldsToolCallSubTask` - verify if a message published to ROS2 service was of proper type and included expected fields +- `CheckActionFieldsToolCallSubTask` - verify if a message published to ROS2 action was of proper type and included expected fields ### Validator @@ -129,7 +130,6 @@ The ToolCallingAgentBenchmark class manages the execution of tasks and collects There are predefined Tasks available which are grouped by categories: - Basic - require retrieving info from certain topics -- Spatial reasoning - questions about surroundings with images attached - Manipulation - Custom Interfaces - requires using messages with custom interfaces @@ -164,3 +164,17 @@ class TaskArgs(BaseModel): - `GetROS2RGBCameraTask` has 1 required tool call and 1 optional. When `extra_tool_calls` set to 5, agent can correct himself couple times and still pass even with 7 tool calls. There can be 2 types of invalid tool calls, first when the tool is used incorrectly and agent receives an error - this allows him to correct himself easier. Second type is when tool is called properly but it is not the tool that should be called or it is called with wrong params. In this case agent won't get any error so it will be harder for him to correct, but BOTH of these cases are counted as `extra tool call`. If you want to know details about every task, visit `rai_bench/tool_calling_agent/tasks` + +## VLM Benchmark + +The VLM Benchmark is a benchmark for VLM models. It includes a set of tasks containing questions related to images and evaluates the performance of the agent that returns the answer in the structured format. + +### Running + +To run the benchmark: + +```bash +cd rai +source setup_shell.sh +python src/rai_bench/rai_bench/examples/vlm_benchmark.py --model-name gemma3:4b --vendor ollama +``` diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index 85dfca688..ccb933a1c 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -73,7 +73,6 @@ if __name__ == "__main__": extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", - "spatial_reasoning", "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt @@ -95,7 +94,7 @@ if __name__ == "__main__": ) ``` -Based on the example above the `Tool Calling` benchmark will run basic, spatial_reasoning and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. +Based on the example above the `Tool Calling` benchmark will run basic and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. !!! note diff --git a/src/rai_bench/rai_bench/__init__.py b/src/rai_bench/rai_bench/__init__.py index 395b5cf9a..5dd4df1e4 100644 --- a/src/rai_bench/rai_bench/__init__.py +++ b/src/rai_bench/rai_bench/__init__.py @@ -14,7 +14,6 @@ from .test_models import ( ManipulationO3DEBenchmarkConfig, ToolCallingAgentBenchmarkConfig, - test_dual_agents, test_models, ) from .utils import ( @@ -22,6 +21,7 @@ get_llm_for_benchmark, parse_manipulation_o3de_benchmark_args, parse_tool_calling_benchmark_args, + parse_vlm_benchmark_args, ) __all__ = [ @@ -31,6 +31,6 @@ "get_llm_for_benchmark", "parse_manipulation_o3de_benchmark_args", "parse_tool_calling_benchmark_args", - "test_dual_agents", + "parse_vlm_benchmark_args", "test_models", ] diff --git a/src/rai_bench/rai_bench/agents.py b/src/rai_bench/rai_bench/agents.py deleted file mode 100644 index a38e4e9d0..000000000 --- a/src/rai_bench/rai_bench/agents.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from functools import partial -from typing import List, Optional - -from langchain.chat_models.base import BaseChatModel -from langchain_core.messages import ( - AIMessage, - BaseMessage, - HumanMessage, -) -from langchain_core.tools import BaseTool -from langgraph.graph import START, StateGraph -from langgraph.graph.state import CompiledStateGraph -from langgraph.prebuilt.tool_node import tools_condition -from rai.agents.langchain.core.conversational_agent import State, agent -from rai.agents.langchain.core.tool_runner import ToolRunner - - -def multimodal_to_tool_bridge(state: State): - """Node of langchain workflow designed to bridge - nodes with llms. Removing images for context - """ - - cleaned_messages: List[BaseMessage] = [] - for msg in state["messages"]: - if isinstance(msg, HumanMessage): - # Remove images but keep the direct request - if isinstance(msg.content, list): - # Extract text only - text_parts = [ - part.get("text", "") - for part in msg.content - if isinstance(part, dict) and part.get("type") == "text" - ] - if text_parts: - cleaned_messages.append(HumanMessage(content=" ".join(text_parts))) - else: - cleaned_messages.append(msg) - elif isinstance(msg, AIMessage): - # Keep AI messages for context - cleaned_messages.append(msg) - - state["messages"] = cleaned_messages - return state - - -def create_multimodal_to_tool_agent( - multimodal_llm: BaseChatModel, - tool_llm: BaseChatModel, - tools: List[BaseTool], - multimodal_system_prompt: str, - tool_system_prompt: str, - logger: Optional[logging.Logger] = None, - debug: bool = False, -) -> CompiledStateGraph: - """ - Creates an agent flow where inputs first go to a multimodal LLM, - then its output is passed to a tool-calling LLM. - Can be usefull when multimodal llm does not provide tool calling. - - Args: - tools: List of tools available to the tool agent - - Returns: - Compiled state graph - """ - _logger = None - if logger: - _logger = logger - else: - _logger = logging.getLogger(__name__) - - _logger.info("Creating multimodal to tool agent flow") - - tool_llm_with_tools = tool_llm.bind_tools(tools) - tool_node = ToolRunner(tools=tools, logger=_logger) - - workflow = StateGraph(State) - workflow.add_node( - "thinker", - partial(agent, multimodal_llm, _logger, multimodal_system_prompt), - ) - # context bridge for altering the - workflow.add_node( - "context_bridge", - multimodal_to_tool_bridge, - ) - workflow.add_node( - "tool_agent", - partial(agent, tool_llm_with_tools, _logger, tool_system_prompt), - ) - workflow.add_node("tools", tool_node) - - workflow.add_edge(START, "thinker") - workflow.add_edge("thinker", "context_bridge") - workflow.add_edge("context_bridge", "tool_agent") - - workflow.add_conditional_edges( - "tool_agent", - tools_condition, - ) - - # Tool node goes back to tool agent - workflow.add_edge("tools", "tool_agent") - - app = workflow.compile(debug=debug) - _logger.info("Multimodal to tool agent flow created") - return app diff --git a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md index 0eb6664b5..0f287d483 100644 --- a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md +++ b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md @@ -14,4 +14,4 @@ Implementations can be found: - Validators [Validators](../tool_calling_agent/validators.py) - Subtasks [Validators](../tool_calling_agent/tasks/subtasks.py) -- Tasks, including basic, spatial, custom interfaces and manipulation [Tasks](../tool_calling_agent/tasks/) +- Tasks, including basic, custom interfaces and manipulation [Tasks](../tool_calling_agent/tasks/) diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index 2a0cc188c..43cdb3099 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -20,7 +20,7 @@ if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen3:4b", "llama3.2:3b"] + model_names = ["qwen2.5:3b", "llama3.2:3b"] vendors = ["ollama", "ollama"] # Define benchmarks that will be used @@ -36,7 +36,7 @@ extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", - "spatial_reasoning", + "manipulation", "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt @@ -48,11 +48,6 @@ test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[mani_conf, tool_conf], + benchmark_configs=[tool_conf, mani_conf], out_dir=out_dir, - # if you want to pass any additinal args to model - additional_model_args=[ - {"reasoning": False}, - {}, - ], ) diff --git a/src/rai_bench/rai_bench/examples/dual_agent.py b/src/rai_bench/rai_bench/examples/dual_agent.py deleted file mode 100644 index cf38cc8d9..000000000 --- a/src/rai_bench/rai_bench/examples/dual_agent.py +++ /dev/null @@ -1,53 +0,0 @@ -# # Copyright (C) 2025 Robotec.AI -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -from langchain_community.chat_models import ChatOllama -from langchain_openai import ChatOpenAI - -from rai_bench import ( - ManipulationO3DEBenchmarkConfig, - ToolCallingAgentBenchmarkConfig, - test_dual_agents, -) - -if __name__ == "__main__": - # Define models you want to benchmark - model_name = "gemma3:4b" - m_llm = ChatOllama( - model=model_name, base_url="http://localhost:11434", keep_alive=30 - ) - - tool_llm = ChatOpenAI(model="gpt-4o-mini", base_url="https://api.openai.com/v1/") - # Define benchmarks that will be used - tool_conf = ToolCallingAgentBenchmarkConfig( - extra_tool_calls=0, # how many extra tool calls allowed to still pass - task_types=["spatial_reasoning"], - repeats=15, - ) - - man_conf = ManipulationO3DEBenchmarkConfig( - o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config - levels=[ # define what difficulty of tasks to include in benchmark - "trivial", - ], - repeats=1, # how many times to repeat - ) - - out_dir = "src/rai_bench/rai_bench/experiments/dual_agents/" - - test_dual_agents( - multimodal_llms=[m_llm], - tool_calling_models=[tool_llm], - benchmark_configs=[man_conf, tool_conf], - out_dir=out_dir, - ) diff --git a/src/rai_bench/rai_bench/examples/vlm_benchmark.py b/src/rai_bench/rai_bench/examples/vlm_benchmark.py new file mode 100644 index 000000000..65f8526a6 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/vlm_benchmark.py @@ -0,0 +1,48 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from rai_bench import ( + define_benchmark_logger, + parse_vlm_benchmark_args, +) +from rai_bench.utils import get_llm_for_benchmark +from rai_bench.vlm_benchmark import get_spatial_tasks, run_benchmark + +if __name__ == "__main__": + args = parse_vlm_benchmark_args() + experiment_dir = Path(args.out_dir) + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + try: + tasks = get_spatial_tasks() + for task in tasks: + task.set_logger(bench_logger) + + llm = get_llm_for_benchmark( + model_name=args.model_name, + vendor=args.vendor, + ) + run_benchmark( + llm=llm, + out_dir=experiment_dir, + tasks=tasks, + bench_logger=bench_logger, + ) + except Exception as e: + bench_logger.critical( + msg=f"Benchmark failed with error: {e}", + exc_info=True, + ) diff --git a/src/rai_bench/rai_bench/manipulation_o3de/__init__.py b/src/rai_bench/rai_bench/manipulation_o3de/__init__.py index 206300fff..8ada5037a 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/__init__.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .benchmark import run_benchmark, run_benchmark_dual_agent +from .benchmark import run_benchmark from .predefined.scenarios import get_scenarios -__all__ = ["get_scenarios", "run_benchmark", "run_benchmark_dual_agent"] +__all__ = ["get_scenarios", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index bb4730ad8..8bf9421d5 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -16,7 +16,7 @@ import time import uuid from pathlib import Path -from typing import List, Optional, TypeVar +from typing import List, TypeVar import rclpy from langchain.tools import BaseTool @@ -45,7 +45,6 @@ ) from rai_open_set_vision.tools import GetGrabbingPointTool -from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, RunSummary, TimeoutException from rai_bench.manipulation_o3de.interfaces import Task from rai_bench.manipulation_o3de.results_tracking import ( @@ -285,7 +284,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: for msg in new_messages: if isinstance(msg, HumanMultimodalMessage): - last_msg = msg.text + last_msg = msg.text() elif isinstance(msg, BaseMessage): if isinstance(msg.content, list): if len(msg.content) == 1: @@ -452,57 +451,3 @@ def run_benchmark( connector.shutdown() o3de.shutdown() rclpy.shutdown() - - -def run_benchmark_dual_agent( - multimodal_llm: BaseChatModel, - tool_calling_llm: BaseChatModel, - out_dir: Path, - scenarios: List[Scenario], - o3de_config_path: str, - bench_logger: logging.Logger, - experiment_id: uuid.UUID = uuid.uuid4(), - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - connector, o3de, benchmark, tools = _setup_benchmark_environment( - o3de_config_path, - get_llm_model_name(multimodal_llm), - scenarios, - out_dir, - bench_logger, - ) - basic_tool_system_prompt = ( - "Based on the conversation call the tools with appropriate arguments" - ) - try: - for scenario in scenarios: - agent = create_multimodal_to_tool_agent( - multimodal_llm=multimodal_llm, - tool_llm=tool_calling_llm, - tools=tools, - multimodal_system_prompt=( - m_system_prompt if m_system_prompt else scenario.task.system_prompt - ), - tool_system_prompt=( - tool_system_prompt - if tool_system_prompt - else basic_tool_system_prompt - ), - logger=bench_logger, - ) - - benchmark.run_next(agent=agent, experiment_id=experiment_id) - - bench_logger.info( - "===============================================================" - ) - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info( - "===============================================================" - ) - - finally: - connector.shutdown() - o3de.shutdown() - rclpy.shutdown() diff --git a/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py b/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py index 772974b40..bee9c70d4 100644 --- a/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py +++ b/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py @@ -47,7 +47,7 @@ def send_score( if isinstance(callback, CallbackHandler): callback.langfuse.score( trace_id=str(run_id), - name="tool calls result", + name="result", value=score, comment=comment, ) @@ -55,7 +55,7 @@ def send_score( if isinstance(callback, LangChainTracer): callback.client.create_feedback( run_id=run_id, - key="tool calls result", + key="result", score=score, comment=comment, ) diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index dbccc393b..84622a5ad 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -17,7 +17,6 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Optional -from langchain.chat_models.base import BaseChatModel from pydantic import BaseModel import rai_bench.manipulation_o3de as manipulation_o3de @@ -25,7 +24,6 @@ from rai_bench.utils import ( define_benchmark_logger, get_llm_for_benchmark, - get_llm_model_name, ) @@ -79,9 +77,9 @@ class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): complexities : List[Literal["easy", "medium", "hard"]], optional complexity levels of tasks to include in the benchmark, by default all levels are included: ["easy", "medium", "hard"] - task_types : List[Literal["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"]], optional + task_types : List[Literal["basic", "manipulation", "navigation", "custom_interfaces"], optional types of tasks to include in the benchmark, by default all types are included: - ["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"] + ["basic", "manipulation", "navigation", "custom_interfaces"] For more detailed explanation of parameters, see the documentation: (https://robotecai.github.io/rai/simulation_and_benchmarking/rai_bench/) @@ -96,13 +94,11 @@ class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ] ] = [ "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ] @property @@ -110,74 +106,6 @@ def name(self) -> str: return "tool_calling_agent" -def test_dual_agents( - multimodal_llms: List[BaseChatModel], - tool_calling_models: List[BaseChatModel], - benchmark_configs: List[BenchmarkConfig], - out_dir: str, - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - if len(multimodal_llms) != len(tool_calling_models): - raise ValueError( - "Number of passed multimodal models must match number of passed tool calling models" - ) - experiment_id = uuid.uuid4() - for bench_conf in benchmark_configs: - # for each bench configuration seperate run folder - now = datetime.now() - run_name = f"run_{now.strftime('%Y-%m-%d_%H-%M-%S')}" - for i, m_llm in enumerate(multimodal_llms): - tool_llm = tool_calling_models[i] - for u in range(bench_conf.repeats): - curr_out_dir = ( - out_dir - + "/" - + run_name - + "/" - + bench_conf.name - + "/" - + get_llm_model_name(m_llm) - + "/" - + str(u) - ) - bench_logger = define_benchmark_logger(out_dir=Path(curr_out_dir)) - try: - if isinstance(bench_conf, ToolCallingAgentBenchmarkConfig): - tool_calling_tasks = tool_calling_agent.get_tasks( - extra_tool_calls=bench_conf.extra_tool_calls, - complexities=bench_conf.complexities, - task_types=bench_conf.task_types, - ) - tool_calling_agent.run_benchmark_dual_agent( - multimodal_llm=m_llm, - tool_calling_llm=tool_llm, - m_system_prompt=m_system_prompt, - tool_system_prompt=tool_system_prompt, - out_dir=Path(curr_out_dir), - tasks=tool_calling_tasks, - experiment_id=experiment_id, - bench_logger=bench_logger, - ) - elif isinstance(bench_conf, ManipulationO3DEBenchmarkConfig): - manipulation_o3de_scenarios = manipulation_o3de.get_scenarios( - levels=bench_conf.levels, - logger=bench_logger, - ) - manipulation_o3de.run_benchmark_dual_agent( - multimodal_llm=m_llm, - tool_calling_llm=tool_llm, - out_dir=Path(curr_out_dir), - o3de_config_path=bench_conf.o3de_config_path, - scenarios=manipulation_o3de_scenarios, - experiment_id=experiment_id, - bench_logger=bench_logger, - ) - except Exception as e: - bench_logger.critical(f"BENCHMARK RUN FAILED: {e}") - raise e - - def test_models( model_names: List[str], vendors: List[str], @@ -185,7 +113,6 @@ def test_models( out_dir: str, additional_model_args: Optional[List[Dict[str, Any]]] = None, ): - # TODO (jmatejcz) add docstring after passing agent logic will be added if additional_model_args is None: additional_model_args = [{} for _ in model_names] @@ -215,7 +142,6 @@ def test_models( vendor=vendors[i], **additional_model_args[i], ) - # TODO (jmatejcz) take param to set log level bench_logger = define_benchmark_logger(out_dir=Path(curr_out_dir)) try: if isinstance(bench_conf, ToolCallingAgentBenchmarkConfig): diff --git a/src/rai_bench/rai_bench/tool_calling_agent/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/__init__.py index 5f9771011..c4c668ac6 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/__init__.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .benchmark import run_benchmark, run_benchmark_dual_agent +from .benchmark import run_benchmark from .predefined.tasks import get_tasks -__all__ = ["get_tasks", "run_benchmark", "run_benchmark_dual_agent"] +__all__ = ["get_tasks", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index d1843b843..8d22bbb2c 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -16,7 +16,7 @@ import time import uuid from pathlib import Path -from typing import Iterator, List, Optional, Sequence, Tuple +from typing import Iterator, List, Sequence, Tuple from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage @@ -26,9 +26,8 @@ from rai.agents.langchain.core import ( create_conversational_agent, ) -from rai.agents.langchain.core.react_agent import ReActAgentState +from rai.messages import HumanMultimodalMessage -from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, TimeoutException from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler from rai_bench.tool_calling_agent.interfaces import ( @@ -67,7 +66,7 @@ def __init__( def run_next( self, agent: CompiledStateGraph, - initial_state: ReActAgentState, + initial_state: dict, experiment_id: uuid.UUID, ) -> None: """Runs the next task of the benchmark. @@ -227,50 +226,14 @@ def run_benchmark( system_prompt=task.get_system_prompt(), logger=bench_logger, ) - benchmark.run_next(agent=agent, experiment_id=experiment_id) - - bench_logger.info("===============================================================") - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info("===============================================================") - - -def run_benchmark_dual_agent( - multimodal_llm: BaseChatModel, - tool_calling_llm: BaseChatModel, - out_dir: Path, - tasks: List[Task], - bench_logger: logging.Logger, - experiment_id: uuid.UUID = uuid.uuid4(), - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - benchmark = ToolCallingAgentBenchmark( - tasks=tasks, - logger=bench_logger, - model_name=get_llm_model_name(multimodal_llm), - results_dir=out_dir, - ) - - basic_tool_system_prompt = ( - "Based on the conversation call the tools with appropriate arguments" - ) - for task in tasks: - agent = create_multimodal_to_tool_agent( - multimodal_llm=multimodal_llm, - tool_llm=tool_calling_llm, - tools=task.available_tools, - multimodal_system_prompt=( - m_system_prompt if m_system_prompt else task.get_system_prompt() - ), - tool_system_prompt=( - tool_system_prompt if tool_system_prompt else basic_tool_system_prompt - ), - logger=bench_logger, - debug=False, + benchmark.run_next( + agent=agent, + initial_state={ + "messages": [HumanMultimodalMessage(content=task.get_prompt())] + }, + experiment_id=experiment_id, ) - benchmark.run_next(agent=agent, experiment_id=experiment_id) - bench_logger.info("===============================================================") bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") bench_logger.info("===============================================================") diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py index 53ff03a2a..af2958f59 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py @@ -15,11 +15,9 @@ from .basic_tasks import get_basic_tasks from .custom_interfaces_tasks import get_custom_interfaces_tasks from .manipulation_tasks import get_manipulation_tasks -from .spatial_reasoning_tasks import get_spatial_tasks __all__ = [ "get_basic_tasks", "get_custom_interfaces_tasks", "get_manipulation_tasks", - "get_spatial_tasks", ] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py deleted file mode 100644 index d3ccbfa5e..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Literal, Sequence - -from rai_bench.tool_calling_agent.interfaces import ( - Task, - TaskArgs, -) -from rai_bench.tool_calling_agent.subtasks import ( - CheckArgsToolCallSubTask, -) -from rai_bench.tool_calling_agent.tasks.spatial import ( - BoolImageTaskEasy, - BoolImageTaskHard, - BoolImageTaskInput, - BoolImageTaskMedium, -) -from rai_bench.tool_calling_agent.validators import ( - OrderedCallsValidator, -) - -IMG_PATH = "src/rai_bench/rai_bench/tool_calling_agent/predefined/images/" -########## SUBTASKS ################################################################# -return_true_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": True} -) -return_false_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": False} -) - -######### VALIDATORS ######################################################################################### -ret_true_ord_val = OrderedCallsValidator(subtasks=[return_true_subtask]) -ret_false_ord_val = OrderedCallsValidator(subtasks=[return_false_subtask]) - - -def get_spatial_tasks( - extra_tool_calls: List[int] = [0], - prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], - n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], -) -> Sequence[Task]: - """Get predefined spatial reasoning tasks. - - Parameters - ---------- - Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. - See the class documentation for parameter descriptions. - - Returns - ------- - Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. - """ - tasks: List[Task] = [] - - # Categorize tasks by complexity based on question difficulty - easy_true_inputs = [ - # Single object presence/detection - BoolImageTaskInput( - question="Is the chair in the room?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", images_paths=[IMG_PATH + "image_2.jpg"] - ), - BoolImageTaskInput( - question="Are there any pictures on the wall?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="is there a TV in the room?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - ] - - medium_true_inputs = [ - # Object state or counting - BoolImageTaskInput( - question="Are there 3 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is the light on in the room?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Is there something to sit on?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - ] - - hard_true_inputs = [ - # Spatial relationships between objects - BoolImageTaskInput( - question="Is the door on the left from the desk?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant behind the rack?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Is there a rug under the bed?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Is there a pillow on the armchain?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - ] - - easy_false_inputs = [ - # Single object presence/detection - BoolImageTaskInput( - question="Is someone in the room?", images_paths=[IMG_PATH + "image_1.jpg"] - ), - BoolImageTaskInput( - question="Do you see the plant?", images_paths=[IMG_PATH + "image_3.jpg"] - ), - BoolImageTaskInput( - question="Is there a red pillow on the armchair?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - BoolImageTaskInput( - question="Is there a red desk with chair in the room?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Do you see the bed?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - ] - - medium_false_inputs = [ - # Object state or counting - BoolImageTaskInput( - question="Is the door open?", images_paths=[IMG_PATH + "image_1.jpg"] - ), - BoolImageTaskInput( - question="Are there 4 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is the TV switched on?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is the window opened?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - ] - - hard_false_inputs = [ - # Spatial relationships between objects - BoolImageTaskInput( - question="Is there a rack on the left from the sofa?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant on the right from the window?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is the chair next to a bed?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - ] - - for extra_calls in extra_tool_calls: - for detail in prompt_detail: - for shots in n_shots: - task_args = TaskArgs( - extra_tool_calls=extra_calls, - prompt_detail=detail, - examples_in_system_prompt=shots, - ) - - tasks.extend( - [ - BoolImageTaskEasy( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in easy_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskEasy( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in easy_false_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskMedium( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in medium_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskMedium( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in medium_false_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskHard( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in hard_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskHard( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in hard_false_inputs - ] - ) - - return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py index 5699841cd..d2302ce22 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py @@ -22,7 +22,6 @@ get_basic_tasks, get_custom_interfaces_tasks, get_manipulation_tasks, - get_spatial_tasks, ) @@ -36,13 +35,11 @@ def get_tasks( "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ] ] = [ "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ], ) -> List[Task]: """Get a list of tasks based on the provided configuration. @@ -55,7 +52,7 @@ def get_tasks( Returns ------- List[Task] - sequence of spatial reasoning tasks with varying difficulty levels. + sequence of tasks with varying difficulty levels. There will be every combination of extra_tool_calls x prompt_detail x n_shots tasks generated. """ @@ -78,12 +75,6 @@ def get_tasks( prompt_detail=prompt_detail, n_shots=n_shots, ) - if "spatial_reasoning" in task_types: - all_tasks += get_spatial_tasks( - extra_tool_calls=extra_tool_calls, - prompt_detail=prompt_detail, - n_shots=n_shots, - ) filtered_tasks: List[Task] = [] for task in all_tasks: diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py deleted file mode 100644 index 2f9b58e0d..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from abc import ABC, abstractmethod -from typing import List - -from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field -from rai.messages import preprocess_image - -from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator - -loggers_type = logging.Logger - -SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT = """You are a helpful and knowledgeable AI assistant that specializes -in interpreting and analyzing visual content. Your task is to answer questions based -on the images provided to you. Please response with the use of the provided tools.""" -# NOTE (jmatejcz) In this case we are using only one tool so there is no difference bettween 2 and 5 shot -# so I made 1 example in '2 shot' and 2 examples in '5 shot' prompt - -SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT = ( - SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT - + """ - -Example of tool calls: -- return_bool_response, args: {'response': True}""" -) - -SPATIAL_REASONING_SYSTEM_PROMPT_5_SHOT = ( - SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT - + """ -- return_bool_response, args: {'response': False}""" -) - - -class TaskParametrizationError(Exception): - """Exception raised when the task parameters are not valid.""" - - pass - - -class ReturnBoolResponseToolInput(BaseModel): - response: bool = Field(..., description="The response to the question.") - - -class ReturnBoolResponseTool(BaseTool): - """Tool that returns a boolean response.""" - - name: str = "return_bool_response" - description: str = "Return a bool response to the question." - args_schema = ReturnBoolResponseToolInput - - def _run(self, response: bool) -> bool: - if type(response) is bool: - return response - raise ValueError("Invalid response type. Response must be a boolean.") - - -class BoolImageTaskInput(BaseModel): - question: str = Field(..., description="The question to be answered.") - images_paths: List[str] = Field( - ..., - description="List of image file paths to be used for answering the question.", - ) - - -class SpatialReasoningAgentTask(Task): - """Abstract class for spatial reasoning tasks for tool calling agent.""" - - type = "spatial_reasoning" - - def __init__( - self, - validators: List[Validator], - task_args: TaskArgs, - logger: loggers_type | None = None, - ) -> None: - super().__init__( - validators=validators, - task_args=task_args, - logger=logger, - ) - self.expected_tools: List[BaseTool] - self.question: str - self.images_paths: List[str] - - @abstractmethod - def get_images(self) -> List[str]: - """Get the images related to the task. - Returns - ------- - List[str] - List of image paths - """ - pass - - def get_system_prompt(self) -> str: - if self.n_shots == 0: - return SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT - elif self.n_shots == 2: - return SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT - else: - return SPATIAL_REASONING_SYSTEM_PROMPT_5_SHOT - - -class BoolImageTask(SpatialReasoningAgentTask, ABC): - def __init__( - self, - task_input: BoolImageTaskInput, - validators: List[Validator], - task_args: TaskArgs, - logger: loggers_type | None = None, - ) -> None: - super().__init__( - validators=validators, - task_args=task_args, - logger=logger, - ) - self.question = task_input.question - self.images_paths = task_input.images_paths - - @property - def available_tools(self) -> List[BaseTool]: - return [ReturnBoolResponseTool()] - - @property - def optional_tool_calls_number(self) -> int: - return 0 - - def get_base_prompt(self) -> str: - return self.question - - def get_prompt(self): - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()}" - "You can examine the provided image(s) carefully to identify relevant features, " - "analyze the visual content, and provide a boolean response based on your observations." - ) - - def get_images(self): - images = [preprocess_image(image_path) for image_path in self.images_paths] - return images - - -# NOTE (jmatejcz) spatial reasoning task's difficulty is based solely on prompt and image -# so in this case when declaring task, please subjectivly decide how hard is the task -# examples: -# easy -> locating single object, tell if it is present -# medium -> tell in what state is the object (is door open?) or locating multiple objects -# hard -> locating multiple objects and resoning about their relative positions (is X on the right side of Y?) -class BoolImageTaskEasy(BoolImageTask): - complexity = "easy" - - -class BoolImageTaskMedium(BoolImageTask): - complexity = "medium" - - -class BoolImageTaskHard(BoolImageTask): - complexity = "hard" diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index 79e0e0f51..6e6943b1e 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -22,8 +22,9 @@ from rai.initialization import get_llm_model_direct -def parse_tool_calling_benchmark_args(): - parser = argparse.ArgumentParser(description="Run the Tool Calling Agent Benchmark") +def parse_base_benchmark_args(description: str, default_out_subdir: str): + """Parse common benchmark arguments shared across different benchmark types.""" + parser = argparse.ArgumentParser(description=description) parser.add_argument( "--model-name", type=str, @@ -31,6 +32,21 @@ def parse_tool_calling_benchmark_args(): required=True, ) parser.add_argument("--vendor", type=str, help="Vendor of the model", required=True) + + now = datetime.now() + parser.add_argument( + "--out-dir", + type=str, + default=f"src/rai_bench/rai_bench/experiments/{default_out_subdir}/{now.strftime('%Y-%m-%d_%H-%M-%S')}", + help="Output directory for results and logs", + ) + return parser + + +def parse_tool_calling_benchmark_args(): + parser = parse_base_benchmark_args( + "Run the Tool Calling Agent Benchmark", "tool_calling" + ) parser.add_argument( "--extra-tool-calls", type=int, @@ -70,35 +86,21 @@ def parse_tool_calling_benchmark_args(): "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ], default=[ "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ], help="Types of tasks to include in the benchmark", ) - now = datetime.now() - parser.add_argument( - "--out-dir", - type=str, - default=f"src/rai_bench/rai_bench/experiments/tool_calling/{now.strftime('%Y-%m-%d_%H-%M-%S')}", - help="Output directory for results and logs", - ) return parser.parse_args() def parse_manipulation_o3de_benchmark_args(): - parser = argparse.ArgumentParser(description="Run the Manipulation O3DE Benchmark") - parser.add_argument( - "--model-name", - type=str, - help="Model name to use for benchmarking", - required=True, + parser = parse_base_benchmark_args( + "Run the Manipulation O3DE Benchmark", "o3de_manipulation" ) - parser.add_argument("--vendor", type=str, help="Vendor of the model", required=True) parser.add_argument( "--o3de-config-path", type=str, @@ -113,13 +115,11 @@ def parse_manipulation_o3de_benchmark_args(): default=["trivial", "easy", "medium", "hard", "very_hard"], help="Difficulty levels to include in the benchmark", ) - now = datetime.now() - parser.add_argument( - "--out-dir", - type=str, - default=f"src/rai_bench/rai_bench/experiments/o3de_manipulation/{now.strftime('%Y-%m-%d_%H-%M-%S')}", - help="Output directory for results and logs", - ) + return parser.parse_args() + + +def parse_vlm_benchmark_args(): + parser = parse_base_benchmark_args("Run the VLM Benchmark", "vlm_benchmark") return parser.parse_args() @@ -134,11 +134,16 @@ def define_benchmark_logger(out_dir: Path, level: int = logging.INFO) -> logging ) file_handler.setFormatter(formatter) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + bench_logger = logging.getLogger("Benchmark logger") for handler in bench_logger.handlers: bench_logger.removeHandler(handler) bench_logger.setLevel(level) bench_logger.addHandler(file_handler) + bench_logger.addHandler(console_handler) return bench_logger diff --git a/src/rai_bench/rai_bench/vlm_benchmark/__init__.py b/src/rai_bench/rai_bench/vlm_benchmark/__init__.py new file mode 100644 index 000000000..b920f5bf0 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/__init__.py @@ -0,0 +1,18 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .benchmark import run_benchmark +from .predefined.tasks import get_spatial_tasks + +__all__ = ["get_spatial_tasks", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py b/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py new file mode 100644 index 000000000..8f1c9d699 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py @@ -0,0 +1,211 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import statistics +import time +import uuid +from pathlib import Path +from typing import Iterator, List, Sequence, Tuple + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.runnables.config import RunnableConfig +from langgraph.errors import GraphRecursionError +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel +from rai.agents.langchain.core import ( + create_structured_output_runnable, +) +from rai.messages import HumanMultimodalMessage + +from rai_bench.base_benchmark import BaseBenchmark, RunSummary, TimeoutException +from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler +from rai_bench.utils import get_llm_model_name +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask, TaskValidationError +from rai_bench.vlm_benchmark.results_tracking import ( + TaskResult, +) + + +class VLMBenchmark(BaseBenchmark): + """Benchmark for VLMs.""" + + def __init__( + self, + tasks: Sequence[ImageReasoningTask[BaseModel]], + model_name: str, + results_dir: Path, + logger: logging.Logger | None = None, + ) -> None: + super().__init__( + model_name=model_name, + results_dir=results_dir, + logger=logger, + ) + self._tasks: Iterator[Tuple[int, ImageReasoningTask[BaseModel]]] = enumerate( + iter(tasks) + ) + self.num_tasks = len(tasks) + self.task_results: List[TaskResult] = [] + + self.score_tracing_handler = ScoreTracingHandler() + self.tasks_results: List[TaskResult] = [] + self.csv_initialize(self.results_filename, TaskResult) + + def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: + """Runs the next task of the benchmark. + + Parameters + ---------- + agent : CompiledStateGraph + LangChain tool calling agent. + model_name : str + Name of the LLM model. + """ + # try: + i, task = next(self._tasks) + self.logger.info( + "======================================================================================" + ) + self.logger.info( + f"RUNNING TASK NUMBER {i + 1} / {self.num_tasks}, TASK {task.get_prompt()}" + ) + callbacks = self.score_tracing_handler.get_callbacks() + run_id = uuid.uuid4() + config: RunnableConfig = { + "run_id": run_id, + "callbacks": callbacks, + "tags": [ + f"experiment-id:{experiment_id}", + "benchmark:vlm-benchmark", + self.model_name, + f"task-complexity:{task.complexity}", + ], + "recursion_limit": len(agent.get_graph().nodes), + } + + ts = time.perf_counter() + messages: List[BaseMessage] = [] + prev_count: int = 0 + errors: List[str] = [] + try: + with self.time_limit(60): + for state in agent.stream( + { + "messages": [ + HumanMultimodalMessage( + content=task.get_prompt(), images=task.get_images() + ) + ] + }, + config=config, + ): + node = next(iter(state)) + all_messages = state[node]["messages"] + for new_msg in all_messages[prev_count:]: + messages.append(new_msg) + prev_count = len(messages) + except TimeoutException as e: + self.logger.error(msg=f"Task timeout: {e}") + except GraphRecursionError as e: + self.logger.error(msg=f"Reached recursion limit {e}") + + structured_output = None + try: + structured_output = task.get_structured_output_from_messages( + messages=messages + ) + except TaskValidationError as e: + errors.append(str(e)) + + if structured_output is not None: + score = task.validate(output=structured_output) + else: + errors.append(f"Not valid structured output: {type(structured_output)}") + score = False + + te = time.perf_counter() + total_time = te - ts + + self.logger.info(f"TASK SCORE: {score}, TOTAL TIME: {total_time:.3f}") + + task_result = TaskResult( + task_prompt=task.get_prompt(), + system_prompt=task.get_system_prompt(), + type=task.type, + complexity=task.complexity, + model_name=self.model_name, + score=score, + total_time=total_time, + run_id=run_id, + ) + + self.task_results.append(task_result) + + self.csv_writerow(self.results_filename, task_result) + # computing after every iteration in case of early stopping + self.compute_and_save_summary() + + for callback in callbacks: + self.score_tracing_handler.send_score( + callback=callback, + run_id=run_id, + score=score, + errors=[errors], + ) + + def compute_and_save_summary(self): + self.logger.info("Computing and saving average results...") + + success_count = sum(1 for r in self.task_results if r.score == 1.0) + success_rate = success_count / len(self.task_results) * 100 + avg_time = statistics.mean(r.total_time for r in self.task_results) + + summary = RunSummary( + model_name=self.model_name, + success_rate=round(success_rate, 2), + avg_time=round(avg_time, 3), + total_tasks=len(self.task_results), + ) + self.csv_initialize(self.summary_filename, RunSummary) + self.csv_writerow(self.summary_filename, summary) + + +def run_benchmark( + llm: BaseChatModel, + out_dir: Path, + tasks: List[ImageReasoningTask[BaseModel]], + bench_logger: logging.Logger, + experiment_id: uuid.UUID = uuid.uuid4(), +): + benchmark = VLMBenchmark( + tasks=tasks, + logger=bench_logger, + model_name=get_llm_model_name(llm), + results_dir=out_dir, + ) + + for task in tasks: + agent = create_structured_output_runnable( + llm=llm, + structured_output=task.structured_output, + system_prompt=task.get_system_prompt(), + logger=bench_logger, + ) + + benchmark.run_next(agent=agent, experiment_id=experiment_id) + + bench_logger.info("===============================================================") + bench_logger.info("ALL TASKS DONE. BENCHMARK COMPLETED!") + bench_logger.info("===============================================================") diff --git a/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py b/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py new file mode 100644 index 000000000..97f769b93 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py @@ -0,0 +1,174 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from typing import Generic, List, Literal, Optional, TypeVar + +from langchain_core.messages import BaseMessage +from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT +from pydantic import BaseModel, ConfigDict, ValidationError + +loggers_type = logging.Logger + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) + + +IMAGE_REASONING_SYSTEM_PROMPT = "You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response in requested structured output format." + + +class LangchainRawOutputModel(BaseModel): + """ + A Pydantic model for wrapping Langchain message parsing results from a structured output agent. See documentation for more details: + https://github.com/langchain-ai/langchain/blob/02001212b0a2b37d90451d8493089389ea220cab/libs/core/langchain_core/language_models/chat_models.py#L1430-L1432 + + + Attributes + ---------- + raw : BaseMessage + The original raw message object from Langchain before parsing. + parsed : BaseModel + The parsed and validated Pydantic model instance derived from the raw message. + parsing_error : Optional[BaseException] + Any exception that occurred during the parsing process, None if parsing + was successful. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + raw: BaseMessage + parsed: BaseModel + parsing_error: Optional[BaseException] + + +class TaskValidationError(Exception): + pass + + +class ImageReasoningTask(ABC, Generic[BaseModelT]): + complexity: Literal["easy", "medium", "hard"] + recursion_limit: int = DEFAULT_RECURSION_LIMIT + + def __init__( + self, + logger: loggers_type | None = None, + ) -> None: + """ + Abstract base class representing a complete image reasoning task to be validated. + + Each Task has a consistent prompt and structured output schema, along + with validation methods that check the output against the expected result. + + Attributes + ---------- + logger : logging.Logger + Logger for recording task validation results and errors. + """ + if logger: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + self.question: str + self.images_paths: List[str] + + def set_logger(self, logger: loggers_type): + self.logger = logger + + @property + @abstractmethod + def structured_output(self) -> type[BaseModelT]: + """Structured output that agent should return.""" + pass + + @property + @abstractmethod + def type(self) -> str: + """Type of task, for example: image_reasoning""" + pass + + def get_system_prompt(self) -> str: + """Get the system prompt that will be passed to agent + + Returns + ------- + str + System prompt + """ + return IMAGE_REASONING_SYSTEM_PROMPT + + @abstractmethod + def get_prompt(self) -> str: + """Get the task instruction - the prompt that will be passed to agent. + + Returns + ------- + str + Prompt + """ + pass + + @abstractmethod + def validate(self, output: BaseModelT) -> bool: + """Validate result of the task.""" + pass + + @abstractmethod + def get_images(self) -> List[str]: + """Get the images related to the task. + + Returns + ------- + List[str] + List of image paths + """ + pass + + def get_structured_output_from_messages( + self, messages: List[BaseMessage] + ) -> BaseModelT | None: + """Extract and validate structured output from a list of messages. + + Iterates through messages in reverse order, attempting to find the message that is + a LangchainRawOutputModel containing the structured output. + + Parameters + ---------- + messages : List[BaseMessage] + List of messages to search for structured output. + + Returns + ------- + BaseModelT | None + The first valid structured output found that matches the task's expected + output type, or None if no valid structured output is found. + + Raises + ------ + TaskValidationError + If a message contains a parsing error during validation. + """ + for message in reversed(messages): + if isinstance(message, dict): + try: + validated_message = LangchainRawOutputModel.model_validate(message) + if validated_message.parsing_error is not None: + raise TaskValidationError( + f"Parsing error: {validated_message.parsing_error}" + ) + + parsed = validated_message.parsed + if isinstance(parsed, self.structured_output): + return parsed + except ValidationError: + continue + return None diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_1.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_1.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_1.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_1.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_2.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_2.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_2.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_2.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_3.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_3.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_3.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_3.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_4.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_4.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_4.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_4.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_5.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_5.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_5.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_5.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_6.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_6.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_6.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_6.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_7.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_7.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_7.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_7.jpg diff --git a/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py b/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py new file mode 100644 index 000000000..bd1d9cd61 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py @@ -0,0 +1,112 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, cast + +from pydantic import BaseModel + +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask +from rai_bench.vlm_benchmark.tasks.tasks import BoolImageTask, BoolImageTaskInput + +IMG_PATH = "src/rai_bench/rai_bench/vlm_benchmark/predefined/images/" +true_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door on the left from the desk?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is the light on in the room?", + images_paths=[IMG_PATH + "image_2.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_2.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Are there any pictures on the wall?", + images_paths=[IMG_PATH + "image_3.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Are there 3 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is there a plant behind the rack?", + images_paths=[IMG_PATH + "image_5.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is there a pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + expected_answer=True, + ), +] +false_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door open?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is someone in the room?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_3.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Are there 4 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a rack on the left from the sofa?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a plant on the right from the window?", + images_paths=[IMG_PATH + "image_6.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a red pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + expected_answer=False, + ), +] + + +def get_spatial_tasks() -> List[ImageReasoningTask[BaseModel]]: + true_tasks = [ + BoolImageTask( + task_input=input_item, + ) + for input_item in true_response_inputs + ] + false_tasks = [ + BoolImageTask( + task_input=input_item, + ) + for input_item in false_response_inputs + ] + return cast(List[ImageReasoningTask[BaseModel]], true_tasks + false_tasks) diff --git a/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py b/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py new file mode 100644 index 000000000..5d8b77b71 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py @@ -0,0 +1,35 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from uuid import UUID + +from pydantic import BaseModel, Field + + +class TaskResult(BaseModel): + task_prompt: str = Field(..., description="The task prompt.") + system_prompt: str = Field(..., description="The system prompt.") + complexity: str = Field(..., description="Complexity of the task.") + type: str = Field( + ..., description="Type of task, for example: bool_response_image_task" + ) + model_name: str = Field(..., description="Name of the LLM.") + score: float = Field( + ..., + description="Value between 0 and 1.", + ) + + total_time: float = Field(..., description="Total time taken to complete the task.") + run_id: UUID = Field(..., description="UUID of the task run.") diff --git a/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py b/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py new file mode 100644 index 000000000..639b50400 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py @@ -0,0 +1,76 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from typing import List + +from pydantic import BaseModel, Field +from rai.messages import preprocess_image + +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask + +loggers_type = logging.Logger + + +class BoolAnswerWithJustification(BaseModel): + """A boolean answer to the user question along with justification for the answer.""" + + answer: bool + justification: str + + +class BoolImageTaskInput(BaseModel): + question: str = Field(..., description="The question to be answered.") + images_paths: List[str] = Field( + ..., + description="List of image file paths to be used for answering the question.", + ) + expected_answer: bool = Field( + ..., description="The expected answer to the question." + ) + + +class BoolImageTask(ImageReasoningTask[BoolAnswerWithJustification]): + complexity = "easy" + + def __init__( + self, + task_input: BoolImageTaskInput, + logger: loggers_type | None = None, + ) -> None: + super().__init__( + logger=logger, + ) + self.question = task_input.question + self.images_paths = task_input.images_paths + self.expected_answer = task_input.expected_answer + + @property + def structured_output(self) -> type[BoolAnswerWithJustification]: + return BoolAnswerWithJustification + + @property + def type(self) -> str: + return "bool_response_image_task" + + def get_prompt(self): + return self.question + + def get_images(self): + images = [preprocess_image(image_path) for image_path in self.images_paths] + return images + + def validate(self, output: BoolAnswerWithJustification) -> bool: + return output.answer == self.expected_answer diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 4f2523a07..714ff448a 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.2.1" +version = "2.5.0" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" diff --git a/src/rai_core/rai/agents/langchain/__init__.py b/src/rai_core/rai/agents/langchain/__init__.py index b608b97bb..186ad47d8 100644 --- a/src/rai_core/rai/agents/langchain/__init__.py +++ b/src/rai_core/rai/agents/langchain/__init__.py @@ -19,6 +19,7 @@ create_react_runnable, create_state_based_runnable, ) +from .invocation_helpers import invoke_llm_with_tracing from .react_agent import ReActAgent from .state_based_agent import BaseStateBasedAgent, StateBasedConfig @@ -32,5 +33,6 @@ "StateBasedConfig", "create_react_runnable", "create_state_based_runnable", + "invoke_llm_with_tracing", "newMessageBehaviorType", ] diff --git a/src/rai_core/rai/agents/langchain/core/__init__.py b/src/rai_core/rai/agents/langchain/core/__init__.py index 9aade0321..e8b0b9d9b 100644 --- a/src/rai_core/rai/agents/langchain/core/__init__.py +++ b/src/rai_core/rai/agents/langchain/core/__init__.py @@ -25,6 +25,7 @@ create_react_runnable, ) from .state_based_agent import create_state_based_runnable +from .structured_output_agent import create_structured_output_runnable from .tool_runner import SubAgentToolRunner, ToolRunner __all__ = [ @@ -38,5 +39,6 @@ "create_megamind", "create_react_runnable", "create_state_based_runnable", + "create_structured_output_runnable", "get_initial_megamind_state", ] diff --git a/src/rai_core/rai/agents/langchain/core/conversational_agent.py b/src/rai_core/rai/agents/langchain/core/conversational_agent.py index 8b940cdf6..e008fdb7d 100644 --- a/src/rai_core/rai/agents/langchain/core/conversational_agent.py +++ b/src/rai_core/rai/agents/langchain/core/conversational_agent.py @@ -17,17 +17,20 @@ from functools import partial from typing import List, Optional, TypedDict +from deprecated import deprecated from langchain.chat_models.base import BaseChatModel from langchain_core.messages import ( BaseMessage, SystemMessage, ) +from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.graph.state import CompiledStateGraph from langgraph.prebuilt.tool_node import tools_condition from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing class State(TypedDict): @@ -39,6 +42,7 @@ def agent( logger: logging.Logger, system_prompt: str | SystemMessage, state: State, + config: RunnableConfig, ): logger.info("Running thinker") @@ -54,11 +58,17 @@ def agent( else system_prompt ) state["messages"].insert(0, system_msg) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) return state +@deprecated( + "Use rai.agents.langchain.core.create_react_runnable instead. " + "Support for the conversational agent will be removed in the 3.0 release." +) def create_conversational_agent( llm: BaseChatModel, tools: List[BaseTool], diff --git a/src/rai_core/rai/agents/langchain/core/react_agent.py b/src/rai_core/rai/agents/langchain/core/react_agent.py index 34424a84e..a15092404 100644 --- a/src/rai_core/rai/agents/langchain/core/react_agent.py +++ b/src/rai_core/rai/agents/langchain/core/react_agent.py @@ -21,13 +21,14 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, SystemMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.prebuilt.tool_node import tools_condition from typing_extensions import TypedDict from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing from rai.initialization import get_llm_model from rai.messages import SystemMultimodalMessage @@ -48,6 +49,7 @@ def llm_node( llm: BaseChatModel, system_prompt: Optional[str | SystemMultimodalMessage], state: ReActAgentState, + config: RunnableConfig, ): """Process messages using the LLM. @@ -57,6 +59,8 @@ def llm_node( The language model to use for processing state : ReActAgentState Current state containing messages + config : RunnableConfig + Configuration including callbacks for tracing Returns ------- @@ -75,7 +79,9 @@ def llm_node( # at this point, state['messages'] length should at least be 1 if not isinstance(state["messages"][0], SystemMessage): state["messages"].insert(0, SystemMessage(content=system_prompt)) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) diff --git a/src/rai_core/rai/agents/langchain/core/state_based_agent.py b/src/rai_core/rai/agents/langchain/core/state_based_agent.py index 0ffcff6d5..ea651acd3 100644 --- a/src/rai_core/rai/agents/langchain/core/state_based_agent.py +++ b/src/rai_core/rai/agents/langchain/core/state_based_agent.py @@ -26,12 +26,13 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.prebuilt.tool_node import tools_condition from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing from rai.initialization import get_llm_model from rai.messages import HumanMultimodalMessage, SystemMultimodalMessage @@ -52,6 +53,7 @@ def llm_node( llm: BaseChatModel, system_prompt: Optional[str | SystemMultimodalMessage], state: ReActAgentState, + config: RunnableConfig, ): """Process messages using the LLM. @@ -61,6 +63,8 @@ def llm_node( The language model to use for processing state : ReActAgentState Current state containing messages + config : RunnableConfig + Configuration including callbacks for tracing Returns ------- @@ -79,7 +83,9 @@ def llm_node( # at this point, state['messages'] length should at least be 1 if not isinstance(state["messages"][0], SystemMessage): state["messages"].insert(0, SystemMessage(content=system_prompt)) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) diff --git a/src/rai_core/rai/agents/langchain/core/structured_output_agent.py b/src/rai_core/rai/agents/langchain/core/structured_output_agent.py new file mode 100644 index 000000000..7c1878cd3 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/core/structured_output_agent.py @@ -0,0 +1,60 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from functools import partial +from typing import Optional + +from langchain.chat_models.base import BaseChatModel +from langchain_core.messages import ( + SystemMessage, +) +from langgraph.graph import START, StateGraph +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel + +from rai.agents.langchain.core.conversational_agent import State, agent + + +def create_structured_output_runnable( + llm: BaseChatModel, + structured_output: type[BaseModel], + system_prompt: str | SystemMessage, + logger: Optional[logging.Logger] = None, + debug: bool = False, +) -> CompiledStateGraph: + _logger = None + if logger: + _logger = logger + else: + _logger = logging.getLogger(__name__) + + _logger.info("Creating structured output runnable") + + llm_with_structured_output = llm.with_structured_output( + schema=structured_output, include_raw=True + ) + + workflow = StateGraph(State) + + workflow.add_node( + "thinker", partial(agent, llm_with_structured_output, _logger, system_prompt) + ) + + workflow.add_edge(START, "thinker") + + app = workflow.compile(debug=debug) + _logger.info("State based agent created") + return app diff --git a/src/rai_core/rai/agents/langchain/invocation_helpers.py b/src/rai_core/rai/agents/langchain/invocation_helpers.py new file mode 100644 index 000000000..d3ecee749 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/invocation_helpers.py @@ -0,0 +1,76 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, List, Optional + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.runnables import RunnableConfig + +from rai.initialization import get_tracing_callbacks + +logger = logging.getLogger(__name__) + + +def invoke_llm_with_tracing( + llm: BaseChatModel, + messages: List[BaseMessage], + config: Optional[RunnableConfig] = None, +) -> Any: + """ + Invoke an LLM with enhanced tracing callbacks. + + This function automatically adds tracing callbacks (like Langfuse) to LLM calls + within LangGraph nodes, solving the callback propagation issue. + + Tracing is controlled by config.toml. If the file is missing, no tracing is applied. + + Parameters + ---------- + llm : BaseChatModel + The language model to invoke + messages : List[BaseMessage] + Messages to send to the LLM + config : Optional[RunnableConfig] + Existing configuration (may contain some callbacks) + + Returns + ------- + Any + The LLM response + """ + tracing_callbacks = get_tracing_callbacks() + + if len(tracing_callbacks) == 0: + # No tracing callbacks available, use config as-is + return llm.invoke(messages, config=config) + + # Create enhanced config with tracing callbacks + enhanced_config = config.copy() if config else {} + + # Add tracing callbacks to existing callbacks + existing_callbacks = config.get("callbacks", []) if config else [] + + if hasattr(existing_callbacks, "handlers"): + # Merge with existing CallbackManager + all_callbacks = existing_callbacks.handlers + tracing_callbacks + elif isinstance(existing_callbacks, list): + all_callbacks = existing_callbacks + tracing_callbacks + else: + all_callbacks = tracing_callbacks + + enhanced_config["callbacks"] = all_callbacks + + return llm.invoke(messages, config=enhanced_config) diff --git a/src/rai_core/rai/communication/hri_connector.py b/src/rai_core/rai/communication/hri_connector.py index 17f18d416..bffc06ba0 100644 --- a/src/rai_core/rai/communication/hri_connector.py +++ b/src/rai_core/rai/communication/hri_connector.py @@ -103,12 +103,11 @@ def from_langchain( seq_no: int = 0, seq_end: bool = False, ) -> "HRIMessage": + text = message.text() if isinstance(message, RAIMultimodalMessage): - text = message.text images = message.images audios = message.audios else: - text = str(message.content) images = None audios = None if message.type not in ["ai", "human"]: diff --git a/src/rai_core/rai/communication/ros2/api/service.py b/src/rai_core/rai/communication/ros2/api/service.py index 464625aed..06dac1bc1 100644 --- a/src/rai_core/rai/communication/ros2/api/service.py +++ b/src/rai_core/rai/communication/ros2/api/service.py @@ -14,6 +14,7 @@ import os import uuid +from threading import Lock from typing import ( Any, Callable, @@ -30,6 +31,7 @@ import rclpy.qos import rclpy.subscription import rclpy.task +from rclpy.client import Client from rclpy.service import Service from rai.communication.ros2.api.base import ( @@ -39,12 +41,18 @@ class ROS2ServiceAPI(BaseROS2API): - """Handles ROS2 service operations including calling services.""" + """Handles ROS 2 service operations including calling services.""" def __init__(self, node: rclpy.node.Node) -> None: self.node = node self._logger = node.get_logger() self._services: Dict[str, Service] = {} + self._persistent_clients: Dict[str, Client] = {} + self._persistent_clients_lock = Lock() + + def release_client(self, service_name: str) -> bool: + with self._persistent_clients_lock: + return self._persistent_clients.pop(service_name, None) is not None def call_service( self, @@ -52,30 +60,57 @@ def call_service( service_type: str, request: Any, timeout_sec: float = 5.0, + *, + reuse_client: bool = True, ) -> Any: """ - Call a ROS2 service. + Call a ROS 2 service. Args: - service_name: Name of the service to call - service_type: ROS2 service type as string - request: Request message content + service_name: Fully-qualified service name. + service_type: ROS 2 service type string (e.g., 'std_srvs/srv/SetBool'). + request: Request payload dict. + timeout_sec: Seconds to wait for availability/response. + reuse_client: Reuse a cached client. Client creation is synchronized; set + False to create a new client per call. Returns: - The response message + Response message instance. + + Raises: + ValueError: Service not available within the timeout. + AttributeError: Service type or request cannot be constructed. + + Note: + With reuse_client=True, access to the cached client (including the + service call) is serialized by a lock, preventing concurrent calls + through the same client. Use reuse_client=False for per-call clients + when concurrent service calls are required. """ srv_msg, srv_cls = self.build_ros2_service_request(service_type, request) - service_client = self.node.create_client(srv_cls, service_name) # type: ignore - client_ready = service_client.wait_for_service(timeout_sec=timeout_sec) - if not client_ready: - raise ValueError( - f"Service {service_name} not ready within {timeout_sec} seconds. " - "Try increasing the timeout or check if the service is running." - ) - if os.getenv("ROS_DISTRO") == "humble": - return service_client.call(srv_msg) + + def _call_service(client: Client, timeout_sec: float) -> Any: + is_service_available = client.wait_for_service(timeout_sec=timeout_sec) + if not is_service_available: + raise ValueError( + f"Service {service_name} not ready within {timeout_sec} seconds. " + "Try increasing the timeout or check if the service is running." + ) + if os.getenv("ROS_DISTRO") == "humble": + return client.call(srv_msg) + else: + return client.call(srv_msg, timeout_sec=timeout_sec) + + if reuse_client: + with self._persistent_clients_lock: + client = self._persistent_clients.get(service_name, None) + if client is None: + client = self.node.create_client(srv_cls, service_name) # type: ignore + self._persistent_clients[service_name] = client + return _call_service(client, timeout_sec) else: - return service_client.call(srv_msg, timeout_sec=timeout_sec) + client = self.node.create_client(srv_cls, service_name) # type: ignore + return _call_service(client, timeout_sec) def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]: return self.node.get_service_names_and_types() diff --git a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py index 985de6c16..7c1597a56 100644 --- a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py +++ b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py @@ -30,6 +30,9 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: f"{self.__class__.__name__} instance must have an attribute '_service_api' of type ROS2ServiceAPI" ) + def release_client(self, service_name: str) -> bool: + return self._service_api.release_client(service_name) + def service_call( self, message: ROS2Message, @@ -37,6 +40,7 @@ def service_call( timeout_sec: float = 5.0, *, msg_type: str, + reuse_client: bool = True, **kwargs: Any, ) -> ROS2Message: msg = self._service_api.call_service( @@ -44,6 +48,7 @@ def service_call( service_type=msg_type, request=message.payload, timeout_sec=timeout_sec, + reuse_client=reuse_client, ) return ROS2Message( payload=msg, metadata={"msg_type": str(type(msg)), "service": target} diff --git a/src/rai_core/rai/initialization/model_initialization.py b/src/rai_core/rai/initialization/model_initialization.py index c7f4fa263..4496286f0 100644 --- a/src/rai_core/rai/initialization/model_initialization.py +++ b/src/rai_core/rai/initialization/model_initialization.py @@ -275,11 +275,16 @@ def get_embeddings_model( def get_tracing_callbacks( - override_use_langfuse: bool = False, override_use_langsmith: bool = False + config_path: Optional[str] = None, ) -> List[BaseCallbackHandler]: - config = load_config() + try: + config = load_config(config_path) + except Exception as e: + logger.warning(f"Failed to load config for tracing: {e}, tracing disabled") + return [] + callbacks: List[BaseCallbackHandler] = [] - if config.tracing.langfuse.use_langfuse or override_use_langfuse: + if config.tracing.langfuse.use_langfuse: from langfuse.callback import CallbackHandler # type: ignore public_key = os.getenv("LANGFUSE_PUBLIC_KEY", None) @@ -294,7 +299,7 @@ def get_tracing_callbacks( ) callbacks.append(callback) - if config.tracing.langsmith.use_langsmith or override_use_langsmith: + if config.tracing.langsmith.use_langsmith: os.environ["LANGCHAIN_TRACING_V2"] = "true" os.environ["LANGCHAIN_PROJECT"] = config.tracing.project api_key = os.getenv("LANGCHAIN_API_KEY", None) diff --git a/src/rai_core/rai/messages/multimodal.py b/src/rai_core/rai/messages/multimodal.py index b1646f37d..1b5129d78 100644 --- a/src/rai_core/rai/messages/multimodal.py +++ b/src/rai_core/rai/messages/multimodal.py @@ -63,10 +63,6 @@ def __init__( _content.extend(_image_content) self.content = _content - @property - def text(self) -> str: - return self.content[0]["text"] - class HumanMultimodalMessage(HumanMessage, MultimodalMessage): def __repr_args__(self) -> Any: diff --git a/src/rai_core/rai/tools/ros2/navigation/__init__.py b/src/rai_core/rai/tools/ros2/navigation/__init__.py index 8c5a94838..4142fcc68 100644 --- a/src/rai_core/rai/tools/ros2/navigation/__init__.py +++ b/src/rai_core/rai/tools/ros2/navigation/__init__.py @@ -20,6 +20,7 @@ Nav2Toolkit, NavigateToPoseTool, ) +from .nav2_blocking import NavigateToPoseBlockingTool __all__ = [ "CancelNavigateToPoseTool", @@ -27,5 +28,6 @@ "GetNavigateToPoseResultTool", "GetOccupancyGridTool", "Nav2Toolkit", + "NavigateToPoseBlockingTool", "NavigateToPoseTool", ] diff --git a/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py b/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py new file mode 100644 index 000000000..b51a79207 --- /dev/null +++ b/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py @@ -0,0 +1,69 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from geometry_msgs.msg import PoseStamped, Quaternion +from nav2_msgs.action import NavigateToPose +from pydantic import BaseModel, Field +from rclpy.action import ActionClient +from tf_transformations import quaternion_from_euler + +from rai.tools.ros2.base import BaseROS2Tool + + +class NavigateToPoseBlockingToolInput(BaseModel): + x: float = Field(..., description="The x coordinate of the pose") + y: float = Field(..., description="The y coordinate of the pose") + z: float = Field(..., description="The z coordinate of the pose") + yaw: float = Field(..., description="The yaw angle of the pose") + + +class NavigateToPoseBlockingTool(BaseROS2Tool): + name: str = "navigate_to_pose_blocking" + description: str = "Navigate to a specific pose" + frame_id: str = Field( + default="map", description="The frame id of the Nav2 stack (map, odom, etc.)" + ) + action_name: str = Field( + default="navigate_to_pose", description="The name of the Nav2 action" + ) + args_schema: Type[NavigateToPoseBlockingToolInput] = NavigateToPoseBlockingToolInput + + def _run(self, x: float, y: float, z: float, yaw: float) -> str: + action_client = ActionClient( + self.connector.node, NavigateToPose, self.action_name + ) + + pose = PoseStamped() + pose.header.frame_id = self.frame_id + pose.header.stamp = self.connector.node.get_clock().now().to_msg() + pose.pose.position.x = x + pose.pose.position.y = y + pose.pose.position.z = z + quat = quaternion_from_euler(0, 0, yaw) + pose.pose.orientation = Quaternion(x=quat[0], y=quat[1], z=quat[2], w=quat[3]) + + goal = NavigateToPose.Goal() + goal.pose = pose + + result = action_client.send_goal(goal) + + if result is None: + return "Navigate to pose action failed. Please try again." + + if result.result.error_code != 0: + return f"Navigate to pose action failed. Error code: {result.result.error_code}" + + return "Navigate to pose successful." diff --git a/tests/agents/langchain/test_langchain_agent.py b/tests/agents/langchain/test_langchain_agent.py index c220f7bbe..44e883120 100644 --- a/tests/agents/langchain/test_langchain_agent.py +++ b/tests/agents/langchain/test_langchain_agent.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from collections import deque from typing import List +from unittest.mock import MagicMock, patch import pytest +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.language_models.fake_chat_models import ParrotFakeChatModel +from langchain_core.runnables import RunnableConfig +from rai.agents.langchain import invoke_llm_with_tracing from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType +from rai.initialization import get_tracing_callbacks +from rai.messages import HumanMultimodalMessage @pytest.mark.parametrize( @@ -39,3 +47,119 @@ def test_reduce_messages( output_ = LangChainAgent._apply_reduction_behavior(new_message_behavior, buffer) assert output == output_ assert buffer == deque(out_buffer) + + +class TestTracingConfiguration: + """Test tracing configuration integration with langchain agents.""" + + def test_tracing_with_missing_config_file(self): + """Test that tracing gracefully handles missing config.toml file in langchain context.""" + # This should not crash even without config.toml + callbacks = get_tracing_callbacks() + assert len(callbacks) == 0 + + def test_tracing_with_config_file_present(self, test_config_toml): + """Test that tracing works when config.toml is present in langchain context.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=True, langsmith_enabled=False + ) + + try: + # Mock environment variables to avoid actual API calls + with patch.dict( + os.environ, + { + "LANGFUSE_PUBLIC_KEY": "test_key", + "LANGFUSE_SECRET_KEY": "test_secret", + }, + ): + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 1 callback for langfuse + assert len(callbacks) == 1 + finally: + cleanup() + + +class TestInvokeLLMWithTracing: + """Test the invoke_llm_with_tracing function.""" + + def test_invoke_llm_without_tracing(self): + """Test that invoke_llm_with_tracing works when no tracing callbacks are available.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock get_tracing_callbacks to return empty list (no config.toml) + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = [] + + result = invoke_llm_with_tracing(mock_llm, mock_messages) + + mock_llm.invoke.assert_called_once_with(mock_messages, config=None) + assert result == "test response" + + def test_invoke_llm_with_tracing(self): + """Test that invoke_llm_with_tracing works when tracing callbacks are available.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock get_tracing_callbacks to return some callbacks + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = ["tracing_callback"] + + _ = invoke_llm_with_tracing(mock_llm, mock_messages) + + # Verify that the LLM was called with enhanced config + mock_llm.invoke.assert_called_once() + call_args = mock_llm.invoke.call_args + assert call_args[0][0] == mock_messages + assert "callbacks" in call_args[1]["config"] + assert "tracing_callback" in call_args[1]["config"]["callbacks"] + + def test_invoke_llm_with_existing_config(self): + """Test that invoke_llm_with_tracing preserves existing config.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock existing config + existing_config = {"callbacks": ["existing_callback"]} + + # Mock get_tracing_callbacks to return some callbacks + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = ["tracing_callback"] + + _ = invoke_llm_with_tracing(mock_llm, mock_messages, existing_config) + + # Verify that the LLM was called with enhanced config + mock_llm.invoke.assert_called_once() + call_args = mock_llm.invoke.call_args + assert call_args[0][0] == mock_messages + assert "callbacks" in call_args[1]["config"] + assert "existing_callback" in call_args[1]["config"]["callbacks"] + assert "tracing_callback" in call_args[1]["config"]["callbacks"] + + def test_invoke_llm_with_callback_integration(self): + """Test that invoke_llm_with_tracing works with a callback handler.""" + llm = ParrotFakeChatModel() + human_msg = HumanMultimodalMessage(content="human") + response = llm.invoke( + [human_msg], config=RunnableConfig(callbacks=[BaseCallbackHandler()]) + ) + assert response.content == [{"type": "text", "text": "human"}] diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index bc102a58a..d51f30378 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -138,11 +138,14 @@ def test_ros2_single_message_publish_wrong_qos_setup( shutdown_executors_and_threads(executors, threads) -def service_call_helper(service_name: str, service_api: ROS2ServiceAPI): +def invoke_set_bool_service( + service_name: str, service_api: ROS2ServiceAPI, reuse_client: bool = True +): response = service_api.call_service( service_name, service_type="std_srvs/srv/SetBool", request={"data": True}, + reuse_client=reuse_client, ) assert response.success assert response.message == "Test service called" @@ -164,7 +167,7 @@ def test_ros2_service_single_call( try: service_api = ROS2ServiceAPI(node) - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api) finally: shutdown_executors_and_threads(executors, threads) @@ -186,7 +189,30 @@ def test_ros2_service_multiple_calls( try: service_api = ROS2ServiceAPI(node) for _ in range(3): - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api, reuse_client=False) + finally: + shutdown_executors_and_threads(executors, threads) + + +@pytest.mark.parametrize( + "callback_group", + [MutuallyExclusiveCallbackGroup(), ReentrantCallbackGroup()], + ids=["MutuallyExclusiveCallbackGroup", "ReentrantCallbackGroup"], +) +def test_ros2_service_multiple_calls_with_reused_client( + ros_setup: None, request: pytest.FixtureRequest, callback_group: CallbackGroup +) -> None: + service_name = f"{request.node.originalname}_service" # type: ignore + node_name = f"{request.node.originalname}_node" # type: ignore + service_server = ServiceServer(service_name, callback_group) + node = Node(node_name) + executors, threads = multi_threaded_spinner([service_server, node]) + + try: + service_api = ROS2ServiceAPI(node) + for _ in range(3): + invoke_set_bool_service(service_name, service_api, reuse_client=True) + assert service_api.release_client(service_name), "Client not released" finally: shutdown_executors_and_threads(executors, threads) @@ -210,7 +236,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_threading( service_threads: List[threading.Thread] = [] for _ in range(10): thread = threading.Thread( - target=service_call_helper, args=(service_name, service_api) + target=invoke_set_bool_service, args=(service_name, service_api) ) service_threads.append(thread) thread.start() @@ -241,7 +267,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_multiprocessing( service_api = ROS2ServiceAPI(node) with Pool(10) as pool: pool.map( - lambda _: service_call_helper(service_name, service_api), range(10) + lambda _: invoke_set_bool_service(service_name, service_api), range(10) ) finally: shutdown_executors_and_threads(executors, threads) diff --git a/tests/conftest.py b/tests/conftest.py index 97ceef6f0..adb9e1850 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,3 +11,132 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import os +import tempfile + +import pytest + + +@pytest.fixture +def test_config_toml(): + """ + Fixture to create a temporary test config.toml file with tracing enabled. + + Returns + ------- + tuple + (config_path, cleanup_function) - The path to the config file and a function to clean it up + """ + + def _create_config(langfuse_enabled=False, langsmith_enabled=False): + # Create a temporary config.toml file + f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) + + # Base config sections (always required) + config_content = """[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "text-embedding-ada-002" + +[aws] +simple_model = "anthropic.claude-instant-v1" +complex_model = "anthropic.claude-v2" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-3.5-turbo" +complex_model = "gpt-4" +embeddings_model = "text-embedding-ada-002" +base_url = "https://api.openai.com/v1" + +[ollama] +simple_model = "llama2" +complex_model = "llama2" +embeddings_model = "llama2" +base_url = "http://localhost:11434" + +[tracing] +project = "test-project" + +[tracing.langfuse] +use_langfuse = {langfuse_enabled} +host = "http://localhost:3000" + +[tracing.langsmith] +use_langsmith = {langsmith_enabled} +host = "https://api.smith.langchain.com" +""".format( + langfuse_enabled=str(langfuse_enabled).lower(), + langsmith_enabled=str(langsmith_enabled).lower(), + ) + + f.write(config_content) + f.close() + + def cleanup(): + try: + f.close() # Ensure file is properly closed + os.unlink(f.name) + except (OSError, PermissionError): + pass # File might already be deleted or have permission issues + + return f.name, cleanup + + return _create_config + + +@pytest.fixture +def test_config_no_tracing(): + """ + Fixture to create a temporary test config.toml file with no tracing section. + + Returns + ------- + tuple + (config_path, cleanup_function) - The path to the config file and a function to clean it up + """ + + def _create_config(): + # Create a temporary config.toml file + f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) + + # Base config sections (always required) + config_content = """[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "text-embedding-ada-002" + +[aws] +simple_model = "anthropic.claude-instant-v1" +complex_model = "anthropic.claude-v2" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-3.5-turbo" +complex_model = "gpt-4" +embeddings_model = "text-embedding-ada-002" +base_url = "https://api.openai.com/v1" + +[ollama] +simple_model = "llama2" +complex_model = "llama2" +embeddings_model = "llama2" +base_url = "http://localhost:11434" +""" + + f.write(config_content) + f.close() + + def cleanup(): + try: + f.close() # Ensure file is properly closed + os.unlink(f.name) + except (OSError, PermissionError): + pass # File might already be deleted or have permission issues + + return f.name, cleanup + + return _create_config diff --git a/tests/initialization/test_tracing.py b/tests/initialization/test_tracing.py new file mode 100644 index 000000000..659941755 --- /dev/null +++ b/tests/initialization/test_tracing.py @@ -0,0 +1,73 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import patch + +from rai.initialization import get_tracing_callbacks + + +class TestInitializationTracing: + """Test the initialization module's tracing functionality.""" + + def test_tracing_with_missing_config_file(self): + """Test that tracing gracefully handles missing config.toml file.""" + # This should not crash even without config.toml + callbacks = get_tracing_callbacks() + assert len(callbacks) == 0 + + def test_tracing_with_config_file_present_tracing_disabled(self, test_config_toml): + """Test that tracing works when config.toml is present but tracing is disabled.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=False, langsmith_enabled=False + ) + + try: + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 0 callbacks since both langfuse and langsmith are disabled + assert len(callbacks) == 0 + finally: + cleanup() + + def test_tracing_with_config_file_present_tracing_enabled(self, test_config_toml): + """Test that tracing works when config.toml is present and tracing is enabled.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=True, langsmith_enabled=False + ) + + try: + # Mock environment variables to avoid actual API calls + with patch.dict( + os.environ, + { + "LANGFUSE_PUBLIC_KEY": "test_key", + "LANGFUSE_SECRET_KEY": "test_secret", + }, + ): + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 1 callback for langfuse + assert len(callbacks) == 1 + finally: + cleanup() + + def test_tracing_with_valid_config_file_no_tracing(self, test_config_no_tracing): + """Test that tracing works when config.toml is valid but has no tracing sections.""" + config_path, cleanup = test_config_no_tracing() + + try: + # This should not crash, should return empty callbacks + callbacks = get_tracing_callbacks(config_path=config_path) + assert len(callbacks) == 0 + finally: + cleanup() diff --git a/tests/messages/test_multimodal_message.py b/tests/messages/test_multimodal_message.py new file mode 100644 index 000000000..221ec0fed --- /dev/null +++ b/tests/messages/test_multimodal_message.py @@ -0,0 +1,37 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from rai.messages import HumanMultimodalMessage + + +class TestMultimodalMessage: + """Test the MultimodalMessage class and expected behaviors.""" + + def test_human_multimodal_message_text_simple(self): + """Test text() method with simple text content.""" + msg = HumanMultimodalMessage(content="Hello world") + assert msg.text() == "Hello world" + assert isinstance(msg.text(), str) + + def test_human_multimodal_message_text_with_images(self): + """Test text() method with text and images.""" + # Use a small valid base64 image (1x1 pixel PNG) + valid_base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + msg = HumanMultimodalMessage( + content="Look at this image", images=[valid_base64_image] + ) + assert msg.text() == "Look at this image" + # Should only return text type blocks, not image content + assert valid_base64_image not in msg.text()