diff --git a/airflow-core/newsfragments/67688.improvement.rst b/airflow-core/newsfragments/67688.improvement.rst new file mode 100644 index 0000000000000..d2a641ac00c80 --- /dev/null +++ b/airflow-core/newsfragments/67688.improvement.rst @@ -0,0 +1 @@ +Further optimize ``TaskGroup.topological_sort`` for reverse-declared DAGs via pass-number traversal; dramatically improves the O(N²) worst-case for adversarial shapes (e.g., reverse-insertion chains). diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py index a5d8b730b05a2..65d59cb15f17e 100644 --- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py +++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py @@ -227,9 +227,24 @@ def topological_sort(self) -> list[DAGNode]: children = self.children if not children: return [] + nodes = list(children.values()) + n = len(nodes) id_to_idx = {nid: i for i, nid in enumerate(children)} - projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)] + + projected: list[tuple[int, ...]] = [()] * n + nodes_with_back_edge = 0 + for i, child in enumerate(nodes): + deps = self._project_child_deps(i, child, id_to_idx) + if deps: + projected[i] = deps + if any(d > i for d in deps): + nodes_with_back_edge += 1 + + # The ratio catches dense back-heavy groups; a 32-node absolute cutoff keeps + # padded reverse-declared runs on the fast path once sweep rescans overtake pass-numbering. + if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n: + return self._sort_via_pass_numbering(nodes, projected) return self._sweep_projection(nodes, projected) def _project_child_deps( @@ -242,16 +257,18 @@ def _project_child_deps( for edge_id in upstream_ids: j = id_to_idx.get(edge_id) if j is not None: - sib_deps.add(j) + if j != child_idx: + sib_deps.add(j) continue - tg = self.dag.get_task(edge_id).task_group + edge = self.dag.get_task(edge_id) + tg = edge.task_group while tg is not None: - j = id_to_idx.get(tg.node_id) - if j is not None: - sib_deps.add(j) + anc_idx = id_to_idx.get(tg.node_id) + if anc_idx is not None: + if anc_idx != child_idx: + sib_deps.add(anc_idx) break tg = tg.parent_group - sib_deps.discard(child_idx) return tuple(sib_deps) def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: @@ -291,6 +308,42 @@ def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ... pending = next_pending return order + def _sort_via_pass_numbering( + self, nodes: list[DAGNode], projected: list[tuple[int, ...]] + ) -> list[DAGNode]: + n = len(nodes) + in_degree = [len(deps) for deps in projected] + successors: list[list[int]] = [[] for _ in range(n)] + for i, deps in enumerate(projected): + for d in deps: + successors[d].append(i) + + pass_of = [0] * n + queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0) + processed = 0 + while queue: + i = queue.popleft() + my_pass = 1 + for d in projected[i]: + d_pass = pass_of[d] + if d < i: + if d_pass > my_pass: + my_pass = d_pass + elif d_pass + 1 > my_pass: + my_pass = d_pass + 1 + pass_of[i] = my_pass + processed += 1 + for s in successors[i]: + in_degree[s] -= 1 + if in_degree[s] == 0: + queue.append(s) + + if processed != n: + raise ValueError(f"A cyclic dependency occurred in dag: {self.dag_id}") + + sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i)) + return [nodes[i] for i in sorted_indices] + def add(self, node: DAGNode) -> DAGNode: # Set the TG first, as setting it might change the return value of node_id! node.task_group = weakref.proxy(self) diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py index e866bce62afa1..2d1458e95fbd1 100644 --- a/airflow-core/tests/unit/utils/test_task_group.py +++ b/airflow-core/tests/unit/utils/test_task_group.py @@ -1100,6 +1100,21 @@ def nested(group): ] +def _make_padded_reverse_chain(chain_length: int, independent_count: int) -> DAG: + with DAG( + f"padded_reverse_chain_{chain_length}_{independent_count}", + schedule=None, + start_date=DEFAULT_DATE, + ) as dag: + tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in range(chain_length)] + by_id = {task.task_id: task for task in tasks} + for i in range(chain_length - 1): + by_id[f"r{i}"] >> by_id[f"r{i + 1}"] + for i in range(independent_count): + EmptyOperator(task_id=f"i{i}") + return dag + + def test_topological_group_dep(): logical_date = pendulum.parse("20200101") with DAG("test_dag_edges", schedule=None, start_date=logical_date) as dag: @@ -1163,6 +1178,33 @@ def test_topological_sort_serialized_layered(): ) +def test_topological_sort_serialized_padded_reverse_chain_uses_pass_numbering(monkeypatch): + dag = _make_padded_reverse_chain(chain_length=80, independent_count=80) + serialized = create_scheduler_dag(dag) + serialized.task_group.children = { + **{f"r{i}": serialized.task_group.children[f"r{i}"] for i in range(79, -1, -1)}, + **{f"i{i}": serialized.task_group.children[f"i{i}"] for i in range(80)}, + } + + called = {"value": False} + serialized_task_group_cls = type(serialized.task_group) + original = serialized_task_group_cls._sort_via_pass_numbering + + def spy(self, nodes, projected): + called["value"] = True + return original(self, nodes, projected) + + monkeypatch.setattr(serialized_task_group_cls, "_sort_via_pass_numbering", spy) + + order = [node.node_id for node in serialized.task_group.topological_sort()] + position = {node_id: i for i, node_id in enumerate(order)} + + assert called["value"] + assert set(position) == {*(f"r{i}" for i in range(80)), *(f"i{i}" for i in range(80))} + for i in range(79): + assert position[f"r{i}"] < position[f"r{i + 1}"] + + def test_task_group_arrow_with_setup_group(): with DAG(dag_id="setup_group_teardown_group") as dag: with TaskGroup("group_1") as g1: diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index cc1fc5cbda627..14bb2fba31918 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -544,25 +544,42 @@ def topological_sort(self) -> list[DAGNode]: """ Sort children topologically — a task always comes after its upstream dependencies. - Projects each child's per-task upstream IDs onto sibling-level integer indices once, - then runs a greedy multi-pass sweep using a bytearray-backed emission flag. Equivalent - in emission order to the previous modified-Kahn implementation, but moves the per-edge - ``upstream_list`` materialization and ``parent_group`` walks out of the sweep's inner - loop so they happen once per call instead of once per outer-loop pass. + Projects per-task upstream edges onto sibling-level integer indices, then dispatches: + + - Forward-declared DAGs (few/no children declared after their dependents): greedy + multi-pass sweep over the projection, O(V + E) for the common case. + - Reverse-declared DAGs (many children declared before their dependents): pass-number + traversal, O((V + E) log V), avoids the O(N²) blowup the sweep would hit. + + Both branches produce the same emission order: level-by-legacy-pass, ties broken by + children insertion order. """ children = self.children if not children: return [] + nodes = list(children.values()) + n = len(nodes) id_to_idx = {nid: i for i, nid in enumerate(children)} - projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)] + + projected: list[tuple[int, ...]] = [()] * n + nodes_with_back_edge = 0 + for i, child in enumerate(nodes): + deps = self._project_child_deps(i, child, id_to_idx) + if deps: + projected[i] = deps + if any(d > i for d in deps): + nodes_with_back_edge += 1 + + # The ratio catches dense back-heavy groups; a 32-node absolute cutoff keeps + # padded reverse-declared runs on the fast path once sweep rescans overtake pass-numbering. + if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n: + return self._sort_via_pass_numbering(nodes, projected) return self._sweep_projection(nodes, projected) def _project_child_deps( self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int] ) -> tuple[int, ...]: - # Project one child's per-task upstream IDs onto sibling-level integer indices. - # Self-deps are filtered once at the end via ``discard`` so the inner loop stays tight. upstream_ids = child.upstream_task_ids if not upstream_ids: return () @@ -570,23 +587,25 @@ def _project_child_deps( for edge_id in upstream_ids: j = id_to_idx.get(edge_id) if j is not None: - sib_deps.add(j) + if j != child_idx: + sib_deps.add(j) continue - tg = self.dag.get_task(edge_id).task_group + edge = self.dag.get_task(edge_id) + tg = edge.task_group while tg is not None: - j = id_to_idx.get(tg.node_id) - if j is not None: - sib_deps.add(j) + anc_idx = id_to_idx.get(tg.node_id) + if anc_idx is not None: + if anc_idx != child_idx: + sib_deps.add(anc_idx) break tg = tg.parent_group - sib_deps.discard(child_idx) return tuple(sib_deps) def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: # Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been emitted. # Pass 1 iterates range(n) directly; only blocked nodes are recorded into - # ``pending`` and re-checked in subsequent passes. Avoids paying for a - # ``list(range(n))`` allocation on single-pass shapes (the common case) while + # `pending` and re-checked in subsequent passes. Avoids paying for a + # `list(range(n))` allocation on single-pass shapes (the common case) while # still skipping already-emitted nodes on multi-pass shapes (e.g. a diamond's # single trailing sink). n = len(nodes) @@ -625,6 +644,47 @@ def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ... pending = next_pending return order + def _sort_via_pass_numbering( + self, nodes: list[DAGNode], projected: list[tuple[int, ...]] + ) -> list[DAGNode]: + # Sort by (pass_number, insertion_index). pass_number(X) is the earliest pass at + # which a greedy-sweep emission of X would occur: + # pass(X) = max over deps d of (pass(d) if idx(d) < idx(X) else pass(d)+1) + # A dep declared before X can be emitted in the same pass; a dep declared after X + # forces X into the next pass. Computed via Kahn's traversal in O(V + E). + n = len(nodes) + in_degree = [len(deps) for deps in projected] + successors: list[list[int]] = [[] for _ in range(n)] + for i, deps in enumerate(projected): + for d in deps: + successors[d].append(i) + + pass_of = [0] * n + queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0) + processed = 0 + while queue: + i = queue.popleft() + my_pass = 1 + for d in projected[i]: + d_pass = pass_of[d] + if d < i: + if d_pass > my_pass: + my_pass = d_pass + elif d_pass + 1 > my_pass: + my_pass = d_pass + 1 + pass_of[i] = my_pass + processed += 1 + for s in successors[i]: + in_degree[s] -= 1 + if in_degree[s] == 0: + queue.append(s) + + if processed != n: + raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}") + + sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i)) + return [nodes[i] for i in sorted_indices] + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: """ Return mapped task groups in the hierarchy. diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py index cf6d309f305eb..d11f8eb3632b3 100644 --- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py +++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py @@ -1011,6 +1011,21 @@ def _make_reverse_chain(n: int) -> DAG: return dag +def _make_padded_reverse_chain(chain_length: int, independent_count: int) -> DAG: + with DAG( + f"padded_reverse_chain_{chain_length}_{independent_count}", + schedule=None, + start_date=DEFAULT_DATE, + ) as dag: + tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in range(chain_length)] + by_id = {t.task_id: t for t in tasks} + for i in range(chain_length - 1): + by_id[f"r{i}"] >> by_id[f"r{i + 1}"] + for i in range(independent_count): + EmptyOperator(task_id=f"i{i}") + return dag + + def _make_diamond(n: int) -> DAG: with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: root = EmptyOperator(task_id="root") @@ -1124,3 +1139,33 @@ def test_topological_sort_shape_correctness(shape, builder, n): for group in _walk_groups(dag.task_group): order = [node.node_id for node in group.topological_sort()] _assert_valid_topological_order(group, order) + + +def test_topological_sort_reverse_declared_order_matches_sweep(): + dag = _make_reverse_chain(100) + group = dag.task_group + nodes = list(group.children.values()) + id_to_idx = {nid: i for i, nid in enumerate(group.children)} + projected = [group._project_child_deps(i, child, id_to_idx) for i, child in enumerate(nodes)] + + sweep_order = [node.node_id for node in group._sweep_projection(nodes, projected)] + pass_number_order = [node.node_id for node in group._sort_via_pass_numbering(nodes, projected)] + + assert pass_number_order == sweep_order + + +def test_topological_sort_padded_reverse_chain_uses_pass_numbering(monkeypatch): + dag = _make_padded_reverse_chain(chain_length=80, independent_count=80) + called = {"value": False} + original = TaskGroup._sort_via_pass_numbering + + def spy(self, nodes, projected): + called["value"] = True + return original(self, nodes, projected) + + monkeypatch.setattr(TaskGroup, "_sort_via_pass_numbering", spy) + + order = [node.node_id for node in dag.task_group.topological_sort()] + + assert called["value"] + _assert_valid_topological_order(dag.task_group, order)