diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py index 2cdd99b36..4166df4b8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py @@ -46,6 +46,9 @@ def __init__(self) -> None: self._queued: dict[str, SchedulableTask] = {} self._task_groups: dict[str, TaskGroupKey] = {} self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} + self._queued_by_group: Counter[TaskGroupKey] = Counter() + self._queued_resource_demand_by_group: dict[TaskGroupKey, Counter[SchedulerResourceKey]] = defaultdict(Counter) + self._queued_peer_demand_by_resource: Counter[SchedulerResourceKey] = Counter() self._group_finish: dict[TaskGroupKey, float] = {} self._heap: list[tuple[float, int, TaskGroupKey]] = [] self._active_heap_keys: set[TaskGroupKey] = set() @@ -69,6 +72,7 @@ def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]: queue.append(item) self._queued[item.task_id] = item self._task_groups[item.task_id] = item.group.key + self._increment_queue_accounting(item) self._activate_group(item.group.key) accepted.append(item.task_id) if accepted: @@ -77,10 +81,8 @@ def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]: def discard(self, task_id: str) -> None: """Remove a queued task lazily if it is no longer dispatchable.""" - if task_id in self._queued: + if self._remove_queued_item(task_id) is not None: self._sequence_version += 1 - self._queued.pop(task_id, None) - self._task_groups.pop(task_id, None) def discard_where(self, predicate: Callable[[SchedulableTask], bool]) -> None: """Remove queued tasks matching a predicate.""" @@ -125,8 +127,7 @@ def commit(self, selection: QueueSelection) -> SchedulableTask | None: return None queue.popleft() - self._queued.pop(item.task_id, None) - self._task_groups.pop(item.task_id, None) + self._remove_queued_item(item.task_id) self._active_heap_keys.discard(key) self._active_heap_entries.pop(key, None) group = self._group_specs[key] @@ -140,35 +141,28 @@ def commit(self, selection: QueueSelection) -> SchedulableTask | None: return item def view(self) -> QueueView: - queued_by_group: Counter[TaskGroupKey] = Counter() - demand_by_group: dict[TaskGroupKey, dict[SchedulerResourceKey, int]] = defaultdict(lambda: defaultdict(int)) first_by_group: dict[TaskGroupKey, Mapping[SchedulerResourceKey, int]] = {} first_tasks_by_group: dict[TaskGroupKey, SchedulableTask] = {} first_group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} - demand_by_resource: Counter[SchedulerResourceKey] = Counter() - for item in self._queued.values(): - key = item.group.key - queued_by_group[key] += 1 - for resource, amount in item.resource_request.amounts.items(): - demand_by_group[key][resource] += amount - demand_by_resource[resource] += amount - - for key, queue in self._queues.items(): + for key in self._queued_by_group: first = self._first_valid_item(key) - if first is not None: - first_by_group[key] = dict(first.resource_request.amounts) - first_tasks_by_group[key] = first - first_group_specs[key] = first.group + if first is None: + continue + first_by_group[key] = dict(first.resource_request.amounts) + first_tasks_by_group[key] = first + first_group_specs[key] = first.group return QueueView( queued_total=len(self._queued), - queued_by_group=dict(queued_by_group), - queued_resource_demand_by_group={key: dict(value) for key, value in demand_by_group.items()}, + queued_by_group=dict(self._queued_by_group), + queued_resource_demand_by_group={ + key: dict(value) for key, value in self._queued_resource_demand_by_group.items() + }, first_candidate_resources_by_group=first_by_group, first_candidate_tasks_by_group=first_tasks_by_group, first_candidate_group_specs_by_group=first_group_specs, - queued_peer_demand_by_resource=dict(demand_by_resource), + queued_peer_demand_by_resource=dict(self._queued_peer_demand_by_resource), ) def _activate_group(self, key: TaskGroupKey) -> None: @@ -183,13 +177,11 @@ def _activate_group(self, key: TaskGroupKey) -> None: self._active_heap_entries[key] = (finish, self._sequence) def _first_valid_item(self, key: TaskGroupKey) -> SchedulableTask | None: + self._purge_queue_head(key) queue = self._queues.get(key) - if queue is None: + if not queue: return None - for item in queue: - if item.task_id in self._queued and self._task_groups.get(item.task_id) == key: - return item - return None + return queue[0] def _purge_queue_head(self, key: TaskGroupKey) -> None: queue = self._queues.get(key) @@ -200,3 +192,37 @@ def _purge_queue_head(self, key: TaskGroupKey) -> None: if item.task_id in self._queued and self._task_groups.get(item.task_id) == key: break queue.popleft() + + def _increment_queue_accounting(self, item: SchedulableTask) -> None: + key = item.group.key + self._queued_by_group[key] += 1 + for resource, amount in item.resource_request.amounts.items(): + self._queued_resource_demand_by_group[key][resource] += amount + self._queued_peer_demand_by_resource[resource] += amount + + def _remove_queued_item(self, task_id: str) -> SchedulableTask | None: + item = self._queued.pop(task_id, None) + key = self._task_groups.pop(task_id, None) + if item is None or key is None: + return item + self._decrement_queue_accounting(item, key) + return item + + def _decrement_queue_accounting(self, item: SchedulableTask, key: TaskGroupKey) -> None: + self._queued_by_group[key] -= 1 + if self._queued_by_group[key] <= 0: + del self._queued_by_group[key] + + group_demand = self._queued_resource_demand_by_group.get(key) + if group_demand is not None: + for resource, amount in item.resource_request.amounts.items(): + group_demand[resource] -= amount + if group_demand[resource] <= 0: + del group_demand[resource] + if not group_demand: + del self._queued_resource_demand_by_group[key] + + for resource, amount in item.resource_request.amounts.items(): + self._queued_peer_demand_by_resource[resource] -= amount + if self._queued_peer_demand_by_resource[resource] <= 0: + del self._queued_peer_demand_by_resource[resource] diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py index e2a9179f0..5e10fe5bd 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py @@ -4,10 +4,12 @@ from __future__ import annotations from collections import Counter +from collections.abc import ItemsView from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue, QueueView from data_designer.engine.dataset_builders.scheduling.resources import ( SchedulableTask, + SchedulerResourceKey, SchedulerResourceRequest, TaskGroupKey, TaskGroupSpec, @@ -16,6 +18,15 @@ from data_designer.engine.dataset_builders.scheduling.task_model import Task +class _FailIfScannedAmounts(dict[SchedulerResourceKey, int]): + locked: bool = False + + def items(self) -> ItemsView[SchedulerResourceKey, int]: + if self.locked: + raise AssertionError("QueueView should use incremental accounting for non-candidate tasks.") + return super().items() + + def _task(column: str, row_index: int) -> Task: return Task(column=column, row_group=0, row_index=row_index, task_type="cell") @@ -118,14 +129,25 @@ def test_select_next_uses_scheduler_eligibility_callback() -> None: def test_enqueue_is_idempotent_by_task_id() -> None: queue = FairTaskQueue() - item = _item("a", 0) + group = _group("a") + task = _task("a", 0) + item = SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 2}), + ) first = queue.enqueue([item]) second = queue.enqueue([item]) + view = queue.view() assert first == (item.task_id,) assert second == () - assert queue.view().queued_total == 1 + assert view.queued_total == 1 + assert view.queued_by_group == {group.key: 1} + assert view.queued_resource_demand_by_group == {group.key: {"submission": 1, "llm_wait": 2}} + assert view.queued_peer_demand_by_resource == {"submission": 1, "llm_wait": 2} def test_discard_where_removes_matching_tasks() -> None: @@ -157,3 +179,63 @@ def test_queue_view_exposes_group_and_resource_demand() -> None: assert view.queued_by_group[group.key] == 1 assert view.queued_resource_demand_by_group[group.key]["llm_wait"] == 1 assert view.first_candidate_resources_by_group[group.key]["submission"] == 1 + + +def test_queue_view_updates_incremental_accounting_after_removals() -> None: + queue = FairTaskQueue() + first_group = _group("a") + second_group = _group("b") + first = SchedulableTask( + task_id=stable_task_id(_task("a", 0)), + payload=_task("a", 0), + group=first_group, + resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 2}), + ) + second = SchedulableTask( + task_id=stable_task_id(_task("b", 0)), + payload=_task("b", 0), + group=second_group, + resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 3}), + ) + third = SchedulableTask( + task_id=stable_task_id(_task("b", 1)), + payload=_task("b", 1), + group=second_group, + resource_request=SchedulerResourceRequest({"submission": 1, "local": 1}), + ) + queue.enqueue([first, second, third]) + + queue.discard(first.task_id) + committed = _select_and_commit(queue) + + assert committed == second + view = queue.view() + assert view.queued_total == 1 + assert first_group.key not in view.queued_by_group + assert view.queued_by_group == {second_group.key: 1} + assert view.queued_resource_demand_by_group == {second_group.key: {"submission": 1, "local": 1}} + assert view.queued_peer_demand_by_resource == {"submission": 1, "local": 1} + + +def test_queue_view_uses_incremental_accounting_for_non_candidate_tasks() -> None: + queue = FairTaskQueue() + group = _group("a") + first = _item("a", 0, group) + amounts = _FailIfScannedAmounts({"submission": 1}) + task = _task("a", 1) + second = SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest(amounts), + ) + queue.enqueue([first, second]) + amounts.locked = True + + view = queue.view() + + assert view.queued_total == 2 + assert view.queued_by_group == {group.key: 2} + assert view.queued_resource_demand_by_group == {group.key: {"submission": 2}} + assert view.first_candidate_resources_by_group == {group.key: {"submission": 1}} + assert view.queued_peer_demand_by_resource == {"submission": 2}