Skip to content

Commit

Permalink
Dependencies: Update requirement for aio-pika~=8.2 (#114)
Browse files Browse the repository at this point in the history
The new major versions `7.0` and `8.0` come with important fixes that
could improve the stability of robust connections, so the minimum
required version is upped to `~=8.2`.

The requirement for `pamqp` also had to be upgraded, since this is
required by `aiormq` which is a direct dependency of `aio-pika`.

The newer version required a number of changes in the code:

* `pamqp`:

  - `specification.Basic.Ack` renamed to `commands.Basic.Ack`

* `aio-pika`:

  - `types.CloseCallbackType` renamed to `abc.ConnectionCloseCallback`
  - `Connection.add_close_callback` was replaced by
    `Connection.close_callbacks.add`
  - The `ack` and `nack` methods were made asynchronous
  - `Connection.connection` moved to `Connection.transport.connection`

Co-Authored-By: Jusong Yu <jusong.yeu@gmail.com>
  • Loading branch information
sphuber and unkcpz committed Oct 13, 2022
1 parent 0e7d02d commit 3806a9e
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 58 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ repos:
"async_generator",
"pytray>=0.2.2, <0.4.0",
"deprecation",
"aio-pika~=6.6,<6.8.2",
"pamqp~=2.0",
"aio-pika~=8.2",
"pamqp~=3.2",
"pyyaml~=5.1",
"pytest~=5.4",
"pytest-notebook",
Expand Down
8 changes: 4 additions & 4 deletions kiwipy/rmq/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ async def _on_rpc(self, subscriber, message):
"""
async with message.process(ignore_processed=True):
# Tell the sender that we've dealt with it
message.ack()
await message.ack()
msg = self._decode(message.body)

try:
Expand Down Expand Up @@ -370,19 +370,19 @@ def server_properties(self) -> Dict:
if self._connection is None:
return {}

return self._connection.connection.server_properties
return self._connection.transport.connection.server_properties

@property
def loop(self):
"""Get the event loop instance driving this communicator connection."""
return self._connection.loop

def add_close_callback(self, callback: aio_pika.types.CloseCallbackType, weak: bool = False) -> None:
def add_close_callback(self, callback: aio_pika.abc.ConnectionCloseCallback, weak: bool = False) -> None:
"""Add a callable to be called each time (after) the connection is closed.
:param weak: If True, the callback will be added to a `WeakSet`
"""
self._connection.add_close_callback(callback, weak)
self._connection.close_callbacks.add(callback, weak)

async def get_default_task_queue(self) -> tasks.RmqTaskQueue:
"""Get a default task queue.
Expand Down
2 changes: 1 addition & 1 deletion kiwipy/rmq/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def connect(self):
self._channel = await self._connection.channel(
publisher_confirms=self._confirm_deliveries, on_return_raises=True
)
self._channel.add_close_callback(self._on_channel_close)
self._channel.close_callbacks.add(self._on_channel_close)

self._exchange = await self._channel.declare_exchange(name=self.get_exchange_name(), **self._exchange_params)

Expand Down
44 changes: 23 additions & 21 deletions kiwipy/rmq/tasks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
import asyncio
import collections
import contextlib
from contextlib import asynccontextmanager
import logging
import uuid
from typing import Optional
import weakref

import aio_pika
from async_generator import async_generator, yield_, asynccontextmanager
from async_generator import async_generator, yield_
import shortuuid

import kiwipy
Expand Down Expand Up @@ -109,7 +109,7 @@ async def __aiter__(self):
# Put back any tasks that are still pending (i.e. not processed or to be processed)
for task in tasks:
if task.state == TASK_PENDING:
task.requeue()
await task.requeue()

@asynccontextmanager
@async_generator
Expand All @@ -130,7 +130,7 @@ async def next_task(self, no_ack=False, fail=True, timeout=defaults.TASK_FETCH_T
await yield_(task)
finally:
if task.state == TASK_PENDING:
task.requeue()
await task.requeue()

async def _create_task_queue(self):
"""Create and bind the task queue"""
Expand All @@ -151,7 +151,7 @@ async def _on_task(self, message: aio_pika.IncomingMessage):
"""
# Decode the message tuple into a task body for easier use
rmq_task = RmqIncomingTask(self, message)
with rmq_task.processing() as outcome:
async with rmq_task.processing() as outcome:
for subscriber in self._subscribers.values():
try:
subscriber = utils.ensure_coroutine(subscriber)
Expand Down Expand Up @@ -215,6 +215,7 @@ def __init__(self, subscriber: RmqTaskSubscriber, message: aio_pika.IncomingMess
self._task_info = TaskInfo(*subscriber._decode(message.body))
self._state = TASK_PENDING
self._outcome_ref = None # type: Optional[weakref.ReferenceType]
self._loop = self._subscriber.loop()

@property
def body(self) -> str:
Expand All @@ -233,31 +234,31 @@ def process(self) -> asyncio.Future:
raise asyncio.InvalidStateError(f'The task is {self._state}')

self._state = TASK_PROCESSING
outcome = self._create_future()
outcome = self._loop.create_future()
# Rely on the done callback to signal the end of processing
outcome.add_done_callback(self._task_done)
outcome.add_done_callback(self._on_task_done)
# Or the user let's the future get destroyed
self._outcome_ref = weakref.ref(outcome, self._outcome_destroyed)

return outcome

def requeue(self):
async def requeue(self):
if self._state not in [TASK_PENDING, TASK_PROCESSING]:
raise asyncio.InvalidStateError(f'The task is {self._state}')

self._state = TASK_REQUEUED
self._message.nack(requeue=True)
await self._message.nack(requeue=True)
self._finalise()

@contextlib.contextmanager
def processing(self):
@asynccontextmanager
async def processing(self):
"""Processing context. The task should be done at the end otherwise it's assumed the
caller doesn't want to process it and it's sent back to the queue"""
if self._state != TASK_PENDING:
raise asyncio.InvalidStateError(f'The task is {self._state}')

self._state = TASK_PROCESSING
outcome = self._subscriber.loop().create_future()
outcome = self._loop.create_future()
try:
yield outcome
except KeyboardInterrupt: # pylint: disable=try-except-raise
Expand All @@ -268,11 +269,15 @@ def processing(self):
raise
finally:
if outcome.done():
self._task_done(outcome)
await self._task_done(outcome)
else:
self.requeue()
await self.requeue()

def _on_task_done(self, outcome):
"""Schedule a task to call ``_task_done`` when the outcome is done."""
self._loop.create_task(self._task_done(outcome))

def _task_done(self, outcome: asyncio.Future):
async def _task_done(self, outcome: asyncio.Future):
assert outcome.done()
self._outcome_ref = None

Expand All @@ -283,7 +288,7 @@ def _task_done(self, outcome: asyncio.Future):
# Task is done or excepted
# Permanently store the outcome
self._state = TASK_FINISHED
self._message.ack()
await self._message.ack()

# We have to get the result from the future here (even if not replying), otherwise
# python complains that it was never retrieved in case of exception
Expand All @@ -295,7 +300,7 @@ def _task_done(self, outcome: asyncio.Future):
if not self.no_reply:
# Schedule a task to send the appropriate response
# pylint: disable=protected-access
self._subscriber.loop().create_task(self._subscriber._send_response(reply_body, self._message))
await self._subscriber._send_response(reply_body, self._message)

# Clean up
self._finalise()
Expand All @@ -306,16 +311,13 @@ def _outcome_destroyed(self, outcome_ref):
assert outcome_ref is self._outcome_ref
# This task will not be processed
self._outcome_ref = None
self.requeue()
self._loop.create_task(self.requeue())

def _finalise(self):
self._outcome_ref = None
self._subscriber = None
self._message = None

def _create_future(self) -> asyncio.Future:
return self._subscriber.loop().create_future()


class RmqTaskPublisher(messages.BasePublisherWithReplyQueue):
"""
Expand Down
4 changes: 2 additions & 2 deletions kiwipy/rmq/threadcomms.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def close(self):
del self._loop
self._closed = True

def add_close_callback(self, callback: aio_pika.types.CloseCallbackType, weak: bool = False) -> None:
def add_close_callback(self, callback: aio_pika.abc.ConnectionCloseCallback, weak: bool = False) -> None:
"""Add a callable to be called each time (after) the connection is closed.
:param weak: If True, the callback will be added to a `WeakSet`
Expand Down Expand Up @@ -258,7 +258,7 @@ def broadcast_send(self, body, sender=None, subject=None, correlation_id=None):
result = self._loop_scheduler.await_(
self._communicator.broadcast_send(body=body, sender=sender, subject=subject, correlation_id=correlation_id)
)
return isinstance(result, pamqp.specification.Basic.Ack)
return isinstance(result, pamqp.commands.Basic.Ack)

def _wrap_subscriber(self, subscriber):
""""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'sphinx-autobuild',
],
'pre-commit': ['pre-commit~=2.2', 'pylint==2.5.2'],
'rmq': ['aio-pika~=6.6,<6.8.2', 'pamqp~=2.0', 'pyyaml~=5.1'],
'rmq': ['aio-pika~=8.2', 'pamqp~=3.2', 'pyyaml~=5.1'],
'tests': [
'coverage',
'ipykernel',
Expand Down
4 changes: 2 additions & 2 deletions test/rmq/bench/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def callback(_comm, _task):
queue.remove_task_subscriber(identifier)


def clear_all_tasks(queue: rmq.RmqThreadTaskQueue):
async def clear_all_tasks(queue: rmq.RmqThreadTaskQueue):
"""Just go through all tasks picking them up so the queue is cleared"""
for task in queue:
with task.processing() as outcome:
async with task.processing() as outcome:
outcome.set_result(True)


Expand Down
6 changes: 3 additions & 3 deletions test/rmq/test_coroutine_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,10 @@ def broadcast_subscriber(_comm, _body, _sender=None, _subject=None, _correlation


@pytest.mark.asyncio
async def test_server_properties(communicator: kiwipy.rmq.RmqCommunicator):
def test_server_properties(communicator: kiwipy.rmq.RmqCommunicator):
props = communicator.server_properties
assert isinstance(props, dict)

assert props['product'] == b'RabbitMQ'
assert props['product'] == 'RabbitMQ'
assert 'version' in props
assert props['platform'].startswith(b'Erlang')
assert props['platform'].startswith('Erlang')
32 changes: 16 additions & 16 deletions test/rmq/test_rmq_thread_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def on_task(_comm, task):
self.assertEqual(TASK, tasks[0])
self.assertEqual(RESULT, result)

def test_task_queue_next(self):
async def test_task_queue_next(self):
"""Test creating a custom task queue"""
TASK = 'The meaning?'
RESULT = 42
Expand All @@ -114,24 +114,24 @@ def test_task_queue_next(self):

# Get the task and carry it out
with task_queue.next_task() as task:
task.process().set_result(RESULT)
await task.process().set_result(RESULT)

# Now wait for the result
result = task_future.result(timeout=self.WAIT_TIMEOUT)
self.assertEqual(RESULT, result)


def test_queue_get_next(thread_task_queue: rmq.RmqThreadTaskQueue):
async def test_queue_get_next(thread_task_queue: rmq.RmqThreadTaskQueue):
"""Test getting the next task from the queue"""
result = thread_task_queue.task_send('Hello!')
with thread_task_queue.next_task(timeout=1.) as task:
with task.processing() as outcome:
async with task.processing() as outcome:
assert task.body == 'Hello!'
outcome.set_result('Goodbye')
assert result.result() == 'Goodbye'


def test_queue_iter(thread_task_queue: rmq.RmqThreadTaskQueue):
async def test_queue_iter(thread_task_queue: rmq.RmqThreadTaskQueue):
"""Test iterating through a task queue"""
results = []

Expand All @@ -140,7 +140,7 @@ def test_queue_iter(thread_task_queue: rmq.RmqThreadTaskQueue):
results.append(thread_task_queue.task_send(i))

for task in thread_task_queue:
with task.processing() as outcome:
async with task.processing() as outcome:
outcome.set_result(task.body * 10)

concurrent.futures.wait(results)
Expand All @@ -151,7 +151,7 @@ def test_queue_iter(thread_task_queue: rmq.RmqThreadTaskQueue):
assert False, "Shouldn't get here"


def test_queue_iter_not_process(thread_task_queue: rmq.RmqThreadTaskQueue):
async def test_queue_iter_not_process(thread_task_queue: rmq.RmqThreadTaskQueue):
"""Check what happens when we iterate a queue but don't process all tasks"""
outcomes = []

Expand All @@ -162,22 +162,22 @@ def test_queue_iter_not_process(thread_task_queue: rmq.RmqThreadTaskQueue):
# Now let's see what happens when we have tasks but don't process some of them
for task in thread_task_queue:
if task.body < 5:
task.process().set_result(task.body * 10)
await task.process().set_result(task.body * 10)

concurrent.futures.wait(outcomes[:5])
for i, outcome in enumerate(outcomes[:5]):
assert outcome.result() == i * 10

# Now, to through and process the rest
for task in thread_task_queue:
task.process().set_result(task.body * 10)
await task.process().set_result(task.body * 10)

concurrent.futures.wait(outcomes)
for i, outcome in enumerate(outcomes):
assert outcome.result() == i * 10


def test_queue_task_forget(thread_task_queue: rmq.RmqThreadTaskQueue):
async def test_queue_task_forget(thread_task_queue: rmq.RmqThreadTaskQueue):
"""
Check what happens when we forget to process a task we said we would
WARNING: This test mail fail when running with a debugger as it relies on the 'outcome'
Expand All @@ -190,7 +190,7 @@ def test_queue_task_forget(thread_task_queue: rmq.RmqThreadTaskQueue):
# Get the first task and say that we will process it
outcome = None
with thread_task_queue.next_task() as task:
outcome = task.process()
outcome = await task.process()

with pytest.raises(kiwipy.exceptions.QueueEmpty):
with thread_task_queue.next_task():
Expand All @@ -201,7 +201,7 @@ def test_queue_task_forget(thread_task_queue: rmq.RmqThreadTaskQueue):

# Now the task should be back in the queue
with thread_task_queue.next_task() as task:
task.process().set_result(10)
await task.process().set_result(10)

concurrent.futures.wait(outcomes)
assert outcomes[0].result() == 10
Expand All @@ -213,14 +213,14 @@ def test_empty_queue(thread_task_queue: rmq.RmqThreadTaskQueue):
pass


def test_task_processing_exception(thread_task_queue: rmq.RmqThreadTaskQueue):
async def test_task_processing_exception(thread_task_queue: rmq.RmqThreadTaskQueue):
"""Check that if there is an exception processing a task that it is removed from the queue"""
task_future = thread_task_queue.task_send('Do this')

# The error should still get propageted in the 'worker'
with pytest.raises(RuntimeError):
with thread_task_queue.next_task(timeout=WAIT_TIMEOUT) as task:
with task.processing():
async with task.processing():
raise RuntimeError('Cannea do it captain!')

# And the task sender should get a remote exception to inform them of the problem
Expand Down Expand Up @@ -269,6 +269,6 @@ def test_server_properties(thread_communicator: kiwipy.rmq.RmqThreadCommunicator
props = thread_communicator.server_properties
assert isinstance(props, dict)

assert props['product'] == b'RabbitMQ'
assert props['product'] == 'RabbitMQ'
assert 'version' in props
assert props['platform'].startswith(b'Erlang')
assert props['platform'].startswith('Erlang')

0 comments on commit 3806a9e

Please sign in to comment.