Skip to content

Commit

Permalink
Fix task cancellation breaking broadcast loop
Browse files Browse the repository at this point in the history
- Differentiate between broadcast shutdown and task cancellation
- Minor refactor to make the code more readable
- Improved and simplified logging
  • Loading branch information
FabianElsmer committed Feb 6, 2023
1 parent 19960e6 commit 7dcd1df
Showing 1 changed file with 106 additions and 70 deletions.
176 changes: 106 additions & 70 deletions amqtt/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import websockets
import asyncio
import re
from asyncio import CancelledError
from asyncio import CancelledError, futures
from collections import deque
from enum import Enum

Expand Down Expand Up @@ -185,6 +185,7 @@ def __init__(self, config=None, loop=None, plugin_namespace=None):
self._broadcast_queue = asyncio.Queue()

self._broadcast_task = None
self._broadcast_shutdown_waiter = futures.Future()

# Init plugins manager
context = BrokerContext(self)
Expand Down Expand Up @@ -365,13 +366,7 @@ async def shutdown(self):
# Fire broker_shutdown event to plugins
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)

# Stop broadcast loop
if self._broadcast_task:
self._broadcast_task.cancel()
if self._broadcast_queue.qsize() > 0:
self.logger.warning(
"%d messages not broadcasted" % self._broadcast_queue.qsize()
)
await self._shutdown_broadcast_loop()

for listener_name in self._servers:
server = self._servers[listener_name]
Expand Down Expand Up @@ -885,72 +880,113 @@ async def _broadcast_loop(self):
try:
task.result() # make asyncio happy and collect results
except Exception:
pass
broadcast = await self._broadcast_queue.get()
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("broadcasting %r" % broadcast)
for k_filter in self._subscriptions:
if broadcast["topic"].startswith("$") and (
k_filter.startswith("+") or k_filter.startswith("#")
):
self.logger.debug(
"[MQTT-4.7.2-1] - ignoring brodcasting $ topic to subscriptions starting with + or #"
self.logger.exception(
"Task failed and will be skipped: %s", task
)
elif self.matches(broadcast["topic"], k_filter):
subscriptions = self._subscriptions[k_filter]
for (target_session, qos) in subscriptions:
if "qos" in broadcast:
qos = broadcast["qos"]
if target_session.transitions.state == "connected":
self.logger.debug(
"broadcasting application message from %s on topic '%s' to %s"
% (
format_client_message(
session=broadcast["session"]
),
broadcast["topic"],
format_client_message(session=target_session),
)
)
handler = self._get_handler(target_session)
task = asyncio.ensure_future(
handler.mqtt_publish(
broadcast["topic"],
broadcast["data"],
qos,
retain=False,
),
)
running_tasks.append(task)
else:
self.logger.debug(
"retaining application message from %s on topic '%s' to client '%s'"
% (
format_client_message(
session=broadcast["session"]
),
broadcast["topic"],
format_client_message(session=target_session),
)
)
retained_message = RetainedApplicationMessage(
broadcast["session"],
broadcast["topic"],
broadcast["data"],
qos,
)
await target_session.retained_messages.put(
retained_message
)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
f"target_session.retained_messages={target_session.retained_messages.qsize()}"
)
except CancelledError:
except CancelledError:
self.logger.warn("Task has been cancelled: %s", task)

run_broadcast_task = self._run_broadcast(running_tasks)

completed, _ = await asyncio.wait(
[run_broadcast_task, self._broadcast_shutdown_waiter],
return_when=asyncio.FIRST_COMPLETED,
)

# Shutdown has been triggered by the broker
# So stop the loop execution
if self._broadcast_shutdown_waiter in completed:
break

except BaseException:
self.logger.exception("Broadcast loop stopped by exception")
raise
finally:
# Wait until current broadcasting tasks end
if running_tasks:
await asyncio.wait(running_tasks)
raise # reraise per CancelledError semantics

async def _run_broadcast(self, running_tasks: deque):
broadcast = await self._broadcast_queue.get()

self.logger.debug("broadcasting %r", broadcast)

for k_filter in self._subscriptions:
if broadcast["topic"].startswith("$") and (
k_filter.startswith("+") or k_filter.startswith("#")
):
self.logger.debug(
"[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #"
)
continue

# Skip all subscriptions which do not match the topic
if not self.matches(broadcast["topic"], k_filter):
continue

subscriptions = self._subscriptions[k_filter]
for (target_session, qos) in subscriptions:
qos = broadcast.get("qos", qos)

# Retain all messages which cannot be broadcasted
# due to the session not being connected
if target_session.transitions.state != "connected":
task = asyncio.ensure_future(
self._retain_broadcast_message(broadcast, qos, target_session)
)
running_tasks.append(task)
continue

self.logger.debug(
"broadcasting application message from %s on topic '%s' to %s"
% (
format_client_message(session=broadcast["session"]),
broadcast["topic"],
format_client_message(session=target_session),
)
)
handler = self._get_handler(target_session)
task = asyncio.ensure_future(
handler.mqtt_publish(
broadcast["topic"],
broadcast["data"],
qos,
retain=False,
),
)
running_tasks.append(task)

async def _retain_broadcast_message(self, broadcast, qos, target_session):
self.logger.debug(
"retaining application message from %s on topic '%s' to client '%s'",
format_client_message(session=broadcast["session"]),
broadcast["topic"],
format_client_message(session=target_session),
)
retained_message = RetainedApplicationMessage(
broadcast["session"],
broadcast["topic"],
broadcast["data"],
qos,
)
await target_session.retained_messages.put(retained_message)
self.logger.debug(
"target_session.retained_messages=%s",
target_session.retained_messages.qsize(),
)

async def _shutdown_broadcast_loop(self):
if self._broadcast_task:
self._broadcast_shutdown_waiter.set_result(True)
try:
await asyncio.wait_for(self._broadcast_task, timeout=30)
except BaseException as e:
self.logger.warning("Failed to cleanly shutdown broadcast loop: %r", e)

if self._broadcast_queue.qsize() > 0:
self.logger.warning(
"%d messages not broadcasted", self._broadcast_queue.qsize()
)

async def _broadcast_message(self, session, topic, data, force_qos=None):
broadcast = {"session": session, "topic": topic, "data": data}
Expand Down

0 comments on commit 7dcd1df

Please sign in to comment.