diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index d3072a783..24c033a1a 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -6,6 +6,7 @@ RAI Bench is a comprehensive package that both provides benchmarks with ready-to - [Manipulation O3DE Benchmark](#manipulation-o3de-benchmark) - [Tool Calling Agent Benchmark](#tool-calling-agent-benchmark) +- [VLM Benchmark](#vlm-benchmark) ## Manipulation O3DE Benchmark @@ -94,9 +95,9 @@ Evaluates agent performance independently from any simulation, based only on too The `SubTask` class is used to validate just one tool call. Following classes are available: - `CheckArgsToolCallSubTask` - verify if a certain tool was called with expected arguments -- `CheckTopicFieldsToolCallSubTask` - verify if a message published to ROS 2topic was of proper type and included expected fields -- `CheckServiceFieldsToolCallSubTask` - verify if a message published to ROS 2service was of proper type and included expected fields -- `CheckActionFieldsToolCallSubTask` - verify if a message published to ROS 2action was of proper type and included expected fields +- `CheckTopicFieldsToolCallSubTask` - verify if a message published to ROS2 topic was of proper type and included expected fields +- `CheckServiceFieldsToolCallSubTask` - verify if a message published to ROS2 service was of proper type and included expected fields +- `CheckActionFieldsToolCallSubTask` - verify if a message published to ROS2 action was of proper type and included expected fields ### Validator @@ -129,7 +130,6 @@ The ToolCallingAgentBenchmark class manages the execution of tasks and collects There are predefined Tasks available which are grouped by categories: - Basic - require retrieving info from certain topics -- Spatial reasoning - questions about surroundings with images attached - Manipulation - Custom Interfaces - requires using messages with custom interfaces @@ -164,3 +164,17 @@ class TaskArgs(BaseModel): - `GetROS2RGBCameraTask` has 1 required tool call and 1 optional. When `extra_tool_calls` set to 5, agent can correct himself couple times and still pass even with 7 tool calls. There can be 2 types of invalid tool calls, first when the tool is used incorrectly and agent receives an error - this allows him to correct himself easier. Second type is when tool is called properly but it is not the tool that should be called or it is called with wrong params. In this case agent won't get any error so it will be harder for him to correct, but BOTH of these cases are counted as `extra tool call`. If you want to know details about every task, visit `rai_bench/tool_calling_agent/tasks` + +## VLM Benchmark + +The VLM Benchmark is a benchmark for VLM models. It includes a set of tasks containing questions related to images and evaluates the performance of the agent that returns the answer in the structured format. + +### Running + +To run the benchmark: + +```bash +cd rai +source setup_shell.sh +python src/rai_bench/rai_bench/examples/vlm_benchmark.py --model-name gemma3:4b --vendor ollama +``` diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index 85dfca688..ccb933a1c 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -73,7 +73,6 @@ if __name__ == "__main__": extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", - "spatial_reasoning", "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt @@ -95,7 +94,7 @@ if __name__ == "__main__": ) ``` -Based on the example above the `Tool Calling` benchmark will run basic, spatial_reasoning and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. +Based on the example above the `Tool Calling` benchmark will run basic and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. !!! note diff --git a/src/rai_bench/rai_bench/__init__.py b/src/rai_bench/rai_bench/__init__.py index 395b5cf9a..5dd4df1e4 100644 --- a/src/rai_bench/rai_bench/__init__.py +++ b/src/rai_bench/rai_bench/__init__.py @@ -14,7 +14,6 @@ from .test_models import ( ManipulationO3DEBenchmarkConfig, ToolCallingAgentBenchmarkConfig, - test_dual_agents, test_models, ) from .utils import ( @@ -22,6 +21,7 @@ get_llm_for_benchmark, parse_manipulation_o3de_benchmark_args, parse_tool_calling_benchmark_args, + parse_vlm_benchmark_args, ) __all__ = [ @@ -31,6 +31,6 @@ "get_llm_for_benchmark", "parse_manipulation_o3de_benchmark_args", "parse_tool_calling_benchmark_args", - "test_dual_agents", + "parse_vlm_benchmark_args", "test_models", ] diff --git a/src/rai_bench/rai_bench/agents.py b/src/rai_bench/rai_bench/agents.py deleted file mode 100644 index a38e4e9d0..000000000 --- a/src/rai_bench/rai_bench/agents.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from functools import partial -from typing import List, Optional - -from langchain.chat_models.base import BaseChatModel -from langchain_core.messages import ( - AIMessage, - BaseMessage, - HumanMessage, -) -from langchain_core.tools import BaseTool -from langgraph.graph import START, StateGraph -from langgraph.graph.state import CompiledStateGraph -from langgraph.prebuilt.tool_node import tools_condition -from rai.agents.langchain.core.conversational_agent import State, agent -from rai.agents.langchain.core.tool_runner import ToolRunner - - -def multimodal_to_tool_bridge(state: State): - """Node of langchain workflow designed to bridge - nodes with llms. Removing images for context - """ - - cleaned_messages: List[BaseMessage] = [] - for msg in state["messages"]: - if isinstance(msg, HumanMessage): - # Remove images but keep the direct request - if isinstance(msg.content, list): - # Extract text only - text_parts = [ - part.get("text", "") - for part in msg.content - if isinstance(part, dict) and part.get("type") == "text" - ] - if text_parts: - cleaned_messages.append(HumanMessage(content=" ".join(text_parts))) - else: - cleaned_messages.append(msg) - elif isinstance(msg, AIMessage): - # Keep AI messages for context - cleaned_messages.append(msg) - - state["messages"] = cleaned_messages - return state - - -def create_multimodal_to_tool_agent( - multimodal_llm: BaseChatModel, - tool_llm: BaseChatModel, - tools: List[BaseTool], - multimodal_system_prompt: str, - tool_system_prompt: str, - logger: Optional[logging.Logger] = None, - debug: bool = False, -) -> CompiledStateGraph: - """ - Creates an agent flow where inputs first go to a multimodal LLM, - then its output is passed to a tool-calling LLM. - Can be usefull when multimodal llm does not provide tool calling. - - Args: - tools: List of tools available to the tool agent - - Returns: - Compiled state graph - """ - _logger = None - if logger: - _logger = logger - else: - _logger = logging.getLogger(__name__) - - _logger.info("Creating multimodal to tool agent flow") - - tool_llm_with_tools = tool_llm.bind_tools(tools) - tool_node = ToolRunner(tools=tools, logger=_logger) - - workflow = StateGraph(State) - workflow.add_node( - "thinker", - partial(agent, multimodal_llm, _logger, multimodal_system_prompt), - ) - # context bridge for altering the - workflow.add_node( - "context_bridge", - multimodal_to_tool_bridge, - ) - workflow.add_node( - "tool_agent", - partial(agent, tool_llm_with_tools, _logger, tool_system_prompt), - ) - workflow.add_node("tools", tool_node) - - workflow.add_edge(START, "thinker") - workflow.add_edge("thinker", "context_bridge") - workflow.add_edge("context_bridge", "tool_agent") - - workflow.add_conditional_edges( - "tool_agent", - tools_condition, - ) - - # Tool node goes back to tool agent - workflow.add_edge("tools", "tool_agent") - - app = workflow.compile(debug=debug) - _logger.info("Multimodal to tool agent flow created") - return app diff --git a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md index 0eb6664b5..0f287d483 100644 --- a/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md +++ b/src/rai_bench/rai_bench/docs/tool_calling_agent_benchmark.md @@ -14,4 +14,4 @@ Implementations can be found: - Validators [Validators](../tool_calling_agent/validators.py) - Subtasks [Validators](../tool_calling_agent/tasks/subtasks.py) -- Tasks, including basic, spatial, custom interfaces and manipulation [Tasks](../tool_calling_agent/tasks/) +- Tasks, including basic, custom interfaces and manipulation [Tasks](../tool_calling_agent/tasks/) diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index 2a0cc188c..43cdb3099 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -20,7 +20,7 @@ if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen3:4b", "llama3.2:3b"] + model_names = ["qwen2.5:3b", "llama3.2:3b"] vendors = ["ollama", "ollama"] # Define benchmarks that will be used @@ -36,7 +36,7 @@ extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", - "spatial_reasoning", + "manipulation", "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt @@ -48,11 +48,6 @@ test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[mani_conf, tool_conf], + benchmark_configs=[tool_conf, mani_conf], out_dir=out_dir, - # if you want to pass any additinal args to model - additional_model_args=[ - {"reasoning": False}, - {}, - ], ) diff --git a/src/rai_bench/rai_bench/examples/dual_agent.py b/src/rai_bench/rai_bench/examples/dual_agent.py deleted file mode 100644 index cf38cc8d9..000000000 --- a/src/rai_bench/rai_bench/examples/dual_agent.py +++ /dev/null @@ -1,53 +0,0 @@ -# # Copyright (C) 2025 Robotec.AI -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -from langchain_community.chat_models import ChatOllama -from langchain_openai import ChatOpenAI - -from rai_bench import ( - ManipulationO3DEBenchmarkConfig, - ToolCallingAgentBenchmarkConfig, - test_dual_agents, -) - -if __name__ == "__main__": - # Define models you want to benchmark - model_name = "gemma3:4b" - m_llm = ChatOllama( - model=model_name, base_url="http://localhost:11434", keep_alive=30 - ) - - tool_llm = ChatOpenAI(model="gpt-4o-mini", base_url="https://api.openai.com/v1/") - # Define benchmarks that will be used - tool_conf = ToolCallingAgentBenchmarkConfig( - extra_tool_calls=0, # how many extra tool calls allowed to still pass - task_types=["spatial_reasoning"], - repeats=15, - ) - - man_conf = ManipulationO3DEBenchmarkConfig( - o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config - levels=[ # define what difficulty of tasks to include in benchmark - "trivial", - ], - repeats=1, # how many times to repeat - ) - - out_dir = "src/rai_bench/rai_bench/experiments/dual_agents/" - - test_dual_agents( - multimodal_llms=[m_llm], - tool_calling_models=[tool_llm], - benchmark_configs=[man_conf, tool_conf], - out_dir=out_dir, - ) diff --git a/src/rai_bench/rai_bench/examples/vlm_benchmark.py b/src/rai_bench/rai_bench/examples/vlm_benchmark.py new file mode 100644 index 000000000..65f8526a6 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/vlm_benchmark.py @@ -0,0 +1,48 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from rai_bench import ( + define_benchmark_logger, + parse_vlm_benchmark_args, +) +from rai_bench.utils import get_llm_for_benchmark +from rai_bench.vlm_benchmark import get_spatial_tasks, run_benchmark + +if __name__ == "__main__": + args = parse_vlm_benchmark_args() + experiment_dir = Path(args.out_dir) + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + try: + tasks = get_spatial_tasks() + for task in tasks: + task.set_logger(bench_logger) + + llm = get_llm_for_benchmark( + model_name=args.model_name, + vendor=args.vendor, + ) + run_benchmark( + llm=llm, + out_dir=experiment_dir, + tasks=tasks, + bench_logger=bench_logger, + ) + except Exception as e: + bench_logger.critical( + msg=f"Benchmark failed with error: {e}", + exc_info=True, + ) diff --git a/src/rai_bench/rai_bench/manipulation_o3de/__init__.py b/src/rai_bench/rai_bench/manipulation_o3de/__init__.py index 206300fff..8ada5037a 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/__init__.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .benchmark import run_benchmark, run_benchmark_dual_agent +from .benchmark import run_benchmark from .predefined.scenarios import get_scenarios -__all__ = ["get_scenarios", "run_benchmark", "run_benchmark_dual_agent"] +__all__ = ["get_scenarios", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index bb4730ad8..8bf9421d5 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -16,7 +16,7 @@ import time import uuid from pathlib import Path -from typing import List, Optional, TypeVar +from typing import List, TypeVar import rclpy from langchain.tools import BaseTool @@ -45,7 +45,6 @@ ) from rai_open_set_vision.tools import GetGrabbingPointTool -from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, RunSummary, TimeoutException from rai_bench.manipulation_o3de.interfaces import Task from rai_bench.manipulation_o3de.results_tracking import ( @@ -285,7 +284,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: for msg in new_messages: if isinstance(msg, HumanMultimodalMessage): - last_msg = msg.text + last_msg = msg.text() elif isinstance(msg, BaseMessage): if isinstance(msg.content, list): if len(msg.content) == 1: @@ -452,57 +451,3 @@ def run_benchmark( connector.shutdown() o3de.shutdown() rclpy.shutdown() - - -def run_benchmark_dual_agent( - multimodal_llm: BaseChatModel, - tool_calling_llm: BaseChatModel, - out_dir: Path, - scenarios: List[Scenario], - o3de_config_path: str, - bench_logger: logging.Logger, - experiment_id: uuid.UUID = uuid.uuid4(), - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - connector, o3de, benchmark, tools = _setup_benchmark_environment( - o3de_config_path, - get_llm_model_name(multimodal_llm), - scenarios, - out_dir, - bench_logger, - ) - basic_tool_system_prompt = ( - "Based on the conversation call the tools with appropriate arguments" - ) - try: - for scenario in scenarios: - agent = create_multimodal_to_tool_agent( - multimodal_llm=multimodal_llm, - tool_llm=tool_calling_llm, - tools=tools, - multimodal_system_prompt=( - m_system_prompt if m_system_prompt else scenario.task.system_prompt - ), - tool_system_prompt=( - tool_system_prompt - if tool_system_prompt - else basic_tool_system_prompt - ), - logger=bench_logger, - ) - - benchmark.run_next(agent=agent, experiment_id=experiment_id) - - bench_logger.info( - "===============================================================" - ) - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info( - "===============================================================" - ) - - finally: - connector.shutdown() - o3de.shutdown() - rclpy.shutdown() diff --git a/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py b/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py index 772974b40..bee9c70d4 100644 --- a/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py +++ b/src/rai_bench/rai_bench/results_processing/langfuse_scores_tracing.py @@ -47,7 +47,7 @@ def send_score( if isinstance(callback, CallbackHandler): callback.langfuse.score( trace_id=str(run_id), - name="tool calls result", + name="result", value=score, comment=comment, ) @@ -55,7 +55,7 @@ def send_score( if isinstance(callback, LangChainTracer): callback.client.create_feedback( run_id=run_id, - key="tool calls result", + key="result", score=score, comment=comment, ) diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index dbccc393b..84622a5ad 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -17,7 +17,6 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Optional -from langchain.chat_models.base import BaseChatModel from pydantic import BaseModel import rai_bench.manipulation_o3de as manipulation_o3de @@ -25,7 +24,6 @@ from rai_bench.utils import ( define_benchmark_logger, get_llm_for_benchmark, - get_llm_model_name, ) @@ -79,9 +77,9 @@ class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): complexities : List[Literal["easy", "medium", "hard"]], optional complexity levels of tasks to include in the benchmark, by default all levels are included: ["easy", "medium", "hard"] - task_types : List[Literal["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"]], optional + task_types : List[Literal["basic", "manipulation", "navigation", "custom_interfaces"], optional types of tasks to include in the benchmark, by default all types are included: - ["basic", "manipulation", "navigation", "custom_interfaces", "spatial_reasoning"] + ["basic", "manipulation", "navigation", "custom_interfaces"] For more detailed explanation of parameters, see the documentation: (https://robotecai.github.io/rai/simulation_and_benchmarking/rai_bench/) @@ -96,13 +94,11 @@ class ToolCallingAgentBenchmarkConfig(BenchmarkConfig): "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ] ] = [ "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ] @property @@ -110,74 +106,6 @@ def name(self) -> str: return "tool_calling_agent" -def test_dual_agents( - multimodal_llms: List[BaseChatModel], - tool_calling_models: List[BaseChatModel], - benchmark_configs: List[BenchmarkConfig], - out_dir: str, - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - if len(multimodal_llms) != len(tool_calling_models): - raise ValueError( - "Number of passed multimodal models must match number of passed tool calling models" - ) - experiment_id = uuid.uuid4() - for bench_conf in benchmark_configs: - # for each bench configuration seperate run folder - now = datetime.now() - run_name = f"run_{now.strftime('%Y-%m-%d_%H-%M-%S')}" - for i, m_llm in enumerate(multimodal_llms): - tool_llm = tool_calling_models[i] - for u in range(bench_conf.repeats): - curr_out_dir = ( - out_dir - + "/" - + run_name - + "/" - + bench_conf.name - + "/" - + get_llm_model_name(m_llm) - + "/" - + str(u) - ) - bench_logger = define_benchmark_logger(out_dir=Path(curr_out_dir)) - try: - if isinstance(bench_conf, ToolCallingAgentBenchmarkConfig): - tool_calling_tasks = tool_calling_agent.get_tasks( - extra_tool_calls=bench_conf.extra_tool_calls, - complexities=bench_conf.complexities, - task_types=bench_conf.task_types, - ) - tool_calling_agent.run_benchmark_dual_agent( - multimodal_llm=m_llm, - tool_calling_llm=tool_llm, - m_system_prompt=m_system_prompt, - tool_system_prompt=tool_system_prompt, - out_dir=Path(curr_out_dir), - tasks=tool_calling_tasks, - experiment_id=experiment_id, - bench_logger=bench_logger, - ) - elif isinstance(bench_conf, ManipulationO3DEBenchmarkConfig): - manipulation_o3de_scenarios = manipulation_o3de.get_scenarios( - levels=bench_conf.levels, - logger=bench_logger, - ) - manipulation_o3de.run_benchmark_dual_agent( - multimodal_llm=m_llm, - tool_calling_llm=tool_llm, - out_dir=Path(curr_out_dir), - o3de_config_path=bench_conf.o3de_config_path, - scenarios=manipulation_o3de_scenarios, - experiment_id=experiment_id, - bench_logger=bench_logger, - ) - except Exception as e: - bench_logger.critical(f"BENCHMARK RUN FAILED: {e}") - raise e - - def test_models( model_names: List[str], vendors: List[str], @@ -185,7 +113,6 @@ def test_models( out_dir: str, additional_model_args: Optional[List[Dict[str, Any]]] = None, ): - # TODO (jmatejcz) add docstring after passing agent logic will be added if additional_model_args is None: additional_model_args = [{} for _ in model_names] @@ -215,7 +142,6 @@ def test_models( vendor=vendors[i], **additional_model_args[i], ) - # TODO (jmatejcz) take param to set log level bench_logger = define_benchmark_logger(out_dir=Path(curr_out_dir)) try: if isinstance(bench_conf, ToolCallingAgentBenchmarkConfig): diff --git a/src/rai_bench/rai_bench/tool_calling_agent/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/__init__.py index 5f9771011..c4c668ac6 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/__init__.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .benchmark import run_benchmark, run_benchmark_dual_agent +from .benchmark import run_benchmark from .predefined.tasks import get_tasks -__all__ = ["get_tasks", "run_benchmark", "run_benchmark_dual_agent"] +__all__ = ["get_tasks", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py index d1843b843..8d22bbb2c 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/benchmark.py @@ -16,7 +16,7 @@ import time import uuid from pathlib import Path -from typing import Iterator, List, Optional, Sequence, Tuple +from typing import Iterator, List, Sequence, Tuple from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage @@ -26,9 +26,8 @@ from rai.agents.langchain.core import ( create_conversational_agent, ) -from rai.agents.langchain.core.react_agent import ReActAgentState +from rai.messages import HumanMultimodalMessage -from rai_bench.agents import create_multimodal_to_tool_agent from rai_bench.base_benchmark import BaseBenchmark, TimeoutException from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler from rai_bench.tool_calling_agent.interfaces import ( @@ -67,7 +66,7 @@ def __init__( def run_next( self, agent: CompiledStateGraph, - initial_state: ReActAgentState, + initial_state: dict, experiment_id: uuid.UUID, ) -> None: """Runs the next task of the benchmark. @@ -227,50 +226,14 @@ def run_benchmark( system_prompt=task.get_system_prompt(), logger=bench_logger, ) - benchmark.run_next(agent=agent, experiment_id=experiment_id) - - bench_logger.info("===============================================================") - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info("===============================================================") - - -def run_benchmark_dual_agent( - multimodal_llm: BaseChatModel, - tool_calling_llm: BaseChatModel, - out_dir: Path, - tasks: List[Task], - bench_logger: logging.Logger, - experiment_id: uuid.UUID = uuid.uuid4(), - m_system_prompt: Optional[str] = None, - tool_system_prompt: Optional[str] = None, -): - benchmark = ToolCallingAgentBenchmark( - tasks=tasks, - logger=bench_logger, - model_name=get_llm_model_name(multimodal_llm), - results_dir=out_dir, - ) - - basic_tool_system_prompt = ( - "Based on the conversation call the tools with appropriate arguments" - ) - for task in tasks: - agent = create_multimodal_to_tool_agent( - multimodal_llm=multimodal_llm, - tool_llm=tool_calling_llm, - tools=task.available_tools, - multimodal_system_prompt=( - m_system_prompt if m_system_prompt else task.get_system_prompt() - ), - tool_system_prompt=( - tool_system_prompt if tool_system_prompt else basic_tool_system_prompt - ), - logger=bench_logger, - debug=False, + benchmark.run_next( + agent=agent, + initial_state={ + "messages": [HumanMultimodalMessage(content=task.get_prompt())] + }, + experiment_id=experiment_id, ) - benchmark.run_next(agent=agent, experiment_id=experiment_id) - bench_logger.info("===============================================================") bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") bench_logger.info("===============================================================") diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py index 53ff03a2a..af2958f59 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/__init__.py @@ -15,11 +15,9 @@ from .basic_tasks import get_basic_tasks from .custom_interfaces_tasks import get_custom_interfaces_tasks from .manipulation_tasks import get_manipulation_tasks -from .spatial_reasoning_tasks import get_spatial_tasks __all__ = [ "get_basic_tasks", "get_custom_interfaces_tasks", "get_manipulation_tasks", - "get_spatial_tasks", ] diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py deleted file mode 100644 index d3ccbfa5e..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/spatial_reasoning_tasks.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Literal, Sequence - -from rai_bench.tool_calling_agent.interfaces import ( - Task, - TaskArgs, -) -from rai_bench.tool_calling_agent.subtasks import ( - CheckArgsToolCallSubTask, -) -from rai_bench.tool_calling_agent.tasks.spatial import ( - BoolImageTaskEasy, - BoolImageTaskHard, - BoolImageTaskInput, - BoolImageTaskMedium, -) -from rai_bench.tool_calling_agent.validators import ( - OrderedCallsValidator, -) - -IMG_PATH = "src/rai_bench/rai_bench/tool_calling_agent/predefined/images/" -########## SUBTASKS ################################################################# -return_true_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": True} -) -return_false_subtask = CheckArgsToolCallSubTask( - expected_tool_name="return_bool_response", expected_args={"response": False} -) - -######### VALIDATORS ######################################################################################### -ret_true_ord_val = OrderedCallsValidator(subtasks=[return_true_subtask]) -ret_false_ord_val = OrderedCallsValidator(subtasks=[return_false_subtask]) - - -def get_spatial_tasks( - extra_tool_calls: List[int] = [0], - prompt_detail: List[Literal["brief", "descriptive"]] = ["brief", "descriptive"], - n_shots: List[Literal[0, 2, 5]] = [0, 2, 5], -) -> Sequence[Task]: - """Get predefined spatial reasoning tasks. - - Parameters - ---------- - Parameters match :class:`~src.rai_bench.rai_bench.test_models.ToolCallingAgentBenchmarkConfig`. - See the class documentation for parameter descriptions. - - Returns - ------- - Returned list match :func:`~src.rai_bench.rai_bench.tool_calling_agent.predefined.tasks.get_tasks`. - """ - tasks: List[Task] = [] - - # Categorize tasks by complexity based on question difficulty - easy_true_inputs = [ - # Single object presence/detection - BoolImageTaskInput( - question="Is the chair in the room?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Do you see the plant?", images_paths=[IMG_PATH + "image_2.jpg"] - ), - BoolImageTaskInput( - question="Are there any pictures on the wall?", - images_paths=[IMG_PATH + "image_3.jpg"], - ), - BoolImageTaskInput( - question="is there a TV in the room?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - ] - - medium_true_inputs = [ - # Object state or counting - BoolImageTaskInput( - question="Are there 3 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is the light on in the room?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Is there something to sit on?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - ] - - hard_true_inputs = [ - # Spatial relationships between objects - BoolImageTaskInput( - question="Is the door on the left from the desk?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant behind the rack?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Is there a rug under the bed?", - images_paths=[IMG_PATH + "image_2.jpg"], - ), - BoolImageTaskInput( - question="Is there a pillow on the armchain?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - ] - - easy_false_inputs = [ - # Single object presence/detection - BoolImageTaskInput( - question="Is someone in the room?", images_paths=[IMG_PATH + "image_1.jpg"] - ), - BoolImageTaskInput( - question="Do you see the plant?", images_paths=[IMG_PATH + "image_3.jpg"] - ), - BoolImageTaskInput( - question="Is there a red pillow on the armchair?", - images_paths=[IMG_PATH + "image_7.jpg"], - ), - BoolImageTaskInput( - question="Is there a red desk with chair in the room?", - images_paths=[IMG_PATH + "image_5.jpg"], - ), - BoolImageTaskInput( - question="Do you see the bed?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - ] - - medium_false_inputs = [ - # Object state or counting - BoolImageTaskInput( - question="Is the door open?", images_paths=[IMG_PATH + "image_1.jpg"] - ), - BoolImageTaskInput( - question="Are there 4 pictures on the wall?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is the TV switched on?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is the window opened?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - ] - - hard_false_inputs = [ - # Spatial relationships between objects - BoolImageTaskInput( - question="Is there a rack on the left from the sofa?", - images_paths=[IMG_PATH + "image_4.jpg"], - ), - BoolImageTaskInput( - question="Is there a plant on the right from the window?", - images_paths=[IMG_PATH + "image_6.jpg"], - ), - BoolImageTaskInput( - question="Is the chair next to a bed?", - images_paths=[IMG_PATH + "image_1.jpg"], - ), - ] - - for extra_calls in extra_tool_calls: - for detail in prompt_detail: - for shots in n_shots: - task_args = TaskArgs( - extra_tool_calls=extra_calls, - prompt_detail=detail, - examples_in_system_prompt=shots, - ) - - tasks.extend( - [ - BoolImageTaskEasy( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in easy_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskEasy( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in easy_false_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskMedium( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in medium_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskMedium( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in medium_false_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskHard( - task_input=input_item, - validators=[ret_true_ord_val], - task_args=task_args, - ) - for input_item in hard_true_inputs - ] - ) - - tasks.extend( - [ - BoolImageTaskHard( - task_input=input_item, - validators=[ret_false_ord_val], - task_args=task_args, - ) - for input_item in hard_false_inputs - ] - ) - - return tasks diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py index 5699841cd..d2302ce22 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent/predefined/tasks.py @@ -22,7 +22,6 @@ get_basic_tasks, get_custom_interfaces_tasks, get_manipulation_tasks, - get_spatial_tasks, ) @@ -36,13 +35,11 @@ def get_tasks( "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ] ] = [ "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ], ) -> List[Task]: """Get a list of tasks based on the provided configuration. @@ -55,7 +52,7 @@ def get_tasks( Returns ------- List[Task] - sequence of spatial reasoning tasks with varying difficulty levels. + sequence of tasks with varying difficulty levels. There will be every combination of extra_tool_calls x prompt_detail x n_shots tasks generated. """ @@ -78,12 +75,6 @@ def get_tasks( prompt_detail=prompt_detail, n_shots=n_shots, ) - if "spatial_reasoning" in task_types: - all_tasks += get_spatial_tasks( - extra_tool_calls=extra_tool_calls, - prompt_detail=prompt_detail, - n_shots=n_shots, - ) filtered_tasks: List[Task] = [] for task in all_tasks: diff --git a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py b/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py deleted file mode 100644 index 2f9b58e0d..000000000 --- a/src/rai_bench/rai_bench/tool_calling_agent/tasks/spatial.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from abc import ABC, abstractmethod -from typing import List - -from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field -from rai.messages import preprocess_image - -from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs, Validator - -loggers_type = logging.Logger - -SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT = """You are a helpful and knowledgeable AI assistant that specializes -in interpreting and analyzing visual content. Your task is to answer questions based -on the images provided to you. Please response with the use of the provided tools.""" -# NOTE (jmatejcz) In this case we are using only one tool so there is no difference bettween 2 and 5 shot -# so I made 1 example in '2 shot' and 2 examples in '5 shot' prompt - -SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT = ( - SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT - + """ - -Example of tool calls: -- return_bool_response, args: {'response': True}""" -) - -SPATIAL_REASONING_SYSTEM_PROMPT_5_SHOT = ( - SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT - + """ -- return_bool_response, args: {'response': False}""" -) - - -class TaskParametrizationError(Exception): - """Exception raised when the task parameters are not valid.""" - - pass - - -class ReturnBoolResponseToolInput(BaseModel): - response: bool = Field(..., description="The response to the question.") - - -class ReturnBoolResponseTool(BaseTool): - """Tool that returns a boolean response.""" - - name: str = "return_bool_response" - description: str = "Return a bool response to the question." - args_schema = ReturnBoolResponseToolInput - - def _run(self, response: bool) -> bool: - if type(response) is bool: - return response - raise ValueError("Invalid response type. Response must be a boolean.") - - -class BoolImageTaskInput(BaseModel): - question: str = Field(..., description="The question to be answered.") - images_paths: List[str] = Field( - ..., - description="List of image file paths to be used for answering the question.", - ) - - -class SpatialReasoningAgentTask(Task): - """Abstract class for spatial reasoning tasks for tool calling agent.""" - - type = "spatial_reasoning" - - def __init__( - self, - validators: List[Validator], - task_args: TaskArgs, - logger: loggers_type | None = None, - ) -> None: - super().__init__( - validators=validators, - task_args=task_args, - logger=logger, - ) - self.expected_tools: List[BaseTool] - self.question: str - self.images_paths: List[str] - - @abstractmethod - def get_images(self) -> List[str]: - """Get the images related to the task. - Returns - ------- - List[str] - List of image paths - """ - pass - - def get_system_prompt(self) -> str: - if self.n_shots == 0: - return SPATIAL_REASONING_SYSTEM_PROMPT_0_SHOT - elif self.n_shots == 2: - return SPATIAL_REASONING_SYSTEM_PROMPT_2_SHOT - else: - return SPATIAL_REASONING_SYSTEM_PROMPT_5_SHOT - - -class BoolImageTask(SpatialReasoningAgentTask, ABC): - def __init__( - self, - task_input: BoolImageTaskInput, - validators: List[Validator], - task_args: TaskArgs, - logger: loggers_type | None = None, - ) -> None: - super().__init__( - validators=validators, - task_args=task_args, - logger=logger, - ) - self.question = task_input.question - self.images_paths = task_input.images_paths - - @property - def available_tools(self) -> List[BaseTool]: - return [ReturnBoolResponseTool()] - - @property - def optional_tool_calls_number(self) -> int: - return 0 - - def get_base_prompt(self) -> str: - return self.question - - def get_prompt(self): - if self.prompt_detail == "brief": - return self.get_base_prompt() - else: - return ( - f"{self.get_base_prompt()}" - "You can examine the provided image(s) carefully to identify relevant features, " - "analyze the visual content, and provide a boolean response based on your observations." - ) - - def get_images(self): - images = [preprocess_image(image_path) for image_path in self.images_paths] - return images - - -# NOTE (jmatejcz) spatial reasoning task's difficulty is based solely on prompt and image -# so in this case when declaring task, please subjectivly decide how hard is the task -# examples: -# easy -> locating single object, tell if it is present -# medium -> tell in what state is the object (is door open?) or locating multiple objects -# hard -> locating multiple objects and resoning about their relative positions (is X on the right side of Y?) -class BoolImageTaskEasy(BoolImageTask): - complexity = "easy" - - -class BoolImageTaskMedium(BoolImageTask): - complexity = "medium" - - -class BoolImageTaskHard(BoolImageTask): - complexity = "hard" diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index 79e0e0f51..6e6943b1e 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -22,8 +22,9 @@ from rai.initialization import get_llm_model_direct -def parse_tool_calling_benchmark_args(): - parser = argparse.ArgumentParser(description="Run the Tool Calling Agent Benchmark") +def parse_base_benchmark_args(description: str, default_out_subdir: str): + """Parse common benchmark arguments shared across different benchmark types.""" + parser = argparse.ArgumentParser(description=description) parser.add_argument( "--model-name", type=str, @@ -31,6 +32,21 @@ def parse_tool_calling_benchmark_args(): required=True, ) parser.add_argument("--vendor", type=str, help="Vendor of the model", required=True) + + now = datetime.now() + parser.add_argument( + "--out-dir", + type=str, + default=f"src/rai_bench/rai_bench/experiments/{default_out_subdir}/{now.strftime('%Y-%m-%d_%H-%M-%S')}", + help="Output directory for results and logs", + ) + return parser + + +def parse_tool_calling_benchmark_args(): + parser = parse_base_benchmark_args( + "Run the Tool Calling Agent Benchmark", "tool_calling" + ) parser.add_argument( "--extra-tool-calls", type=int, @@ -70,35 +86,21 @@ def parse_tool_calling_benchmark_args(): "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ], default=[ "basic", "manipulation", "custom_interfaces", - "spatial_reasoning", ], help="Types of tasks to include in the benchmark", ) - now = datetime.now() - parser.add_argument( - "--out-dir", - type=str, - default=f"src/rai_bench/rai_bench/experiments/tool_calling/{now.strftime('%Y-%m-%d_%H-%M-%S')}", - help="Output directory for results and logs", - ) return parser.parse_args() def parse_manipulation_o3de_benchmark_args(): - parser = argparse.ArgumentParser(description="Run the Manipulation O3DE Benchmark") - parser.add_argument( - "--model-name", - type=str, - help="Model name to use for benchmarking", - required=True, + parser = parse_base_benchmark_args( + "Run the Manipulation O3DE Benchmark", "o3de_manipulation" ) - parser.add_argument("--vendor", type=str, help="Vendor of the model", required=True) parser.add_argument( "--o3de-config-path", type=str, @@ -113,13 +115,11 @@ def parse_manipulation_o3de_benchmark_args(): default=["trivial", "easy", "medium", "hard", "very_hard"], help="Difficulty levels to include in the benchmark", ) - now = datetime.now() - parser.add_argument( - "--out-dir", - type=str, - default=f"src/rai_bench/rai_bench/experiments/o3de_manipulation/{now.strftime('%Y-%m-%d_%H-%M-%S')}", - help="Output directory for results and logs", - ) + return parser.parse_args() + + +def parse_vlm_benchmark_args(): + parser = parse_base_benchmark_args("Run the VLM Benchmark", "vlm_benchmark") return parser.parse_args() @@ -134,11 +134,16 @@ def define_benchmark_logger(out_dir: Path, level: int = logging.INFO) -> logging ) file_handler.setFormatter(formatter) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + bench_logger = logging.getLogger("Benchmark logger") for handler in bench_logger.handlers: bench_logger.removeHandler(handler) bench_logger.setLevel(level) bench_logger.addHandler(file_handler) + bench_logger.addHandler(console_handler) return bench_logger diff --git a/src/rai_bench/rai_bench/vlm_benchmark/__init__.py b/src/rai_bench/rai_bench/vlm_benchmark/__init__.py new file mode 100644 index 000000000..b920f5bf0 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/__init__.py @@ -0,0 +1,18 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .benchmark import run_benchmark +from .predefined.tasks import get_spatial_tasks + +__all__ = ["get_spatial_tasks", "run_benchmark"] diff --git a/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py b/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py new file mode 100644 index 000000000..8f1c9d699 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/benchmark.py @@ -0,0 +1,211 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import statistics +import time +import uuid +from pathlib import Path +from typing import Iterator, List, Sequence, Tuple + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.runnables.config import RunnableConfig +from langgraph.errors import GraphRecursionError +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel +from rai.agents.langchain.core import ( + create_structured_output_runnable, +) +from rai.messages import HumanMultimodalMessage + +from rai_bench.base_benchmark import BaseBenchmark, RunSummary, TimeoutException +from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler +from rai_bench.utils import get_llm_model_name +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask, TaskValidationError +from rai_bench.vlm_benchmark.results_tracking import ( + TaskResult, +) + + +class VLMBenchmark(BaseBenchmark): + """Benchmark for VLMs.""" + + def __init__( + self, + tasks: Sequence[ImageReasoningTask[BaseModel]], + model_name: str, + results_dir: Path, + logger: logging.Logger | None = None, + ) -> None: + super().__init__( + model_name=model_name, + results_dir=results_dir, + logger=logger, + ) + self._tasks: Iterator[Tuple[int, ImageReasoningTask[BaseModel]]] = enumerate( + iter(tasks) + ) + self.num_tasks = len(tasks) + self.task_results: List[TaskResult] = [] + + self.score_tracing_handler = ScoreTracingHandler() + self.tasks_results: List[TaskResult] = [] + self.csv_initialize(self.results_filename, TaskResult) + + def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None: + """Runs the next task of the benchmark. + + Parameters + ---------- + agent : CompiledStateGraph + LangChain tool calling agent. + model_name : str + Name of the LLM model. + """ + # try: + i, task = next(self._tasks) + self.logger.info( + "======================================================================================" + ) + self.logger.info( + f"RUNNING TASK NUMBER {i + 1} / {self.num_tasks}, TASK {task.get_prompt()}" + ) + callbacks = self.score_tracing_handler.get_callbacks() + run_id = uuid.uuid4() + config: RunnableConfig = { + "run_id": run_id, + "callbacks": callbacks, + "tags": [ + f"experiment-id:{experiment_id}", + "benchmark:vlm-benchmark", + self.model_name, + f"task-complexity:{task.complexity}", + ], + "recursion_limit": len(agent.get_graph().nodes), + } + + ts = time.perf_counter() + messages: List[BaseMessage] = [] + prev_count: int = 0 + errors: List[str] = [] + try: + with self.time_limit(60): + for state in agent.stream( + { + "messages": [ + HumanMultimodalMessage( + content=task.get_prompt(), images=task.get_images() + ) + ] + }, + config=config, + ): + node = next(iter(state)) + all_messages = state[node]["messages"] + for new_msg in all_messages[prev_count:]: + messages.append(new_msg) + prev_count = len(messages) + except TimeoutException as e: + self.logger.error(msg=f"Task timeout: {e}") + except GraphRecursionError as e: + self.logger.error(msg=f"Reached recursion limit {e}") + + structured_output = None + try: + structured_output = task.get_structured_output_from_messages( + messages=messages + ) + except TaskValidationError as e: + errors.append(str(e)) + + if structured_output is not None: + score = task.validate(output=structured_output) + else: + errors.append(f"Not valid structured output: {type(structured_output)}") + score = False + + te = time.perf_counter() + total_time = te - ts + + self.logger.info(f"TASK SCORE: {score}, TOTAL TIME: {total_time:.3f}") + + task_result = TaskResult( + task_prompt=task.get_prompt(), + system_prompt=task.get_system_prompt(), + type=task.type, + complexity=task.complexity, + model_name=self.model_name, + score=score, + total_time=total_time, + run_id=run_id, + ) + + self.task_results.append(task_result) + + self.csv_writerow(self.results_filename, task_result) + # computing after every iteration in case of early stopping + self.compute_and_save_summary() + + for callback in callbacks: + self.score_tracing_handler.send_score( + callback=callback, + run_id=run_id, + score=score, + errors=[errors], + ) + + def compute_and_save_summary(self): + self.logger.info("Computing and saving average results...") + + success_count = sum(1 for r in self.task_results if r.score == 1.0) + success_rate = success_count / len(self.task_results) * 100 + avg_time = statistics.mean(r.total_time for r in self.task_results) + + summary = RunSummary( + model_name=self.model_name, + success_rate=round(success_rate, 2), + avg_time=round(avg_time, 3), + total_tasks=len(self.task_results), + ) + self.csv_initialize(self.summary_filename, RunSummary) + self.csv_writerow(self.summary_filename, summary) + + +def run_benchmark( + llm: BaseChatModel, + out_dir: Path, + tasks: List[ImageReasoningTask[BaseModel]], + bench_logger: logging.Logger, + experiment_id: uuid.UUID = uuid.uuid4(), +): + benchmark = VLMBenchmark( + tasks=tasks, + logger=bench_logger, + model_name=get_llm_model_name(llm), + results_dir=out_dir, + ) + + for task in tasks: + agent = create_structured_output_runnable( + llm=llm, + structured_output=task.structured_output, + system_prompt=task.get_system_prompt(), + logger=bench_logger, + ) + + benchmark.run_next(agent=agent, experiment_id=experiment_id) + + bench_logger.info("===============================================================") + bench_logger.info("ALL TASKS DONE. BENCHMARK COMPLETED!") + bench_logger.info("===============================================================") diff --git a/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py b/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py new file mode 100644 index 000000000..97f769b93 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/interfaces.py @@ -0,0 +1,174 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from typing import Generic, List, Literal, Optional, TypeVar + +from langchain_core.messages import BaseMessage +from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT +from pydantic import BaseModel, ConfigDict, ValidationError + +loggers_type = logging.Logger + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) + + +IMAGE_REASONING_SYSTEM_PROMPT = "You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response in requested structured output format." + + +class LangchainRawOutputModel(BaseModel): + """ + A Pydantic model for wrapping Langchain message parsing results from a structured output agent. See documentation for more details: + https://github.com/langchain-ai/langchain/blob/02001212b0a2b37d90451d8493089389ea220cab/libs/core/langchain_core/language_models/chat_models.py#L1430-L1432 + + + Attributes + ---------- + raw : BaseMessage + The original raw message object from Langchain before parsing. + parsed : BaseModel + The parsed and validated Pydantic model instance derived from the raw message. + parsing_error : Optional[BaseException] + Any exception that occurred during the parsing process, None if parsing + was successful. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + raw: BaseMessage + parsed: BaseModel + parsing_error: Optional[BaseException] + + +class TaskValidationError(Exception): + pass + + +class ImageReasoningTask(ABC, Generic[BaseModelT]): + complexity: Literal["easy", "medium", "hard"] + recursion_limit: int = DEFAULT_RECURSION_LIMIT + + def __init__( + self, + logger: loggers_type | None = None, + ) -> None: + """ + Abstract base class representing a complete image reasoning task to be validated. + + Each Task has a consistent prompt and structured output schema, along + with validation methods that check the output against the expected result. + + Attributes + ---------- + logger : logging.Logger + Logger for recording task validation results and errors. + """ + if logger: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + self.question: str + self.images_paths: List[str] + + def set_logger(self, logger: loggers_type): + self.logger = logger + + @property + @abstractmethod + def structured_output(self) -> type[BaseModelT]: + """Structured output that agent should return.""" + pass + + @property + @abstractmethod + def type(self) -> str: + """Type of task, for example: image_reasoning""" + pass + + def get_system_prompt(self) -> str: + """Get the system prompt that will be passed to agent + + Returns + ------- + str + System prompt + """ + return IMAGE_REASONING_SYSTEM_PROMPT + + @abstractmethod + def get_prompt(self) -> str: + """Get the task instruction - the prompt that will be passed to agent. + + Returns + ------- + str + Prompt + """ + pass + + @abstractmethod + def validate(self, output: BaseModelT) -> bool: + """Validate result of the task.""" + pass + + @abstractmethod + def get_images(self) -> List[str]: + """Get the images related to the task. + + Returns + ------- + List[str] + List of image paths + """ + pass + + def get_structured_output_from_messages( + self, messages: List[BaseMessage] + ) -> BaseModelT | None: + """Extract and validate structured output from a list of messages. + + Iterates through messages in reverse order, attempting to find the message that is + a LangchainRawOutputModel containing the structured output. + + Parameters + ---------- + messages : List[BaseMessage] + List of messages to search for structured output. + + Returns + ------- + BaseModelT | None + The first valid structured output found that matches the task's expected + output type, or None if no valid structured output is found. + + Raises + ------ + TaskValidationError + If a message contains a parsing error during validation. + """ + for message in reversed(messages): + if isinstance(message, dict): + try: + validated_message = LangchainRawOutputModel.model_validate(message) + if validated_message.parsing_error is not None: + raise TaskValidationError( + f"Parsing error: {validated_message.parsing_error}" + ) + + parsed = validated_message.parsed + if isinstance(parsed, self.structured_output): + return parsed + except ValidationError: + continue + return None diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_1.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_1.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_1.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_1.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_2.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_2.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_2.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_2.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_3.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_3.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_3.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_3.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_4.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_4.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_4.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_4.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_5.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_5.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_5.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_5.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_6.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_6.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_6.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_6.jpg diff --git a/src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_7.jpg b/src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_7.jpg similarity index 100% rename from src/rai_bench/rai_bench/tool_calling_agent/predefined/images/image_7.jpg rename to src/rai_bench/rai_bench/vlm_benchmark/predefined/images/image_7.jpg diff --git a/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py b/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py new file mode 100644 index 000000000..bd1d9cd61 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/predefined/tasks.py @@ -0,0 +1,112 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, cast + +from pydantic import BaseModel + +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask +from rai_bench.vlm_benchmark.tasks.tasks import BoolImageTask, BoolImageTaskInput + +IMG_PATH = "src/rai_bench/rai_bench/vlm_benchmark/predefined/images/" +true_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door on the left from the desk?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is the light on in the room?", + images_paths=[IMG_PATH + "image_2.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_2.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Are there any pictures on the wall?", + images_paths=[IMG_PATH + "image_3.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Are there 3 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is there a plant behind the rack?", + images_paths=[IMG_PATH + "image_5.jpg"], + expected_answer=True, + ), + BoolImageTaskInput( + question="Is there a pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + expected_answer=True, + ), +] +false_response_inputs: List[BoolImageTaskInput] = [ + BoolImageTaskInput( + question="Is the door open?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is someone in the room?", + images_paths=[IMG_PATH + "image_1.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Do you see the plant?", + images_paths=[IMG_PATH + "image_3.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Are there 4 pictures on the wall?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a rack on the left from the sofa?", + images_paths=[IMG_PATH + "image_4.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a plant on the right from the window?", + images_paths=[IMG_PATH + "image_6.jpg"], + expected_answer=False, + ), + BoolImageTaskInput( + question="Is there a red pillow on the armchair?", + images_paths=[IMG_PATH + "image_7.jpg"], + expected_answer=False, + ), +] + + +def get_spatial_tasks() -> List[ImageReasoningTask[BaseModel]]: + true_tasks = [ + BoolImageTask( + task_input=input_item, + ) + for input_item in true_response_inputs + ] + false_tasks = [ + BoolImageTask( + task_input=input_item, + ) + for input_item in false_response_inputs + ] + return cast(List[ImageReasoningTask[BaseModel]], true_tasks + false_tasks) diff --git a/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py b/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py new file mode 100644 index 000000000..5d8b77b71 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py @@ -0,0 +1,35 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from uuid import UUID + +from pydantic import BaseModel, Field + + +class TaskResult(BaseModel): + task_prompt: str = Field(..., description="The task prompt.") + system_prompt: str = Field(..., description="The system prompt.") + complexity: str = Field(..., description="Complexity of the task.") + type: str = Field( + ..., description="Type of task, for example: bool_response_image_task" + ) + model_name: str = Field(..., description="Name of the LLM.") + score: float = Field( + ..., + description="Value between 0 and 1.", + ) + + total_time: float = Field(..., description="Total time taken to complete the task.") + run_id: UUID = Field(..., description="UUID of the task run.") diff --git a/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py b/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py new file mode 100644 index 000000000..639b50400 --- /dev/null +++ b/src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py @@ -0,0 +1,76 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from typing import List + +from pydantic import BaseModel, Field +from rai.messages import preprocess_image + +from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask + +loggers_type = logging.Logger + + +class BoolAnswerWithJustification(BaseModel): + """A boolean answer to the user question along with justification for the answer.""" + + answer: bool + justification: str + + +class BoolImageTaskInput(BaseModel): + question: str = Field(..., description="The question to be answered.") + images_paths: List[str] = Field( + ..., + description="List of image file paths to be used for answering the question.", + ) + expected_answer: bool = Field( + ..., description="The expected answer to the question." + ) + + +class BoolImageTask(ImageReasoningTask[BoolAnswerWithJustification]): + complexity = "easy" + + def __init__( + self, + task_input: BoolImageTaskInput, + logger: loggers_type | None = None, + ) -> None: + super().__init__( + logger=logger, + ) + self.question = task_input.question + self.images_paths = task_input.images_paths + self.expected_answer = task_input.expected_answer + + @property + def structured_output(self) -> type[BoolAnswerWithJustification]: + return BoolAnswerWithJustification + + @property + def type(self) -> str: + return "bool_response_image_task" + + def get_prompt(self): + return self.question + + def get_images(self): + images = [preprocess_image(image_path) for image_path in self.images_paths] + return images + + def validate(self, output: BoolAnswerWithJustification) -> bool: + return output.answer == self.expected_answer diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 4f2523a07..714ff448a 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.2.1" +version = "2.5.0" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" diff --git a/src/rai_core/rai/agents/langchain/__init__.py b/src/rai_core/rai/agents/langchain/__init__.py index b608b97bb..186ad47d8 100644 --- a/src/rai_core/rai/agents/langchain/__init__.py +++ b/src/rai_core/rai/agents/langchain/__init__.py @@ -19,6 +19,7 @@ create_react_runnable, create_state_based_runnable, ) +from .invocation_helpers import invoke_llm_with_tracing from .react_agent import ReActAgent from .state_based_agent import BaseStateBasedAgent, StateBasedConfig @@ -32,5 +33,6 @@ "StateBasedConfig", "create_react_runnable", "create_state_based_runnable", + "invoke_llm_with_tracing", "newMessageBehaviorType", ] diff --git a/src/rai_core/rai/agents/langchain/core/__init__.py b/src/rai_core/rai/agents/langchain/core/__init__.py index 9aade0321..e8b0b9d9b 100644 --- a/src/rai_core/rai/agents/langchain/core/__init__.py +++ b/src/rai_core/rai/agents/langchain/core/__init__.py @@ -25,6 +25,7 @@ create_react_runnable, ) from .state_based_agent import create_state_based_runnable +from .structured_output_agent import create_structured_output_runnable from .tool_runner import SubAgentToolRunner, ToolRunner __all__ = [ @@ -38,5 +39,6 @@ "create_megamind", "create_react_runnable", "create_state_based_runnable", + "create_structured_output_runnable", "get_initial_megamind_state", ] diff --git a/src/rai_core/rai/agents/langchain/core/conversational_agent.py b/src/rai_core/rai/agents/langchain/core/conversational_agent.py index 8b940cdf6..e008fdb7d 100644 --- a/src/rai_core/rai/agents/langchain/core/conversational_agent.py +++ b/src/rai_core/rai/agents/langchain/core/conversational_agent.py @@ -17,17 +17,20 @@ from functools import partial from typing import List, Optional, TypedDict +from deprecated import deprecated from langchain.chat_models.base import BaseChatModel from langchain_core.messages import ( BaseMessage, SystemMessage, ) +from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.graph.state import CompiledStateGraph from langgraph.prebuilt.tool_node import tools_condition from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing class State(TypedDict): @@ -39,6 +42,7 @@ def agent( logger: logging.Logger, system_prompt: str | SystemMessage, state: State, + config: RunnableConfig, ): logger.info("Running thinker") @@ -54,11 +58,17 @@ def agent( else system_prompt ) state["messages"].insert(0, system_msg) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) return state +@deprecated( + "Use rai.agents.langchain.core.create_react_runnable instead. " + "Support for the conversational agent will be removed in the 3.0 release." +) def create_conversational_agent( llm: BaseChatModel, tools: List[BaseTool], diff --git a/src/rai_core/rai/agents/langchain/core/react_agent.py b/src/rai_core/rai/agents/langchain/core/react_agent.py index 34424a84e..a15092404 100644 --- a/src/rai_core/rai/agents/langchain/core/react_agent.py +++ b/src/rai_core/rai/agents/langchain/core/react_agent.py @@ -21,13 +21,14 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, SystemMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.prebuilt.tool_node import tools_condition from typing_extensions import TypedDict from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing from rai.initialization import get_llm_model from rai.messages import SystemMultimodalMessage @@ -48,6 +49,7 @@ def llm_node( llm: BaseChatModel, system_prompt: Optional[str | SystemMultimodalMessage], state: ReActAgentState, + config: RunnableConfig, ): """Process messages using the LLM. @@ -57,6 +59,8 @@ def llm_node( The language model to use for processing state : ReActAgentState Current state containing messages + config : RunnableConfig + Configuration including callbacks for tracing Returns ------- @@ -75,7 +79,9 @@ def llm_node( # at this point, state['messages'] length should at least be 1 if not isinstance(state["messages"][0], SystemMessage): state["messages"].insert(0, SystemMessage(content=system_prompt)) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) diff --git a/src/rai_core/rai/agents/langchain/core/state_based_agent.py b/src/rai_core/rai/agents/langchain/core/state_based_agent.py index 0ffcff6d5..ea651acd3 100644 --- a/src/rai_core/rai/agents/langchain/core/state_based_agent.py +++ b/src/rai_core/rai/agents/langchain/core/state_based_agent.py @@ -26,12 +26,13 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph from langgraph.prebuilt.tool_node import tools_condition from rai.agents.langchain.core.tool_runner import ToolRunner +from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing from rai.initialization import get_llm_model from rai.messages import HumanMultimodalMessage, SystemMultimodalMessage @@ -52,6 +53,7 @@ def llm_node( llm: BaseChatModel, system_prompt: Optional[str | SystemMultimodalMessage], state: ReActAgentState, + config: RunnableConfig, ): """Process messages using the LLM. @@ -61,6 +63,8 @@ def llm_node( The language model to use for processing state : ReActAgentState Current state containing messages + config : RunnableConfig + Configuration including callbacks for tracing Returns ------- @@ -79,7 +83,9 @@ def llm_node( # at this point, state['messages'] length should at least be 1 if not isinstance(state["messages"][0], SystemMessage): state["messages"].insert(0, SystemMessage(content=system_prompt)) - ai_msg = llm.invoke(state["messages"]) + + # Invoke LLM with tracing if it is configured and available + ai_msg = invoke_llm_with_tracing(llm, state["messages"], config) state["messages"].append(ai_msg) diff --git a/src/rai_core/rai/agents/langchain/core/structured_output_agent.py b/src/rai_core/rai/agents/langchain/core/structured_output_agent.py new file mode 100644 index 000000000..7c1878cd3 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/core/structured_output_agent.py @@ -0,0 +1,60 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from functools import partial +from typing import Optional + +from langchain.chat_models.base import BaseChatModel +from langchain_core.messages import ( + SystemMessage, +) +from langgraph.graph import START, StateGraph +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel + +from rai.agents.langchain.core.conversational_agent import State, agent + + +def create_structured_output_runnable( + llm: BaseChatModel, + structured_output: type[BaseModel], + system_prompt: str | SystemMessage, + logger: Optional[logging.Logger] = None, + debug: bool = False, +) -> CompiledStateGraph: + _logger = None + if logger: + _logger = logger + else: + _logger = logging.getLogger(__name__) + + _logger.info("Creating structured output runnable") + + llm_with_structured_output = llm.with_structured_output( + schema=structured_output, include_raw=True + ) + + workflow = StateGraph(State) + + workflow.add_node( + "thinker", partial(agent, llm_with_structured_output, _logger, system_prompt) + ) + + workflow.add_edge(START, "thinker") + + app = workflow.compile(debug=debug) + _logger.info("State based agent created") + return app diff --git a/src/rai_core/rai/agents/langchain/invocation_helpers.py b/src/rai_core/rai/agents/langchain/invocation_helpers.py new file mode 100644 index 000000000..d3ecee749 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/invocation_helpers.py @@ -0,0 +1,76 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, List, Optional + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.runnables import RunnableConfig + +from rai.initialization import get_tracing_callbacks + +logger = logging.getLogger(__name__) + + +def invoke_llm_with_tracing( + llm: BaseChatModel, + messages: List[BaseMessage], + config: Optional[RunnableConfig] = None, +) -> Any: + """ + Invoke an LLM with enhanced tracing callbacks. + + This function automatically adds tracing callbacks (like Langfuse) to LLM calls + within LangGraph nodes, solving the callback propagation issue. + + Tracing is controlled by config.toml. If the file is missing, no tracing is applied. + + Parameters + ---------- + llm : BaseChatModel + The language model to invoke + messages : List[BaseMessage] + Messages to send to the LLM + config : Optional[RunnableConfig] + Existing configuration (may contain some callbacks) + + Returns + ------- + Any + The LLM response + """ + tracing_callbacks = get_tracing_callbacks() + + if len(tracing_callbacks) == 0: + # No tracing callbacks available, use config as-is + return llm.invoke(messages, config=config) + + # Create enhanced config with tracing callbacks + enhanced_config = config.copy() if config else {} + + # Add tracing callbacks to existing callbacks + existing_callbacks = config.get("callbacks", []) if config else [] + + if hasattr(existing_callbacks, "handlers"): + # Merge with existing CallbackManager + all_callbacks = existing_callbacks.handlers + tracing_callbacks + elif isinstance(existing_callbacks, list): + all_callbacks = existing_callbacks + tracing_callbacks + else: + all_callbacks = tracing_callbacks + + enhanced_config["callbacks"] = all_callbacks + + return llm.invoke(messages, config=enhanced_config) diff --git a/src/rai_core/rai/communication/hri_connector.py b/src/rai_core/rai/communication/hri_connector.py index 17f18d416..bffc06ba0 100644 --- a/src/rai_core/rai/communication/hri_connector.py +++ b/src/rai_core/rai/communication/hri_connector.py @@ -103,12 +103,11 @@ def from_langchain( seq_no: int = 0, seq_end: bool = False, ) -> "HRIMessage": + text = message.text() if isinstance(message, RAIMultimodalMessage): - text = message.text images = message.images audios = message.audios else: - text = str(message.content) images = None audios = None if message.type not in ["ai", "human"]: diff --git a/src/rai_core/rai/communication/ros2/api/service.py b/src/rai_core/rai/communication/ros2/api/service.py index 464625aed..06dac1bc1 100644 --- a/src/rai_core/rai/communication/ros2/api/service.py +++ b/src/rai_core/rai/communication/ros2/api/service.py @@ -14,6 +14,7 @@ import os import uuid +from threading import Lock from typing import ( Any, Callable, @@ -30,6 +31,7 @@ import rclpy.qos import rclpy.subscription import rclpy.task +from rclpy.client import Client from rclpy.service import Service from rai.communication.ros2.api.base import ( @@ -39,12 +41,18 @@ class ROS2ServiceAPI(BaseROS2API): - """Handles ROS2 service operations including calling services.""" + """Handles ROS 2 service operations including calling services.""" def __init__(self, node: rclpy.node.Node) -> None: self.node = node self._logger = node.get_logger() self._services: Dict[str, Service] = {} + self._persistent_clients: Dict[str, Client] = {} + self._persistent_clients_lock = Lock() + + def release_client(self, service_name: str) -> bool: + with self._persistent_clients_lock: + return self._persistent_clients.pop(service_name, None) is not None def call_service( self, @@ -52,30 +60,57 @@ def call_service( service_type: str, request: Any, timeout_sec: float = 5.0, + *, + reuse_client: bool = True, ) -> Any: """ - Call a ROS2 service. + Call a ROS 2 service. Args: - service_name: Name of the service to call - service_type: ROS2 service type as string - request: Request message content + service_name: Fully-qualified service name. + service_type: ROS 2 service type string (e.g., 'std_srvs/srv/SetBool'). + request: Request payload dict. + timeout_sec: Seconds to wait for availability/response. + reuse_client: Reuse a cached client. Client creation is synchronized; set + False to create a new client per call. Returns: - The response message + Response message instance. + + Raises: + ValueError: Service not available within the timeout. + AttributeError: Service type or request cannot be constructed. + + Note: + With reuse_client=True, access to the cached client (including the + service call) is serialized by a lock, preventing concurrent calls + through the same client. Use reuse_client=False for per-call clients + when concurrent service calls are required. """ srv_msg, srv_cls = self.build_ros2_service_request(service_type, request) - service_client = self.node.create_client(srv_cls, service_name) # type: ignore - client_ready = service_client.wait_for_service(timeout_sec=timeout_sec) - if not client_ready: - raise ValueError( - f"Service {service_name} not ready within {timeout_sec} seconds. " - "Try increasing the timeout or check if the service is running." - ) - if os.getenv("ROS_DISTRO") == "humble": - return service_client.call(srv_msg) + + def _call_service(client: Client, timeout_sec: float) -> Any: + is_service_available = client.wait_for_service(timeout_sec=timeout_sec) + if not is_service_available: + raise ValueError( + f"Service {service_name} not ready within {timeout_sec} seconds. " + "Try increasing the timeout or check if the service is running." + ) + if os.getenv("ROS_DISTRO") == "humble": + return client.call(srv_msg) + else: + return client.call(srv_msg, timeout_sec=timeout_sec) + + if reuse_client: + with self._persistent_clients_lock: + client = self._persistent_clients.get(service_name, None) + if client is None: + client = self.node.create_client(srv_cls, service_name) # type: ignore + self._persistent_clients[service_name] = client + return _call_service(client, timeout_sec) else: - return service_client.call(srv_msg, timeout_sec=timeout_sec) + client = self.node.create_client(srv_cls, service_name) # type: ignore + return _call_service(client, timeout_sec) def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]: return self.node.get_service_names_and_types() diff --git a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py index 985de6c16..7c1597a56 100644 --- a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py +++ b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py @@ -30,6 +30,9 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: f"{self.__class__.__name__} instance must have an attribute '_service_api' of type ROS2ServiceAPI" ) + def release_client(self, service_name: str) -> bool: + return self._service_api.release_client(service_name) + def service_call( self, message: ROS2Message, @@ -37,6 +40,7 @@ def service_call( timeout_sec: float = 5.0, *, msg_type: str, + reuse_client: bool = True, **kwargs: Any, ) -> ROS2Message: msg = self._service_api.call_service( @@ -44,6 +48,7 @@ def service_call( service_type=msg_type, request=message.payload, timeout_sec=timeout_sec, + reuse_client=reuse_client, ) return ROS2Message( payload=msg, metadata={"msg_type": str(type(msg)), "service": target} diff --git a/src/rai_core/rai/initialization/model_initialization.py b/src/rai_core/rai/initialization/model_initialization.py index c7f4fa263..4496286f0 100644 --- a/src/rai_core/rai/initialization/model_initialization.py +++ b/src/rai_core/rai/initialization/model_initialization.py @@ -275,11 +275,16 @@ def get_embeddings_model( def get_tracing_callbacks( - override_use_langfuse: bool = False, override_use_langsmith: bool = False + config_path: Optional[str] = None, ) -> List[BaseCallbackHandler]: - config = load_config() + try: + config = load_config(config_path) + except Exception as e: + logger.warning(f"Failed to load config for tracing: {e}, tracing disabled") + return [] + callbacks: List[BaseCallbackHandler] = [] - if config.tracing.langfuse.use_langfuse or override_use_langfuse: + if config.tracing.langfuse.use_langfuse: from langfuse.callback import CallbackHandler # type: ignore public_key = os.getenv("LANGFUSE_PUBLIC_KEY", None) @@ -294,7 +299,7 @@ def get_tracing_callbacks( ) callbacks.append(callback) - if config.tracing.langsmith.use_langsmith or override_use_langsmith: + if config.tracing.langsmith.use_langsmith: os.environ["LANGCHAIN_TRACING_V2"] = "true" os.environ["LANGCHAIN_PROJECT"] = config.tracing.project api_key = os.getenv("LANGCHAIN_API_KEY", None) diff --git a/src/rai_core/rai/messages/multimodal.py b/src/rai_core/rai/messages/multimodal.py index b1646f37d..1b5129d78 100644 --- a/src/rai_core/rai/messages/multimodal.py +++ b/src/rai_core/rai/messages/multimodal.py @@ -63,10 +63,6 @@ def __init__( _content.extend(_image_content) self.content = _content - @property - def text(self) -> str: - return self.content[0]["text"] - class HumanMultimodalMessage(HumanMessage, MultimodalMessage): def __repr_args__(self) -> Any: diff --git a/src/rai_core/rai/tools/ros2/navigation/__init__.py b/src/rai_core/rai/tools/ros2/navigation/__init__.py index 8c5a94838..4142fcc68 100644 --- a/src/rai_core/rai/tools/ros2/navigation/__init__.py +++ b/src/rai_core/rai/tools/ros2/navigation/__init__.py @@ -20,6 +20,7 @@ Nav2Toolkit, NavigateToPoseTool, ) +from .nav2_blocking import NavigateToPoseBlockingTool __all__ = [ "CancelNavigateToPoseTool", @@ -27,5 +28,6 @@ "GetNavigateToPoseResultTool", "GetOccupancyGridTool", "Nav2Toolkit", + "NavigateToPoseBlockingTool", "NavigateToPoseTool", ] diff --git a/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py b/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py new file mode 100644 index 000000000..b51a79207 --- /dev/null +++ b/src/rai_core/rai/tools/ros2/navigation/nav2_blocking.py @@ -0,0 +1,69 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from geometry_msgs.msg import PoseStamped, Quaternion +from nav2_msgs.action import NavigateToPose +from pydantic import BaseModel, Field +from rclpy.action import ActionClient +from tf_transformations import quaternion_from_euler + +from rai.tools.ros2.base import BaseROS2Tool + + +class NavigateToPoseBlockingToolInput(BaseModel): + x: float = Field(..., description="The x coordinate of the pose") + y: float = Field(..., description="The y coordinate of the pose") + z: float = Field(..., description="The z coordinate of the pose") + yaw: float = Field(..., description="The yaw angle of the pose") + + +class NavigateToPoseBlockingTool(BaseROS2Tool): + name: str = "navigate_to_pose_blocking" + description: str = "Navigate to a specific pose" + frame_id: str = Field( + default="map", description="The frame id of the Nav2 stack (map, odom, etc.)" + ) + action_name: str = Field( + default="navigate_to_pose", description="The name of the Nav2 action" + ) + args_schema: Type[NavigateToPoseBlockingToolInput] = NavigateToPoseBlockingToolInput + + def _run(self, x: float, y: float, z: float, yaw: float) -> str: + action_client = ActionClient( + self.connector.node, NavigateToPose, self.action_name + ) + + pose = PoseStamped() + pose.header.frame_id = self.frame_id + pose.header.stamp = self.connector.node.get_clock().now().to_msg() + pose.pose.position.x = x + pose.pose.position.y = y + pose.pose.position.z = z + quat = quaternion_from_euler(0, 0, yaw) + pose.pose.orientation = Quaternion(x=quat[0], y=quat[1], z=quat[2], w=quat[3]) + + goal = NavigateToPose.Goal() + goal.pose = pose + + result = action_client.send_goal(goal) + + if result is None: + return "Navigate to pose action failed. Please try again." + + if result.result.error_code != 0: + return f"Navigate to pose action failed. Error code: {result.result.error_code}" + + return "Navigate to pose successful." diff --git a/tests/agents/langchain/test_langchain_agent.py b/tests/agents/langchain/test_langchain_agent.py index c220f7bbe..44e883120 100644 --- a/tests/agents/langchain/test_langchain_agent.py +++ b/tests/agents/langchain/test_langchain_agent.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from collections import deque from typing import List +from unittest.mock import MagicMock, patch import pytest +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.language_models.fake_chat_models import ParrotFakeChatModel +from langchain_core.runnables import RunnableConfig +from rai.agents.langchain import invoke_llm_with_tracing from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType +from rai.initialization import get_tracing_callbacks +from rai.messages import HumanMultimodalMessage @pytest.mark.parametrize( @@ -39,3 +47,119 @@ def test_reduce_messages( output_ = LangChainAgent._apply_reduction_behavior(new_message_behavior, buffer) assert output == output_ assert buffer == deque(out_buffer) + + +class TestTracingConfiguration: + """Test tracing configuration integration with langchain agents.""" + + def test_tracing_with_missing_config_file(self): + """Test that tracing gracefully handles missing config.toml file in langchain context.""" + # This should not crash even without config.toml + callbacks = get_tracing_callbacks() + assert len(callbacks) == 0 + + def test_tracing_with_config_file_present(self, test_config_toml): + """Test that tracing works when config.toml is present in langchain context.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=True, langsmith_enabled=False + ) + + try: + # Mock environment variables to avoid actual API calls + with patch.dict( + os.environ, + { + "LANGFUSE_PUBLIC_KEY": "test_key", + "LANGFUSE_SECRET_KEY": "test_secret", + }, + ): + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 1 callback for langfuse + assert len(callbacks) == 1 + finally: + cleanup() + + +class TestInvokeLLMWithTracing: + """Test the invoke_llm_with_tracing function.""" + + def test_invoke_llm_without_tracing(self): + """Test that invoke_llm_with_tracing works when no tracing callbacks are available.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock get_tracing_callbacks to return empty list (no config.toml) + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = [] + + result = invoke_llm_with_tracing(mock_llm, mock_messages) + + mock_llm.invoke.assert_called_once_with(mock_messages, config=None) + assert result == "test response" + + def test_invoke_llm_with_tracing(self): + """Test that invoke_llm_with_tracing works when tracing callbacks are available.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock get_tracing_callbacks to return some callbacks + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = ["tracing_callback"] + + _ = invoke_llm_with_tracing(mock_llm, mock_messages) + + # Verify that the LLM was called with enhanced config + mock_llm.invoke.assert_called_once() + call_args = mock_llm.invoke.call_args + assert call_args[0][0] == mock_messages + assert "callbacks" in call_args[1]["config"] + assert "tracing_callback" in call_args[1]["config"]["callbacks"] + + def test_invoke_llm_with_existing_config(self): + """Test that invoke_llm_with_tracing preserves existing config.""" + # Mock LLM + mock_llm = MagicMock() + mock_llm.invoke.return_value = "test response" + + # Mock messages + mock_messages = ["test message"] + + # Mock existing config + existing_config = {"callbacks": ["existing_callback"]} + + # Mock get_tracing_callbacks to return some callbacks + with patch( + "rai.agents.langchain.invocation_helpers.get_tracing_callbacks" + ) as mock_get_callbacks: + mock_get_callbacks.return_value = ["tracing_callback"] + + _ = invoke_llm_with_tracing(mock_llm, mock_messages, existing_config) + + # Verify that the LLM was called with enhanced config + mock_llm.invoke.assert_called_once() + call_args = mock_llm.invoke.call_args + assert call_args[0][0] == mock_messages + assert "callbacks" in call_args[1]["config"] + assert "existing_callback" in call_args[1]["config"]["callbacks"] + assert "tracing_callback" in call_args[1]["config"]["callbacks"] + + def test_invoke_llm_with_callback_integration(self): + """Test that invoke_llm_with_tracing works with a callback handler.""" + llm = ParrotFakeChatModel() + human_msg = HumanMultimodalMessage(content="human") + response = llm.invoke( + [human_msg], config=RunnableConfig(callbacks=[BaseCallbackHandler()]) + ) + assert response.content == [{"type": "text", "text": "human"}] diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index bc102a58a..d51f30378 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -138,11 +138,14 @@ def test_ros2_single_message_publish_wrong_qos_setup( shutdown_executors_and_threads(executors, threads) -def service_call_helper(service_name: str, service_api: ROS2ServiceAPI): +def invoke_set_bool_service( + service_name: str, service_api: ROS2ServiceAPI, reuse_client: bool = True +): response = service_api.call_service( service_name, service_type="std_srvs/srv/SetBool", request={"data": True}, + reuse_client=reuse_client, ) assert response.success assert response.message == "Test service called" @@ -164,7 +167,7 @@ def test_ros2_service_single_call( try: service_api = ROS2ServiceAPI(node) - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api) finally: shutdown_executors_and_threads(executors, threads) @@ -186,7 +189,30 @@ def test_ros2_service_multiple_calls( try: service_api = ROS2ServiceAPI(node) for _ in range(3): - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api, reuse_client=False) + finally: + shutdown_executors_and_threads(executors, threads) + + +@pytest.mark.parametrize( + "callback_group", + [MutuallyExclusiveCallbackGroup(), ReentrantCallbackGroup()], + ids=["MutuallyExclusiveCallbackGroup", "ReentrantCallbackGroup"], +) +def test_ros2_service_multiple_calls_with_reused_client( + ros_setup: None, request: pytest.FixtureRequest, callback_group: CallbackGroup +) -> None: + service_name = f"{request.node.originalname}_service" # type: ignore + node_name = f"{request.node.originalname}_node" # type: ignore + service_server = ServiceServer(service_name, callback_group) + node = Node(node_name) + executors, threads = multi_threaded_spinner([service_server, node]) + + try: + service_api = ROS2ServiceAPI(node) + for _ in range(3): + invoke_set_bool_service(service_name, service_api, reuse_client=True) + assert service_api.release_client(service_name), "Client not released" finally: shutdown_executors_and_threads(executors, threads) @@ -210,7 +236,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_threading( service_threads: List[threading.Thread] = [] for _ in range(10): thread = threading.Thread( - target=service_call_helper, args=(service_name, service_api) + target=invoke_set_bool_service, args=(service_name, service_api) ) service_threads.append(thread) thread.start() @@ -241,7 +267,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_multiprocessing( service_api = ROS2ServiceAPI(node) with Pool(10) as pool: pool.map( - lambda _: service_call_helper(service_name, service_api), range(10) + lambda _: invoke_set_bool_service(service_name, service_api), range(10) ) finally: shutdown_executors_and_threads(executors, threads) diff --git a/tests/conftest.py b/tests/conftest.py index 97ceef6f0..adb9e1850 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,3 +11,132 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import os +import tempfile + +import pytest + + +@pytest.fixture +def test_config_toml(): + """ + Fixture to create a temporary test config.toml file with tracing enabled. + + Returns + ------- + tuple + (config_path, cleanup_function) - The path to the config file and a function to clean it up + """ + + def _create_config(langfuse_enabled=False, langsmith_enabled=False): + # Create a temporary config.toml file + f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) + + # Base config sections (always required) + config_content = """[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "text-embedding-ada-002" + +[aws] +simple_model = "anthropic.claude-instant-v1" +complex_model = "anthropic.claude-v2" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-3.5-turbo" +complex_model = "gpt-4" +embeddings_model = "text-embedding-ada-002" +base_url = "https://api.openai.com/v1" + +[ollama] +simple_model = "llama2" +complex_model = "llama2" +embeddings_model = "llama2" +base_url = "http://localhost:11434" + +[tracing] +project = "test-project" + +[tracing.langfuse] +use_langfuse = {langfuse_enabled} +host = "http://localhost:3000" + +[tracing.langsmith] +use_langsmith = {langsmith_enabled} +host = "https://api.smith.langchain.com" +""".format( + langfuse_enabled=str(langfuse_enabled).lower(), + langsmith_enabled=str(langsmith_enabled).lower(), + ) + + f.write(config_content) + f.close() + + def cleanup(): + try: + f.close() # Ensure file is properly closed + os.unlink(f.name) + except (OSError, PermissionError): + pass # File might already be deleted or have permission issues + + return f.name, cleanup + + return _create_config + + +@pytest.fixture +def test_config_no_tracing(): + """ + Fixture to create a temporary test config.toml file with no tracing section. + + Returns + ------- + tuple + (config_path, cleanup_function) - The path to the config file and a function to clean it up + """ + + def _create_config(): + # Create a temporary config.toml file + f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) + + # Base config sections (always required) + config_content = """[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "text-embedding-ada-002" + +[aws] +simple_model = "anthropic.claude-instant-v1" +complex_model = "anthropic.claude-v2" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-3.5-turbo" +complex_model = "gpt-4" +embeddings_model = "text-embedding-ada-002" +base_url = "https://api.openai.com/v1" + +[ollama] +simple_model = "llama2" +complex_model = "llama2" +embeddings_model = "llama2" +base_url = "http://localhost:11434" +""" + + f.write(config_content) + f.close() + + def cleanup(): + try: + f.close() # Ensure file is properly closed + os.unlink(f.name) + except (OSError, PermissionError): + pass # File might already be deleted or have permission issues + + return f.name, cleanup + + return _create_config diff --git a/tests/initialization/test_tracing.py b/tests/initialization/test_tracing.py new file mode 100644 index 000000000..659941755 --- /dev/null +++ b/tests/initialization/test_tracing.py @@ -0,0 +1,73 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import patch + +from rai.initialization import get_tracing_callbacks + + +class TestInitializationTracing: + """Test the initialization module's tracing functionality.""" + + def test_tracing_with_missing_config_file(self): + """Test that tracing gracefully handles missing config.toml file.""" + # This should not crash even without config.toml + callbacks = get_tracing_callbacks() + assert len(callbacks) == 0 + + def test_tracing_with_config_file_present_tracing_disabled(self, test_config_toml): + """Test that tracing works when config.toml is present but tracing is disabled.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=False, langsmith_enabled=False + ) + + try: + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 0 callbacks since both langfuse and langsmith are disabled + assert len(callbacks) == 0 + finally: + cleanup() + + def test_tracing_with_config_file_present_tracing_enabled(self, test_config_toml): + """Test that tracing works when config.toml is present and tracing is enabled.""" + config_path, cleanup = test_config_toml( + langfuse_enabled=True, langsmith_enabled=False + ) + + try: + # Mock environment variables to avoid actual API calls + with patch.dict( + os.environ, + { + "LANGFUSE_PUBLIC_KEY": "test_key", + "LANGFUSE_SECRET_KEY": "test_secret", + }, + ): + callbacks = get_tracing_callbacks(config_path=config_path) + # Should return 1 callback for langfuse + assert len(callbacks) == 1 + finally: + cleanup() + + def test_tracing_with_valid_config_file_no_tracing(self, test_config_no_tracing): + """Test that tracing works when config.toml is valid but has no tracing sections.""" + config_path, cleanup = test_config_no_tracing() + + try: + # This should not crash, should return empty callbacks + callbacks = get_tracing_callbacks(config_path=config_path) + assert len(callbacks) == 0 + finally: + cleanup() diff --git a/tests/messages/test_multimodal_message.py b/tests/messages/test_multimodal_message.py new file mode 100644 index 000000000..221ec0fed --- /dev/null +++ b/tests/messages/test_multimodal_message.py @@ -0,0 +1,37 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from rai.messages import HumanMultimodalMessage + + +class TestMultimodalMessage: + """Test the MultimodalMessage class and expected behaviors.""" + + def test_human_multimodal_message_text_simple(self): + """Test text() method with simple text content.""" + msg = HumanMultimodalMessage(content="Hello world") + assert msg.text() == "Hello world" + assert isinstance(msg.text(), str) + + def test_human_multimodal_message_text_with_images(self): + """Test text() method with text and images.""" + # Use a small valid base64 image (1x1 pixel PNG) + valid_base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + msg = HumanMultimodalMessage( + content="Look at this image", images=[valid_base64_image] + ) + assert msg.text() == "Look at this image" + # Should only return text type blocks, not image content + assert valid_base64_image not in msg.text()