Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ async def _queue_loop(
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": queue_name,
"state": state.copy(),
"extensions": {"message.ack.out_of_order": {}},
}
await self._app(
scope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __eq__(self, other: Any) -> bool:

@pytest.fixture(scope="module")
async def localstack_container() -> AsyncGenerator[LocalStackContainer, None]:
with LocalStackContainer(
image="localstack/localstack:4.9.2"
with LocalStackContainer(image="localstack/localstack:4.9.2").with_services(
"sqs"
) as localstack_container:
yield localstack_container

Expand Down Expand Up @@ -89,6 +89,7 @@ async def test_message(
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": queue_name,
"state": {},
"extensions": {"message.ack.out_of_order": {}},
}
message_receive = await receive()

Expand Down Expand Up @@ -131,6 +132,7 @@ async def test_message_nack(
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": queue_name,
"state": {},
"extensions": {"message.ack.out_of_order": {}},
}
message_receive = await receive()

Expand Down Expand Up @@ -290,6 +292,7 @@ async def test_lifespan(
"amgi": {"spec_version": "1.0", "version": "1.0"},
"type": "message",
"state": {"item": state_item},
"extensions": {"message.ack.out_of_order": {}},
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ async def _call_source_batch(
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": event_source_arn_match["queue"],
"state": self._state.copy(),
"extensions": {"message.ack.out_of_order": {}},
}

records_send = _Send(self._queue_url_cache, self._send_batcher, message_ids)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async def test_sqs_handler_records(app: MockApp, sqs_handler: SqsHandler) -> Non
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": "my-queue",
"state": {},
"extensions": {"message.ack.out_of_order": {}},
}

assert await receive() == {
Expand Down Expand Up @@ -178,6 +179,7 @@ async def test_sqs_handler_record_nack(app: MockApp, sqs_handler: SqsHandler) ->
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": "my-queue",
"state": {},
"extensions": {"message.ack.out_of_order": {}},
}

assert await receive() == {
Expand Down Expand Up @@ -243,6 +245,7 @@ async def test_sqs_handler_record_unacked(
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": "my-queue",
"state": {},
"extensions": {"message.ack.out_of_order": {}},
}

assert await receive() == {
Expand Down Expand Up @@ -300,6 +303,7 @@ async def test_sqs_handler_record_message_attribute_binary_value(
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": "my-queue",
"state": {},
"extensions": {"message.ack.out_of_order": {}},
}

assert await receive() == {
Expand Down Expand Up @@ -415,6 +419,7 @@ async def test_lifespan() -> None:
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": "my-queue",
"state": {"item": state_item},
"extensions": {"message.ack.out_of_order": {}},
}

await call_task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

@pytest.fixture(scope="module")
async def localstack_container() -> AsyncGenerator[LocalStackContainer, None]:
with LocalStackContainer(
image="localstack/localstack:4.9.2"
with LocalStackContainer(image="localstack/localstack:4.9.2").with_services(
"sqs"
) as localstack_container:
yield localstack_container

Expand Down
1 change: 1 addition & 0 deletions packages/amgi-types/src/amgi_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class MessageScope(TypedDict):
amgi: AMGIVersions
address: str
state: NotRequired[dict[str, Any]]
extensions: NotRequired[dict[str, dict[str, Any]]]


class LifespanScope(TypedDict):
Expand Down
83 changes: 53 additions & 30 deletions packages/asyncfast/src/asyncfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,17 @@ async def _handle_generator(
exception = None


async def _receive_messages(
receive: AMGIReceiveCallable,
) -> AsyncGenerator[MessageReceiveEvent, None]:
more_messages = True
while more_messages:
message = await receive()
assert message["type"] == "message.receive"
yield message
more_messages = message.get("more_messages", False)


class Channel:

def __init__(
Expand Down Expand Up @@ -566,36 +577,48 @@ async def __call__(
send: AMGISendCallable,
parameters: dict[str, str],
) -> None:
more_messages = True
while more_messages:
message = await receive()
if message["type"] != "message.receive":
continue
more_messages = message.get("more_messages", False)
try:
arguments = dict(self._generate_arguments(message, parameters, send))

if inspect.isasyncgenfunction(self._handler):
await _handle_async_generator(self._handler, arguments, send)
elif inspect.isgeneratorfunction(self._handler):
await _handle_generator(self._handler, arguments, send)
elif inspect.iscoroutinefunction(self._handler):
await self._handler(**arguments)
else:
await asyncio.to_thread(self._handler, **arguments)

message_ack_event: MessageAckEvent = {
"type": "message.ack",
"id": message["id"],
}
await send(message_ack_event)
except Exception as e:
message_nack_event: MessageNackEvent = {
"type": "message.nack",
"id": message["id"],
"message": str(e),
}
await send(message_nack_event)
ack_out_of_order = "message.ack.out_of_order" in scope.get("extensions", {})
if ack_out_of_order:
await asyncio.gather(
*[
self._handle_message(message, parameters, send)
async for message in _receive_messages(receive)
]
)
else:
async for message in _receive_messages(receive):
await self._handle_message(message, parameters, send)

async def _handle_message(
self,
message: MessageReceiveEvent,
parameters: dict[str, str],
send: AMGISendCallable,
) -> None:
try:
arguments = dict(self._generate_arguments(message, parameters, send))

if inspect.isasyncgenfunction(self._handler):
await _handle_async_generator(self._handler, arguments, send)
elif inspect.isgeneratorfunction(self._handler):
await _handle_generator(self._handler, arguments, send)
elif inspect.iscoroutinefunction(self._handler):
await self._handler(**arguments)
else:
await asyncio.to_thread(self._handler, **arguments)

message_ack_event: MessageAckEvent = {
"type": "message.ack",
"id": message["id"],
}
await send(message_ack_event)
except Exception as e:
message_nack_event: MessageNackEvent = {
"type": "message.nack",
"id": message["id"],
"message": str(e),
}
await send(message_nack_event)

def _generate_arguments(
self,
Expand Down
40 changes: 40 additions & 0 deletions packages/asyncfast/tests_asyncfast/test_message.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from asyncio import Event
from collections.abc import AsyncGenerator
from collections.abc import Generator
from collections.abc import Iterable
Expand Down Expand Up @@ -884,3 +885,42 @@ async def topic_handler(message_sender: MessageSender[SendMessage]) -> None:
)
]
)


async def test_message_ack_out_of_order() -> None:
app = AsyncFast()

received = set()
block_event = Event()

@app.channel("topic")
async def topic_handler(i: int) -> None:
received.add(i)
if received == {1, 2}:
block_event.set()
await block_event.wait()

message_scope: MessageScope = {
"type": "message",
"amgi": {"version": "1.0", "spec_version": "1.0"},
"address": "topic",
"extensions": {"message.ack.out_of_order": {}},
}
message_receive_event1: MessageReceiveEvent = {
"type": "message.receive",
"id": "id-1",
"headers": [],
"payload": b"1",
"more_messages": True,
}
message_receive_event2: MessageReceiveEvent = {
"type": "message.receive",
"id": "id-2",
"payload": b"2",
"headers": [],
}
await app(
message_scope,
AsyncMock(side_effect=[message_receive_event1, message_receive_event2]),
AsyncMock(),
)