diff --git a/docs/user/reference/openapi.yaml b/docs/user/reference/openapi.yaml index f35f859314..cf2b2e6ffb 100644 --- a/docs/user/reference/openapi.yaml +++ b/docs/user/reference/openapi.yaml @@ -100,14 +100,15 @@ components: additionalProperties: false description: A representation of a task that the worker recognizes properties: + errors: + items: + type: string + title: Errors + type: array isComplete: default: false title: Iscomplete type: boolean - isError: - default: false - title: Iserror - type: boolean isPending: default: true title: Ispending diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 20a2a2877e..12b32a5d4b 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -119,12 +119,30 @@ def submit_task(self, task: Task) -> str: return task_id def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: + if self.state is not WorkerState.IDLE: + raise WorkerBusyError(f"Worker is in state {self.state}") + + task_started = Event() + + def mark_task_as_started(event: WorkerEvent, _: Optional[str]) -> None: + if ( + event.task_status is not None + and event.task_status.task_id == trackable_task.task_id + ): + task_started.set() + LOGGER.info(f"Submitting: {trackable_task}") try: + sub = self.worker_events.subscribe(mark_task_as_started) self._task_channel.put_nowait(trackable_task) + task_started.wait(timeout=5.0) + if not task_started.is_set(): + raise TimeoutError("Failed to start plan within timeout") except Full: LOGGER.error("Cannot submit task while another is running") raise WorkerBusyError("Cannot submit task while another is running") + finally: + self.worker_events.unsubscribe(sub) def start(self) -> None: if self._started.is_set(): @@ -222,7 +240,7 @@ def _on_state_change( def _report_error(self, err: Exception) -> None: LOGGER.error(err, exc_info=True) if self._current is not None: - self._current.is_error = True + self._current.errors.append(str(err)) self._errors.append(str(err)) def _report_status( @@ -235,7 +253,7 @@ def _report_status( task_status = TaskStatus( task_id=self._current.task_id, task_complete=self._current.is_complete, - task_failed=self._current.is_error or bool(errors), + task_failed=bool(self._current.errors), ) correlation_id = self._current.task_id else: diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index ae144a96b6..4226e1e914 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Generic, List, Optional, TypeVar +from pydantic import Field + from blueapi.core import DataEvent, EventStream from blueapi.utils import BlueapiBaseModel @@ -17,8 +19,8 @@ class TrackableTask(BlueapiBaseModel, Generic[T]): task_id: str task: T is_complete: bool = False - is_error: bool = False is_pending: bool = True + errors: List[str] = Field(default_factory=list) class Worker(ABC, Generic[T]): diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index 1a74bf2d4f..93cc7f2fe0 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -86,13 +86,17 @@ def test_create_task(handler: Handler, client: TestClient) -> None: def test_put_plan_begins_task(handler: Handler, client: TestClient) -> None: + handler.worker.start() response = client.post("/tasks", json=_TASK.dict()) task_id = response.json()["taskId"] task_json = {"task_id": task_id} client.put("/worker/task", json=task_json) - assert handler.worker._task_channel.get().task_id == task_id # type: ignore + active_task = handler.worker.get_active_task() + assert active_task is not None + assert active_task.task_id == task_id + handler.worker.stop() def test_get_state_updates(handler: Handler, client: TestClient) -> None: diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index 5e0f60d17b..1740afae45 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -6,7 +6,7 @@ import pytest from blueapi.config import EnvironmentConfig, Source, SourceKind -from blueapi.core import BlueskyContext, EventStream +from blueapi.core import BlueskyContext, EventStream, MsgGenerator from blueapi.core.bluesky_types import DataEvent from blueapi.worker import ( ProgressEvent, @@ -27,6 +27,7 @@ name="set_absolute", params={"movable": "fake_device", "value": 4.0}, ) +_FAILING_TASK = RunPlan(name="failing_plan", params={}) class FakeDevice: @@ -44,6 +45,10 @@ def set(self, pos: float) -> None: self.event.clear() +def failing_plan() -> MsgGenerator: + raise KeyError("I failed") + + @pytest.fixture def fake_device() -> FakeDevice: return FakeDevice() @@ -56,6 +61,7 @@ def context(fake_device: FakeDevice) -> BlueskyContext: ctx_config.sources.append( Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") ) + ctx.plan(failing_plan) ctx.device(fake_device) ctx.with_config(ctx_config) return ctx @@ -173,6 +179,32 @@ def test_does_not_allow_simultaneous_running_tasks( fake_device.event.set() +def test_begin_task_blocks_until_current_task_set(worker: Worker) -> None: + task_id = worker.submit_task(_SIMPLE_TASK) + assert worker.get_active_task() is None + worker.begin_task(task_id) + active_task = worker.get_active_task() + assert active_task is not None + assert active_task.task == _SIMPLE_TASK + + +def test_plan_failure_recorded_in_active_task(worker: Worker) -> None: + task_id = worker.submit_task(_FAILING_TASK) + events_future: Future[List[WorkerEvent]] = take_events( + worker.worker_events, + lambda event: event.task_status is not None and event.task_status.task_failed, + ) + worker.begin_task(task_id) + events = events_future.result(timeout=5.0) + assert events[-1].task_status is not None + assert events[-1].task_status.task_failed + assert events[-1].errors == ["'I failed'"] + + active_task = worker.get_active_task() + assert active_task is not None + assert active_task.errors == ["'I failed'"] + + @pytest.mark.parametrize("num_runs", [0, 1, 2]) def test_produces_worker_events(worker: Worker, num_runs: int) -> None: task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)]