Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(self) -> None:
self._queued: dict[str, SchedulableTask] = {}
self._task_groups: dict[str, TaskGroupKey] = {}
self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {}
self._queued_by_group: Counter[TaskGroupKey] = Counter()
self._queued_resource_demand_by_group: dict[TaskGroupKey, Counter[SchedulerResourceKey]] = defaultdict(Counter)
self._queued_peer_demand_by_resource: Counter[SchedulerResourceKey] = Counter()
self._group_finish: dict[TaskGroupKey, float] = {}
self._heap: list[tuple[float, int, TaskGroupKey]] = []
self._active_heap_keys: set[TaskGroupKey] = set()
Expand All @@ -69,6 +72,7 @@ def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]:
queue.append(item)
self._queued[item.task_id] = item
self._task_groups[item.task_id] = item.group.key
self._increment_queue_accounting(item)
self._activate_group(item.group.key)
accepted.append(item.task_id)
if accepted:
Expand All @@ -77,10 +81,8 @@ def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]:

def discard(self, task_id: str) -> None:
"""Remove a queued task lazily if it is no longer dispatchable."""
if task_id in self._queued:
if self._remove_queued_item(task_id) is not None:
self._sequence_version += 1
self._queued.pop(task_id, None)
self._task_groups.pop(task_id, None)

def discard_where(self, predicate: Callable[[SchedulableTask], bool]) -> None:
"""Remove queued tasks matching a predicate."""
Expand Down Expand Up @@ -125,8 +127,7 @@ def commit(self, selection: QueueSelection) -> SchedulableTask | None:
return None

queue.popleft()
self._queued.pop(item.task_id, None)
self._task_groups.pop(item.task_id, None)
self._remove_queued_item(item.task_id)
self._active_heap_keys.discard(key)
self._active_heap_entries.pop(key, None)
group = self._group_specs[key]
Expand All @@ -140,35 +141,28 @@ def commit(self, selection: QueueSelection) -> SchedulableTask | None:
return item

def view(self) -> QueueView:
queued_by_group: Counter[TaskGroupKey] = Counter()
demand_by_group: dict[TaskGroupKey, dict[SchedulerResourceKey, int]] = defaultdict(lambda: defaultdict(int))
first_by_group: dict[TaskGroupKey, Mapping[SchedulerResourceKey, int]] = {}
first_tasks_by_group: dict[TaskGroupKey, SchedulableTask] = {}
first_group_specs: dict[TaskGroupKey, TaskGroupSpec] = {}
demand_by_resource: Counter[SchedulerResourceKey] = Counter()

for item in self._queued.values():
key = item.group.key
queued_by_group[key] += 1
for resource, amount in item.resource_request.amounts.items():
demand_by_group[key][resource] += amount
demand_by_resource[resource] += amount

for key, queue in self._queues.items():
for key in self._queued_by_group:
first = self._first_valid_item(key)
if first is not None:
first_by_group[key] = dict(first.resource_request.amounts)
first_tasks_by_group[key] = first
first_group_specs[key] = first.group
if first is None:
continue
first_by_group[key] = dict(first.resource_request.amounts)
first_tasks_by_group[key] = first
first_group_specs[key] = first.group

return QueueView(
queued_total=len(self._queued),
queued_by_group=dict(queued_by_group),
queued_resource_demand_by_group={key: dict(value) for key, value in demand_by_group.items()},
queued_by_group=dict(self._queued_by_group),
queued_resource_demand_by_group={
key: dict(value) for key, value in self._queued_resource_demand_by_group.items()
},
first_candidate_resources_by_group=first_by_group,
first_candidate_tasks_by_group=first_tasks_by_group,
first_candidate_group_specs_by_group=first_group_specs,
queued_peer_demand_by_resource=dict(demand_by_resource),
queued_peer_demand_by_resource=dict(self._queued_peer_demand_by_resource),
)

def _activate_group(self, key: TaskGroupKey) -> None:
Expand All @@ -183,13 +177,11 @@ def _activate_group(self, key: TaskGroupKey) -> None:
self._active_heap_entries[key] = (finish, self._sequence)

def _first_valid_item(self, key: TaskGroupKey) -> SchedulableTask | None:
self._purge_queue_head(key)
queue = self._queues.get(key)
if queue is None:
if not queue:
return None
for item in queue:
if item.task_id in self._queued and self._task_groups.get(item.task_id) == key:
return item
return None
return queue[0]

def _purge_queue_head(self, key: TaskGroupKey) -> None:
queue = self._queues.get(key)
Expand All @@ -200,3 +192,37 @@ def _purge_queue_head(self, key: TaskGroupKey) -> None:
if item.task_id in self._queued and self._task_groups.get(item.task_id) == key:
break
queue.popleft()

def _increment_queue_accounting(self, item: SchedulableTask) -> None:
key = item.group.key
self._queued_by_group[key] += 1
for resource, amount in item.resource_request.amounts.items():
self._queued_resource_demand_by_group[key][resource] += amount
self._queued_peer_demand_by_resource[resource] += amount

