Skip to content

Commit

Permalink
Fix concurrency timeout scoping for 2.x (#14183)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Jun 20, 2024
1 parent 62f0db6 commit 47668a4
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 53 deletions.
18 changes: 14 additions & 4 deletions src/prefect/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from prefect import get_client
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
from prefect.utilities.timeout import timeout_async

from .events import (
_emit_concurrency_acquisition_events,
Expand All @@ -26,6 +25,10 @@ class ConcurrencySlotAcquisitionError(Exception):
"""Raised when an unhandlable occurs while acquiring concurrency slots."""


class AcquireConcurrencySlotTimeoutError(TimeoutError):
"""Raised when acquiring a concurrency slot times out."""


@asynccontextmanager
async def concurrency(
names: Union[str, List[str]],
Expand Down Expand Up @@ -58,8 +61,9 @@ async def main():
```
"""
names = names if isinstance(names, list) else [names]
with timeout_async(seconds=timeout_seconds):
limits = await _acquire_concurrency_slots(names, occupy)
limits = await _acquire_concurrency_slots(
names, occupy, timeout_seconds=timeout_seconds
)
acquisition_time = pendulum.now("UTC")
emitted_events = _emit_concurrency_acquisition_events(limits, occupy)

Expand Down Expand Up @@ -91,12 +95,18 @@ async def _acquire_concurrency_slots(
names: List[str],
slots: int,
mode: Union[Literal["concurrency"], Literal["rate_limit"]] = "concurrency",
timeout_seconds: Optional[float] = None,
) -> List[MinimalConcurrencyLimitResponse]:
service = ConcurrencySlotAcquisitionService.instance(frozenset(names))
future = service.send((slots, mode))
future = service.send((slots, mode, timeout_seconds))
response_or_exception = await asyncio.wrap_future(future)

if isinstance(response_or_exception, Exception):
if isinstance(response_or_exception, TimeoutError):
raise AcquireConcurrencySlotTimeoutError(
f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)"
) from response_or_exception

raise ConcurrencySlotAcquisitionError(
f"Unable to acquire concurrency slots on {names!r}"
) from response_or_exception
Expand Down
51 changes: 29 additions & 22 deletions src/prefect/concurrency/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import asynccontextmanager
from typing import (
FrozenSet,
Optional,
Tuple,
)

Expand All @@ -13,6 +14,7 @@
from prefect._internal.concurrency import logger
from prefect._internal.concurrency.services import QueueService
from prefect.client.orchestration import PrefectClient
from prefect.utilities.timeout import timeout_async


class ConcurrencySlotAcquisitionService(QueueService):
Expand All @@ -27,10 +29,12 @@ async def _lifespan(self):
self._client = client
yield

async def _handle(self, item: Tuple[int, str, concurrent.futures.Future]):
occupy, mode, future = item
async def _handle(
self, item: Tuple[int, str, Optional[float], concurrent.futures.Future]
):
occupy, mode, timeout_seconds, future = item
try:
response = await self.acquire_slots(occupy, mode)
response = await self.acquire_slots(occupy, mode, timeout_seconds)
except Exception as exc:
# If the request to the increment endpoint fails in a non-standard
# way, we need to set the future's result so that the caller can
Expand All @@ -40,33 +44,36 @@ async def _handle(self, item: Tuple[int, str, concurrent.futures.Future]):
else:
future.set_result(response)

async def acquire_slots(self, slots: int, mode: str) -> httpx.Response:
while True:
try:
response = await self._client.increment_concurrency_slots(
names=self.concurrency_limit_names, slots=slots, mode=mode
)
except Exception as exc:
if (
isinstance(exc, httpx.HTTPStatusError)
and exc.response.status_code == status.HTTP_423_LOCKED
):
retry_after = float(exc.response.headers["Retry-After"])
await asyncio.sleep(retry_after)
async def acquire_slots(
self, slots: int, mode: str, timeout_seconds: Optional[float] = None
):
with timeout_async(timeout_seconds):
while True:
try:
response = await self._client.increment_concurrency_slots(
names=self.concurrency_limit_names, slots=slots, mode=mode
)
except Exception as exc:
if (
isinstance(exc, httpx.HTTPStatusError)
and exc.response.status_code == status.HTTP_423_LOCKED
):
retry_after = float(exc.response.headers["Retry-After"])
await asyncio.sleep(retry_after)
else:
raise exc
else:
raise exc
else:
return response
return response

def send(self, item: Tuple[int, str]):
def send(self, item: Tuple[int, str, Optional[float]]) -> concurrent.futures.Future:
with self._lock:
if self._stopped:
raise RuntimeError("Cannot put items in a stopped service instance.")

logger.debug("Service %r enqueuing item %r", self, item)
future: concurrent.futures.Future = concurrent.futures.Future()

occupy, mode = item
self._queue.put_nowait((occupy, mode, future))
occupy, mode, timeout_seconds = item
self._queue.put_nowait((occupy, mode, timeout_seconds, future))

return future
8 changes: 3 additions & 5 deletions src/prefect/concurrency/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from prefect._internal.concurrency.api import create_call, from_sync
from prefect._internal.concurrency.event_loop import get_running_loop
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
from prefect.utilities.timeout import timeout

from .asyncio import (
_acquire_concurrency_slots,
Expand Down Expand Up @@ -57,10 +56,9 @@ def main():
"""
names = names if isinstance(names, list) else [names]

with timeout(seconds=timeout_seconds):
limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
_acquire_concurrency_slots, names, occupy
)
limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
_acquire_concurrency_slots, names, occupy, timeout_seconds=timeout_seconds
)
acquisition_time = pendulum.now("UTC")
emitted_events = _emit_concurrency_acquisition_events(limits, occupy)

Expand Down
26 changes: 17 additions & 9 deletions tests/concurrency/test_concurrency_asyncio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from unittest import mock

import pytest
from httpx import HTTPStatusError, Request, Response
from prefect._vendor.starlette import status

from prefect import flow, task
from prefect.concurrency.asyncio import (
Expand Down Expand Up @@ -35,7 +36,7 @@ async def resource_heavy():
) as release_spy:
await resource_heavy()

acquire_spy.assert_called_once_with(["test"], 1)
acquire_spy.assert_called_once_with(["test"], 1, timeout_seconds=None)

# On release we calculate how many seconds the slots were occupied
# for, so here we really just want to make sure that the value
Expand Down Expand Up @@ -173,18 +174,25 @@ async def resource_heavy():


@pytest.fixture
def mock_acquire_concurrency_slots(monkeypatch):
async def blocks_forever(*args, **kwargs):
while True:
await asyncio.sleep(1)
def mock_increment_concurrency_slots(monkeypatch):
async def mocked_increment_concurrency_slots(*args, **kwargs):
response = Response(
status_code=status.HTTP_423_LOCKED,
headers={"Retry-After": "0.01"},
)
raise HTTPStatusError(
message="Locked",
request=Request("GET", "http://test.com"),
response=response,
)

monkeypatch.setattr(
"prefect.concurrency.asyncio._acquire_concurrency_slots",
blocks_forever,
"prefect.client.orchestration.PrefectClient.increment_concurrency_slots",
mocked_increment_concurrency_slots,
)


@pytest.mark.usefixtures("concurrency_limit", "mock_acquire_concurrency_slots")
@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots")
async def test_concurrency_respects_timeout():
with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*"):
async with concurrency("test", occupy=1, timeout_seconds=0.01):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_returns_successful_response(mocked_client):
expected_mode = "concurrency"

service = ConcurrencySlotAcquisitionService.instance(frozenset(expected_names))
future = service.send((expected_slots, expected_mode))
future = service.send((expected_slots, expected_mode, None))
await service.drain()
returned_response = await asyncio.wrap_future(future)
assert returned_response == response
Expand All @@ -67,7 +67,7 @@ async def test_retries_failed_call_respects_retry_after_header(mocked_client):
service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

with mock.patch("prefect.concurrency.asyncio.asyncio.sleep") as sleep:
future = service.send((1, "concurrency"))
future = service.send((1, "concurrency", None))
await service.drain()
returned_response = await asyncio.wrap_future(future)

Expand All @@ -91,7 +91,7 @@ async def test_failed_call_status_code_not_retryable_returns_exception(mocked_cl
limit_names = sorted(["api", "database"])
service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

future = service.send((1, "concurrency"))
future = service.send((1, "concurrency", None))
await service.drain()
exception = await asyncio.wrap_future(future)

Expand All @@ -106,7 +106,7 @@ async def test_basic_exception_returns_exception(mocked_client):
limit_names = sorted(["api", "database"])
service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

future = service.send((1, "concurrency"))
future = service.send((1, "concurrency", None))
await service.drain()
exception = await asyncio.wrap_future(future)

Expand Down
26 changes: 17 additions & 9 deletions tests/concurrency/test_concurrency_sync.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from unittest import mock

import pytest
from httpx import HTTPStatusError, Request, Response
from prefect._vendor.starlette import status

from prefect import flow, task
from prefect.concurrency.asyncio import (
Expand Down Expand Up @@ -34,7 +35,7 @@ def resource_heavy():
) as release_spy:
resource_heavy()

acquire_spy.assert_called_once_with(["test"], 1)
acquire_spy.assert_called_once_with(["test"], 1, timeout_seconds=None)

# On release we calculate how many seconds the slots were occupied
# for, so here we really just want to make sure that the value
Expand Down Expand Up @@ -168,18 +169,25 @@ def resource_heavy():


@pytest.fixture
def mock_acquire_concurrency_slots(monkeypatch):
async def blocks_forever(*args, **kwargs):
while True:
await asyncio.sleep(1)
def mock_increment_concurrency_slots(monkeypatch):
async def mocked_increment_concurrency_slots(*args, **kwargs):
response = Response(
status_code=status.HTTP_423_LOCKED,
headers={"Retry-After": "0.01"},
)
raise HTTPStatusError(
message="Locked",
request=Request("GET", "http://test.com"),
response=response,
)

monkeypatch.setattr(
"prefect.concurrency.sync._acquire_concurrency_slots",
blocks_forever,
"prefect.client.orchestration.PrefectClient.increment_concurrency_slots",
mocked_increment_concurrency_slots,
)


@pytest.mark.usefixtures("concurrency_limit", "mock_acquire_concurrency_slots")
@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots")
def test_concurrency_respects_timeout():
with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*."):
with concurrency("test", occupy=1, timeout_seconds=0.01):
Expand Down

0 comments on commit 47668a4

Please sign in to comment.