Skip to content
Permalink
Browse files

feat(stub): add fail_fast param to StubBroker.join

Closes #195.  This makes it possible to receive the original exception
inside the joining thread.
  • Loading branch information...
Bogdanp committed Apr 4, 2019
1 parent 106ba99 commit f968118922cd025d236e1c99ab802767283f3f53
Showing with 73 additions and 21 deletions.
  1. +8 −0 dramatiq/broker.py
  2. +42 −21 dramatiq/brokers/stub.py
  3. +5 −0 dramatiq/worker.py
  4. +18 −0 tests/test_stub_broker.py
@@ -319,6 +319,14 @@ class MessageProxy:
def __init__(self, message):
self.failed = False
self._message = message
self._exception = None

def stuff_exception(self, exception):
"""Stuff an exception into this message. Currently, this is
used by the stub broker to known why a particular message has
failed.
"""
self._exception = exception

def fail(self):
"""Mark this message for rejection.
@@ -15,6 +15,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
from collections import defaultdict
from itertools import chain
from queue import Empty, Queue

@@ -26,16 +27,18 @@

class StubBroker(Broker):
"""A broker that can be used within unit tests.
Attributes:
dead_letters(list[Message]): Contains the dead-lettered messages
for all defined queues.
"""

def __init__(self, middleware=None):
super().__init__(middleware)

self.dead_letters = []
self.dead_letters_by_queue = defaultdict(list)

@property
def dead_letters(self):
"""The dead-lettered messages for all defined queues.
"""
return [message for messages in self.dead_letters_by_queue.values() for message in messages]

def consume(self, queue_name, prefetch=1, timeout=100):
"""Create a new consumer for a queue.
@@ -52,7 +55,11 @@ def consume(self, queue_name, prefetch=1, timeout=100):
Consumer: A consumer that retrieves messages from Redis.
"""
try:
return _StubConsumer(self.queues[queue_name], self.dead_letters, timeout)
return _StubConsumer(
self.queues[queue_name],
self.dead_letters_by_queue[queue_name],
timeout,
)
except KeyError:
raise QueueNotFound(queue_name)

@@ -119,7 +126,8 @@ def flush_all(self):
for queue_name in chain(self.queues, self.delay_queues):
self.flush(queue_name)

def join(self, queue_name, *, timeout=None):
# TODO: Make fail_fast default to True.
def join(self, queue_name, *, fail_fast=False, timeout=None):
"""Wait for all the messages on the given queue to be
processed. This method is only meant to be used in tests
to wait for all the messages in a queue to be processed.
@@ -130,27 +138,40 @@ def join(self, queue_name, *, timeout=None):
Parameters:
queue_name(str): The queue to wait on.
fail_fast(bool): When this is True and any message gets
dead-lettered during the join, then an exception will be
raised. This will be True by default starting with
version 2.0.
timeout(Optional[int]): The max amount of time, in
milliseconds, to wait on this queue.
"""
try:
deadline = timeout and time.monotonic() + timeout / 1000
while True:
for name in [queue_name, dq_name(queue_name)]:
timeout = deadline and deadline - time.monotonic()
join_queue(self.queues[name], timeout=timeout)

# We cycle through $queue then $queue.DQ then $queue
# again in case the messages that were on the DQ got
# moved back on $queue.
for name in [queue_name, dq_name(queue_name)]:
if self.queues[name].unfinished_tasks:
break
else:
return
queues = [
self.queues[queue_name],
self.queues[dq_name(queue_name)],
]
except KeyError:
raise QueueNotFound(queue_name)

deadline = timeout and time.monotonic() + timeout / 1000
while True:
for queue in queues:
timeout = deadline and deadline - time.monotonic()
join_queue(queue, timeout=timeout)

# We cycle through $queue then $queue.DQ then $queue
# again in case the messages that were on the DQ got
# moved back on $queue.
for queue in queues:
if queue.unfinished_tasks:
break
else:
if fail_fast:
for message in self.dead_letters_by_queue[queue_name]:
raise message._exception from None

return


class _StubConsumer(Consumer):
def __init__(self, queue, dead_letters, timeout):
@@ -439,6 +439,11 @@ def process_message(self, message):
self.broker.emit_after("skip_message", message)

except BaseException as e:
# Stuff the exception into the message [proxy] so that it
# may be used by the stub broker to provide a nicer
# testing experience.
message.stuff_exception(e)

if isinstance(e, RateLimitExceeded):
self.logger.warning("Rate limit exceeded in message %s: %s.", message, e)
else:
@@ -64,3 +64,21 @@ def do_work():
# Then I expect a QueueJoinTimeout to be raised
with pytest.raises(QueueJoinTimeout):
stub_broker.join(do_work.queue_name, timeout=500)


def test_stub_broker_join_reraises_actor_exceptions_in_the_joining_current_thread(stub_broker, stub_worker):
# Given that I have an actor that always fails with a custom exception
class CustomError(Exception):
pass

@dramatiq.actor(max_retries=0)
def do_work():
raise CustomError("well, shit")

# When I send that actor a message
do_work.send()

# And join on its queu
# Then that exception should be raised in my thread
with pytest.raises(CustomError):
stub_broker.join(do_work.queue_name, fail_fast=True)

0 comments on commit f968118

Please sign in to comment.
You can’t perform that action at this time.