Skip to content

Commit

Permalink
Port #464, #478, #496, and #502 from ROS1 branch
Browse files Browse the repository at this point in the history
  • Loading branch information
DomenicP committed Oct 1, 2021
1 parent 516c6a6 commit 74450f6
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def unregister(self):
"""Unsubscribes this subscription and cleans up resources"""
manager.unsubscribe(self.client_id, self.topic)
with self.handler_lock:
self.handler.finish()
self.handler.finish(block=False)
self.clients.clear()

def subscribe(
Expand Down
36 changes: 18 additions & 18 deletions rosbridge_library/src/rosbridge_library/internal/subscribers.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def verify_type(self, msg_type):
"""
if not ros_loader.get_message_class(msg_type) is self.msg_class:
raise TypeConflictException(self.topic, msg_class_type_repr(self.msg_class), msg_type)
return

def subscribe(self, client_id, callback):
"""Subscribe the specified client to this subscriber.
Expand Down Expand Up @@ -168,8 +167,7 @@ def unsubscribe(self, client_id):
def has_subscribers(self):
"""Return true if there are subscribers"""
with self.lock:
ret = len(self.subscriptions) != 0
return ret
return len(self.subscriptions) != 0

def callback(self, msg, callbacks=None):
"""Callback for incoming messages on the rclpy subscription.
Expand All @@ -195,7 +193,6 @@ def callback(self, msg, callbacks=None):
except Exception as exc:
# Do nothing if one particular callback fails except log it
self.node_handle.get_logger().error(f"Exception calling subscribe callback: {exc}")
pass

def _new_sub_callback(self, msg):
"""
Expand All @@ -222,6 +219,7 @@ class SubscriberManager:
"""

def __init__(self):
self._lock = Lock()
self._subscribers = {}

def subscribe(self, client_id, topic, callback, node_handle, msg_type=None, raw=False):
Expand All @@ -234,15 +232,16 @@ def subscribe(self, client_id, topic, callback, node_handle, msg_type=None, raw=
msg_type -- (optional) the type of the topic
"""
if topic not in self._subscribers:
self._subscribers[topic] = MultiSubscriber(
topic, client_id, callback, node_handle, msg_type=msg_type, raw=raw
)
else:
self._subscribers[topic].subscribe(client_id, callback)
with self._lock:
if topic not in self._subscribers:
self._subscribers[topic] = MultiSubscriber(
topic, client_id, callback, node_handle, msg_type=msg_type, raw=raw
)
else:
self._subscribers[topic].subscribe(client_id, callback)

if msg_type is not None and not raw:
self._subscribers[topic].verify_type(msg_type)
if msg_type is not None and not raw:
self._subscribers[topic].verify_type(msg_type)

def unsubscribe(self, client_id, topic):
"""Unsubscribe from a topic
Expand All @@ -252,14 +251,15 @@ def unsubscribe(self, client_id, topic):
topic -- the topic to unsubscribe from
"""
if topic not in self._subscribers:
return
with self._lock:
if topic not in self._subscribers:
return

self._subscribers[topic].unsubscribe(client_id)
self._subscribers[topic].unsubscribe(client_id)

if not self._subscribers[topic].has_subscribers():
self._subscribers[topic].unregister()
del self._subscribers[topic]
if not self._subscribers[topic].has_subscribers():
self._subscribers[topic].unregister()
del self._subscribers[topic]


manager = SubscriberManager()
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import sys
import traceback
from collections import deque
from threading import Condition, Thread
from time import time

Expand Down Expand Up @@ -77,7 +80,7 @@ def transition(self):
else:
return QueueMessageHandler(self)

def finish(self):
def finish(self, block=True):
pass


Expand All @@ -94,7 +97,7 @@ def transition(self):
else:
return QueueMessageHandler(self)

def finish(self):
def finish(self, block=True):
pass


Expand All @@ -103,17 +106,17 @@ def __init__(self, previous_handler):
Thread.__init__(self)
MessageHandler.__init__(self, previous_handler)
self.daemon = True
self.queue = []
self.queue = deque(maxlen=self.queue_length)
self.c = Condition()
self.alive = True
self.start()

def handle_message(self, msg):
with self.c:
if not self.alive:
return
should_notify = len(self.queue) == 0
self.queue.append(msg)
if len(self.queue) > self.queue_length:
del self.queue[0 : len(self.queue) - self.queue_length]
if should_notify:
self.c.notify()

Expand All @@ -126,37 +129,40 @@ def transition(self):
return ThrottleMessageHandler(self)
else:
with self.c:
if len(self.queue) > self.queue_length:
del self.queue[0 : len(self.queue) - self.queue_length]
old_queue = self.queue
self.queue = deque(maxlen=self.queue_length)
while len(old_queue) > 0:
self.queue.append(old_queue.popleft())
self.c.notify()
return self

def finish(self):
def finish(self, block=True):
"""If throttle was set to 0, this pushes all buffered messages"""
# Notify the thread to finish
with self.c:
self.alive = False
self.c.notify()

self.join()
if block:
self.join()

def run(self):
while self.alive:
msg = None
with self.c:
while self.alive and (self.time_remaining() > 0 or len(self.queue) == 0):
if len(self.queue) == 0:
self.c.wait()
else:
self.c.wait(self.time_remaining())
if len(self.queue) == 0:
self.c.wait()
else:
self.c.wait(self.time_remaining())
if self.alive and self.time_remaining() == 0 and len(self.queue) > 0:
try:
MessageHandler.handle_message(self, self.queue[0])
except Exception:
pass
del self.queue[0]
msg = self.queue.popleft()
if msg is not None:
try:
MessageHandler.handle_message(self, msg)
except Exception:
traceback.print_exc(file=sys.stderr)
while self.time_remaining() == 0 and len(self.queue) > 0:
try:
MessageHandler.handle_message(self, self.queue[0])
except Exception:
pass
del self.queue[0]
traceback.print_exc(file=sys.stderr)
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,41 @@ def cb(msg):
finally:
handler.finish()

def test_queue_message_handler_dropping(self):
received = {"msgs": []}

def cb(msg):
received["msgs"].append(msg)
time.sleep(1)

queue_length = 5
msgs = range(queue_length * 5)

handler = subscribe.MessageHandler(None, cb)

handler = handler.set_queue_length(queue_length)
self.assertIsInstance(handler, subscribe.QueueMessageHandler)

# send all messages at once.
# only the first and the last queue_length should get through,
# because the callbacks are blocked.
for x in msgs:
handler.handle_message(x)
# yield the thread so the first callback can append,
# otherwise the first handled value is non-deterministic.
time.sleep(0)

# wait long enough for all the callbacks, and then some.
time.sleep(queue_length + 3)

try:
self.assertEqual([msgs[0]] + msgs[-queue_length:], received["msgs"])
except: # noqa: E722 # Will finish and raise
handler.finish()
raise

handler.finish()

def test_queue_message_handler_rate(self):
handler = subscribe.MessageHandler(None, self.dummy_cb)
self.help_test_queue_rate(handler, 50, 10)
Expand Down
57 changes: 54 additions & 3 deletions rosbridge_server/src/rosbridge_server/websocket_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import threading
import traceback
import uuid
from collections import deque
from functools import partial, wraps

from rosbridge_library.rosbridge_protocol import RosbridgeProtocol
Expand All @@ -45,6 +46,9 @@
from tornado.websocket import WebSocketClosedError, WebSocketHandler


_io_loop = IOLoop.instance()


def _log_exception():
"""Log the most recent exception to ROS."""
exc = traceback.format_exception(*sys.exc_info())
Expand All @@ -65,6 +69,51 @@ def wrapper(*args, **kwargs):
return wrapper


class IncomingQueue(threading.Thread):
"""Decouples incoming messages from the Tornado thread.
This mitigates cases where outgoing messages are blocked by incoming,
and vice versa.
"""

def __init__(self, protocol):
threading.Thread.__init__(self)
self.daemon = True
self.queue = deque()
self.protocol = protocol

self.cond = threading.Condition()
self._finished = False

def finish(self):
"""Clear the queue and do not accept further messages."""
with self.cond:
self._finished = True
while len(self.queue) > 0:
self.queue.popleft()
self.cond.notify()

def push(self, msg):
with self.cond:
self.queue.append(msg)
self.cond.notify()

def run(self):
while True:
with self.cond:
if len(self.queue) == 0 and not self._finished:
self.cond.wait()

if self._finished:
break

msg = self.queue.popleft()

self.protocol.incoming(msg)

self.protocol.finish()


class RosbridgeWebSocket(WebSocketHandler):
client_id_seed = 0
clients_connected = 0
Expand Down Expand Up @@ -94,6 +143,8 @@ def open(self):
self.protocol = RosbridgeProtocol(
cls.client_id_seed, cls.node_handle, parameters=parameters
)
self.incoming_queue = IncomingQueue(self.protocol)
self.incoming_queue.start()
self.protocol.outgoing = self.send_message
self.set_nodelay(True)
self._write_lock = threading.RLock()
Expand All @@ -115,18 +166,18 @@ def open(self):
def on_message(self, message):
if isinstance(message, bytes):
message = message.decode("utf-8")
self.protocol.incoming(message)
self.incoming_queue.push(message)

@log_exceptions
def on_close(self):
cls = self.__class__
cls.clients_connected -= 1
self.protocol.finish()
if cls.client_manager:
cls.client_manager.remove_client(self.client_id, self.request.remote_ip)
cls.node_handle.get_logger().info(
f"Client disconnected. {cls.clients_connected} clients total."
)
self.incoming_queue.finish()

def send_message(self, message):
if isinstance(message, bson.BSON):
Expand All @@ -138,7 +189,7 @@ def send_message(self, message):
binary = False

with self._write_lock:
IOLoop.instance().add_callback(partial(self.prewrite_message, message, binary))
_io_loop.add_callback(partial(self.prewrite_message, message, binary))

@coroutine
def prewrite_message(self, message, binary):
Expand Down

0 comments on commit 74450f6

Please sign in to comment.