diff --git a/docs/extensions/perception.md b/docs/extensions/perception.md index 3680b559e..b9f12e77e 100644 --- a/docs/extensions/perception.md +++ b/docs/extensions/perception.md @@ -1,29 +1,5 @@ --8<-- "src/rai_extensions/rai_perception/README.md:sec1" -Agents create two ROS 2 Nodes: `grounding_dino` and `grounded_sam` using [ROS2Connector](../API_documentation/connectors/ROS_2_Connectors.md). -These agents can be triggered by ROS2 services: - -- `grounding_dino_classify`: `rai_interfaces/srv/RAIGroundingDino` -- `grounded_sam_segment`: `rai_interfaces/srv/RAIGroundedSam` - -> [!TIP] -> -> If you wish to integrate open-set detection into your ros2 launch file, a premade launch -> file can be found in `rai/src/rai_bringup/launch/openset.launch.py` - -> [!NOTE] -> The weights will be downloaded to `~/.cache/rai` directory. - -## RAI Tools - -`rai_perception` package contains tools that can be used by [RAI LLM agents](../tutorials/walkthrough.md) -enhance their perception capabilities. For more information on RAI Tools see -[Tool use and development](../tutorials/tools.md) tutorial. - ---8<-- "src/rai_extensions/rai_perception/README.md:sec3" - -> [!TIP] -> -> you can try example below with [rosbotxl demo](../demos/rosbot_xl.md) binary. -> The binary exposes `/camera/camera/color/image_raw` and `/camera/camera/depth/image_raw` topics. 
--8<-- "src/rai_extensions/rai_perception/README.md:sec4" + +--8<-- "src/rai_extensions/rai_perception/README.md:sec5" diff --git a/src/rai_extensions/rai_perception/README.md b/src/rai_extensions/rai_perception/README.md index ab3028e23..4c02f3ee0 100644 --- a/src/rai_extensions/rai_perception/README.md +++ b/src/rai_extensions/rai_perception/README.md @@ -2,125 +2,139 @@ # RAI Perception -This package provides ROS2 integration with [Idea-Research GroundingDINO Model](https://github.com/IDEA-Research/GroundingDINO) and [Grounded-SAM-2, RobotecAI fork](https://github.com/RobotecAI/Grounded-SAM-2) for object detection, segmentation, and gripping point calculation. The `GroundedSamAgent` and `GroundingDinoAgent` are ROS2 service nodes that can be readily added to ROS2 applications. It also provides tools that can be used with [RAI LLM agents](../tutorials/walkthrough.md) to construct conversational scenarios. +RAI Perception brings powerful computer vision capabilities to your ROS2 applications. It integrates [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) and [Grounded-SAM-2](https://github.com/RobotecAI/Grounded-SAM-2) to detect objects, create segmentation masks, and calculate gripping points. -In addition to these building blocks, this package includes utilities to facilitate development, such as a ROS2 client that demonstrates interactions with agent nodes. +The package includes two ready-to-use ROS2 service nodes (`GroundedSamAgent` and `GroundingDinoAgent`) that you can easily add to your applications. It also provides tools that work seamlessly with [RAI LLM agents](../tutorials/walkthrough.md) to build conversational robot scenarios. -## Installation +## Prerequisites + +Before installing `rai-perception`, ensure you have: -While installing `rai_perception` via Pip is being actively worked on, to incorporate it into your application, you will need to set up a ROS2 workspace. +1. **ROS2 installed** (Jazzy recommended, or Humble). 
If you don't have ROS2 yet, follow the official ROS2 installation guide for [jazzy](https://docs.ros.org/en/jazzy/Installation.html) or [humble](https://docs.ros.org/en/humble/Installation.html). +2. **Python 3.8+** and `pip` installed (usually pre-installed on Ubuntu). +3. **NVIDIA GPU** with CUDA support (required for optimal performance). +4. **wget** installed (required for downloading model weights): + ```bash + sudo apt install wget + ``` -### ROS2 Workspace Setup +## Installation -Create a ROS2 workspace and copy this package: +**Step 1:** Source ROS2 in your terminal: ```bash -mkdir -p ~/rai_perception_ws/src -cd ~/rai_perception_ws/src - -# only checkout rai_perception package -git clone --depth 1 --branch main https://github.com/RobotecAI/rai.git temp -cd temp -git archive --format=tar --prefix=rai_perception/ HEAD:src/rai_extensions/rai_perception | tar -xf - -mv rai_perception ../rai_perception -cd .. -rm -rf temp +# For ROS2 Jazzy (recommended) +source /opt/ros/jazzy/setup.bash + +# For ROS2 Humble +source /opt/ros/humble/setup.bash ``` -### ROS2 Dependencies +**Step 2:** Install ROS2 dependencies. `rai-perception` requires additional ROS2 packages that need to be installed separately: + +```bash +# Update package lists first +sudo apt update + +# Install rai_interfaces as a Debian package +sudo apt install ros-jazzy-rai-interfaces # or ros-humble-rai-interfaces for Humble +``` -Add required ROS dependencies. From the workspace root, run +**Step 3:** Install `rai-perception` via pip: ```bash -rosdep install --from-paths src --ignore-src -r +pip install rai-perception ``` -### Build and Run +> [!TIP] +> It's recommended to install `rai-perception` in a virtual environment to avoid conflicts with other Python packages.
-Source ROS2 and build: +> [!TIP] +> To avoid sourcing ROS2 in every new terminal, add the source command to your `~/.bashrc` file: +> +> ```bash +> echo "source /opt/ros/jazzy/setup.bash" >> ~/.bashrc # or humble +> ``` -```bash -# Source ROS2 (humble or jazzy) -source /opt/ros/${ROS_DISTRO}/setup.bash + -# Build workspace -cd ~/rai_perception_ws -colcon build --symlink-install + -# Source ROS2 packages -source install/setup.bash -``` +## Getting Started -### Python Dependencies +This section provides a step-by-step guide to get you up and running with RAI Perception. -`rai_perception` depends on `rai-core` and `sam2`. There are many ways to set up a virtual environment and install these dependencies. Below, we provide an example using Poetry. +### Quick Start -**Step 1:** Copy the following template to `pyproject.toml` in your workspace root, updating it according to your directory setup: +After installing `rai-perception`, launch the perception agents: -```toml -# rai_perception_project pyproject template -[tool.poetry] -name = "rai_perception_ws" -version = "0.1.0" -description = "ROS2 workspace for RAI perception" -package-mode = false +**Step 1:** Open a terminal and source ROS2: -[tool.poetry.dependencies] -python = "^3.10, <3.13" -rai-core = ">=2.5.4" -rai-perception = {path = "src/rai_perception", develop = true} +```bash +source /opt/ros/jazzy/setup.bash # or humble +``` -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +**Step 2:** Launch the perception agents: + +```bash +python -m rai_perception.scripts.run_perception_agents ``` -**Step 2:** Install dependencies: +> [!NOTE] +> The weights will be downloaded to `~/.cache/rai` directory on first use. + +The agents create two ROS 2 nodes: `grounding_dino` and `grounded_sam` using [ROS2Connector](../API_documentation/connectors/ROS_2_Connectors.md). 
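With the agents running, a quick sanity check from a second terminal (sourced the same way) can confirm that both nodes and their services are registered. This sketch uses the standard `ros2` CLI; the node and service names come from the documentation above, and exact output depends on your setup:

```shell
# Source ROS2 first, e.g.: source /opt/ros/jazzy/setup.bash

# The two nodes created by run_perception_agents
ros2 node list | grep -e grounding_dino -e grounded_sam

# The services they expose
ros2 service list | grep -e grounding_dino_classify -e grounded_sam_segment
```

If the grep returns nothing, the agents are likely still downloading weights on first use.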
-First, we create Virtual Environment with Poetry: +### Testing with Example Client + +The `rai_perception/talker.py` example demonstrates how to use the perception services for object detection and segmentation. It shows the complete pipeline: GroundingDINO for object detection followed by GroundedSAM for instance segmentation, with visualization output. + +**Step 1:** Open a terminal and source ROS2: ```bash -cd ~/rai_perception_ws -poetry lock -poetry install +source /opt/ros/jazzy/setup.bash # or humble ``` -Now, we are ready to launch perception agents: +**Step 2:** Launch the perception agents: ```bash -# Activate virtual environment -source "$(poetry env info --path)"/bin/activate -export PYTHONPATH -PYTHONPATH="$(dirname "$(dirname "$(poetry run which python)")")/lib/python$(poetry run python --version | awk '{print $2}' | cut -d. -f1,2)/site-packages:$PYTHONPATH" +python -m rai_perception.scripts.run_perception_agents +``` -# run agents -python src/rai_perception/scripts/run_perception_agents.py +**Step 3:** In a different terminal (remember to source ROS2 first), run the example client: + +```bash +source /opt/ros/jazzy/setup.bash # or humble +python -m rai_perception.examples.talker --ros-args -p image_path:="" ``` +You can use any image containing objects like dragons, lizards, or dinosaurs. For example, use the `sample.jpg` from the package's `images` folder. The client will detect these objects and save a visualization with bounding boxes and masks to `masks.png` in the current directory. + > [!TIP] -> To manage ROS 2 + Poetry environment with less friction: Keep build tools (colcon) at system level, use Poetry only for runtime dependencies of your packages. 
+> +> If you wish to integrate open-set vision into your ros2 launch file, a premade launch +> file can be found in `rai/src/rai_bringup/launch/openset.launch.py` - +### ROS2 Service Interface -`rai-perception` agents create two ROS 2 nodes: `grounding_dino` and `grounded_sam` using [ROS2Connector](../../../docs/API_documentation/connectors/ROS_2_Connectors.md). -These agents can be triggered by ROS2 services: +The agents can be triggered by ROS2 services: - `grounding_dino_classify`: `rai_interfaces/srv/RAIGroundingDino` - `grounded_sam_segment`: `rai_interfaces/srv/RAIGroundedSam` -> [!TIP] -> -> If you wish to integrate open-set vision into your ros2 launch file, a premade launch -> file can be found in `rai/src/rai_bringup/launch/openset.launch.py` + -> [!NOTE] -> The weights will be downloaded to `~/.cache/rai` directory. + + +## Dive Deeper: Tools and Integration -## RAI Tools +This section provides information for developers looking to integrate RAI Perception tools into their applications. -`rai_perception` package contains tools that can be used by [RAI LLM agents](../../../docs/tutorials/walkthrough.md) +### RAI Tools + +`rai_perception` package contains tools that can be used by [RAI LLM agents](../tutorials/walkthrough.md) to enhance their perception capabilities. For more information on RAI Tools see -[Tool use and development](../../../docs/tutorials/tools.md) tutorial. +[Tool use and development](../tutorials/tools.md) tutorial. @@ -132,7 +146,7 @@ This tool calls the GroundingDINO service to detect objects from a comma-separat > [!TIP] > -> you can try example below with [rosbotxl demo](../../../docs/demos/rosbot_xl.md) binary. +> you can try example below with [rosbotxl demo](../demos/rosbot_xl.md) binary. > The binary exposes `/camera/camera/color/image_raw` and `/camera/camera/depth/image_rect_raw` topics. 
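The `model_config = ConfigDict(arbitrary_types_allowed=True)` line added to `GroundingDinoBaseTool` further down this diff deserves a note: pydantic v2 refuses to build a model whose field is an arbitrary class (such as `ROS2Connector`) unless that flag is set. A minimal, self-contained sketch — `FakeConnector` is a hypothetical stand-in, not part of the package, while the threshold defaults mirror those declared in `gdino_tools.py`:

```python
from pydantic import BaseModel, ConfigDict, Field


class FakeConnector:
    """Hypothetical stand-in for rai.communication.ros2.ROS2Connector."""


class GroundingDinoToolSketch(BaseModel):
    # Without this line, pydantic v2 raises a schema-generation error at
    # class-definition time for the non-pydantic `FakeConnector` field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    connector: FakeConnector
    box_threshold: float = Field(default=0.35, description="Box threshold for GDINO")
    text_threshold: float = Field(default=0.45, description="Text threshold for GDINO")


tool = GroundingDinoToolSketch(connector=FakeConnector())
print(tool.box_threshold, tool.text_threshold)  # -> 0.35 0.45
```

Everything beyond the two thresholds and the `ConfigDict` line is illustrative only.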
@@ -198,30 +212,6 @@ with ROS2Context(): I have detected the following items in the picture desk: 2.43m away ``` -## Simple ROS2 Client Node Example - -The `rai_perception/talker.py` example demonstrates how to use the perception services for object detection and segmentation. It shows the complete pipeline: GroundingDINO for object detection followed by GroundedSAM for instance segmentation, with visualization output. - -This example is useful for: - -- Testing perception services integration -- Understanding the ROS2 service call patterns -- Seeing detection and segmentation results with bounding boxes and masks - -Run the example: - -```bash -cd ~/rai_perception_ws -python src/rai_perception/scripts/run_perception_agents.py -``` - -In a different window, run - -```bash -cd ~/rai_perception_ws -ros2 run rai_perception talker --ros-args -p image_path:=src/rai_perception/images/sample.jpg -``` - -The example will detect objects (dragon, lizard, dinosaur) and save a visualization with bounding boxes and masks to `masks.png`. - + + diff --git a/src/rai_extensions/rai_perception/pyproject.toml b/src/rai_extensions/rai_perception/pyproject.toml index d0a1f0712..0ebeb0079 100644 --- a/src/rai_extensions/rai_perception/pyproject.toml +++ b/src/rai_extensions/rai_perception/pyproject.toml @@ -1,6 +1,7 @@ [tool.poetry] -name = "rai_perception" -version = "0.1.2" +name = "rai-perception" +# TODO, update the version once it is published to PyPi +version = "0.1.5" description = "Package for object detection, segmentation and gripping point detection." 
authors = ["Kajetan RachwaƂ "] readme = "README.md" diff --git a/src/rai_extensions/rai_perception/rai_perception/agents/base_vision_agent.py b/src/rai_extensions/rai_perception/rai_perception/agents/base_vision_agent.py index 26be9557e..a695b7610 100644 --- a/src/rai_extensions/rai_perception/rai_perception/agents/base_vision_agent.py +++ b/src/rai_extensions/rai_perception/rai_perception/agents/base_vision_agent.py @@ -67,19 +67,46 @@ def _load_model_with_error_handling(self, model_class): raise e def _download_weights(self): + self.logger.info( + f"Downloading weights from {self.WEIGHTS_URL} to {self.weights_path}" + ) try: subprocess.run( [ "wget", self.WEIGHTS_URL, "-O", - self.weights_path, + str(self.weights_path), "--progress=dot:giga", - ] + ], + check=True, + capture_output=True, + text=True, + ) + # Verify file exists and has reasonable size (> 1MB) + if not os.path.exists(self.weights_path): + raise Exception(f"Downloaded file not found at {self.weights_path}") + file_size = os.path.getsize(self.weights_path) + if file_size < 1024 * 1024: + raise Exception( + f"Downloaded file is too small ({file_size} bytes), expected > 1MB" + ) + self.logger.info( + f"Successfully downloaded weights ({file_size / (1024 * 1024):.2f} MB)" ) - except Exception: - self.logger.error("Could not download weights") - raise Exception("Could not download weights") + except subprocess.CalledProcessError as e: + error_msg = e.stderr if e.stderr else e.stdout if e.stdout else str(e) + self.logger.error(f"wget failed: {error_msg}") + # Clean up partial download + if os.path.exists(self.weights_path): + os.remove(self.weights_path) + raise Exception(f"Could not download weights: {error_msg}") + except Exception as e: + self.logger.error(f"Could not download weights: {e}") + # Clean up partial download + if os.path.exists(self.weights_path): + os.remove(self.weights_path) + raise def _remove_weights(self): os.remove(self.weights_path) diff --git 
a/src/rai_extensions/rai_perception/rai_perception/scripts/__init__.py b/src/rai_extensions/rai_perception/rai_perception/scripts/__init__.py new file mode 100644 index 000000000..97ceef6f0 --- /dev/null +++ b/src/rai_extensions/rai_perception/rai_perception/scripts/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/rai_extensions/rai_perception/scripts/run_perception_agents.py b/src/rai_extensions/rai_perception/rai_perception/scripts/run_perception_agents.py similarity index 99% rename from src/rai_extensions/rai_perception/scripts/run_perception_agents.py rename to src/rai_extensions/rai_perception/rai_perception/scripts/run_perception_agents.py index dc29c3221..a3a56260f 100644 --- a/src/rai_extensions/rai_perception/scripts/run_perception_agents.py +++ b/src/rai_extensions/rai_perception/rai_perception/scripts/run_perception_agents.py @@ -15,6 +15,7 @@ import rclpy from rai.agents import wait_for_shutdown + from rai_perception.agents import GroundedSamAgent, GroundingDinoAgent diff --git a/src/rai_extensions/rai_perception/rai_perception/tools/gdino_tools.py b/src/rai_extensions/rai_perception/rai_perception/tools/gdino_tools.py index 416fba619..3407267a7 100644 --- a/src/rai_extensions/rai_perception/rai_perception/tools/gdino_tools.py +++ b/src/rai_extensions/rai_perception/rai_perception/tools/gdino_tools.py @@ -17,7 +17,7 @@ import numpy as np import sensor_msgs.msg from 
langchain_core.tools import BaseTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from rai.communication.ros2 import ROS2Connector from rai.communication.ros2.api import convert_ros_img_to_ndarray from rai.communication.ros2.ros_async import get_future_result @@ -84,12 +84,18 @@ class GroundingDinoBaseTool(BaseTool): box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") + model_config = ConfigDict(arbitrary_types_allowed=True) + + def _run(self, *args, **kwargs): + """Abstract method - must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement _run method") + def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_names: list[str] ) -> Future: cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info( + self.connector.node.get_logger().info( f"service {GDINO_SERVICE_NAME} not available, waiting again..." ) req = RAIGroundingDino.Request() diff --git a/src/rai_extensions/rai_perception/rai_perception/vision_markup/boxer.py b/src/rai_extensions/rai_perception/rai_perception/vision_markup/boxer.py index 949d18a2c..08982c2d7 100644 --- a/src/rai_extensions/rai_perception/rai_perception/vision_markup/boxer.py +++ b/src/rai_extensions/rai_perception/rai_perception/vision_markup/boxer.py @@ -43,7 +43,16 @@ def to_detection_msg( ) -> Detection2D: detection = Detection2D() detection.header = Header() - detection.header.stamp = timestamp + # TODO(juliaj): Investigate why timestamp is sometimes rclpy.time.Time and sometimes + # builtin_interfaces.msg.Time. The function signature expects rclpy.time.Time, but + # grounding_dino.py calls .to_msg() before passing it. 
Should we fix the caller or + # change the signature to accept Union[rclpy.time.Time, builtin_interfaces.msg.Time]? + # Handle both rclpy.time.Time (call to_msg()) and builtin_interfaces.msg.Time (use directly) + if hasattr(timestamp, "to_msg"): + detection.header.stamp = timestamp.to_msg() + else: + # Already a builtin_interfaces.msg.Time + detection.header.stamp = timestamp detection.results = [] hypothesis_with_pose = ObjectHypothesisWithPose() hypothesis_with_pose.hypothesis = ObjectHypothesis() diff --git a/tests/rai_perception/conftest.py b/tests/rai_perception/conftest.py new file mode 100644 index 000000000..56f88a991 --- /dev/null +++ b/tests/rai_perception/conftest.py @@ -0,0 +1,71 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_connector(): + """Mock ROS2Connector for testing perception tools.
+ + Provides a mock ROS2Connector with attributes and methods used by + perception tools: + - connector.node: Mock node with create_client, get_logger, get_parameter + - connector.receive_message: Mock method for receiving ROS2 messages + + Note: Unlike communication package tests which use real ROS2Connector + instances with actual ROS2 infrastructure (integration tests), we use + MagicMock here because: + - We're testing tool logic, not ROS2 integration + - Unit tests should be fast and not require ROS2 infrastructure + - We can control mock behavior for specific test scenarios + """ + connector = MagicMock() + + # Mock the node with all required methods + mock_node = MagicMock() + mock_node.create_client = MagicMock() + mock_node.get_logger = MagicMock(return_value=MagicMock()) + mock_node.get_parameter = MagicMock() + + connector.node = mock_node + connector._node = mock_node # Some code accesses _node directly + connector.receive_message = MagicMock() + + return connector + + +@contextmanager +def patch_ros2_for_agent_tests(mock_connector): + """Context manager to patch ROS2Connector and rclpy.ok for agent tests. + + This patches: + - ROS2Connector at both the source and usage locations to return the provided mock_connector + - rclpy.ok to return False (prevents cleanup_agent from calling rclpy.shutdown) + + Use this in agent tests where BaseVisionAgent creates a real ROS2Connector + which would otherwise require ROS2 to be initialized. 
+ """ + with ( + patch("rai.communication.ros2.ROS2Connector", return_value=mock_connector), + patch( + "rai_perception.agents.base_vision_agent.ROS2Connector", + return_value=mock_connector, + ), + patch("rclpy.ok", return_value=False), + ): + yield diff --git a/tests/rai_perception/test_base_vision_agent.py b/tests/rai_perception/test_base_vision_agent.py index 13f6637e3..371685bd3 100644 --- a/tests/rai_perception/test_base_vision_agent.py +++ b/tests/rai_perception/test_base_vision_agent.py @@ -13,6 +13,7 @@ # limitations under the License. import subprocess +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -31,67 +32,333 @@ def run(self): pass +def create_valid_weights_file(weights_path: Path, size_mb: int = 2) -> None: + """Helper to create a valid weights file for testing. + + Args: + weights_path: Path where the weights file should be created + size_mb: Size of the file in megabytes (default: 2MB) + """ + weights_path.parent.mkdir(parents=True, exist_ok=True) + weights_path.write_bytes(b"0" * (size_mb * 1024 * 1024)) + + +def get_weights_path(tmp_path: Path) -> Path: + """Helper to get the standard weights path for testing. + + Args: + tmp_path: Temporary directory path + + Returns: + Path to the weights file + """ + return tmp_path / "vision" / "weights" / "test_weights.pth" + + +def create_agent_with_weights( + tmp_path: Path, weights_path: Path +) -> MockBaseVisionAgent: + """Helper to create an agent with weights path set. + + Args: + tmp_path: Temporary directory path + weights_path: Path to weights file + + Returns: + Configured MockBaseVisionAgent instance + """ + agent = MockBaseVisionAgent(weights_root_path=str(tmp_path), ros2_name="test_agent") + agent.weights_path = weights_path + return agent + + +def cleanup_agent(agent: MockBaseVisionAgent) -> None: + """Helper to clean up agent and ROS2 context. 
+ + Args: + agent: Agent instance to clean up + """ + agent.stop() + if rclpy.ok(): + rclpy.shutdown() + + +def extract_output_path_from_wget_args(args) -> Path: + """Helper to extract output path from wget subprocess args. + + Args: + args: Arguments passed to subprocess.run (args[0] is the command list) + + Returns: + Path object for the output file + """ + output_path_str = args[0][3] # -O argument is at index 3 + return Path(output_path_str) + + class TestVisionWeightsDownload: """Test cases for BaseVisionAgent._download_weights method.""" def test_download_weights_success(self, tmp_path): """Test successful weight download.""" - # make sure we have multiple levels of directories - weights_path = tmp_path / "vision" / "weights" / "test_weights.pth" + weights_path = get_weights_path(tmp_path) # check whether file doesn't exist before download assert not weights_path.exists() def mock_wget(*args, **kwargs): # Simulate wget creating the file - output_path = args[0][3] # -O argument is at index 3 - output_path.write_text("downloaded weights content") + output_path = extract_output_path_from_wget_args(args) + create_valid_weights_file(output_path) return MagicMock(returncode=0) with patch("subprocess.run", side_effect=mock_wget) as mock_run: - agent = MockBaseVisionAgent( - weights_root_path=str(tmp_path), ros2_name="test_agent" - ) - agent.weights_path = weights_path + agent = create_agent_with_weights(tmp_path, weights_path) mock_run.assert_called_once_with( [ "wget", "https://example.com/test_weights.pth", "-O", - weights_path, + str(weights_path), "--progress=dot:giga", - ] + ], + check=True, + capture_output=True, + text=True, ) # Verify file exists after download assert weights_path.exists() - # Clean up ROS2 node and context - agent.stop() - if rclpy.ok(): - rclpy.shutdown() + cleanup_agent(agent) def test_download_weights_failure(self, tmp_path): """Test weight download failure raises exception.""" + weights_path = get_weights_path(tmp_path) + + call_count = 
0 + + def mock_wget(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call succeeds (during initialization) + output_path = extract_output_path_from_wget_args(args) + create_valid_weights_file(output_path) + result = MagicMock() + result.returncode = 0 + return result + else: + # Second call fails - raise CalledProcessError + # This will be caught and re-raised as "Could not download weights" + raise subprocess.CalledProcessError( + returncode=1, cmd="wget", stderr="Download failed" + ) + + with patch("subprocess.run", side_effect=mock_wget): + agent = create_agent_with_weights(tmp_path, weights_path) + + # Remove the file to force re-download + weights_path.unlink() + + with pytest.raises(Exception, match="Could not download weights"): + agent._download_weights() + + cleanup_agent(agent) + + def test_download_weights_file_too_small(self, tmp_path): + """Test download failure when file is too small.""" + weights_path = get_weights_path(tmp_path) + # Create file first so initialization doesn't trigger download + create_valid_weights_file(weights_path) + + def mock_wget(*args, **kwargs): + # Simulate wget creating a file that's too small + output_path = extract_output_path_from_wget_args(args) + output_path.write_bytes(b"0" * 100) # 100 bytes, too small + return MagicMock(returncode=0) + + with patch("subprocess.run", side_effect=mock_wget): + agent = create_agent_with_weights(tmp_path, weights_path) + + with pytest.raises(Exception, match="Downloaded file is too small"): + agent._download_weights() + + # Verify file was cleaned up + assert not weights_path.exists() + + cleanup_agent(agent) + + +class TestBaseVisionAgentInit: + """Test cases for BaseVisionAgent.__init__ method.""" + + def test_init_without_weights_filename(self): + """Test that ValueError is raised when WEIGHTS_FILENAME is not set.""" + + class InvalidAgent(BaseVisionAgent): + WEIGHTS_FILENAME = "" + + def run(self): + """Dummy implementation of abstract run 
method.""" + pass + + with pytest.raises(ValueError, match="WEIGHTS_FILENAME is not set"): + InvalidAgent() + + def test_init_with_path_string(self, tmp_path): + """Test initialization with string path.""" weights_path = tmp_path / "vision" / "weights" / "test_weights.pth" + create_valid_weights_file(weights_path) - with patch("subprocess.run") as mock_run: - # First call succeeds (during initialization), second call fails - mock_run.side_effect = [ - MagicMock(returncode=0), # Initial download succeeds - subprocess.CalledProcessError(1, "wget"), # Explicit call fails - ] + agent = MockBaseVisionAgent(weights_root_path=str(tmp_path), ros2_name="test") + assert agent.weights_root_path == tmp_path + assert agent.weights_path == weights_path + + agent.stop() + if rclpy.ok(): + rclpy.shutdown() + + def test_init_with_path_object(self, tmp_path): + """Test initialization with Path object.""" + weights_path = tmp_path / "vision" / "weights" / "test_weights.pth" + create_valid_weights_file(weights_path) + agent = MockBaseVisionAgent(weights_root_path=tmp_path, ros2_name="test") + assert agent.weights_root_path == tmp_path + assert agent.weights_path == weights_path + + agent.stop() + if rclpy.ok(): + rclpy.shutdown() + + def test_init_with_existing_file(self, tmp_path): + """Test initialization when weights file already exists.""" + weights_path = tmp_path / "vision" / "weights" / "test_weights.pth" + create_valid_weights_file(weights_path) + + with patch("subprocess.run") as mock_run: agent = MockBaseVisionAgent( weights_root_path=str(tmp_path), ros2_name="test_agent" ) + # Should not call download since file exists + mock_run.assert_not_called() + + agent.stop() + if rclpy.ok(): + rclpy.shutdown() + + +class TestLoadModelWithErrorHandling: + """Test cases for BaseVisionAgent._load_model_with_error_handling method.""" + + def test_load_model_success(self, tmp_path): + """Test successful model loading.""" + weights_path = tmp_path / "vision" / "weights" / 
"test_weights.pth"
+        create_valid_weights_file(weights_path)
+
+        class MockModel:
+            def __init__(self, weights_path):
+                self.weights_path = weights_path
+
+        agent = MockBaseVisionAgent(weights_root_path=str(tmp_path), ros2_name="test")
+        agent.weights_path = weights_path
+
+        model = agent._load_model_with_error_handling(MockModel)
+        assert model.weights_path == weights_path
+
+        agent.stop()
+        if rclpy.ok():
+            rclpy.shutdown()
+
+    def test_load_model_corrupted_weights(self, tmp_path):
+        """Test model loading with corrupted weights triggers redownload."""
+        weights_path = tmp_path / "vision" / "weights" / "test_weights.pth"
+        weights_path.parent.mkdir(parents=True, exist_ok=True)
+        weights_path.write_bytes(b"corrupted")
+
+        call_count = 0
+
+        class MockModel:
+            def __init__(self, weights_path):
+                nonlocal call_count
+                call_count += 1
+                self.weights_path = weights_path
+                if call_count == 1:
+                    raise RuntimeError("PytorchStreamReader failed")
+
+        def mock_wget(*args, **kwargs):
+            output_path_str = args[0][3]
+            output_path = Path(output_path_str)
+            output_path.write_bytes(b"0" * (2 * 1024 * 1024))
+            return MagicMock(returncode=0)
+
+        with patch("subprocess.run", side_effect=mock_wget):
+            agent = MockBaseVisionAgent(
+                weights_root_path=str(tmp_path), ros2_name="test"
+            )
             agent.weights_path = weights_path
-        with pytest.raises(Exception, match="Could not download weights"):
-            agent._download_weights()
+            model = agent._load_model_with_error_handling(MockModel)
+            assert model.weights_path == weights_path
+            assert call_count == 2  # Called twice: once fails, once succeeds
+
+        agent.stop()
+        if rclpy.ok():
+            rclpy.shutdown()
+
+    def test_load_model_other_runtime_error(self, tmp_path):
+        """Test that non-corruption RuntimeErrors are re-raised."""
+        weights_path = tmp_path / "vision" / "weights" / "test_weights.pth"
+        create_valid_weights_file(weights_path)
+
+        class MockModel:
+            def __init__(self, weights_path):
+                raise RuntimeError("Some other error")
+
+        agent = MockBaseVisionAgent(weights_root_path=str(tmp_path), ros2_name="test")
+        agent.weights_path = weights_path
+
+        with pytest.raises(RuntimeError, match="Some other error"):
+            agent._load_model_with_error_handling(MockModel)
+
+        agent.stop()
+        if rclpy.ok():
+            rclpy.shutdown()
-        # Clean up ROS2 node and context
+
+class TestBaseVisionAgentMethods:
+    """Test cases for other BaseVisionAgent methods."""
+
+    def test_remove_weights(self, tmp_path):
+        """Test _remove_weights method."""
+        weights_path = tmp_path / "vision" / "weights" / "test_weights.pth"
+        weights_path.parent.mkdir(parents=True, exist_ok=True)
+        weights_path.write_bytes(b"test")
+
+        agent = MockBaseVisionAgent(weights_root_path=str(tmp_path), ros2_name="test")
+        agent.weights_path = weights_path
+
+        assert weights_path.exists()
+        agent._remove_weights()
+        assert not weights_path.exists()
+
+        agent.stop()
+        if rclpy.ok():
+            rclpy.shutdown()
+
+    def test_stop(self, tmp_path):
+        """Test stop method shuts down ROS2 connector."""
+        weights_path = tmp_path / "vision" / "weights" / "test_weights.pth"
+        create_valid_weights_file(weights_path)
+
+        agent = MockBaseVisionAgent(weights_root_path=str(tmp_path), ros2_name="test")
+        agent.weights_path = weights_path
+
+        with patch.object(agent.ros2_connector, "shutdown") as mock_shutdown:
             agent.stop()
-        if rclpy.ok():
-            rclpy.shutdown()
+            mock_shutdown.assert_called_once()
+
+        if rclpy.ok():
+            rclpy.shutdown()
diff --git a/tests/rai_perception/test_boxer.py b/tests/rai_perception/test_boxer.py
new file mode 100644
index 000000000..bc9c00805
--- /dev/null
+++ b/tests/rai_perception/test_boxer.py
@@ -0,0 +1,167 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import rclpy
+from rai_perception.vision_markup.boxer import Box, GDBoxer
+from rclpy.time import Time
+from sensor_msgs.msg import Image
+from vision_msgs.msg import Detection2D
+
+
+class TestBox:
+    """Test cases for Box class."""
+
+    def test_box_initialization(self):
+        """Test Box initialization."""
+        box = Box((50.0, 50.0), 40.0, 40.0, "dinosaur", 0.9)
+
+        assert box.center == (50.0, 50.0)
+        assert box.size_x == 40.0
+        assert box.size_y == 40.0
+        assert box.phrase == "dinosaur"
+        assert box.confidence == 0.9
+
+    def test_box_to_detection_msg(self):
+        """Test Box conversion to Detection2D message."""
+        box = Box((50.0, 50.0), 40.0, 40.0, "dinosaur", 0.9)
+
+        class_dict = {"dinosaur": 0, "dragon": 1}
+        timestamp = Time()
+
+        detection = box.to_detection_msg(class_dict, timestamp)
+
+        assert isinstance(detection, Detection2D)
+        assert detection.bbox.center.position.x == 50.0
+        assert detection.bbox.center.position.y == 50.0
+        assert detection.bbox.size_x == 40.0
+        assert detection.bbox.size_y == 40.0
+        assert detection.results[0].hypothesis.class_id == "dinosaur"
+        assert detection.results[0].hypothesis.score == 0.9
+        assert detection.header.stamp == timestamp.to_msg()
+
+
+class TestGDBoxer:
+    """Test cases for GDBoxer class."""
+
+    def setup_method(self):
+        """Initialize ROS2 before tests that use Time() or ROS2 messages."""
+        if not rclpy.ok():
+            rclpy.init()
+
+    def teardown_method(self):
+        """Clean up ROS2 context after each test to prevent thread exceptions."""
+        try:
+            if rclpy.ok():
+                # Give any executor threads a moment to finish before shutting down
+                time.sleep(0.1)
+                rclpy.shutdown()
+        except Exception:
+            # Ignore errors during shutdown - thread may have already been cleaned up
+            pass
+
+    def test_gdboxer_initialization(self, tmp_path):
+        """Test GDBoxer initialization."""
+        weights_path = tmp_path / "weights.pth"
+        weights_path.parent.mkdir(parents=True, exist_ok=True)
+        weights_path.write_bytes(b"test")
+
+        with patch("rai_perception.vision_markup.boxer.Model") as mock_model:
+            mock_model_instance = MagicMock()
+            mock_model.return_value = mock_model_instance
+            mock_model_instance.predict_with_classes.return_value = MagicMock(
+                xyxy=[[10, 10, 50, 50], [60, 60, 90, 90]],
+                class_id=[0, 1],
+                confidence=[0.9, 0.8],
+            )
+
+            boxer = GDBoxer(str(weights_path), use_cuda=False)
+
+            assert boxer.weight_path == str(weights_path)
+            mock_model.assert_called_once()
+
+    def test_gdboxer_get_boxes(self, tmp_path):
+        """Test GDBoxer get_boxes method."""
+        weights_path = tmp_path / "weights.pth"
+        weights_path.parent.mkdir(parents=True, exist_ok=True)
+        weights_path.write_bytes(b"test")
+
+        with (
+            patch("rai_perception.vision_markup.boxer.Model") as mock_model,
+            patch("rai_perception.vision_markup.boxer.CvBridge") as mock_bridge,
+        ):
+            mock_model_instance = MagicMock()
+            mock_model.return_value = mock_model_instance
+
+            # Mock predictions
+            mock_predictions = MagicMock()
+            mock_predictions.xyxy = [[10, 10, 50, 50], [60, 60, 90, 90]]
+            mock_predictions.class_id = [0, 1]
+            mock_predictions.confidence = [0.9, 0.8]
+            mock_model_instance.predict_with_classes.return_value = mock_predictions
+
+            # Mock bridge
+            mock_bridge_instance = MagicMock()
+            mock_bridge.return_value = mock_bridge_instance
+            # Return a valid numpy array (BGR format) that cv2.cvtColor can process
+            mock_bridge_instance.imgmsg_to_cv2.return_value = np.zeros(
+                (100, 100, 3), dtype=np.uint8
+            )
+
+            boxer = GDBoxer(str(weights_path), use_cuda=False)
+
+            image_msg = Image()
+            classes = ["dinosaur", "dragon"]
+            boxes = boxer.get_boxes(image_msg, classes, 0.4, 0.4)
+
+            assert len(boxes) == 2
+            assert boxes[0].phrase == "dinosaur"
+            assert boxes[0].confidence == 0.9
+            assert boxes[1].phrase == "dragon"
+            assert boxes[1].confidence == 0.8
+
+    def test_gdboxer_get_boxes_empty(self, tmp_path):
+        """Test GDBoxer get_boxes with no detections."""
+        weights_path = tmp_path / "weights.pth"
+        weights_path.parent.mkdir(parents=True, exist_ok=True)
+        weights_path.write_bytes(b"test")
+
+        with (
+            patch("rai_perception.vision_markup.boxer.Model") as mock_model,
+            patch("rai_perception.vision_markup.boxer.CvBridge") as mock_bridge,
+        ):
+            mock_model_instance = MagicMock()
+            mock_model.return_value = mock_model_instance
+
+            mock_predictions = MagicMock()
+            mock_predictions.xyxy = []
+            mock_model_instance.predict_with_classes.return_value = mock_predictions
+
+            mock_bridge_instance = MagicMock()
+            mock_bridge.return_value = mock_bridge_instance
+            # Return a valid numpy array (BGR format) that cv2.cvtColor can process
+            mock_bridge_instance.imgmsg_to_cv2.return_value = np.zeros(
+                (100, 100, 3), dtype=np.uint8
+            )
+
+            boxer = GDBoxer(str(weights_path), use_cuda=False)
+
+            image_msg = Image()
+            classes = ["dinosaur"]
+            boxes = boxer.get_boxes(image_msg, classes, 0.4, 0.4)
+
+            assert len(boxes) == 0
diff --git a/tests/rai_perception/test_gdino_tools.py b/tests/rai_perception/test_gdino_tools.py
new file mode 100644
index 000000000..88f00de2d
--- /dev/null
+++ b/tests/rai_perception/test_gdino_tools.py
@@ -0,0 +1,283 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import sensor_msgs.msg
+from rai_perception.tools.gdino_tools import (
+    BoundingBox,
+    DetectionData,
+    GetDetectionTool,
+    GetDistanceToObjectsTool,
+    GroundingDinoBaseTool,
+)
+
+from rai_interfaces.srv import RAIGroundingDino
+
+
+# Create a concrete test subclass for GroundingDinoBaseTool
+# Note: Name doesn't start with "Test" to avoid pytest collection
+class ConcreteGroundingDinoBaseTool(GroundingDinoBaseTool):
+    """Concrete implementation for testing GroundingDinoBaseTool."""
+
+    def _run(self, *args, **kwargs):
+        """Test implementation of _run."""
+        return "test"
+
+
+class TestGroundingDinoBaseTool:
+    """Test cases for GroundingDinoBaseTool."""
+
+    @pytest.fixture
+    def base_tool(self, mock_connector):
+        """Create a GroundingDinoBaseTool instance."""
+        # Use model_construct to bypass Pydantic validation
+        tool = ConcreteGroundingDinoBaseTool.model_construct(connector=mock_connector)
+        return tool
+
+    def test_base_tool_initialization(self, base_tool):
+        """Test GroundingDinoBaseTool initialization."""
+        assert base_tool.box_threshold == 0.35
+        assert base_tool.text_threshold == 0.45
+        assert base_tool.connector is not None
+
+    def test_get_img_from_topic_success(self, base_tool, mock_connector):
+        """Test get_img_from_topic with successful message."""
+        image_msg = sensor_msgs.msg.Image()
+        mock_connector.receive_message.return_value.payload = image_msg
+
+        result = base_tool.get_img_from_topic("test_topic", timeout_sec=2)
+
+        assert result == image_msg
+        mock_connector.receive_message.assert_called_once_with(
+            "test_topic", timeout_sec=2
+        )
+
+    def test_get_image_message_success(self, base_tool, mock_connector):
+        """Test _get_image_message with valid image."""
+        image_msg = sensor_msgs.msg.Image()
+        mock_connector.receive_message.return_value.payload = image_msg
+
+        result = base_tool._get_image_message("test_topic")
+
+        assert result == image_msg
+
+    def test_call_gdino_node(self, base_tool, mock_connector):
+        """Test _call_gdino_node creates service call."""
+        image_msg = sensor_msgs.msg.Image()
+        mock_client = MagicMock()
+        mock_client.wait_for_service.return_value = True
+        mock_connector.node.create_client.return_value = mock_client
+
+        future = base_tool._call_gdino_node(image_msg, ["dinosaur", "dragon"])
+
+        assert future is not None
+        mock_client.call_async.assert_called_once()
+
+    def test_parse_detection_array(self, base_tool):
+        """Test _parse_detection_array converts response correctly."""
+        response = RAIGroundingDino.Response()
+
+        from vision_msgs.msg import (
+            BoundingBox2D,
+            Detection2D,
+            ObjectHypothesis,
+            ObjectHypothesisWithPose,
+        )
+
+        from rai_interfaces.msg import RAIDetectionArray
+
+        detection1 = Detection2D()
+        detection1.bbox = BoundingBox2D()
+        detection1.bbox.center.position.x = 50.0
+        detection1.bbox.center.position.y = 50.0
+        detection1.bbox.size_x = 40.0
+        detection1.bbox.size_y = 40.0
+        detection1.results = [ObjectHypothesisWithPose()]
+        detection1.results[0].hypothesis = ObjectHypothesis()
+        detection1.results[0].hypothesis.class_id = "dinosaur"
+        detection1.results[0].hypothesis.score = 0.9
+
+        detection2 = Detection2D()
+        detection2.bbox = BoundingBox2D()
+        detection2.bbox.center.position.x = 100.0
+        detection2.bbox.center.position.y = 100.0
+        detection2.bbox.size_x = 30.0
+        detection2.bbox.size_y = 30.0
+        detection2.results = [ObjectHypothesisWithPose()]
+        detection2.results[0].hypothesis = ObjectHypothesis()
+        detection2.results[0].hypothesis.class_id = "dragon"
+        detection2.results[0].hypothesis.score = 0.8
+
+        response.detections = RAIDetectionArray()
+        response.detections.detections = [detection1, detection2]
+
+        result = base_tool._parse_detection_array(response)
+
+        assert len(result) == 2
+        assert result[0].class_name == "dinosaur"
+        assert result[0].confidence == 0.9
+        assert result[0].bbox.x_center == 50.0
+        assert result[1].class_name == "dragon"
+        assert result[1].confidence == 0.8
+
+
+class TestGetDetectionTool:
+    """Test cases for GetDetectionTool."""
+
+    @pytest.fixture
+    def detection_tool(self, mock_connector):
+        """Create a GetDetectionTool instance."""
+        # Use model_construct to bypass Pydantic validation
+        tool = GetDetectionTool.model_construct(connector=mock_connector)
+        return tool
+
+    def test_get_detection_tool_run_success(self, detection_tool, mock_connector):
+        """Test GetDetectionTool._run with successful detection."""
+        image_msg = sensor_msgs.msg.Image()
+        mock_connector.receive_message.return_value.payload = image_msg
+
+        mock_client = MagicMock()
+        mock_client.wait_for_service.return_value = True
+        mock_connector.node.create_client.return_value = mock_client
+
+        mock_future = MagicMock()
+        mock_client.call_async.return_value = mock_future
+
+        response = RAIGroundingDino.Response()
+        from vision_msgs.msg import (
+            BoundingBox2D,
+            Detection2D,
+            ObjectHypothesis,
+            ObjectHypothesisWithPose,
+        )
+
+        from rai_interfaces.msg import RAIDetectionArray
+
+        detection = Detection2D()
+        detection.bbox = BoundingBox2D()
+        detection.bbox.center.position.x = 50.0
+        detection.bbox.center.position.y = 50.0
+        detection.results = [ObjectHypothesisWithPose()]
+        detection.results[0].hypothesis = ObjectHypothesis()
+        detection.results[0].hypothesis.class_id = "dinosaur"
+        detection.results[0].hypothesis.score = 0.9
+
+        response.detections = RAIDetectionArray()
+        response.detections.detections = [detection]
+
+        with patch(
+            "rai_perception.tools.gdino_tools.get_future_result",
+            return_value=response,
+        ):
+            result = detection_tool._run("test_topic", ["dinosaur", "dragon"])
+
+        assert "detected" in result.lower()
+        assert "dinosaur" in result
+
+
+class TestGetDistanceToObjectsTool:
+    """Test cases for GetDistanceToObjectsTool."""
+
+    @pytest.fixture
+    def distance_tool(self, mock_connector):
+        """Create a GetDistanceToObjectsTool instance."""
+        # Use model_construct to bypass Pydantic validation
+        tool = GetDistanceToObjectsTool.model_construct(connector=mock_connector)
+        return tool
+
+    def test_get_distance_from_detections(self, distance_tool):
+        """Test _get_distance_from_detections calculates distances."""
+        # Create mock depth image data
+        depth_arr = np.ones((200, 200), dtype=np.uint16) * 1000  # 1 meter in mm
+
+        detection1 = DetectionData(
+            class_name="dinosaur",
+            confidence=0.9,
+            bbox=BoundingBox(x_center=50.0, y_center=50.0, width=40.0, height=40.0),
+        )
+
+        with patch(
+            "rai_perception.tools.gdino_tools.convert_ros_img_to_ndarray",
+            return_value=depth_arr,
+        ):
+            measurements = distance_tool._get_distance_from_detections(
+                MagicMock(), [detection1], sigma_threshold=1.0, conversion_ratio=0.001
+            )
+
+        assert len(measurements) == 1
+        assert measurements[0][0] == "dinosaur"
+        assert isinstance(measurements[0][1], (int, float))
+
+    def test_get_distance_tool_run(self, distance_tool, mock_connector):
+        """Test GetDistanceToObjectsTool._run."""
+        image_msg = sensor_msgs.msg.Image()
+        depth_msg = sensor_msgs.msg.Image()
+        mock_connector.receive_message.side_effect = [
+            MagicMock(payload=image_msg),
+            MagicMock(payload=depth_msg),
+        ]
+
+        mock_client = MagicMock()
+        mock_client.wait_for_service.return_value = True
+        mock_connector.node.create_client.return_value = mock_client
+
+        mock_connector.node.get_parameter.side_effect = [
+            MagicMock(value=1.0),  # outlier_sigma_threshold
+            MagicMock(value=0.001),  # conversion_ratio
+        ]
+
+        response = RAIGroundingDino.Response()
+        from vision_msgs.msg import (
+            BoundingBox2D,
+            Detection2D,
+            ObjectHypothesis,
+            ObjectHypothesisWithPose,
+        )
+
+        from rai_interfaces.msg import RAIDetectionArray
+
+        detection = Detection2D()
+        detection.bbox = BoundingBox2D()
+        detection.bbox.center.position.x = 50.0
+        detection.bbox.center.position.y = 50.0
+        detection.bbox.size_x = 40.0
+        detection.bbox.size_y = 40.0
+        detection.results = [ObjectHypothesisWithPose()]
+        detection.results[0].hypothesis = ObjectHypothesis()
+        detection.results[0].hypothesis.class_id = "dinosaur"
+        detection.results[0].hypothesis.score = 0.9
+
+        response.detections = RAIDetectionArray()
+        response.detections.detections = [detection]
+
+        depth_arr = np.ones((200, 200), dtype=np.uint16) * 1000
+
+        with (
+            patch(
+                "rai_perception.tools.gdino_tools.get_future_result",
+                return_value=response,
+            ),
+            patch(
+                "rai_perception.tools.gdino_tools.convert_ros_img_to_ndarray",
+                return_value=depth_arr,
+            ),
+        ):
+            result = distance_tool._run("camera_topic", "depth_topic", ["dinosaur"])
+
+        assert "detected" in result.lower()
+        assert "dinosaur" in result
+        assert "away" in result
diff --git a/tests/rai_perception/test_grounded_sam.py b/tests/rai_perception/test_grounded_sam.py
new file mode 100644
index 000000000..e7c697445
--- /dev/null
+++ b/tests/rai_perception/test_grounded_sam.py
@@ -0,0 +1,210 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from unittest.mock import patch
+
+import numpy as np
+from rai_perception.agents.grounded_sam import (
+    GSAM_SERVICE_NAME,
+    GroundedSamAgent,
+)
+from sensor_msgs.msg import Image
+
+from rai_interfaces.srv import RAIGroundedSam
+from tests.rai_perception.conftest import patch_ros2_for_agent_tests
+from tests.rai_perception.test_base_vision_agent import (
+    cleanup_agent,
+    create_valid_weights_file,
+    get_weights_path,
+)
+
+
+class MockGDSegmenter:
+    """Mock GDSegmenter for testing."""
+
+    def __init__(self, weights_path):
+        self.weights_path = weights_path
+
+    def get_segmentation(self, image, boxes):
+        """Mock segmentation that returns simple masks."""
+        # Return 2 masks for testing
+        mask1 = np.zeros((100, 100), dtype=np.float32)
+        mask1[10:50, 10:50] = 1.0
+        mask2 = np.zeros((100, 100), dtype=np.float32)
+        mask2[60:90, 60:90] = 1.0
+        return [mask1, mask2]
+
+
+class TestGroundedSamAgent:
+    """Test cases for GroundedSamAgent.
+
+    Note: All tests patch ROS2Connector to prevent hanging. BaseVisionAgent.__init__
+    creates a real ROS2Connector which requires ROS2 to be initialized, so we patch
+    it to use a mock instead for unit testing.
+    """
+
+    def test_init(self, tmp_path, mock_connector):
+        """Test GroundedSamAgent initialization."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounded_sam.GDSegmenter", MockGDSegmenter),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundedSamAgent(weights_root_path=str(tmp_path), ros2_name="test")
+
+            assert agent.WEIGHTS_URL is not None
+            assert agent.WEIGHTS_FILENAME == "sam2_hiera_large.pt"
+            assert agent._segmenter is not None
+
+            cleanup_agent(agent)
+
+    def test_init_default_path(self, mock_connector):
+        """Test GroundedSamAgent initialization with default path."""
+        weights_path = Path.home() / ".cache/rai/vision/weights/sam2_hiera_large.pt"
+        weights_path.parent.mkdir(parents=True, exist_ok=True)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounded_sam.GDSegmenter", MockGDSegmenter),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundedSamAgent(ros2_name="test")
+
+            assert agent._segmenter is not None
+
+            cleanup_agent(agent)
+        weights_path.unlink()
+
+    def test_run_creates_service(self, tmp_path, mock_connector):
+        """Test that run() creates the ROS2 service."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounded_sam.GDSegmenter", MockGDSegmenter),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundedSamAgent(weights_root_path=str(tmp_path), ros2_name="test")
+
+            with patch.object(
+                agent.ros2_connector, "create_service"
+            ) as mock_create_service:
+                agent.run()
+
+                mock_create_service.assert_called_once_with(
+                    service_name=GSAM_SERVICE_NAME,
+                    on_request=agent._segment_callback,
+                    service_type="rai_interfaces/srv/RAIGroundedSam",
+                )
+
+            cleanup_agent(agent)
+
+    def test_segment_callback(self, tmp_path, mock_connector):
+        """Test segment callback processes request correctly."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounded_sam.GDSegmenter", MockGDSegmenter),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundedSamAgent(weights_root_path=str(tmp_path), ros2_name="test")
+
+            # Create mock request
+            request = RAIGroundedSam.Request()
+            request.source_img = Image()
+
+            # Create mock detections
+            from vision_msgs.msg import BoundingBox2D, Detection2D
+
+            from rai_interfaces.msg import RAIDetectionArray
+
+            detection1 = Detection2D()
+            detection1.bbox = BoundingBox2D()
+            detection1.bbox.center.position.x = 30.0
+            detection1.bbox.center.position.y = 30.0
+            detection1.bbox.size_x = 40.0
+            detection1.bbox.size_y = 40.0
+
+            detection2 = Detection2D()
+            detection2.bbox = BoundingBox2D()
+            detection2.bbox.center.position.x = 75.0
+            detection2.bbox.center.position.y = 75.0
+            detection2.bbox.size_x = 30.0
+            detection2.bbox.size_y = 30.0
+
+            request.detections = RAIDetectionArray()
+            request.detections.detections = [detection1, detection2]
+
+            response = RAIGroundedSam.Response()
+
+            # Call callback
+            result = agent._segment_callback(request, response)
+
+            # Verify response contains masks
+            assert len(result.masks) == 2
+            assert result is response
+
+            cleanup_agent(agent)
+
+    def test_segment_callback_empty_detections(self, tmp_path, mock_connector):
+        """Test segment callback with empty detections."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        class EmptySegmenter:
+            def __init__(self, weights_path):
+                self.weights_path = weights_path
+
+            def get_segmentation(self, image, boxes):
+                return []
+
+        with (
+            patch("rai_perception.agents.grounded_sam.GDSegmenter", EmptySegmenter),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundedSamAgent(weights_root_path=str(tmp_path), ros2_name="test")
+
+            request = RAIGroundedSam.Request()
+            request.source_img = Image()
+
+            from rai_interfaces.msg import RAIDetectionArray
+
+            request.detections = RAIDetectionArray()
+            request.detections.detections = []
+
+            response = RAIGroundedSam.Response()
+            result = agent._segment_callback(request, response)
+
+            assert len(result.masks) == 0
+
+            cleanup_agent(agent)
diff --git a/tests/rai_perception/test_grounding_dino.py b/tests/rai_perception/test_grounding_dino.py
new file mode 100644
index 000000000..907cfe1ff
--- /dev/null
+++ b/tests/rai_perception/test_grounding_dino.py
@@ -0,0 +1,242 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+from rai_perception.agents.grounding_dino import (
+    GDINO_SERVICE_NAME,
+    GroundingDinoAgent,
+)
+from rai_perception.vision_markup.boxer import Box
+from sensor_msgs.msg import Image
+
+from tests.rai_perception.conftest import patch_ros2_for_agent_tests
+from tests.rai_perception.test_base_vision_agent import (
+    cleanup_agent,
+    create_valid_weights_file,
+    get_weights_path,
+)
+
+
+def setup_mock_clock(agent):
+    """Setup mock clock for agent tests.
+
+    The code calls clock().now().to_msg() to get ts, then passes ts to
+    to_detection_msg which expects rclpy.time.Time and calls ts.to_msg() again.
+    However, ts is also assigned to response.detections.header.stamp which expects
+    builtin_interfaces.msg.Time.
+
+    ROS2 Humble vs Jazzy difference:
+    - Humble: Strict type checking in __debug__ mode requires actual BuiltinTime
+      instances, not MagicMock objects. Using MagicMock causes AssertionError.
+    - Jazzy: More lenient with MagicMock, but BuiltinTime instances don't allow
+      dynamically adding methods (AttributeError when accessing to_msg).
+
+    Solution: Create a wrapper class that inherits from BuiltinTime and adds to_msg().
+    """
+    from builtin_interfaces.msg import Time as BuiltinTime
+
+    class TimeWithToMsg(BuiltinTime):
+        """BuiltinTime wrapper that adds to_msg() method for compatibility."""
+
+        def to_msg(self):
+            return self
+
+    mock_clock = MagicMock()
+    mock_time = MagicMock()
+    # Create a TimeWithToMsg instance (passes isinstance checks and has to_msg())
+    mock_ts = TimeWithToMsg()
+    mock_time.to_msg.return_value = mock_ts
+    mock_clock.now.return_value = mock_time
+    agent.ros2_connector._node.get_clock = MagicMock(return_value=mock_clock)
+
+
+class MockGDBoxer:
+    """Mock GDBoxer for testing."""
+
+    def __init__(self, weights_path):
+        self.weights_path = weights_path
+
+    def get_boxes(self, image_msg, classes, box_threshold, text_threshold):
+        """Mock box detection."""
+        box1 = Box((50.0, 50.0), 40.0, 40.0, classes[0], 0.9)
+        box2 = Box((100.0, 100.0), 30.0, 30.0, classes[1], 0.8)
+        return [box1, box2]
+
+
+class TestGroundingDinoAgent:
+    """Test cases for GroundingDinoAgent.
+
+    Note: All tests patch ROS2Connector to prevent hanging. BaseVisionAgent.__init__
+    creates a real ROS2Connector which requires ROS2 to be initialized, so we patch
+    it to use a mock instead for unit testing.
+    """
+
+    def test_init(self, tmp_path, mock_connector):
+        """Test GroundingDinoAgent initialization."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounding_dino.GDBoxer", MockGDBoxer),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundingDinoAgent(
+                weights_root_path=str(tmp_path), ros2_name="test"
+            )
+
+            assert agent.WEIGHTS_URL is not None
+            assert agent.WEIGHTS_FILENAME == "groundingdino_swint_ogc.pth"
+            assert agent._boxer is not None
+
+            cleanup_agent(agent)
+
+    def test_init_default_path(self, mock_connector):
+        """Test GroundingDinoAgent initialization with default path."""
+        weights_path = (
+            Path.home() / ".cache/rai/vision/weights/groundingdino_swint_ogc.pth"
+        )
+        weights_path.parent.mkdir(parents=True, exist_ok=True)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounding_dino.GDBoxer", MockGDBoxer),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundingDinoAgent(ros2_name="test")
+
+            assert agent._boxer is not None
+
+            cleanup_agent(agent)
+        weights_path.unlink()
+
+    def test_run_creates_service(self, tmp_path, mock_connector):
+        """Test that run() creates the ROS2 service."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounding_dino.GDBoxer", MockGDBoxer),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundingDinoAgent(
+                weights_root_path=str(tmp_path), ros2_name="test"
+            )
+
+            with patch.object(
+                agent.ros2_connector, "create_service"
+            ) as mock_create_service:
+                agent.run()
+
+                mock_create_service.assert_called_once()
+                call_args = mock_create_service.call_args
+                assert call_args[0][0] == GDINO_SERVICE_NAME
+                assert call_args[0][1] == agent._classify_callback
+                assert (
+                    call_args[1]["service_type"]
+                    == "rai_interfaces/srv/RAIGroundingDino"
+                )
+
+            cleanup_agent(agent)
+
+    def test_classify_callback(self, tmp_path, mock_connector):
+        """Test classify callback processes request correctly."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        with (
+            patch("rai_perception.agents.grounding_dino.GDBoxer", MockGDBoxer),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundingDinoAgent(
+                weights_root_path=str(tmp_path), ros2_name="test"
+            )
+
+            # Create mock request
+            from rai_interfaces.srv import RAIGroundingDino
+
+            request = RAIGroundingDino.Request()
+            request.source_img = Image()
+            request.classes = "dinosaur, dragon"
+            request.box_threshold = 0.4
+            request.text_threshold = 0.4
+
+            response = RAIGroundingDino.Response()
+
+            setup_mock_clock(agent)
+
+            # Call callback
+            result = agent._classify_callback(request, response)
+
+            # Verify response
+            assert len(result.detections.detections) == 2
+            assert result.detections.detection_classes == ["dinosaur", "dragon"]
+            assert result is response
+
+            cleanup_agent(agent)
+
+    def test_classify_callback_empty_boxes(self, tmp_path, mock_connector):
+        """Test classify callback with no detections."""
+        weights_path = get_weights_path(tmp_path)
+        create_valid_weights_file(weights_path)
+
+        class EmptyBoxer:
+            def __init__(self, weights_path):
+                self.weights_path = weights_path
+
+            def get_boxes(self, image_msg, classes, box_threshold, text_threshold):
+                return []
+
+        with (
+            patch("rai_perception.agents.grounding_dino.GDBoxer", EmptyBoxer),
+            patch_ros2_for_agent_tests(mock_connector),
+            patch(
+                "rai_perception.agents.base_vision_agent.BaseVisionAgent._download_weights"
+            ),
+        ):
+            agent = GroundingDinoAgent(
+                weights_root_path=str(tmp_path), ros2_name="test"
+            )
+
+            from rai_interfaces.srv import RAIGroundingDino
+
+            request = RAIGroundingDino.Request()
+            request.source_img = Image()
+            request.classes = "dinosaur"
+            request.box_threshold = 0.4
+            request.text_threshold = 0.4
+
+            response = RAIGroundingDino.Response()
+
+            setup_mock_clock(agent)
+
+            result = agent._classify_callback(request, response)
+
+            assert len(result.detections.detections) == 0
+            assert result.detections.detection_classes == ["dinosaur"]
+
+            cleanup_agent(agent)
diff --git a/tests/rai_perception/test_run_perception_agents.py b/tests/rai_perception/test_run_perception_agents.py
new file mode 100644
index 000000000..600c4742b
--- /dev/null
+++ b/tests/rai_perception/test_run_perception_agents.py
@@ -0,0 +1,103 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+from rai_perception.scripts.run_perception_agents import main
+
+
+class TestRunPerceptionAgents:
+    """Test cases for run_perception_agents.main function."""
+
+    def test_main_initializes_agents(self):
+        """Test that main function initializes both agents."""
+        with (
+            patch("rai_perception.scripts.run_perception_agents.rclpy") as mock_rclpy,
+            patch(
+                "rai_perception.scripts.run_perception_agents.GroundingDinoAgent"
+            ) as mock_dino,
+            patch(
+                "rai_perception.scripts.run_perception_agents.GroundedSamAgent"
+            ) as mock_sam,
+            patch(
+                "rai_perception.scripts.run_perception_agents.wait_for_shutdown"
+            ) as mock_wait,
+        ):
+            mock_dino_instance = MagicMock()
+            mock_sam_instance = MagicMock()
+            mock_dino.return_value = mock_dino_instance
+            mock_sam.return_value = mock_sam_instance
+
+            main()
+
+            mock_rclpy.init.assert_called_once()
+            mock_dino.assert_called_once()
+            mock_sam.assert_called_once()
+            mock_dino_instance.run.assert_called_once()
+            mock_sam_instance.run.assert_called_once()
+            mock_wait.assert_called_once_with([mock_dino_instance, mock_sam_instance])
+            mock_rclpy.shutdown.assert_called_once()
+
+    def test_main_calls_agents_in_order(self):
+        """Test that agents are called in the correct order."""
+        call_order = []
+
+        def track_dino_init(*args, **kwargs):
+            call_order.append("dino_init")
+            mock_instance = MagicMock()
+            mock_instance.run.side_effect = lambda: call_order.append("dino_run")
+            return mock_instance
+
+        def track_sam_init(*args, **kwargs):
+            call_order.append("sam_init")
+            mock_instance = MagicMock()
+            mock_instance.run.side_effect = lambda: call_order.append("sam_run")
+            return mock_instance
+
+        with (
+            patch("rai_perception.scripts.run_perception_agents.rclpy"),
+            patch(
+                "rai_perception.scripts.run_perception_agents.GroundingDinoAgent",
+                side_effect=track_dino_init,
+            ),
+            patch(
+                "rai_perception.scripts.run_perception_agents.GroundedSamAgent",
+                side_effect=track_sam_init,
+            ),
+            patch("rai_perception.scripts.run_perception_agents.wait_for_shutdown"),
+        ):
+            main()
+
+            # Verify order: init dino, init sam, run dino, run sam
+            assert call_order == ["dino_init", "sam_init", "dino_run", "sam_run"]
+
+    def test_main_handles_shutdown(self):
+        """Test that main properly shuts down ROS2."""
+        with (
+            patch("rai_perception.scripts.run_perception_agents.rclpy") as mock_rclpy,
+            patch(
+                "rai_perception.scripts.run_perception_agents.GroundingDinoAgent"
+            ) as mock_dino,
+            patch(
+                "rai_perception.scripts.run_perception_agents.GroundedSamAgent"
+            ) as mock_sam,
+            patch("rai_perception.scripts.run_perception_agents.wait_for_shutdown"),
+        ):
+            mock_dino.return_value = MagicMock()
+            mock_sam.return_value = MagicMock()
+
+            main()
+
+            mock_rclpy.init.assert_called_once()
+            mock_rclpy.shutdown.assert_called_once()
diff --git a/tests/rai_perception/test_segmentation_tools.py b/tests/rai_perception/test_segmentation_tools.py
new file mode 100644
index 000000000..939f2c4ec
--- /dev/null
+++ b/tests/rai_perception/test_segmentation_tools.py
@@ -0,0 +1,309 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import sensor_msgs.msg
+from rai_perception.tools.segmentation_tools import (
+    GetGrabbingPointTool,
+    GetSegmentationTool,
+    depth_to_point_cloud,
+)
+
+from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino
+
+
+class TestGetSegmentationTool:
+    """Test cases for GetSegmentationTool."""
+
+    @pytest.fixture
+    def segmentation_tool(self, mock_connector):
+        """Create a GetSegmentationTool instance."""
+        tool = GetSegmentationTool()
+        tool.connector = mock_connector
+        # Set actual float values since GetSegmentationTool uses Field annotations
+        # but isn't a Pydantic model, so self.box_threshold would be a Field object
+        tool.box_threshold = 0.35
+        tool.text_threshold = 0.45
+        return tool
+
+    def test_get_image_message_success(self, segmentation_tool, mock_connector):
+        """Test _get_image_message with valid image."""
+        image_msg = sensor_msgs.msg.Image()
+        mock_connector.receive_message.return_value.payload = image_msg
+
+        result = segmentation_tool._get_image_message("test_topic")
+
+        assert result == image_msg
+
+    def test_call_gdino_node(self, segmentation_tool, mock_connector):
+        """Test _call_gdino_node creates service call."""
+        image_msg = sensor_msgs.msg.Image()
+        mock_client = MagicMock()
+        mock_client.wait_for_service.return_value = True
+        mock_connector.node.create_client.return_value = mock_client
+
+        future = segmentation_tool._call_gdino_node(image_msg, "dinosaur")
+
+        assert future is not None
+        mock_client.call_async.assert_called_once()
+
+    def test_call_gsam_node(self, segmentation_tool, mock_connector):
+        """Test _call_gsam_node creates service call."""
+        image_msg = sensor_msgs.msg.Image()
+        gdino_response = RAIGroundingDino.Response()
+        from rai_interfaces.msg import RAIDetectionArray
+
+        gdino_response.detections = RAIDetectionArray()
+
+        mock_client = MagicMock()
+        mock_client.wait_for_service.return_value = True
+        mock_connector.node.create_client.return_value = mock_client
+
+        future = segmentation_tool._call_gsam_node(image_msg, gdino_response)
+
+        assert future is not None
+        mock_client.call_async.assert_called_once()
+
+    def test_run_success(self, segmentation_tool, mock_connector):
+        """Test _run method with successful segmentation."""
+        image_msg = sensor_msgs.msg.Image()
+        mock_connector.receive_message.return_value.payload = image_msg
+
+        mock_gdino_client = MagicMock()
+        mock_gdino_client.wait_for_service.return_value = True
+        mock_gsam_client = MagicMock()
+        mock_gsam_client.wait_for_service.return_value = True
+
+        def create_client_side_effect(service_type, service_name):
+            if "GroundingDino" in str(service_type):
+                return mock_gdino_client
+            return mock_gsam_client
+
+        mock_connector.node.create_client.side_effect = create_client_side_effect
+
+        gdino_response = RAIGroundingDino.Response()
+        from rai_interfaces.msg import RAIDetectionArray
+
+        gdino_response.detections = RAIDetectionArray()
+
+        gsam_response = RAIGroundedSam.Response()
+        mask_msg1 = sensor_msgs.msg.Image()
+        mask_msg1.encoding = "mono8"  # Set encoding to avoid cv_bridge errors
+        mask_msg2 = sensor_msgs.msg.Image()
+        mask_msg2.encoding = "mono8"  # Set encoding to avoid cv_bridge errors
+        gsam_response.masks = [mask_msg1, mask_msg2]
+
+        mock_connector.node.get_parameter.return_value.value = 0.001
+
+        with (
+            patch(
+                "rai_perception.tools.segmentation_tools.get_future_result"
+            ) as mock_get_result,
+            patch(
+                "rai_perception.tools.segmentation_tools.convert_ros_img_to_base64"
+            ) as mock_convert,
+        ):
+            mock_get_result.side_effect = [gdino_response, gsam_response]
+            mock_convert.side_effect = ["base64_1", "base64_2"]
+
+            with patch("rclpy.ok", return_value=True):
+                result_text, result_data = segmentation_tool._run(
+                    "camera_topic", "dinosaur"
+                )
+
+        assert result_text == ""
+        assert "segmentations" in result_data
+        assert len(result_data["segmentations"]) == 2
+
+
+class 
TestGetGrabbingPointTool: + """Test cases for GetGrabbingPointTool.""" + + @pytest.fixture + def grabbing_tool(self, mock_connector): + """Create a GetGrabbingPointTool instance.""" + # Use model_construct to bypass Pydantic validation for connector field + tool = GetGrabbingPointTool.model_construct(connector=mock_connector) + return tool + + def test_get_camera_info_message(self, grabbing_tool, mock_connector): + """Test _get_camera_info_message.""" + camera_info = sensor_msgs.msg.CameraInfo() + mock_connector.receive_message.return_value.payload = camera_info + + result = grabbing_tool._get_camera_info_message("camera_info_topic") + + assert result == camera_info + + def test_get_intrinsic_from_camera_info(self, grabbing_tool): + """Test _get_intrinsic_from_camera_info extracts parameters.""" + camera_info = sensor_msgs.msg.CameraInfo() + camera_info.k = [500.0, 0.0, 320.0, 0.0, 500.0, 240.0, 0.0, 0.0, 1.0] + + fx, fy, cx, cy = grabbing_tool._get_intrinsic_from_camera_info(camera_info) + + assert fx == 500.0 + assert fy == 500.0 + assert cx == 320.0 + assert cy == 240.0 + + def test_process_mask(self, grabbing_tool): + """Test _process_mask calculates centroid and rotation.""" + mask_msg = sensor_msgs.msg.Image() + depth_msg = sensor_msgs.msg.Image() + + # Create mock mask (100x100 with a square region) + mask = np.zeros((100, 100), dtype=np.uint8) + mask[20:80, 20:80] = 255 + + # Create mock depth (1 meter = 1000mm) + depth = np.ones((100, 100), dtype=np.uint16) * 1000 + + intrinsic = (500.0, 500.0, 50.0, 50.0) # fx, fy, cx, cy + + with patch( + "rai_perception.tools.segmentation_tools.convert_ros_img_to_ndarray" + ) as mock_convert: + mock_convert.side_effect = [mask, depth] + + with patch("cv2.minAreaRect") as mock_min_area: + mock_min_area.return_value = ( + (50.0, 50.0), + (60.0, 60.0), + 0.0, + ) # center, dimensions, angle + + centroid, rotation = grabbing_tool._process_mask( + mask_msg, depth_msg, intrinsic, depth_to_meters_ratio=0.001 + ) + + assert 
len(centroid) == 3 + assert isinstance(rotation, (int, float)) + + def test_run(self, grabbing_tool, mock_connector): + """Test GetGrabbingPointTool._run.""" + image_msg = sensor_msgs.msg.Image() + depth_msg = sensor_msgs.msg.Image() + camera_info = sensor_msgs.msg.CameraInfo() + camera_info.k = [500.0, 0.0, 320.0, 0.0, 500.0, 240.0, 0.0, 0.0, 1.0] + + mock_connector.receive_message.side_effect = [ + MagicMock(payload=image_msg), + MagicMock(payload=depth_msg), + MagicMock(payload=camera_info), + ] + + mock_gdino_client = MagicMock() + mock_gdino_client.wait_for_service.return_value = True + mock_gsam_client = MagicMock() + mock_gsam_client.wait_for_service.return_value = True + + def create_client_side_effect(service_type, service_name): + if "GroundingDino" in str(service_type): + return mock_gdino_client + return mock_gsam_client + + mock_connector.node.create_client.side_effect = create_client_side_effect + + gdino_response = RAIGroundingDino.Response() + from rai_interfaces.msg import RAIDetectionArray + + gdino_response.detections = RAIDetectionArray() + + gsam_response = RAIGroundedSam.Response() + mask_msg = sensor_msgs.msg.Image() + mask_msg.encoding = "mono8" # Set encoding to avoid cv_bridge errors + gsam_response.masks = [mask_msg] + + mock_connector.node.get_parameter.return_value.value = 0.001 + + mask = np.zeros((100, 100), dtype=np.uint8) + mask[20:80, 20:80] = 255 + depth = np.ones((100, 100), dtype=np.uint16) * 1000 + + call_count = [0] # Use list to allow modification in nested function + + def convert_side_effect(msg): + """Return appropriate array based on call order: first call is mask, second is depth.""" + call_count[0] += 1 + if call_count[0] == 1: + return mask + else: + return depth + + # Patch convert_ros_img_to_base64 to avoid cv2 errors - GetGrabbingPointTool._run + # calls this function which internally uses cv2.cvtColor on empty mock images + with ( + patch( + "rai_perception.tools.segmentation_tools.get_future_result" + ) as 
mock_get_result, + patch( + "rai_perception.tools.segmentation_tools.convert_ros_img_to_ndarray", + side_effect=convert_side_effect, + ), + patch( + "rai_perception.tools.segmentation_tools.convert_ros_img_to_base64", + return_value="mock_base64_string", + ), + patch("cv2.minAreaRect") as mock_min_area, + patch( + "rai.communication.ros2.api.convert_ros_img_to_ndarray", + side_effect=convert_side_effect, + ), + ): + mock_get_result.side_effect = [gdino_response, gsam_response] + mock_min_area.return_value = ((50.0, 50.0), (60.0, 60.0), 0.0) + + result = grabbing_tool._run( + "camera_topic", "depth_topic", "camera_info_topic", "dinosaur" + ) + + assert isinstance(result, list) + assert len(result) == 1 + assert len(result[0]) == 2 # (centroid, rotation) + + +class TestDepthToPointCloud: + """Test cases for depth_to_point_cloud function.""" + + def test_depth_to_point_cloud(self): + """Test depth_to_point_cloud conversion.""" + # Create a simple depth image (100x100, 1 meter depth) + depth_image = np.ones((100, 100), dtype=np.float32) * 1.0 + + fx, fy = 500.0, 500.0 + cx, cy = 50.0, 50.0 + + points = depth_to_point_cloud(depth_image, fx, fy, cx, cy) + + # Should have points (excluding zero depth) + assert len(points) > 0 + assert points.shape[1] == 3 # x, y, z coordinates + + def test_depth_to_point_cloud_with_zero_depth(self): + """Test depth_to_point_cloud filters zero depth.""" + depth_image = np.zeros((100, 100), dtype=np.float32) + depth_image[20:80, 20:80] = 1.0 # Only center region has depth + + fx, fy = 500.0, 500.0 + cx, cy = 50.0, 50.0 + + points = depth_to_point_cloud(depth_image, fx, fy, cx, cy) + + # Should only have points from non-zero depth region + assert len(points) > 0 + assert all(points[:, 2] > 0) # All z values should be positive
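
The `TestDepthToPointCloud` cases above assume the standard pinhole back-projection: for a pixel at column `u`, row `v` with depth `z`, the 3D point is `x = (u - cx) * z / fx`, `y = (v - cy) * z / fy`, with zero-depth pixels dropped. A minimal standalone sketch of that geometry (an illustrative reimplementation for reference only, not the package's actual `depth_to_point_cloud`) that satisfies the same assertions:

```python
import numpy as np


def depth_to_point_cloud_sketch(depth, fx, fy, cx, cy):
    """Back-project a depth image (meters) into an (N, 3) point cloud,
    dropping zero-depth pixels -- mirrors what the tests above assert."""
    v, u = np.indices(depth.shape)       # row (v) and column (u) pixel grids
    z = depth.astype(np.float64)
    valid = z > 0                        # filter out zero-depth pixels
    x = (u[valid] - cx) * z[valid] / fx  # pinhole model: x = (u - cx) * z / fx
    y = (v[valid] - cy) * z[valid] / fy  # pinhole model: y = (v - cy) * z / fy
    return np.column_stack((x, y, z[valid]))


# Same fixture as test_depth_to_point_cloud_with_zero_depth: only the
# center 60x60 square has depth, so 3600 points survive the filter.
depth = np.zeros((100, 100), dtype=np.float32)
depth[20:80, 20:80] = 1.0
pts = depth_to_point_cloud_sketch(depth, 500.0, 500.0, 50.0, 50.0)
print(pts.shape)  # (3600, 3)
```

Note that the helper name and the exact output dtype are assumptions; the real implementation in `rai_perception.tools.segmentation_tools` may vectorize or order points differently, as long as it preserves the `(N, 3)` shape and the zero-depth filtering that the tests check.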