Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/user/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ components:
additionalProperties: false
description: A representation of a task that the worker recognizes
properties:
errors:
items:
type: string
title: Errors
type: array
isComplete:
default: false
title: Iscomplete
type: boolean
isError:
default: false
title: Iserror
type: boolean
isPending:
default: true
title: Ispending
Expand Down
22 changes: 20 additions & 2 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,30 @@ def submit_task(self, task: Task) -> str:
return task_id

def _submit_trackable_task(self, trackable_task: TrackableTask) -> None:
if self.state is not WorkerState.IDLE:
raise WorkerBusyError(f"Worker is in state {self.state}")

task_started = Event()

def mark_task_as_started(event: WorkerEvent, _: Optional[str]) -> None:
if (
event.task_status is not None
and event.task_status.task_id == trackable_task.task_id
):
task_started.set()

LOGGER.info(f"Submitting: {trackable_task}")
try:
sub = self.worker_events.subscribe(mark_task_as_started)
self._task_channel.put_nowait(trackable_task)
task_started.wait(timeout=5.0)
if not task_started.is_set():
raise TimeoutError("Failed to start plan within timeout")
except Full:
LOGGER.error("Cannot submit task while another is running")
raise WorkerBusyError("Cannot submit task while another is running")
finally:
self.worker_events.unsubscribe(sub)

def start(self) -> None:
if self._started.is_set():
Expand Down Expand Up @@ -222,7 +240,7 @@ def _on_state_change(
def _report_error(self, err: Exception) -> None:
LOGGER.error(err, exc_info=True)
if self._current is not None:
self._current.is_error = True
self._current.errors.append(str(err))
self._errors.append(str(err))

def _report_status(
Expand All @@ -235,7 +253,7 @@ def _report_status(
task_status = TaskStatus(
task_id=self._current.task_id,
task_complete=self._current.is_complete,
task_failed=self._current.is_error or bool(errors),
task_failed=bool(self._current.errors),
)
correlation_id = self._current.task_id
else:
Expand Down
4 changes: 3 additions & 1 deletion src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import Generic, List, Optional, TypeVar

from pydantic import Field

from blueapi.core import DataEvent, EventStream
from blueapi.utils import BlueapiBaseModel

Expand All @@ -17,8 +19,8 @@ class TrackableTask(BlueapiBaseModel, Generic[T]):
task_id: str
task: T
is_complete: bool = False
is_error: bool = False
is_pending: bool = True
errors: List[str] = Field(default_factory=list)


class Worker(ABC, Generic[T]):
Expand Down
6 changes: 5 additions & 1 deletion tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,17 @@ def test_create_task(handler: Handler, client: TestClient) -> None:


def test_put_plan_begins_task(handler: Handler, client: TestClient) -> None:
handler.worker.start()
response = client.post("/tasks", json=_TASK.dict())
task_id = response.json()["taskId"]

task_json = {"task_id": task_id}
client.put("/worker/task", json=task_json)

assert handler.worker._task_channel.get().task_id == task_id # type: ignore
active_task = handler.worker.get_active_task()
assert active_task is not None
assert active_task.task_id == task_id
handler.worker.stop()


def test_get_state_updates(handler: Handler, client: TestClient) -> None:
Expand Down
34 changes: 33 additions & 1 deletion tests/worker/test_reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from blueapi.config import EnvironmentConfig, Source, SourceKind
from blueapi.core import BlueskyContext, EventStream
from blueapi.core import BlueskyContext, EventStream, MsgGenerator
from blueapi.core.bluesky_types import DataEvent
from blueapi.worker import (
ProgressEvent,
Expand All @@ -27,6 +27,7 @@
name="set_absolute",
params={"movable": "fake_device", "value": 4.0},
)
_FAILING_TASK = RunPlan(name="failing_plan", params={})


class FakeDevice:
Expand All @@ -44,6 +45,10 @@ def set(self, pos: float) -> None:
self.event.clear()


def failing_plan() -> MsgGenerator:
raise KeyError("I failed")


@pytest.fixture
def fake_device() -> FakeDevice:
return FakeDevice()
Expand All @@ -56,6 +61,7 @@ def context(fake_device: FakeDevice) -> BlueskyContext:
ctx_config.sources.append(
Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices")
)
ctx.plan(failing_plan)
ctx.device(fake_device)
ctx.with_config(ctx_config)
return ctx
Expand Down Expand Up @@ -173,6 +179,32 @@ def test_does_not_allow_simultaneous_running_tasks(
fake_device.event.set()


def test_begin_task_blocks_until_current_task_set(worker: Worker) -> None:
task_id = worker.submit_task(_SIMPLE_TASK)
assert worker.get_active_task() is None
worker.begin_task(task_id)
active_task = worker.get_active_task()
assert active_task is not None
assert active_task.task == _SIMPLE_TASK


def test_plan_failure_recorded_in_active_task(worker: Worker) -> None:
task_id = worker.submit_task(_FAILING_TASK)
events_future: Future[List[WorkerEvent]] = take_events(
worker.worker_events,
lambda event: event.task_status is not None and event.task_status.task_failed,
)
worker.begin_task(task_id)
events = events_future.result(timeout=5.0)
assert events[-1].task_status is not None
assert events[-1].task_status.task_failed
assert events[-1].errors == ["'I failed'"]

active_task = worker.get_active_task()
assert active_task is not None
assert active_task.errors == ["'I failed'"]


@pytest.mark.parametrize("num_runs", [0, 1, 2])
def test_produces_worker_events(worker: Worker, num_runs: int) -> None:
task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)]
Expand Down