Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/rai_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "rai_core"
version = "2.2.1"
version = "2.3.0"
description = "Core functionality for RAI framework"
authors = ["Maciej Majek <maciej.majek@robotec.ai>", "Bartłomiej Boczek <bartlomiej.boczek@robotec.ai>", "Kajetan Rachwał <kajetan.rachwal@robotec.ai>"]
readme = "README.md"
Expand Down
67 changes: 51 additions & 16 deletions src/rai_core/rai/communication/ros2/api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
import uuid
from threading import Lock
from typing import (
Any,
Callable,
Expand All @@ -30,6 +31,7 @@
import rclpy.qos
import rclpy.subscription
import rclpy.task
from rclpy.client import Client
from rclpy.service import Service

from rai.communication.ros2.api.base import (
Expand All @@ -39,43 +41,76 @@


class ROS2ServiceAPI(BaseROS2API):
"""Handles ROS2 service operations including calling services."""
"""Handles ROS 2 service operations including calling services."""

def __init__(self, node: rclpy.node.Node) -> None:
self.node = node
self._logger = node.get_logger()
self._services: Dict[str, Service] = {}
self._persistent_clients: Dict[str, Client] = {}
self._persistent_clients_lock = Lock()

def release_client(self, service_name: str) -> bool:
with self._persistent_clients_lock:
return self._persistent_clients.pop(service_name, None) is not None

def call_service(
self,
service_name: str,
service_type: str,
request: Any,
timeout_sec: float = 5.0,
*,
reuse_client: bool = True,
) -> Any:
"""
Call a ROS2 service.
Call a ROS 2 service.

Args:
service_name: Name of the service to call
service_type: ROS2 service type as string
request: Request message content
service_name: Fully-qualified service name.
service_type: ROS 2 service type string (e.g., 'std_srvs/srv/SetBool').
request: Request payload dict.
timeout_sec: Seconds to wait for availability/response.
reuse_client: Reuse a cached client. Client creation is synchronized; set
False to create a new client per call.

Returns:
The response message
Response message instance.

Raises:
ValueError: Service not available within the timeout.
AttributeError: Service type or request cannot be constructed.

Note:
With reuse_client=True, access to the cached client (including the
service call) is serialized by a lock, preventing concurrent calls
through the same client. Use reuse_client=False for per-call clients
when concurrent service calls are required.
"""
srv_msg, srv_cls = self.build_ros2_service_request(service_type, request)
service_client = self.node.create_client(srv_cls, service_name) # type: ignore
client_ready = service_client.wait_for_service(timeout_sec=timeout_sec)
if not client_ready:
raise ValueError(
f"Service {service_name} not ready within {timeout_sec} seconds. "
"Try increasing the timeout or check if the service is running."
)
if os.getenv("ROS_DISTRO") == "humble":
return service_client.call(srv_msg)

def _call_service(client: Client, timeout_sec: float) -> Any:
is_service_available = client.wait_for_service(timeout_sec=timeout_sec)
if not is_service_available:
raise ValueError(
f"Service {service_name} not ready within {timeout_sec} seconds. "
"Try increasing the timeout or check if the service is running."
)
if os.getenv("ROS_DISTRO") == "humble":
return client.call(srv_msg)
else:
return client.call(srv_msg, timeout_sec=timeout_sec)

if reuse_client:
with self._persistent_clients_lock:
client = self._persistent_clients.get(service_name, None)
if client is None:
client = self.node.create_client(srv_cls, service_name) # type: ignore
self._persistent_clients[service_name] = client
return _call_service(client, timeout_sec)
else:
return service_client.call(srv_msg, timeout_sec=timeout_sec)
client = self.node.create_client(srv_cls, service_name) # type: ignore
return _call_service(client, timeout_sec)

def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]:
return self.node.get_service_names_and_types()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,25 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
f"{self.__class__.__name__} instance must have an attribute '_service_api' of type ROS2ServiceAPI"
)

def release_client(self, service_name: str) -> bool:
return self._service_api.release_client(service_name)

def service_call(
self,
message: ROS2Message,
target: str,
timeout_sec: float = 5.0,
*,
msg_type: str,
reuse_client: bool = True,
**kwargs: Any,
) -> ROS2Message:
msg = self._service_api.call_service(
service_name=target,
service_type=msg_type,
request=message.payload,
timeout_sec=timeout_sec,
reuse_client=reuse_client,
)
return ROS2Message(
payload=msg, metadata={"msg_type": str(type(msg)), "service": target}
Expand Down
36 changes: 31 additions & 5 deletions tests/communication/ros2/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,14 @@ def test_ros2_single_message_publish_wrong_qos_setup(
shutdown_executors_and_threads(executors, threads)


def service_call_helper(service_name: str, service_api: ROS2ServiceAPI):
def invoke_set_bool_service(
service_name: str, service_api: ROS2ServiceAPI, reuse_client: bool = True
):
response = service_api.call_service(
service_name,
service_type="std_srvs/srv/SetBool",
request={"data": True},
reuse_client=reuse_client,
)
assert response.success
assert response.message == "Test service called"
Expand All @@ -164,7 +167,7 @@ def test_ros2_service_single_call(

try:
service_api = ROS2ServiceAPI(node)
service_call_helper(service_name, service_api)
invoke_set_bool_service(service_name, service_api)
finally:
shutdown_executors_and_threads(executors, threads)

Expand All @@ -186,7 +189,30 @@ def test_ros2_service_multiple_calls(
try:
service_api = ROS2ServiceAPI(node)
for _ in range(3):
service_call_helper(service_name, service_api)
invoke_set_bool_service(service_name, service_api, reuse_client=False)
finally:
shutdown_executors_and_threads(executors, threads)


@pytest.mark.parametrize(
"callback_group",
[MutuallyExclusiveCallbackGroup(), ReentrantCallbackGroup()],
ids=["MutuallyExclusiveCallbackGroup", "ReentrantCallbackGroup"],
)
def test_ros2_service_multiple_calls_with_reused_client(
ros_setup: None, request: pytest.FixtureRequest, callback_group: CallbackGroup
) -> None:
service_name = f"{request.node.originalname}_service" # type: ignore
node_name = f"{request.node.originalname}_node" # type: ignore
service_server = ServiceServer(service_name, callback_group)
node = Node(node_name)
executors, threads = multi_threaded_spinner([service_server, node])

try:
service_api = ROS2ServiceAPI(node)
for _ in range(3):
invoke_set_bool_service(service_name, service_api, reuse_client=True)
assert service_api.release_client(service_name), "Client not released"
finally:
shutdown_executors_and_threads(executors, threads)

Expand All @@ -210,7 +236,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_threading(
service_threads: List[threading.Thread] = []
for _ in range(10):
thread = threading.Thread(
target=service_call_helper, args=(service_name, service_api)
target=invoke_set_bool_service, args=(service_name, service_api)
)
service_threads.append(thread)
thread.start()
Expand Down Expand Up @@ -241,7 +267,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_multiprocessing(
service_api = ROS2ServiceAPI(node)
with Pool(10) as pool:
pool.map(
lambda _: service_call_helper(service_name, service_api), range(10)
lambda _: invoke_set_bool_service(service_name, service_api), range(10)
)
finally:
shutdown_executors_and_threads(executors, threads)
Expand Down