feat(stub): add fail_fast param to StubBroker.join #203

Merged
merged 4 commits on Apr 5, 2019
feat(stub): add fail_fast param to StubBroker.join

Closes #195.  This makes it possible to receive the original exception
inside the joining thread.
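For reviewers, a quick sketch of the call-site difference (the do_work actor, CustomError, and the stub_broker fixture below are illustrative, mirroring the test added in this PR):

    # Before: a failure inside the actor only surfaces as a dead-lettered
    # message; join() returns normally and the test has to inspect
    # stub_broker.dead_letters to find out what went wrong.
    stub_broker.join(do_work.queue_name)

    # After: with fail_fast=True, the exception raised inside the actor is
    # re-raised in the joining (test) thread.
    with pytest.raises(CustomError):
        stub_broker.join(do_work.queue_name, fail_fast=True)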
Bogdanp committed Apr 4, 2019
commit 3777f7a71db95dcfaabd1c0a6d1262f153028496
@@ -319,6 +319,14 @@ class MessageProxy:
    def __init__(self, message):
        self.failed = False
        self._message = message
        self._exception = None

    def stuff_exception(self, exception):
        """Stuff an exception into this message. Currently, this is
        used by the stub broker to know why a particular message has
        failed.
        """
        self._exception = exception

    def fail(self):
        """Mark this message for rejection.
@@ -15,6 +15,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
from collections import defaultdict
from itertools import chain
from queue import Empty, Queue

@@ -26,16 +27,18 @@

class StubBroker(Broker):
    """A broker that can be used within unit tests.
    Attributes:
      dead_letters(list[Message]): Contains the dead-lettered messages
        for all defined queues.
    """

    def __init__(self, middleware=None):
        super().__init__(middleware)

        self.dead_letters = []
        self.dead_letters_by_queue = defaultdict(list)

    @property
    def dead_letters(self):
        """The dead-lettered messages for all defined queues.
        """
        return [message for messages in self.dead_letters_by_queue.values() for message in messages]

    def consume(self, queue_name, prefetch=1, timeout=100):
        """Create a new consumer for a queue.
@@ -52,7 +55,11 @@ def consume(self, queue_name, prefetch=1, timeout=100):
          Consumer: A consumer that retrieves messages from Redis.
        """
        try:
            return _StubConsumer(self.queues[queue_name], self.dead_letters, timeout)
            return _StubConsumer(
                self.queues[queue_name],
                self.dead_letters_by_queue[queue_name],
                timeout,
            )
        except KeyError:
            raise QueueNotFound(queue_name)

@@ -119,7 +126,8 @@ def flush_all(self):
        for queue_name in chain(self.queues, self.delay_queues):
            self.flush(queue_name)

    def join(self, queue_name, *, timeout=None):
    # TODO: Make fail_fast default to True.
    def join(self, queue_name, *, fail_fast=False, timeout=None):
        """Wait for all the messages on the given queue to be
        processed. This method is only meant to be used in tests
        to wait for all the messages in a queue to be processed.
@@ -130,27 +138,40 @@ def join(self, queue_name, *, timeout=None):
        Parameters:
          queue_name(str): The queue to wait on.
          fail_fast(bool): When this is True and any message gets
            dead-lettered during the join, then an exception will be
            raised. This will be True by default starting with
            version 2.0.
          timeout(Optional[int]): The max amount of time, in
            milliseconds, to wait on this queue.
        """
        try:
            deadline = timeout and time.monotonic() + timeout / 1000
            while True:
                for name in [queue_name, dq_name(queue_name)]:
                    timeout = deadline and deadline - time.monotonic()
                    join_queue(self.queues[name], timeout=timeout)

                # We cycle through $queue then $queue.DQ then $queue
                # again in case the messages that were on the DQ got
                # moved back on $queue.
                for name in [queue_name, dq_name(queue_name)]:
                    if self.queues[name].unfinished_tasks:
                        break
                else:
                    return
            queues = [
                self.queues[queue_name],
                self.queues[dq_name(queue_name)],
            ]
        except KeyError:
            raise QueueNotFound(queue_name)

        deadline = timeout and time.monotonic() + timeout / 1000
        while True:
            for queue in queues:
                timeout = deadline and deadline - time.monotonic()
                join_queue(queue, timeout=timeout)

            # We cycle through $queue then $queue.DQ then $queue
            # again in case the messages that were on the DQ got
            # moved back on $queue.
            for queue in queues:
                if queue.unfinished_tasks:
                    break
            else:
                if fail_fast:
                    for message in self.dead_letters_by_queue[queue_name]:
                        raise message._exception from None

                return


class _StubConsumer(Consumer):
    def __init__(self, queue, dead_letters, timeout):
@@ -439,6 +439,11 @@ def process_message(self, message):
            self.broker.emit_after("skip_message", message)

        except BaseException as e:
            # Stuff the exception into the message [proxy] so that it
            # may be used by the stub broker to provide a nicer
            # testing experience.
            message.stuff_exception(e)

            if isinstance(e, RateLimitExceeded):
                self.logger.warning("Rate limit exceeded in message %s: %s.", message, e)
            else:
@@ -64,3 +64,21 @@ def do_work():
    # Then I expect a QueueJoinTimeout to be raised
    with pytest.raises(QueueJoinTimeout):
        stub_broker.join(do_work.queue_name, timeout=500)


def test_stub_broker_join_reraises_actor_exceptions_in_the_joining_current_thread(stub_broker, stub_worker):
    # Given that I have an actor that always fails with a custom exception
    class CustomError(Exception):
        pass

    @dramatiq.actor(max_retries=0)
    def do_work():
        raise CustomError("well, shit")

    # When I send that actor a message
    do_work.send()

    # And join on its queue
    # Then that exception should be raised in my thread
    with pytest.raises(CustomError):
        stub_broker.join(do_work.queue_name, fail_fast=True)
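
For comparison, a minimal sketch (not part of this diff) of how a test might keep the default fail_fast=False and inspect the broker's dead_letters property instead; the do_work actor and test name here are assumed for illustration:

    def test_failed_messages_are_dead_lettered(stub_broker, stub_worker):
        # Given an actor that always fails and never retries
        @dramatiq.actor(max_retries=0)
        def do_work():
            raise RuntimeError("boom")

        # When I send it a message and join on its queue without fail_fast,
        # join() returns normally and the message lands in dead_letters.
        do_work.send()
        stub_broker.join(do_work.queue_name)
        stub_worker.join()
        assert len(stub_broker.dead_letters) == 1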