diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 4f2523a07..9c8152a88 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.2.1" +version = "2.3.0" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" diff --git a/src/rai_core/rai/communication/ros2/api/service.py b/src/rai_core/rai/communication/ros2/api/service.py index 464625aed..06dac1bc1 100644 --- a/src/rai_core/rai/communication/ros2/api/service.py +++ b/src/rai_core/rai/communication/ros2/api/service.py @@ -14,6 +14,7 @@ import os import uuid +from threading import Lock from typing import ( Any, Callable, @@ -30,6 +31,7 @@ import rclpy.qos import rclpy.subscription import rclpy.task +from rclpy.client import Client from rclpy.service import Service from rai.communication.ros2.api.base import ( @@ -39,12 +41,18 @@ class ROS2ServiceAPI(BaseROS2API): - """Handles ROS2 service operations including calling services.""" + """Handles ROS 2 service operations including calling services.""" def __init__(self, node: rclpy.node.Node) -> None: self.node = node self._logger = node.get_logger() self._services: Dict[str, Service] = {} + self._persistent_clients: Dict[str, Client] = {} + self._persistent_clients_lock = Lock() + + def release_client(self, service_name: str) -> bool: + with self._persistent_clients_lock: + return self._persistent_clients.pop(service_name, None) is not None def call_service( self, @@ -52,30 +60,57 @@ def call_service( service_type: str, request: Any, timeout_sec: float = 5.0, + *, + reuse_client: bool = True, ) -> Any: """ - Call a ROS2 service. + Call a ROS 2 service. Args: - service_name: Name of the service to call - service_type: ROS2 service type as string - request: Request message content + service_name: Fully-qualified service name. + service_type: ROS 2 service type string (e.g., 'std_srvs/srv/SetBool'). + request: Request payload dict. + timeout_sec: Seconds to wait for availability/response. + reuse_client: Reuse a cached client. Client creation is synchronized; set + False to create a new client per call. Returns: - The response message + Response message instance. + + Raises: + ValueError: Service not available within the timeout. + AttributeError: Service type or request cannot be constructed. + + Note: + With reuse_client=True, access to the cached client (including the + service call) is serialized by a lock, preventing concurrent calls + through the same client. Use reuse_client=False for per-call clients + when concurrent service calls are required. """ srv_msg, srv_cls = self.build_ros2_service_request(service_type, request) - service_client = self.node.create_client(srv_cls, service_name) # type: ignore - client_ready = service_client.wait_for_service(timeout_sec=timeout_sec) - if not client_ready: - raise ValueError( - f"Service {service_name} not ready within {timeout_sec} seconds. " - "Try increasing the timeout or check if the service is running." - ) - if os.getenv("ROS_DISTRO") == "humble": - return service_client.call(srv_msg) + + def _call_service(client: Client, timeout_sec: float) -> Any: + is_service_available = client.wait_for_service(timeout_sec=timeout_sec) + if not is_service_available: + raise ValueError( + f"Service {service_name} not ready within {timeout_sec} seconds. " + "Try increasing the timeout or check if the service is running." + ) + if os.getenv("ROS_DISTRO") == "humble": + return client.call(srv_msg) + else: + return client.call(srv_msg, timeout_sec=timeout_sec) + + if reuse_client: + with self._persistent_clients_lock: + client = self._persistent_clients.get(service_name, None) + if client is None: + client = self.node.create_client(srv_cls, service_name) # type: ignore + self._persistent_clients[service_name] = client + return _call_service(client, timeout_sec) else: - return service_client.call(srv_msg, timeout_sec=timeout_sec) + client = self.node.create_client(srv_cls, service_name) # type: ignore + return _call_service(client, timeout_sec) def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]: return self.node.get_service_names_and_types() diff --git a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py index 985de6c16..7c1597a56 100644 --- a/src/rai_core/rai/communication/ros2/connectors/service_mixin.py +++ b/src/rai_core/rai/communication/ros2/connectors/service_mixin.py @@ -30,6 +30,9 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: f"{self.__class__.__name__} instance must have an attribute '_service_api' of type ROS2ServiceAPI" ) + def release_client(self, service_name: str) -> bool: + return self._service_api.release_client(service_name) + def service_call( self, message: ROS2Message, @@ -37,6 +40,7 @@ def service_call( timeout_sec: float = 5.0, *, msg_type: str, + reuse_client: bool = True, **kwargs: Any, ) -> ROS2Message: msg = self._service_api.call_service( @@ -44,6 +48,7 @@ def service_call( service_type=msg_type, request=message.payload, timeout_sec=timeout_sec, + reuse_client=reuse_client, ) return ROS2Message( payload=msg, metadata={"msg_type": str(type(msg)), "service": target} diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index bc102a58a..d51f30378 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -138,11 +138,14 @@ def test_ros2_single_message_publish_wrong_qos_setup( shutdown_executors_and_threads(executors, threads) -def service_call_helper(service_name: str, service_api: ROS2ServiceAPI): +def invoke_set_bool_service( + service_name: str, service_api: ROS2ServiceAPI, reuse_client: bool = True +): response = service_api.call_service( service_name, service_type="std_srvs/srv/SetBool", request={"data": True}, + reuse_client=reuse_client, ) assert response.success assert response.message == "Test service called" @@ -164,7 +167,7 @@ def test_ros2_service_single_call( try: service_api = ROS2ServiceAPI(node) - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api) finally: shutdown_executors_and_threads(executors, threads) @@ -186,7 +189,30 @@ def test_ros2_service_multiple_calls( try: service_api = ROS2ServiceAPI(node) for _ in range(3): - service_call_helper(service_name, service_api) + invoke_set_bool_service(service_name, service_api, reuse_client=False) + finally: + shutdown_executors_and_threads(executors, threads) + + +@pytest.mark.parametrize( + "callback_group", + [MutuallyExclusiveCallbackGroup(), ReentrantCallbackGroup()], + ids=["MutuallyExclusiveCallbackGroup", "ReentrantCallbackGroup"], +) +def test_ros2_service_multiple_calls_with_reused_client( + ros_setup: None, request: pytest.FixtureRequest, callback_group: CallbackGroup +) -> None: + service_name = f"{request.node.originalname}_service" # type: ignore + node_name = f"{request.node.originalname}_node" # type: ignore + service_server = ServiceServer(service_name, callback_group) + node = Node(node_name) + executors, threads = multi_threaded_spinner([service_server, node]) + + try: + service_api = ROS2ServiceAPI(node) + for _ in range(3): + invoke_set_bool_service(service_name, service_api, reuse_client=True) + assert service_api.release_client(service_name), "Client not released" finally: shutdown_executors_and_threads(executors, threads) @@ -210,7 +236,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_threading( service_threads: List[threading.Thread] = [] for _ in range(10): thread = threading.Thread( - target=service_call_helper, args=(service_name, service_api) + target=invoke_set_bool_service, args=(service_name, service_api) ) service_threads.append(thread) thread.start() @@ -241,7 +267,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_multiprocessing( service_api = ROS2ServiceAPI(node) with Pool(10) as pool: pool.map( - lambda _: service_call_helper(service_name, service_api), range(10) + lambda _: invoke_set_bool_service(service_name, service_api), range(10) ) finally: shutdown_executors_and_threads(executors, threads)