Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/67688.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Further optimize ``TaskGroup.topological_sort`` for reverse-declared DAGs via pass-number traversal; dramatically improves the O(N²) worst-case for adversarial shapes (e.g., reverse-insertion chains).
67 changes: 60 additions & 7 deletions airflow-core/src/airflow/serialization/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,24 @@ def topological_sort(self) -> list[DAGNode]:
children = self.children
if not children:
return []

nodes = list(children.values())
n = len(nodes)
id_to_idx = {nid: i for i, nid in enumerate(children)}
projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)]

projected: list[tuple[int, ...]] = [()] * n
nodes_with_back_edge = 0
for i, child in enumerate(nodes):
deps = self._project_child_deps(i, child, id_to_idx)
if deps:
projected[i] = deps
if any(d > i for d in deps):
nodes_with_back_edge += 1

# The ratio catches dense back-heavy groups; a 32-node absolute cutoff keeps
# padded reverse-declared runs on the fast path once sweep rescans overtake pass-numbering.
if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n:
return self._sort_via_pass_numbering(nodes, projected)
return self._sweep_projection(nodes, projected)

def _project_child_deps(
Expand All @@ -242,16 +257,18 @@ def _project_child_deps(
for edge_id in upstream_ids:
j = id_to_idx.get(edge_id)
if j is not None:
sib_deps.add(j)
if j != child_idx:
sib_deps.add(j)
continue
tg = self.dag.get_task(edge_id).task_group
edge = self.dag.get_task(edge_id)
tg = edge.task_group
while tg is not None:
j = id_to_idx.get(tg.node_id)
if j is not None:
sib_deps.add(j)
anc_idx = id_to_idx.get(tg.node_id)
if anc_idx is not None:
if anc_idx != child_idx:
sib_deps.add(anc_idx)
break
tg = tg.parent_group
sib_deps.discard(child_idx)
return tuple(sib_deps)

def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]:
Expand Down Expand Up @@ -291,6 +308,42 @@ def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...
pending = next_pending
return order

def _sort_via_pass_numbering(
self, nodes: list[DAGNode], projected: list[tuple[int, ...]]
) -> list[DAGNode]:
n = len(nodes)
in_degree = [len(deps) for deps in projected]
successors: list[list[int]] = [[] for _ in range(n)]
for i, deps in enumerate(projected):
for d in deps:
successors[d].append(i)

pass_of = [0] * n
queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0)
processed = 0
while queue:
i = queue.popleft()
my_pass = 1
for d in projected[i]:
d_pass = pass_of[d]
if d < i:
if d_pass > my_pass:
my_pass = d_pass
elif d_pass + 1 > my_pass:
my_pass = d_pass + 1
pass_of[i] = my_pass
processed += 1
for s in successors[i]:
in_degree[s] -= 1
if in_degree[s] == 0:
queue.append(s)

if processed != n:
raise ValueError(f"A cyclic dependency occurred in dag: {self.dag_id}")

sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i))
return [nodes[i] for i in sorted_indices]

def add(self, node: DAGNode) -> DAGNode:
# Set the TG first, as setting it might change the return value of node_id!
node.task_group = weakref.proxy(self)
Expand Down
42 changes: 42 additions & 0 deletions airflow-core/tests/unit/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,21 @@ def nested(group):
]


def _make_padded_reverse_chain(chain_length: int, independent_count: int) -> DAG:
with DAG(
f"padded_reverse_chain_{chain_length}_{independent_count}",
schedule=None,
start_date=DEFAULT_DATE,
) as dag:
tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in range(chain_length)]
by_id = {task.task_id: task for task in tasks}
for i in range(chain_length - 1):
by_id[f"r{i}"] >> by_id[f"r{i + 1}"]
for i in range(independent_count):
EmptyOperator(task_id=f"i{i}")
return dag


def test_topological_group_dep():
logical_date = pendulum.parse("20200101")
with DAG("test_dag_edges", schedule=None, start_date=logical_date) as dag:
Expand Down Expand Up @@ -1163,6 +1178,33 @@ def test_topological_sort_serialized_layered():
)


def test_topological_sort_serialized_padded_reverse_chain_uses_pass_numbering(monkeypatch):
dag = _make_padded_reverse_chain(chain_length=80, independent_count=80)
serialized = create_scheduler_dag(dag)
serialized.task_group.children = {
**{f"r{i}": serialized.task_group.children[f"r{i}"] for i in range(79, -1, -1)},
**{f"i{i}": serialized.task_group.children[f"i{i}"] for i in range(80)},
}

