Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
DomenicP committed Sep 30, 2021
1 parent 516c6a6 commit 4fb65c8
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def unregister(self):
"""Unsubscribes this subscription and cleans up resources"""
manager.unsubscribe(self.client_id, self.topic)
with self.handler_lock:
self.handler.finish()
self.handler.finish(block=False)
self.clients.clear()

def subscribe(
Expand Down
31 changes: 17 additions & 14 deletions rosbridge_library/src/rosbridge_library/internal/subscribers.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class SubscriberManager:
"""

def __init__(self):
self._lock = Lock()
self._subscribers = {}

def subscribe(self, client_id, topic, callback, node_handle, msg_type=None, raw=False):
Expand All @@ -234,15 +235,16 @@ def subscribe(self, client_id, topic, callback, node_handle, msg_type=None, raw=
msg_type -- (optional) the type of the topic
"""
if topic not in self._subscribers:
self._subscribers[topic] = MultiSubscriber(
topic, client_id, callback, node_handle, msg_type=msg_type, raw=raw
)
else:
self._subscribers[topic].subscribe(client_id, callback)
with self._lock:
if topic not in self._subscribers:
self._subscribers[topic] = MultiSubscriber(
topic, client_id, callback, node_handle, msg_type=msg_type, raw=raw
)
else:
self._subscribers[topic].subscribe(client_id, callback)

if msg_type is not None and not raw:
self._subscribers[topic].verify_type(msg_type)
if msg_type is not None and not raw:
self._subscribers[topic].verify_type(msg_type)

def unsubscribe(self, client_id, topic):
"""Unsubscribe from a topic
Expand All @@ -252,14 +254,15 @@ def unsubscribe(self, client_id, topic):
topic -- the topic to unsubscribe from
"""
if topic not in self._subscribers:
return
with self._lock:
if topic not in self._subscribers:
return

self._subscribers[topic].unsubscribe(client_id)
self._subscribers[topic].unsubscribe(client_id)

if not self._subscribers[topic].has_subscribers():
self._subscribers[topic].unregister()
del self._subscribers[topic]
if not self._subscribers[topic].has_subscribers():
self._subscribers[topic].unregister()
del self._subscribers[topic]


manager = SubscriberManager()
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import sys
import traceback
from collections import deque
from threading import Condition, Thread
from time import time

Expand Down Expand Up @@ -77,7 +80,7 @@ def transition(self):
else:
return QueueMessageHandler(self)

def finish(self):
def finish(self, block=True):
pass


Expand All @@ -94,7 +97,7 @@ def transition(self):
else:
return QueueMessageHandler(self)

def finish(self):
def finish(self, block=True):
pass


Expand All @@ -103,17 +106,17 @@ def __init__(self, previous_handler):
Thread.__init__(self)
MessageHandler.__init__(self, previous_handler)
self.daemon = True
self.queue = []
self.queue = deque(maxlen=self.queue_length)
self.c = Condition()
self.alive = True
self.start()

def handle_message(self, msg):
with self.c:
if not self.alive:
return
should_notify = len(self.queue) == 0
self.queue.append(msg)
if len(self.queue) > self.queue_length:
del self.queue[0 : len(self.queue) - self.queue_length]
if should_notify:
self.c.notify()

Expand All @@ -126,37 +129,40 @@ def transition(self):
return ThrottleMessageHandler(self)
else:
with self.c:
if len(self.queue) > self.queue_length:
del self.queue[0 : len(self.queue) - self.queue_length]
old_queue = self.queue
self.queue = deque(maxlen=self.queue_length)
while len(old_queue) > 0:
self.queue.append(old_queue.popleft())
self.c.notify()
return self

def finish(self):
def finish(self, block=True):
"""If throttle was set to 0, this pushes all buffered messages"""
# Notify the thread to finish
with self.c:
self.alive = False
self.c.notify()

self.join()
if block:
self.join()

def run(self):
while self.alive:
msg = None
with self.c:
while self.alive and (self.time_remaining() > 0 or len(self.queue) == 0):
if len(self.queue) == 0:
self.c.wait()
else:
self.c.wait(self.time_remaining())
if len(self.queue) == 0:
self.c.wait()
else:
self.c.wait(self.time_remaining())
if self.alive and self.time_remaining() == 0 and len(self.queue) > 0:
try:
MessageHandler.handle_message(self, self.queue[0])
except Exception:
pass
del self.queue[0]
msg = self.queue.popleft()
if msg is not None:
try:
MessageHandler.handle_message(self, msg)
except:
traceback.print_exc(file=sys.stderr)
while self.time_remaining() == 0 and len(self.queue) > 0:
try:
MessageHandler.handle_message(self, self.queue[0])
except Exception:
pass
del self.queue[0]
except:
traceback.print_exc(file=sys.stderr)
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,42 @@ def cb(msg):
finally:
handler.finish()

def test_queue_message_handler_dropping(self):
received = {"msgs": []}

def cb(msg):
received["msgs"].append(msg)
time.sleep(1)

queue_length = 5
msgs = range(queue_length * 5)

handler = subscribe.MessageHandler(None, cb)

handler = handler.set_queue_length(queue_length)
self.assertIsInstance(handler, subscribe.QueueMessageHandler)

# send all messages at once.
# only the first and the last queue_length should get through,
# because the callbacks are blocked.
for x in msgs:
handler.handle_message(x)
# yield the thread so the first callback can append,
# otherwise the first handled value is non-deterministic.
time.sleep(0)

# wait long enough for all the callbacks, and then some.
time.sleep(queue_length + 3)

try:
self.assertEqual([msgs[0]] + msgs[-queue_length:], received["msgs"])
except:
handler.finish()
raise

handler.finish()


def test_queue_message_handler_rate(self):
handler = subscribe.MessageHandler(None, self.dummy_cb)
self.help_test_queue_rate(handler, 50, 10)
Expand Down
51 changes: 49 additions & 2 deletions rosbridge_server/src/rosbridge_server/websocket_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import threading
import traceback
import uuid
from collections import deque
from functools import partial, wraps

from rosbridge_library.rosbridge_protocol import RosbridgeProtocol
Expand Down Expand Up @@ -65,6 +66,50 @@ def wrapper(*args, **kwargs):
return wrapper


class IncomingQueue(threading.Thread):
"""Decouples incoming messages from the Tornado thread.
This mitigates cases where outgoing messages are blocked by incoming,
and vice versa.
"""
def __init__(self, protocol):
threading.Thread.__init__(self)
self.daemon = True
self.queue = deque()
self.protocol = protocol

self.cond = threading.Condition()
self._finished = False

def finish(self):
"""Clear the queue and do not accept further messages."""
with self.cond:
self._finished = True
while len(self.queue) > 0:
self.queue.popleft()
self.cond.notify()

def push(self, msg):
with self.cond:
self.queue.append(msg)
self.cond.notify()

def run(self):
while True:
with self.cond:
if len(self.queue) == 0 and not self._finished:
self.cond.wait()

if self._finished:
break

msg = self.queue.popleft()

self.protocol.incoming(msg)

self.protocol.finish()


class RosbridgeWebSocket(WebSocketHandler):
client_id_seed = 0
clients_connected = 0
Expand Down Expand Up @@ -94,6 +139,8 @@ def open(self):
self.protocol = RosbridgeProtocol(
cls.client_id_seed, cls.node_handle, parameters=parameters
)
self.incoming_queue = IncomingQueue(self.protocol)
self.incoming_queue.start()
self.protocol.outgoing = self.send_message
self.set_nodelay(True)
self._write_lock = threading.RLock()
Expand All @@ -115,18 +162,18 @@ def open(self):
def on_message(self, message):
if isinstance(message, bytes):
message = message.decode("utf-8")
self.protocol.incoming(message)
self.incoming_queue.push(message)

@log_exceptions
def on_close(self):
cls = self.__class__
cls.clients_connected -= 1
self.protocol.finish()
if cls.client_manager:
cls.client_manager.remove_client(self.client_id, self.request.remote_ip)
cls.node_handle.get_logger().info(
f"Client disconnected. {cls.clients_connected} clients total."
)
self.incoming_queue.finish()

def send_message(self, message):
if isinstance(message, bson.BSON):
Expand Down

0 comments on commit 4fb65c8

Please sign in to comment.