Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#133: Fix task cancellation breaking broadcast loop #134

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
181 changes: 111 additions & 70 deletions amqtt/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import websockets
import asyncio
import re
from asyncio import CancelledError
from asyncio import CancelledError, futures
from collections import deque
from enum import Enum

Expand Down Expand Up @@ -185,6 +185,7 @@ def __init__(self, config=None, loop=None, plugin_namespace=None):
self._broadcast_queue = asyncio.Queue()

self._broadcast_task = None
self._broadcast_shutdown_waiter = futures.Future()

# Init plugins manager
context = BrokerContext(self)
Expand Down Expand Up @@ -365,13 +366,7 @@ async def shutdown(self):
# Fire broker_shutdown event to plugins
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)

# Stop broadcast loop
if self._broadcast_task:
self._broadcast_task.cancel()
if self._broadcast_queue.qsize() > 0:
self.logger.warning(
"%d messages not broadcasted" % self._broadcast_queue.qsize()
)
await self._shutdown_broadcast_loop()

for listener_name in self._servers:
server = self._servers[listener_name]
Expand Down Expand Up @@ -884,73 +879,119 @@ async def _broadcast_loop(self):
task = running_tasks.popleft()
try:
task.result() # make asyncio happy and collect results
except CancelledError:
self.logger.info("Task has been cancelled: %s", task)
except Exception:
pass
broadcast = await self._broadcast_queue.get()
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("broadcasting %r" % broadcast)
for k_filter in self._subscriptions:
if broadcast["topic"].startswith("$") and (
k_filter.startswith("+") or k_filter.startswith("#")
):
self.logger.debug(
"[MQTT-4.7.2-1] - ignoring brodcasting $ topic to subscriptions starting with + or #"
self.logger.exception(
"Task failed and will be skipped: %s", task
)
elif self.matches(broadcast["topic"], k_filter):
subscriptions = self._subscriptions[k_filter]
for (target_session, qos) in subscriptions:
if "qos" in broadcast:
qos = broadcast["qos"]
if target_session.transitions.state == "connected":
self.logger.debug(
"broadcasting application message from %s on topic '%s' to %s"
% (
format_client_message(
session=broadcast["session"]
),
broadcast["topic"],
format_client_message(session=target_session),
)
)
handler = self._get_handler(target_session)
task = asyncio.ensure_future(
handler.mqtt_publish(
broadcast["topic"],
broadcast["data"],
qos,
retain=False,
),
)
running_tasks.append(task)
else:
self.logger.debug(
"retaining application message from %s on topic '%s' to client '%s'"
% (
format_client_message(
session=broadcast["session"]
),
broadcast["topic"],
format_client_message(session=target_session),
)
)
retained_message = RetainedApplicationMessage(
broadcast["session"],
broadcast["topic"],
broadcast["data"],
qos,
)
await target_session.retained_messages.put(
retained_message
)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
f"target_session.retained_messages={target_session.retained_messages.qsize()}"
)
except CancelledError:

run_broadcast_task = asyncio.Task(self._run_broadcast(running_tasks))

completed, _ = await asyncio.wait(
[run_broadcast_task, self._broadcast_shutdown_waiter],
return_when=asyncio.FIRST_COMPLETED,
)

# Shutdown has been triggered by the broker
# So stop the loop execution
if self._broadcast_shutdown_waiter in completed:
run_broadcast_task.cancel()
break

except BaseException:
self.logger.exception("Broadcast loop stopped by exception")
raise
finally:
# Wait until current broadcasting tasks end
if running_tasks:
await asyncio.wait(running_tasks)
raise # reraise per CancelledError semantics

async def _run_broadcast(self, running_tasks: deque):
broadcast = await self._broadcast_queue.get()

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("broadcasting %r", broadcast)

for k_filter in self._subscriptions:
if broadcast["topic"].startswith("$") and (
k_filter.startswith("+") or k_filter.startswith("#")
):
self.logger.debug(
"[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #"
)
continue

# Skip all subscriptions which do not match the topic
if not self.matches(broadcast["topic"], k_filter):
continue

subscriptions = self._subscriptions[k_filter]
for (target_session, qos) in subscriptions:
qos = broadcast.get("qos", qos)

# Retain all messages which cannot be broadcasted
# due to the session not being connected
if target_session.transitions.state != "connected":
await self._retain_broadcast_message(broadcast, qos, target_session)
continue

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"broadcasting application message from %s on topic '%s' to %s"
% (
format_client_message(session=broadcast["session"]),
broadcast["topic"],
format_client_message(session=target_session),
)
)

handler = self._get_handler(target_session)
task = asyncio.ensure_future(
handler.mqtt_publish(
broadcast["topic"],
broadcast["data"],
qos,
retain=False,
),
)
running_tasks.append(task)

async def _retain_broadcast_message(self, broadcast, qos, target_session):
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"retaining application message from %s on topic '%s' to client '%s'",
format_client_message(session=broadcast["session"]),
broadcast["topic"],
format_client_message(session=target_session),
)

retained_message = RetainedApplicationMessage(
broadcast["session"],
broadcast["topic"],
broadcast["data"],
qos,
)
await target_session.retained_messages.put(retained_message)

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"target_session.retained_messages=%s",
target_session.retained_messages.qsize(),
)

async def _shutdown_broadcast_loop(self):
if self._broadcast_task:
self._broadcast_shutdown_waiter.set_result(True)
try:
await asyncio.wait_for(self._broadcast_task, timeout=30)
except BaseException as e:
self.logger.warning("Failed to cleanly shutdown broadcast loop: %r", e)

if self._broadcast_queue.qsize() > 0:
self.logger.warning(
"%d messages not broadcasted", self._broadcast_queue.qsize()
)

async def _broadcast_message(self, session, topic, data, force_qos=None):
broadcast = {"session": session, "topic": topic, "data": data}
Expand Down
35 changes: 34 additions & 1 deletion tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# See the file license.txt for copying permission.
import asyncio
import logging
from unittest.mock import call, MagicMock
import sys
from unittest.mock import call, MagicMock, patch

import pytest

Expand Down Expand Up @@ -31,6 +32,7 @@
)
from amqtt.mqtt.connect import ConnectVariableHeader, ConnectPayload
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler


formatter = (
Expand Down Expand Up @@ -625,3 +627,34 @@ def test_matches_single_level_wildcard(broker):
"sport/tennis/player2",
]:
assert broker.matches(good_topic, test_filter)


@pytest.mark.asyncio
async def test_broker_broadcast_cancellation(broker):
topic = "test"
data = b"data"
qos = QOS_0

sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1")
await sub_client.subscribe([(topic, qos)])

with patch.object(
BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError
) as mocked_mqtt_publish:
await _client_publish(topic, data, qos)

# Second publish triggers the awaiting of first `mqtt_publish` task
await _client_publish(topic, data, qos)
await asyncio.sleep(0.01)

# `assert_awaited` does not exist in Python before `3.8`
if sys.version_info >= (3, 8):
mocked_mqtt_publish.assert_awaited()
else:
mocked_mqtt_publish.assert_called()

# Ensure broadcast loop is still functional and can deliver the message
await _client_publish(topic, data, qos)
message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1)
assert message