diff --git a/src/prefect/client/subscriptions.py b/src/prefect/client/subscriptions.py index 91a40e1b4e12..68abdaa33d4b 100644 --- a/src/prefect/client/subscriptions.py +++ b/src/prefect/client/subscriptions.py @@ -9,7 +9,7 @@ from prefect._internal.schemas.bases import IDBaseModel from prefect.logging import get_logger -from prefect.settings import PREFECT_API_KEY, PREFECT_API_URL +from prefect.settings import PREFECT_API_KEY logger = get_logger(__name__) @@ -23,10 +23,11 @@ def __init__( path: str, keys: List[str], client_id: Optional[str] = None, + base_url: Optional[str] = None, ): self.model = model self.client_id = client_id - base_url = PREFECT_API_URL.value().replace("http", "ws", 1) + base_url = base_url.replace("http", "ws", 1) self.subscription_url = f"{base_url}{path}" self.keys = keys diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index b5f8080d9201..2e994199d236 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -139,6 +139,12 @@ async def stop(self): raise StopTaskWorker async def _subscribe_to_task_scheduling(self): + base_url = PREFECT_API_URL.value() + if base_url is None: + raise ValueError( + "`PREFECT_API_URL` must be set to use the task worker. " + "Task workers are not compatible with the ephemeral API." + ) logger.info( f"Subscribing to tasks: {' | '.join(t.task_key.split('.')[-1] for t in self.tasks)}" ) @@ -147,6 +153,7 @@ async def _subscribe_to_task_scheduling(self): path="/task_runs/subscriptions/scheduled", keys=[task.task_key for task in self.tasks], client_id=self._client_id, + base_url=base_url, ): if self._limiter: await self._limiter.acquire_on_behalf_of(task_run.id) diff --git a/tests/test_task_worker.py b/tests/test_task_worker.py index c166bc473aba..853a8f26f3b0 100644 --- a/tests/test_task_worker.py +++ b/tests/test_task_worker.py @@ -11,10 +11,13 @@ from prefect.exceptions import MissingResult from prefect.filesystems import LocalFileSystem from prefect.futures import PrefectDistributedFuture +from prefect.settings import PREFECT_API_URL, temporary_settings from prefect.states import Running from prefect.task_worker import TaskWorker, serve from prefect.tasks import task_input_hash +pytestmark = pytest.mark.usefixtures("use_hosted_api_server") + @pytest.fixture(autouse=True) async def clear_cached_filesystems(): @@ -87,6 +90,12 @@ def mock_subscription(monkeypatch): return mock_subscription +async def test_task_worker_does_not_run_against_ephemeral_api(): + with pytest.raises(ValueError): + with temporary_settings({PREFECT_API_URL: None}): + await TaskWorker(...)._subscribe_to_task_scheduling() + + async def test_task_worker_basic_context_management(): async with TaskWorker(...) as task_worker: assert task_worker.started is True