def _remove_queued_item(self, task_id: str) -> SchedulableTask | None:
item = self._queued.pop(task_id, None)
key = self._task_groups.pop(task_id, None)
if item is None or key is None:
return item
self._decrement_queue_accounting(item, key)
return item

def _decrement_queue_accounting(self, item: SchedulableTask, key: TaskGroupKey) -> None:
self._queued_by_group[key] -= 1
if self._queued_by_group[key] <= 0:
del self._queued_by_group[key]

group_demand = self._queued_resource_demand_by_group.get(key)
if group_demand is not None:
for resource, amount in item.resource_request.amounts.items():
group_demand[resource] -= amount
if group_demand[resource] <= 0:
del group_demand[resource]
if not group_demand:
del self._queued_resource_demand_by_group[key]

for resource, amount in item.resource_request.amounts.items():
self._queued_peer_demand_by_resource[resource] -= amount
if self._queued_peer_demand_by_resource[resource] <= 0:
del self._queued_peer_demand_by_resource[resource]
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from __future__ import annotations

from collections import Counter
from collections.abc import ItemsView

from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue, QueueView
from data_designer.engine.dataset_builders.scheduling.resources import (
SchedulableTask,
SchedulerResourceKey,
SchedulerResourceRequest,
TaskGroupKey,
TaskGroupSpec,
Expand All @@ -16,6 +18,15 @@
from data_designer.engine.dataset_builders.scheduling.task_model import Task


class _FailIfScannedAmounts(dict[SchedulerResourceKey, int]):
locked: bool = False

def items(self) -> ItemsView[SchedulerResourceKey, int]:
if self.locked:
raise AssertionError("QueueView should use incremental accounting for non-candidate tasks.")
return super().items()


def _task(column: str, row_index: int) -> Task:
return Task(column=column, row_group=0, row_index=row_index, task_type="cell")

Expand Down Expand Up @@ -118,14 +129,25 @@ def test_select_next_uses_scheduler_eligibility_callback() -> None:

def test_enqueue_is_idempotent_by_task_id() -> None:
queue = FairTaskQueue()
item = _item("a", 0)
group = _group("a")
task = _task("a", 0)
item = SchedulableTask(
task_id=stable_task_id(task),
payload=task,
group=group,
resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 2}),
)

first = queue.enqueue([item])
second = queue.enqueue([item])
view = queue.view()

assert first == (item.task_id,)
assert second == ()
assert queue.view().queued_total == 1
assert view.queued_total == 1
assert view.queued_by_group == {group.key: 1}
assert view.queued_resource_demand_by_group == {group.key: {"submission": 1, "llm_wait": 2}}
assert view.queued_peer_demand_by_resource == {"submission": 1, "llm_wait": 2}


def test_discard_where_removes_matching_tasks() -> None:
Expand Down Expand Up @@ -157,3 +179,63 @@ def test_queue_view_exposes_group_and_resource_demand() -> None:
assert view.queued_by_group[group.key] == 1
assert view.queued_resource_demand_by_group[group.key]["llm_wait"] == 1
assert view.first_candidate_resources_by_group[group.key]["submission"] == 1


def test_queue_view_updates_incremental_accounting_after_removals() -> None:
queue = FairTaskQueue()
first_group = _group("a")
second_group = _group("b")
first = SchedulableTask(
task_id=stable_task_id(_task("a", 0)),
payload=_task("a", 0),
group=first_group,
resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 2}),
)
second = SchedulableTask(
task_id=stable_task_id(_task("b", 0)),
payload=_task("b", 0),
group=second_group,
resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 3}),
)
third = SchedulableTask(
task_id=stable_task_id(_task("b", 1)),
payload=_task("b", 1),
group=second_group,
resource_request=SchedulerResourceRequest({"submission": 1, "local": 1}),
)
queue.enqueue([first, second, third])

queue.discard(first.task_id)
committed = _select_and_commit(queue)

assert committed == second
view = queue.view()
assert view.queued_total == 1
assert first_group.key not in view.queued_by_group
assert view.queued_by_group == {second_group.key: 1}
assert view.queued_resource_demand_by_group == {second_group.key: {"submission": 1, "local": 1}}
assert view.queued_peer_demand_by_resource == {"submission": 1, "local": 1}


def test_queue_view_uses_incremental_accounting_for_non_candidate_tasks() -> None:
queue = FairTaskQueue()
group = _group("a")
first = _item("a", 0, group)
amounts = _FailIfScannedAmounts({"submission": 1})
task = _task("a", 1)
second = SchedulableTask(
task_id=stable_task_id(task),
payload=task,
group=group,
resource_request=SchedulerResourceRequest(amounts),
)
queue.enqueue([first, second])
amounts.locked = True

view = queue.view()

assert view.queued_total == 2
assert view.queued_by_group == {group.key: 2}
assert view.queued_resource_demand_by_group == {group.key: {"submission": 2}}
assert view.first_candidate_resources_by_group == {group.key: {"submission": 1}}
assert view.queued_peer_demand_by_resource == {"submission": 2}
Loading