Move timeout context into slot acquisition service #14121

Merged (4 commits) on Jun 20, 2024
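In summary, per the diff below: the concurrency() context managers in prefect.concurrency.asyncio and prefect.concurrency.sync no longer wrap slot acquisition in timeout_async/timeout themselves; instead they pass timeout_seconds through to ConcurrencySlotAcquisitionService.acquire_slots, which enforces the deadline around its own retry loop, and a timeout now surfaces as the new AcquireConcurrencySlotTimeoutError. A minimal caller-side sketch of the resulting behavior (it assumes a reachable Prefect API and an existing concurrency limit named "database"; the 0.5-second value is illustrative):

```python
import asyncio

from prefect.concurrency.asyncio import (
    AcquireConcurrencySlotTimeoutError,
    concurrency,
)


async def main() -> None:
    try:
        # Give up if a "database" slot cannot be acquired within 0.5 seconds.
        async with concurrency("database", occupy=1, timeout_seconds=0.5):
            ...  # work that needs the slot
    except AcquireConcurrencySlotTimeoutError:
        # Subclasses TimeoutError, so `except TimeoutError` also works.
        print("could not acquire a concurrency slot in time")


if __name__ == "__main__":
    asyncio.run(main())
```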
18 changes: 14 additions & 4 deletions src/prefect/concurrency/asyncio.py
@@ -13,7 +13,6 @@

from prefect.client.orchestration import get_client
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
from prefect.utilities.timeout import timeout_async

from .events import (
_emit_concurrency_acquisition_events,
@@ -26,6 +25,10 @@ class ConcurrencySlotAcquisitionError(Exception):
"""Raised when an unhandlable occurs while acquiring concurrency slots."""


class AcquireConcurrencySlotTimeoutError(TimeoutError):
"""Raised when acquiring a concurrency slot times out."""


@asynccontextmanager
async def concurrency(
names: Union[str, List[str]],
@@ -58,8 +61,9 @@ async def main():
```
"""
names = names if isinstance(names, list) else [names]
with timeout_async(seconds=timeout_seconds):
limits = await _acquire_concurrency_slots(names, occupy)
limits = await _acquire_concurrency_slots(
names, occupy, timeout_seconds=timeout_seconds
)
acquisition_time = pendulum.now("UTC")
emitted_events = _emit_concurrency_acquisition_events(limits, occupy)

@@ -91,12 +95,18 @@ async def _acquire_concurrency_slots(
names: List[str],
slots: int,
mode: Union[Literal["concurrency"], Literal["rate_limit"]] = "concurrency",
timeout_seconds: Optional[float] = None,
) -> List[MinimalConcurrencyLimitResponse]:
service = ConcurrencySlotAcquisitionService.instance(frozenset(names))
future = service.send((slots, mode))
future = service.send((slots, mode, timeout_seconds))
response_or_exception = await asyncio.wrap_future(future)

if isinstance(response_or_exception, Exception):
if isinstance(response_or_exception, TimeoutError):
raise AcquireConcurrencySlotTimeoutError(
f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)"
) from response_or_exception

Comment on lines +105 to +109 (Collaborator, Author):

Raising a custom error here because the generic "Scope timed out after ..." message from timeout_async isn't so illuminating in this context.

raise ConcurrencySlotAcquisitionError(
f"Unable to acquire concurrency slots on {names!r}"
) from response_or_exception
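To illustrate the comment above: before this change the caller saw timeout_async's generic TimeoutError (roughly "Scope timed out after 0.01 second(s)"); now _acquire_concurrency_slots re-raises it with acquisition context. A hedged sketch of what the caller observes (this uses the private helper shown in the diff and assumes a reachable Prefect API):

```python
import asyncio

from prefect.concurrency.asyncio import (
    AcquireConcurrencySlotTimeoutError,
    _acquire_concurrency_slots,
)


async def demo() -> None:
    try:
        await _acquire_concurrency_slots(["database"], 1, timeout_seconds=0.01)
    except AcquireConcurrencySlotTimeoutError as exc:
        # e.g. "Attempt to acquire concurrency slots timed out after 0.01 second(s)"
        print(exc)


asyncio.run(demo())
```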
51 changes: 29 additions & 22 deletions src/prefect/concurrency/services.py
@@ -4,6 +4,7 @@
from typing import (
TYPE_CHECKING,
FrozenSet,
Optional,
Tuple,
)

@@ -13,6 +14,7 @@
from prefect._internal.concurrency import logger
from prefect._internal.concurrency.services import QueueService
from prefect.client.orchestration import get_client
from prefect.utilities.timeout import timeout_async

if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
@@ -30,10 +32,12 @@ async def _lifespan(self):
self._client = client
yield

async def _handle(self, item: Tuple[int, str, concurrent.futures.Future]):
occupy, mode, future = item
async def _handle(
self, item: Tuple[int, str, Optional[float], concurrent.futures.Future]
):
occupy, mode, timeout_seconds, future = item
try:
response = await self.acquire_slots(occupy, mode)
response = await self.acquire_slots(occupy, mode, timeout_seconds)
except Exception as exc:
# If the request to the increment endpoint fails in a non-standard
# way, we need to set the future's result so that the caller can
@@ -43,33 +47,36 @@ async def _handle(self, item: Tuple[int, str, concurrent.futures.Future]):
else:
future.set_result(response)

async def acquire_slots(self, slots: int, mode: str) -> httpx.Response:
while True:
try:
response = await self._client.increment_concurrency_slots(
names=self.concurrency_limit_names, slots=slots, mode=mode
)
except Exception as exc:
if (
isinstance(exc, httpx.HTTPStatusError)
and exc.response.status_code == status.HTTP_423_LOCKED
):
retry_after = float(exc.response.headers["Retry-After"])
await asyncio.sleep(retry_after)
async def acquire_slots(
self, slots: int, mode: str, timeout_seconds: Optional[float] = None
) -> httpx.Response:
with timeout_async(seconds=timeout_seconds):
while True:
try:
response = await self._client.increment_concurrency_slots(
names=self.concurrency_limit_names, slots=slots, mode=mode
)
except Exception as exc:
if (
isinstance(exc, httpx.HTTPStatusError)
and exc.response.status_code == status.HTTP_423_LOCKED
):
retry_after = float(exc.response.headers["Retry-After"])
await asyncio.sleep(retry_after)
else:
raise exc
else:
raise exc
else:
return response
return response

def send(self, item: Tuple[int, str]):
def send(self, item: Tuple[int, str, Optional[float]]) -> concurrent.futures.Future:
with self._lock:
if self._stopped:
raise RuntimeError("Cannot put items in a stopped service instance.")

logger.debug("Service %r enqueuing item %r", self, item)
future: concurrent.futures.Future = concurrent.futures.Future()

occupy, mode = item
self._queue.put_nowait((occupy, mode, future))
occupy, mode, timeout_seconds = item
self._queue.put_nowait((occupy, mode, timeout_seconds, future))

return future
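For orientation, a condensed sketch of the call pattern the service now supports: the queue item carries (slots, mode, timeout_seconds), and the service enforces the deadline around its own retry loop rather than the caller doing so. This mirrors the updated tests below; the helper name is illustrative, not new API:

```python
import asyncio
from typing import FrozenSet, Optional

from prefect.concurrency.services import ConcurrencySlotAcquisitionService


async def acquire(
    names: FrozenSet[str], slots: int, timeout_seconds: Optional[float] = None
):
    service = ConcurrencySlotAcquisitionService.instance(names)
    # The third tuple element is new in this PR; None means "wait indefinitely".
    future = service.send((slots, "concurrency", timeout_seconds))
    result = await asyncio.wrap_future(future)
    # The service sets exceptions (including TimeoutError) as the future's
    # result rather than raising, so callers re-raise explicitly.
    if isinstance(result, Exception):
        raise result
    return result
```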
8 changes: 3 additions & 5 deletions src/prefect/concurrency/sync.py
@@ -12,7 +12,6 @@
from prefect._internal.concurrency.api import create_call, from_sync
from prefect._internal.concurrency.event_loop import get_running_loop
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
from prefect.utilities.timeout import timeout

from .asyncio import (
_acquire_concurrency_slots,
@@ -57,10 +56,9 @@ def main():
"""
names = names if isinstance(names, list) else [names]

with timeout(seconds=timeout_seconds):
limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
_acquire_concurrency_slots, names, occupy
)
limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
_acquire_concurrency_slots, names, occupy, timeout_seconds=timeout_seconds
)
acquisition_time = pendulum.now("UTC")
emitted_events = _emit_concurrency_acquisition_events(limits, occupy)

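The sync path gets the same treatment: prefect.concurrency.sync.concurrency forwards timeout_seconds instead of wrapping the call in the timeout utility. A minimal sync-side sketch under the same assumptions as the async example above:

```python
from prefect.concurrency.sync import concurrency


def process_batch(batch: list) -> None:
    try:
        with concurrency("database", occupy=1, timeout_seconds=0.5):
            ...  # work that needs a database slot
    except TimeoutError:
        # AcquireConcurrencySlotTimeoutError subclasses TimeoutError,
        # so catching the base class covers the timeout case here too.
        print("could not acquire a concurrency slot in time")
```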
26 changes: 17 additions & 9 deletions tests/concurrency/test_concurrency_asyncio.py
@@ -1,7 +1,8 @@
import asyncio
from unittest import mock

import pytest
from httpx import HTTPStatusError, Request, Response
from starlette import status

from prefect import flow, task
from prefect.concurrency.asyncio import (
@@ -35,7 +36,7 @@ async def resource_heavy():
) as release_spy:
await resource_heavy()

acquire_spy.assert_called_once_with(["test"], 1)
acquire_spy.assert_called_once_with(["test"], 1, timeout_seconds=None)

# On release we calculate how many seconds the slots were occupied
# for, so here we really just want to make sure that the value
@@ -176,18 +177,25 @@ async def resource_heavy():


@pytest.fixture
def mock_acquire_concurrency_slots(monkeypatch):
async def blocks_forever(*args, **kwargs):
while True:
await asyncio.sleep(1)
def mock_increment_concurrency_slots(monkeypatch):
async def mocked_increment_concurrency_slots(*args, **kwargs):
response = Response(
status_code=status.HTTP_423_LOCKED,
headers={"Retry-After": "0.01"},
)
raise HTTPStatusError(
message="Locked",
request=Request("GET", "http://test.com"),
response=response,
)

monkeypatch.setattr(
"prefect.concurrency.asyncio._acquire_concurrency_slots",
blocks_forever,
"prefect.client.orchestration.PrefectClient.increment_concurrency_slots",
mocked_increment_concurrency_slots,
)


@pytest.mark.usefixtures("concurrency_limit", "mock_acquire_concurrency_slots")
@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots")
async def test_concurrency_respects_timeout():
with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*"):
async with concurrency("test", occupy=1, timeout_seconds=0.01):
@@ -41,7 +41,7 @@ async def test_returns_successful_response(mocked_client):
expected_mode = "concurrency"

service = ConcurrencySlotAcquisitionService.instance(frozenset(expected_names))
future = service.send((expected_slots, expected_mode))
future = service.send((expected_slots, expected_mode, None))
await service.drain()
returned_response = await asyncio.wrap_future(future)
assert returned_response == response
@@ -67,7 +67,7 @@ async def test_retries_failed_call_respects_retry_after_header(mocked_client):
service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

with mock.patch("prefect.concurrency.asyncio.asyncio.sleep") as sleep:
future = service.send((1, "concurrency"))
future = service.send((1, "concurrency", None))
await service.drain()
returned_response = await asyncio.wrap_future(future)

@@ -91,7 +91,7 @@ async def test_failed_call_status_code_not_retryable_returns_exception(mocked_client):
limit_names = sorted(["api", "database"])
service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

future = service.send((1, "concurrency"))
future = service.send((1, "concurrency", None))
await service.drain()
exception = await asyncio.wrap_future(future)

@@ -106,7 +106,7 @@ async def test_basic_exception_returns_exception(mocked_client):
limit_names = sorted(["api", "database"])
service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

future = service.send((1, "concurrency"))
future = service.send((1, "concurrency", None))
await service.drain()
exception = await asyncio.wrap_future(future)

26 changes: 17 additions & 9 deletions tests/concurrency/test_concurrency_sync.py
@@ -1,7 +1,8 @@
import asyncio
from unittest import mock

import pytest
from httpx import HTTPStatusError, Request, Response
from starlette import status

from prefect import flow, task
from prefect.concurrency.asyncio import (
@@ -34,7 +35,7 @@ def resource_heavy():
) as release_spy:
resource_heavy()

acquire_spy.assert_called_once_with(["test"], 1)
acquire_spy.assert_called_once_with(["test"], 1, timeout_seconds=None)

# On release we calculate how many seconds the slots were occupied
# for, so here we really just want to make sure that the value
@@ -168,18 +169,25 @@ def resource_heavy():


@pytest.fixture
def mock_acquire_concurrency_slots(monkeypatch):
async def blocks_forever(*args, **kwargs):
while True:
await asyncio.sleep(1)
def mock_increment_concurrency_slots(monkeypatch):
async def mocked_increment_concurrency_slots(*args, **kwargs):
response = Response(
status_code=status.HTTP_423_LOCKED,
headers={"Retry-After": "0.01"},
)
raise HTTPStatusError(
message="Locked",
request=Request("GET", "http://test.com"),
response=response,
)

monkeypatch.setattr(
"prefect.concurrency.sync._acquire_concurrency_slots",
blocks_forever,
"prefect.client.orchestration.PrefectClient.increment_concurrency_slots",
mocked_increment_concurrency_slots,
)


@pytest.mark.usefixtures("concurrency_limit", "mock_acquire_concurrency_slots")
@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots")
def test_concurrency_respects_timeout():
with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*."):
with concurrency("test", occupy=1, timeout_seconds=0.01):