Skip to content

Commit

Permalink
Replace busy wait in AdvertiseService with async handler (#666)
Browse files Browse the repository at this point in the history
Fixes #650 (closes #652). Adds a test to show that multiple service calls can be handled concurrently.
  • Loading branch information
jtbandes authored Nov 29, 2021
1 parent c54ba61 commit 34d7bd5
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import fnmatch
import time
from threading import Lock

import rclpy
from rclpy.callback_groups import ReentrantCallbackGroup
from rosbridge_library.capability import Capability
from rosbridge_library.internal import message_conversion
from rosbridge_library.internal.ros_loader import get_service_class
Expand All @@ -10,31 +10,32 @@
class AdvertisedServiceHandler:

id_counter = 1
responses = {}

def __init__(self, service_name, service_type, protocol):
self.active_requests = 0
self.shutdown_requested = False
self.lock = Lock()
self.request_futures = {}
self.service_name = service_name
self.service_type = service_type
self.protocol = protocol
# setup the service
self.service_handle = protocol.node_handle.create_service(
get_service_class(service_type), service_name, self.handle_request
get_service_class(service_type),
service_name,
self.handle_request,
callback_group=ReentrantCallbackGroup(), # https://github.com/ros2/rclpy/issues/834#issuecomment-961331870
)

def next_id(self):
id = self.id_counter
self.id_counter += 1
return id

def handle_request(self, req):
with self.lock:
self.active_requests += 1
async def handle_request(self, req, res):
# generate a unique ID
request_id = "service_request:" + self.service_name + ":" + str(self.next_id())

future = rclpy.task.Future()
self.request_futures[request_id] = future

# build a request to send to the external client
request_message = {
"op": "call_service",
Expand All @@ -44,41 +45,36 @@ def handle_request(self, req):
}
self.protocol.send(request_message)

# wait for a response
while request_id not in self.responses.keys():
with self.lock:
if self.shutdown_requested:
break
time.sleep(0)

with self.lock:
self.active_requests -= 1
res = await future
del self.request_futures[request_id]
return res

if self.shutdown_requested:
self.protocol.log(
"warning",
"Service %s was unadvertised with a service call in progress, "
"aborting service call with request ID %s" % (self.service_name, request_id),
)
return None

resp = self.responses[request_id]
del self.responses[request_id]
return resp
def handle_response(self, request_id, res):
"""
Called by the ServiceResponse capability to handle a service response from the external client.
"""
if request_id in self.request_futures:
self.request_futures[request_id].set_result(res)
else:
self.protocol.log(
"warning", f"Received service response for unrecognized id: {request_id}"
)

def graceful_shutdown(self, timeout):
def graceful_shutdown(self):
"""
Signal the AdvertisedServiceHandler to shutdown
Using this, rather than just node_handle.destroy_service, allows us
time to stop any active service requests, ending their busy wait
loops.
"""
with self.lock:
self.shutdown_requested = True
start_time = time.clock()
while time.clock() - start_time < timeout:
time.sleep(0)
if self.request_futures:
incomplete_ids = ", ".join(self.request_futures.keys())
self.protocol.log(
"warning",
f"Service {self.service_name} was unadvertised with a service call in progress, "
f"aborting service calls with request IDs {incomplete_ids}",
)
self.protocol.node_handle.destroy_service(self.service_handle)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def service_response(self, message):
resp = ros_loader.get_service_response_instance(service_handler.service_type)
message_conversion.populate_instance(values, resp)
# pass along the response
service_handler.responses[request_id] = resp
service_handler.handle_response(request_id, resp)
else:
self.protocol.log(
"error",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def unadvertise_service(self, message):

# unregister service in ROS
if service_name in self.protocol.external_service_list.keys():
self.protocol.external_service_list[service_name].graceful_shutdown(timeout=1.0)
self.protocol.external_service_list[service_name].graceful_shutdown()
self.protocol.external_service_list[service_name].service_handle.shutdown(
"Unadvertise request."
)
Expand Down
1 change: 1 addition & 0 deletions rosbridge_server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ install(FILES
if(BUILD_TESTING)
find_package(launch_testing_ament_cmake REQUIRED)
add_launch_test(test/websocket/smoke.test.py)
add_launch_test(test/websocket/advertise_service.test.py)
endif()
1 change: 1 addition & 0 deletions rosbridge_server/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
<test_depend>launch_testing</test_depend>
<test_depend>launch_testing_ros</test_depend>
<test_depend>launch_testing_ament_cmake</test_depend>
<test_depend>std_srvs</test_depend>

<export>
<build_type>ament_cmake</build_type>
Expand Down
77 changes: 77 additions & 0 deletions rosbridge_server/test/websocket/advertise_service.test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python
import os
import sys
import unittest

from rclpy.node import Node
from std_srvs.srv import SetBool
from twisted.python import log

sys.path.append(os.path.dirname(__file__)) # enable importing from common.py in this directory

import common # noqa: E402
from common import expect_messages, websocket_test # noqa: E402

log.startLogging(sys.stderr)

generate_test_description = common.generate_test_description


class TestAdvertiseService(unittest.TestCase):
@websocket_test
async def test_two_concurrent_calls(self, node: Node, ws_client):
ws_client.sendJson(
{
"op": "advertise_service",
"type": "std_srvs/SetBool",
"service": "/test_service",
}
)
client = node.create_client(SetBool, "/test_service")
client.wait_for_service()

requests_future, ws_client.message_handler = expect_messages(
2, "WebSocket", node.get_logger()
)
requests_future.add_done_callback(lambda _: node.executor.wake())

response1_future = client.call_async(SetBool.Request(data=True))
response2_future = client.call_async(SetBool.Request(data=False))

requests = await requests_future
self.assertEqual(len(requests), 2)

self.assertEqual(requests[0]["op"], "call_service")
self.assertEqual(requests[0]["service"], "/test_service")
self.assertEqual(requests[0]["args"], {"data": True})
ws_client.sendJson(
{
"op": "service_response",
"service": "/test_service",
"values": {"success": True, "message": "Hello world 1"},
"id": requests[0]["id"],
"result": True,
}
)

self.assertEqual(requests[1]["op"], "call_service")
self.assertEqual(requests[1]["service"], "/test_service")
self.assertEqual(requests[1]["args"], {"data": False})
ws_client.sendJson(
{
"op": "service_response",
"service": "/test_service",
"values": {"success": True, "message": "Hello world 2"},
"id": requests[1]["id"],
"result": True,
}
)

self.assertEqual(
await response1_future, SetBool.Response(success=True, message="Hello world 1")
)
self.assertEqual(
await response2_future, SetBool.Response(success=True, message="Hello world 2")
)

node.destroy_client(client)
22 changes: 22 additions & 0 deletions rosbridge_server/test/websocket/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,25 @@ def run_test(self):
run_websocket_test(test_fn.__name__, lambda *args: test_fn(self, *args))

return run_test


def expect_messages(count: int, description: str, logger):
"""
Convenience function to create a Future and a message handler function which gathers results
into a list and waits for the list to have the expected number of items.
"""
future = rclpy.Future()
results = []

def handler(msg):
logger.info(f"Received message on {description}: {msg}")
results.append(msg)
if len(results) == count:
logger.info(f"Received all messages on {description}")
future.set_result(results)
elif len(results) > count:
raise AssertionError(
f"Received {len(results)} messages on {description} but expected {count}"
)

return future, handler
47 changes: 13 additions & 34 deletions rosbridge_server/test/websocket/smoke.test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import sys
import unittest

import rclpy
import rclpy.task
from rclpy.node import Node
from std_msgs.msg import String
from twisted.python import log

sys.path.append(os.path.dirname(__file__)) # enable importing from common.py in this directory

import common # noqa: E402
from common import sleep, websocket_test # noqa: E402
from common import expect_messages, sleep, websocket_test # noqa: E402

log.startLogging(sys.stderr)

Expand All @@ -20,7 +19,7 @@

class TestWebsocketSmoke(unittest.TestCase):
@websocket_test
async def test_smoke(self, node, ws_client):
async def test_smoke(self, node: Node, ws_client):
# For consistency, the number of messages must not exceed the the protocol
# Subscriber queue_size.
NUM_MSGS = 10
Expand All @@ -31,33 +30,14 @@ async def test_smoke(self, node, ws_client):
B_STRING = "B" * MSG_SIZE
WARMUP_DELAY = 1.0 # seconds

ros_received = []
ws_received = []
sub_completed_future = rclpy.task.Future()
ws_completed_future = rclpy.task.Future()

def sub_handler(msg):
node.get_logger().info(f"Received message via ROS subscriber: {msg}")
ros_received.append(msg)
if len(ros_received) == NUM_MSGS:
node.get_logger().info("Received all messages on ROS subscriber")
sub_completed_future.set_result(None)
elif len(ros_received) > NUM_MSGS:
raise AssertionError(
f"Received {len(ros_received)} messages on ROS subscriber but expected {NUM_MSGS}"
)

def ws_handler(msg):
ws_received.append(msg)
if len(ws_received) == NUM_MSGS:
node.get_logger().info("Received all WebSocket messages")
ws_completed_future.set_result(None)
elif len(ws_received) > NUM_MSGS:
raise AssertionError(
f"Received {len(ws_received)} WebSocket messages but expected {NUM_MSGS}"
)

ws_client.message_handler = ws_handler
sub_completed_future, sub_handler = expect_messages(
NUM_MSGS, "ROS subscriber", node.get_logger()
)
ws_completed_future, ws_client.message_handler = expect_messages(
NUM_MSGS, "WebSocket", node.get_logger()
)
ws_completed_future.add_done_callback(lambda _: node.executor.wake())

sub_a = node.create_subscription(String, A_TOPIC, sub_handler, NUM_MSGS)
pub_b = node.create_publisher(String, B_TOPIC, NUM_MSGS)

Expand Down Expand Up @@ -93,9 +73,8 @@ def ws_handler(msg):
for _ in range(NUM_MSGS):
pub_b.publish(String(data=B_STRING))

ws_completed_future.add_done_callback(lambda _: node.executor.wake())
await sub_completed_future
await ws_completed_future
ros_received = await sub_completed_future
ws_received = await ws_completed_future

for msg in ws_received:
self.assertEqual("publish", msg["op"])
Expand Down

0 comments on commit 34d7bd5

Please sign in to comment.