called = {"value": False}
serialized_task_group_cls = type(serialized.task_group)
original = serialized_task_group_cls._sort_via_pass_numbering

def spy(self, nodes, projected):
called["value"] = True
return original(self, nodes, projected)

monkeypatch.setattr(serialized_task_group_cls, "_sort_via_pass_numbering", spy)

order = [node.node_id for node in serialized.task_group.topological_sort()]
position = {node_id: i for i, node_id in enumerate(order)}

assert called["value"]
assert set(position) == {*(f"r{i}" for i in range(80)), *(f"i{i}" for i in range(80))}
for i in range(79):
assert position[f"r{i}"] < position[f"r{i + 1}"]


def test_task_group_arrow_with_setup_group():
with DAG(dag_id="setup_group_teardown_group") as dag:
with TaskGroup("group_1") as g1:
Expand Down
92 changes: 76 additions & 16 deletions task-sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,49 +544,68 @@ def topological_sort(self) -> list[DAGNode]:
"""
Sort children topologically — a task always comes after its upstream dependencies.

Projects each child's per-task upstream IDs onto sibling-level integer indices once,
then runs a greedy multi-pass sweep using a bytearray-backed emission flag. Equivalent
in emission order to the previous modified-Kahn implementation, but moves the per-edge
``upstream_list`` materialization and ``parent_group`` walks out of the sweep's inner
loop so they happen once per call instead of once per outer-loop pass.
Projects per-task upstream edges onto sibling-level integer indices, then dispatches:

- Forward-declared DAGs (few/no children declared after their dependents): greedy
multi-pass sweep over the projection, O(V + E) for the common case.
- Reverse-declared DAGs (many children declared before their dependents): pass-number
traversal, O((V + E) log V), avoids the O(N²) blowup the sweep would hit.

Both branches produce the same emission order: level-by-legacy-pass, ties broken by
children insertion order.
"""
children = self.children
if not children:
return []

nodes = list(children.values())
n = len(nodes)
id_to_idx = {nid: i for i, nid in enumerate(children)}
projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)]

projected: list[tuple[int, ...]] = [()] * n
nodes_with_back_edge = 0
for i, child in enumerate(nodes):
deps = self._project_child_deps(i, child, id_to_idx)
if deps:
projected[i] = deps
if any(d > i for d in deps):
nodes_with_back_edge += 1

# The ratio catches dense back-heavy groups; a 32-node absolute cutoff keeps
# padded reverse-declared runs on the fast path once sweep rescans overtake pass-numbering.
if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n:
return self._sort_via_pass_numbering(nodes, projected)
return self._sweep_projection(nodes, projected)

def _project_child_deps(
self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
) -> tuple[int, ...]:
# Project one child's per-task upstream IDs onto sibling-level integer indices.
# Self-deps are filtered once at the end via ``discard`` so the inner loop stays tight.
upstream_ids = child.upstream_task_ids
if not upstream_ids:
return ()
sib_deps: set[int] = set()
for edge_id in upstream_ids:
j = id_to_idx.get(edge_id)
if j is not None:
sib_deps.add(j)
if j != child_idx:
sib_deps.add(j)
continue
tg = self.dag.get_task(edge_id).task_group
edge = self.dag.get_task(edge_id)
tg = edge.task_group
while tg is not None:
j = id_to_idx.get(tg.node_id)
if j is not None:
sib_deps.add(j)
anc_idx = id_to_idx.get(tg.node_id)
if anc_idx is not None:
if anc_idx != child_idx:
sib_deps.add(anc_idx)
break
tg = tg.parent_group
sib_deps.discard(child_idx)
return tuple(sib_deps)

def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]:
# Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been emitted.
# Pass 1 iterates range(n) directly; only blocked nodes are recorded into
# ``pending`` and re-checked in subsequent passes. Avoids paying for a
# ``list(range(n))`` allocation on single-pass shapes (the common case) while
# `pending` and re-checked in subsequent passes. Avoids paying for a
# `list(range(n))` allocation on single-pass shapes (the common case) while
# still skipping already-emitted nodes on multi-pass shapes (e.g. a diamond's
# single trailing sink).
n = len(nodes)
Expand Down Expand Up @@ -625,6 +644,47 @@ def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...
pending = next_pending
return order

