Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions airflow-core/src/airflow/serialization/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,16 @@ def topological_sort(self) -> list[DAGNode]:
# We are already going to visit that TG
break
else:
del graph_unsorted[node.node_id]
graph_sorted.append(node)
# When list-based deps (e.g. `[b0, b1] >> a`) are used between TaskGroups,
# only upstream_group_ids is populated (not upstream_task_ids), so upstream_list
# is empty and the task-level check above won't block the node. Check group-level
# upstreams explicitly to handle this case.
for group_id in getattr(node, "upstream_group_ids", ()):
if group_id in graph_unsorted:
break
else:
del graph_unsorted[node.node_id]
graph_sorted.append(node)
return graph_sorted

def add(self, node: DAGNode) -> DAGNode:
Expand Down
22 changes: 22 additions & 0 deletions airflow-core/tests/unit/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,28 @@ def nested_topo(group):
]


def test_topological_group_dep_list_syntax():
    """Verify that a list of TaskGroups placed upstream of a group sorts ahead of it.

    Three groups ``b_0`` .. ``b_2`` are made upstream of group ``a`` using the
    list form of the shift operator (``[b_0, b_1, b_2] >> a``). The top-level
    topological sort must then list every ``b_x`` before ``a``.
    """
    start = pendulum.parse("20200101")
    with DAG("test_dag_list_dep", schedule=None, start_date=start) as dag:
        with TaskGroup("a") as downstream_group:
            EmptyOperator(task_id="task")

        upstream_groups = []
        for x in range(3):
            with TaskGroup(f"b_{x}") as upstream_group:
                EmptyOperator(task_id="task")
            upstream_groups.append(upstream_group)

        # The list sits on the left-hand side of ``>>``; this form used to
        # yield an incorrect topological order.
        upstream_groups >> downstream_group

    ids = [node.node_id for node in dag.task_group.topological_sort()]
    position_of_a = ids.index("a")
    positions_of_b = [ids.index(f"b_{x}") for x in range(3)]
    assert all(
        position < position_of_a for position in positions_of_b
    ), f"Expected all b_x before a in topological order, got: {ids}"


def test_task_group_arrow_with_setup_group():
with DAG(dag_id="setup_group_teardown_group") as dag:
with TaskGroup("group_1") as g1:
Expand Down
14 changes: 11 additions & 3 deletions task-sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,17 @@ def topological_sort(self):
# We are already going to visit that TG
break
else:
acyclic = True
del graph_unsorted[node.node_id]
graph_sorted.append(node)
# When list-based deps (e.g. `[b0, b1] >> a`) are used between TaskGroups,
# only upstream_group_ids is populated (not upstream_task_ids), so upstream_list
# is empty and the task-level check above won't block the node. Check group-level
# upstreams explicitly to handle this case.
for group_id in getattr(node, "upstream_group_ids", ()):
if group_id in graph_unsorted:
break
else:
acyclic = True
del graph_unsorted[node.node_id]
graph_sorted.append(node)

if not acyclic:
raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
Expand Down
Loading