From 47668a47ca840b15290933268a2626c4daf9c587 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Thu, 20 Jun 2024 13:58:01 -0500 Subject: [PATCH] Fix `concurrency` timeout scoping for 2.x (#14183) --- src/prefect/concurrency/asyncio.py | 18 +++++-- src/prefect/concurrency/services.py | 51 +++++++++++-------- src/prefect/concurrency/sync.py | 8 ++- tests/concurrency/test_concurrency_asyncio.py | 26 ++++++---- ...st_concurrency_slot_acquisition_service.py | 8 +-- tests/concurrency/test_concurrency_sync.py | 26 ++++++---- 6 files changed, 84 insertions(+), 53 deletions(-) diff --git a/src/prefect/concurrency/asyncio.py b/src/prefect/concurrency/asyncio.py index b9e346c27e8a..95713d0f70be 100644 --- a/src/prefect/concurrency/asyncio.py +++ b/src/prefect/concurrency/asyncio.py @@ -13,7 +13,6 @@ from prefect import get_client from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse -from prefect.utilities.timeout import timeout_async from .events import ( _emit_concurrency_acquisition_events, @@ -26,6 +25,10 @@ class ConcurrencySlotAcquisitionError(Exception): """Raised when an unhandlable occurs while acquiring concurrency slots.""" +class AcquireConcurrencySlotTimeoutError(TimeoutError): + """Raised when acquiring a concurrency slot times out.""" + + @asynccontextmanager async def concurrency( names: Union[str, List[str]], @@ -58,8 +61,9 @@ async def main(): ``` """ names = names if isinstance(names, list) else [names] - with timeout_async(seconds=timeout_seconds): - limits = await _acquire_concurrency_slots(names, occupy) + limits = await _acquire_concurrency_slots( + names, occupy, timeout_seconds=timeout_seconds + ) acquisition_time = pendulum.now("UTC") emitted_events = _emit_concurrency_acquisition_events(limits, occupy) @@ -91,12 +95,18 @@ async def _acquire_concurrency_slots( names: List[str], slots: int, mode: Union[Literal["concurrency"], Literal["rate_limit"]] = "concurrency", + timeout_seconds: Optional[float] = None, ) -> List[MinimalConcurrencyLimitResponse]: service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) - future = service.send((slots, mode)) + future = service.send((slots, mode, timeout_seconds)) response_or_exception = await asyncio.wrap_future(future) if isinstance(response_or_exception, Exception): + if isinstance(response_or_exception, TimeoutError): + raise AcquireConcurrencySlotTimeoutError( + f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)" + ) from response_or_exception + raise ConcurrencySlotAcquisitionError( f"Unable to acquire concurrency slots on {names!r}" ) from response_or_exception diff --git a/src/prefect/concurrency/services.py b/src/prefect/concurrency/services.py index 8aede9ae6eaf..2ab19362e00f 100644 --- a/src/prefect/concurrency/services.py +++ b/src/prefect/concurrency/services.py @@ -3,6 +3,7 @@ from contextlib import asynccontextmanager from typing import ( FrozenSet, + Optional, Tuple, ) @@ -13,6 +14,7 @@ from prefect._internal.concurrency import logger from prefect._internal.concurrency.services import QueueService from prefect.client.orchestration import PrefectClient +from prefect.utilities.timeout import timeout_async class ConcurrencySlotAcquisitionService(QueueService): @@ -27,10 +29,12 @@ async def _lifespan(self): self._client = client yield - async def _handle(self, item: Tuple[int, str, concurrent.futures.Future]): - occupy, mode, future = item + async def _handle( + self, item: Tuple[int, str, Optional[float], concurrent.futures.Future] + ): + occupy, mode, timeout_seconds, future = item try: - response = await self.acquire_slots(occupy, mode) + response = await self.acquire_slots(occupy, mode, timeout_seconds) except Exception as exc: # If the request to the increment endpoint fails in a non-standard # way, we need to set the future's result so that the caller can @@ -40,25 +44,28 @@ async def _handle(self, item: Tuple[int, str, concurrent.futures.Future]): else: future.set_result(response) - async def acquire_slots(self, slots: int, mode: str) -> httpx.Response: - while True: - try: - response = await self._client.increment_concurrency_slots( - names=self.concurrency_limit_names, slots=slots, mode=mode - ) - except Exception as exc: - if ( - isinstance(exc, httpx.HTTPStatusError) - and exc.response.status_code == status.HTTP_423_LOCKED - ): - retry_after = float(exc.response.headers["Retry-After"]) - await asyncio.sleep(retry_after) + async def acquire_slots( + self, slots: int, mode: str, timeout_seconds: Optional[float] = None + ): + with timeout_async(timeout_seconds): + while True: + try: + response = await self._client.increment_concurrency_slots( + names=self.concurrency_limit_names, slots=slots, mode=mode + ) + except Exception as exc: + if ( + isinstance(exc, httpx.HTTPStatusError) + and exc.response.status_code == status.HTTP_423_LOCKED + ): + retry_after = float(exc.response.headers["Retry-After"]) + await asyncio.sleep(retry_after) + else: + raise exc else: - raise exc - else: - return response + return response - def send(self, item: Tuple[int, str]): + def send(self, item: Tuple[int, str, Optional[float]]) -> concurrent.futures.Future: with self._lock: if self._stopped: raise RuntimeError("Cannot put items in a stopped service instance.") @@ -66,7 +73,7 @@ def send(self, item: Tuple[int, str]): logger.debug("Service %r enqueuing item %r", self, item) future: concurrent.futures.Future = concurrent.futures.Future() - occupy, mode = item - self._queue.put_nowait((occupy, mode, future)) + occupy, mode, timeout_seconds = item + self._queue.put_nowait((occupy, mode, timeout_seconds, future)) return future diff --git a/src/prefect/concurrency/sync.py b/src/prefect/concurrency/sync.py index 3551c28b2853..d572ca641551 100644 --- a/src/prefect/concurrency/sync.py +++ b/src/prefect/concurrency/sync.py @@ -12,7 +12,6 @@ from prefect._internal.concurrency.api import create_call, from_sync from prefect._internal.concurrency.event_loop import get_running_loop from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse -from prefect.utilities.timeout import timeout from .asyncio import ( _acquire_concurrency_slots, @@ -57,10 +56,9 @@ def main(): """ names = names if isinstance(names, list) else [names] - with timeout(seconds=timeout_seconds): - limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync( - _acquire_concurrency_slots, names, occupy - ) + limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync( + _acquire_concurrency_slots, names, occupy, timeout_seconds=timeout_seconds + ) acquisition_time = pendulum.now("UTC") emitted_events = _emit_concurrency_acquisition_events(limits, occupy) diff --git a/tests/concurrency/test_concurrency_asyncio.py b/tests/concurrency/test_concurrency_asyncio.py index 2f747ee8a3e6..ce9633cc5cd3 100644 --- a/tests/concurrency/test_concurrency_asyncio.py +++ b/tests/concurrency/test_concurrency_asyncio.py @@ -1,7 +1,8 @@ -import asyncio from unittest import mock import pytest +from httpx import HTTPStatusError, Request, Response +from prefect._vendor.starlette import status from prefect import flow, task from prefect.concurrency.asyncio import ( @@ -35,7 +36,7 @@ async def resource_heavy(): ) as release_spy: await resource_heavy() - acquire_spy.assert_called_once_with(["test"], 1) + acquire_spy.assert_called_once_with(["test"], 1, timeout_seconds=None) # On release we calculate how many seconds the slots were occupied # for, so here we really just want to make sure that the value @@ -173,18 +174,25 @@ async def resource_heavy(): @pytest.fixture -def mock_acquire_concurrency_slots(monkeypatch): - async def blocks_forever(*args, **kwargs): - while True: - await asyncio.sleep(1) +def mock_increment_concurrency_slots(monkeypatch): + async def mocked_increment_concurrency_slots(*args, **kwargs): + response = Response( + status_code=status.HTTP_423_LOCKED, + headers={"Retry-After": "0.01"}, + ) + raise HTTPStatusError( + message="Locked", + request=Request("GET", "http://test.com"), + response=response, + ) monkeypatch.setattr( - "prefect.concurrency.asyncio._acquire_concurrency_slots", - blocks_forever, + "prefect.client.orchestration.PrefectClient.increment_concurrency_slots", + mocked_increment_concurrency_slots, ) -@pytest.mark.usefixtures("concurrency_limit", "mock_acquire_concurrency_slots") +@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots") async def test_concurrency_respects_timeout(): with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*"): async with concurrency("test", occupy=1, timeout_seconds=0.01): diff --git a/tests/concurrency/test_concurrency_slot_acquisition_service.py b/tests/concurrency/test_concurrency_slot_acquisition_service.py index 86eddecc5894..16cdd0fc688f 100644 --- a/tests/concurrency/test_concurrency_slot_acquisition_service.py +++ b/tests/concurrency/test_concurrency_slot_acquisition_service.py @@ -41,7 +41,7 @@ async def test_returns_successful_response(mocked_client): expected_mode = "concurrency" service = ConcurrencySlotAcquisitionService.instance(frozenset(expected_names)) - future = service.send((expected_slots, expected_mode)) + future = service.send((expected_slots, expected_mode, None)) await service.drain() returned_response = await asyncio.wrap_future(future) assert returned_response == response @@ -67,7 +67,7 @@ async def test_retries_failed_call_respects_retry_after_header(mocked_client): service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) with mock.patch("prefect.concurrency.asyncio.asyncio.sleep") as sleep: - future = service.send((1, "concurrency")) + future = service.send((1, "concurrency", None)) await service.drain() returned_response = await asyncio.wrap_future(future) @@ -91,7 +91,7 @@ async def test_failed_call_status_code_not_retryable_returns_exception(mocked_cl limit_names = sorted(["api", "database"]) service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) - future = service.send((1, "concurrency")) + future = service.send((1, "concurrency", None)) await service.drain() exception = await asyncio.wrap_future(future) @@ -106,7 +106,7 @@ async def test_basic_exception_returns_exception(mocked_client): limit_names = sorted(["api", "database"]) service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) - future = service.send((1, "concurrency")) + future = service.send((1, "concurrency", None)) await service.drain() exception = await asyncio.wrap_future(future) diff --git a/tests/concurrency/test_concurrency_sync.py b/tests/concurrency/test_concurrency_sync.py index c76007a6f8c4..2d6c5e36a992 100644 --- a/tests/concurrency/test_concurrency_sync.py +++ b/tests/concurrency/test_concurrency_sync.py @@ -1,7 +1,8 @@ -import asyncio from unittest import mock import pytest +from httpx import HTTPStatusError, Request, Response +from prefect._vendor.starlette import status from prefect import flow, task from prefect.concurrency.asyncio import ( @@ -34,7 +35,7 @@ def resource_heavy(): ) as release_spy: resource_heavy() - acquire_spy.assert_called_once_with(["test"], 1) + acquire_spy.assert_called_once_with(["test"], 1, timeout_seconds=None) # On release we calculate how many seconds the slots were occupied # for, so here we really just want to make sure that the value @@ -168,18 +169,25 @@ def resource_heavy(): @pytest.fixture -def mock_acquire_concurrency_slots(monkeypatch): - async def blocks_forever(*args, **kwargs): - while True: - await asyncio.sleep(1) +def mock_increment_concurrency_slots(monkeypatch): + async def mocked_increment_concurrency_slots(*args, **kwargs): + response = Response( + status_code=status.HTTP_423_LOCKED, + headers={"Retry-After": "0.01"}, + ) + raise HTTPStatusError( + message="Locked", + request=Request("GET", "http://test.com"), + response=response, + ) monkeypatch.setattr( - "prefect.concurrency.sync._acquire_concurrency_slots", - blocks_forever, + "prefect.client.orchestration.PrefectClient.increment_concurrency_slots", + mocked_increment_concurrency_slots, ) -@pytest.mark.usefixtures("concurrency_limit", "mock_acquire_concurrency_slots") +@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots") def test_concurrency_respects_timeout(): with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*."): with concurrency("test", occupy=1, timeout_seconds=0.01):