def _sort_via_pass_numbering(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new branch has no test pinning its emission order against the sweep. The reverse-chain cases in test_topological_sort_shape_correctness route through here, but _assert_valid_topological_order only checks the result is a valid topological sort, not that it matches what _sweep_projection emits. The order-sensitive tests (test_topological_sort1/2, test_topological_nested_groups) use forward-declared DAGs and route through the sweep, so nothing pins the "both branches produce identical order" invariant this PR rests on.

Worth a test that builds a reverse-declared DAG and asserts _sort_via_pass_numbering and _sweep_projection return identical orders. I checked equivalence empirically across ~150k random DAGs and it holds, so this is a coverage gap rather than a bug, but it's the property most likely to silently break in a future refactor.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a reverse-declared regression that directly asserts _sort_via_pass_numbering() emits the same order as _sweep_projection(), so that invariant is now pinned independently of the validity-only shape tests.

self, nodes: list[DAGNode], projected: list[tuple[int, ...]]
) -> list[DAGNode]:
# Sort by (pass_number, insertion_index). pass_number(X) is the earliest pass at
# which a greedy-sweep emission of X would occur:
# pass(X) = max over deps d of (pass(d) if idx(d) < idx(X) else pass(d)+1)
# A dep declared before X can be emitted in the same pass; a dep declared after X
# forces X into the next pass. Computed via Kahn's traversal in O(V + E).
n = len(nodes)
in_degree = [len(deps) for deps in projected]
successors: list[list[int]] = [[] for _ in range(n)]
for i, deps in enumerate(projected):
for d in deps:
successors[d].append(i)

pass_of = [0] * n
queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0)
processed = 0
while queue:
i = queue.popleft()
my_pass = 1
for d in projected[i]:
d_pass = pass_of[d]
if d < i:
if d_pass > my_pass:
my_pass = d_pass
elif d_pass + 1 > my_pass:
my_pass = d_pass + 1
pass_of[i] = my_pass
processed += 1
for s in successors[i]:
in_degree[s] -= 1
if in_degree[s] == 0:
queue.append(s)

if processed != n:
raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i))
return [nodes[i] for i in sorted_indices]

def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
"""
Return mapped task groups in the hierarchy.
Expand Down
45 changes: 45 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,21 @@ def _make_reverse_chain(n: int) -> DAG:
return dag


def _make_padded_reverse_chain(chain_length: int, independent_count: int) -> DAG:
with DAG(
f"padded_reverse_chain_{chain_length}_{independent_count}",
schedule=None,
start_date=DEFAULT_DATE,
) as dag:
tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in range(chain_length)]
by_id = {t.task_id: t for t in tasks}
for i in range(chain_length - 1):
by_id[f"r{i}"] >> by_id[f"r{i + 1}"]
for i in range(independent_count):
EmptyOperator(task_id=f"i{i}")
return dag


def _make_diamond(n: int) -> DAG:
with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
root = EmptyOperator(task_id="root")
Expand Down Expand Up @@ -1124,3 +1139,33 @@ def test_topological_sort_shape_correctness(shape, builder, n):
for group in _walk_groups(dag.task_group):
order = [node.node_id for node in group.topological_sort()]
_assert_valid_topological_order(group, order)


def test_topological_sort_reverse_declared_order_matches_sweep():
dag = _make_reverse_chain(100)
group = dag.task_group
nodes = list(group.children.values())
id_to_idx = {nid: i for i, nid in enumerate(group.children)}
projected = [group._project_child_deps(i, child, id_to_idx) for i, child in enumerate(nodes)]

sweep_order = [node.node_id for node in group._sweep_projection(nodes, projected)]
pass_number_order = [node.node_id for node in group._sort_via_pass_numbering(nodes, projected)]

assert pass_number_order == sweep_order


def test_topological_sort_padded_reverse_chain_uses_pass_numbering(monkeypatch):
dag = _make_padded_reverse_chain(chain_length=80, independent_count=80)
called = {"value": False}
original = TaskGroup._sort_via_pass_numbering

def spy(self, nodes, projected):
called["value"] = True
return original(self, nodes, projected)

monkeypatch.setattr(TaskGroup, "_sort_via_pass_numbering", spy)

order = [node.node_id for node in dag.task_group.topological_sort()]

assert called["value"]
_assert_valid_topological_order(dag.task_group, order)
Loading