Skip to content

Commit

Permalink
Fix get_leaves calculation for teardown in nested group (#36456)
Browse files Browse the repository at this point in the history
When arrowing `group` >> `task`, the "leaves" of `group` are connected to `task`. When calculating leaves in the group, teardown tasks are ignored, and we recurse upstream to find non-teardowns.

What was happening, and what this fixes: the recursion could reach a work task that already has another non-teardown downstream task within the group. In that case the work task should be ignored, because it already has a non-teardown descendant.

Resolves #36345

(cherry picked from commit 949fc57)
  • Loading branch information
dstandish authored and ephraimbuddy committed Jan 11, 2024
1 parent 4535686 commit 808ed02
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
13 changes: 12 additions & 1 deletion airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,16 @@ def get_leaves(self) -> Generator[BaseOperator, None, None]:
tasks = list(self)
ids = {x.task_id for x in tasks}

def has_non_teardown_downstream(task, exclude: str):
    """Return True if *task* has any non-teardown downstream task within this
    group, ignoring the task whose id is *exclude*.

    Downstream tasks outside the group (task_id not in ``ids``) are skipped.
    """
    return any(
        candidate.task_id != exclude
        and candidate.task_id in ids
        and not candidate.is_teardown
        for candidate in task.downstream_list
    )

def recurse_for_first_non_teardown(task):
for upstream_task in task.upstream_list:
if upstream_task.task_id not in ids:
Expand All @@ -381,7 +391,8 @@ def recurse_for_first_non_teardown(task):
elif task.is_teardown and upstream_task.is_setup:
# don't go through the teardown-to-setup path
continue
else:
# return unless upstream task already has non-teardown downstream in group
elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
yield upstream_task

for task in tasks:
Expand Down
47 changes: 47 additions & 0 deletions tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,53 @@ def test_dag_edges_setup_teardown():
]


def test_dag_edges_setup_teardown_nested():
    """Edges for a setup/teardown pair inside a nested task group.

    The inner group's work task (not its teardown ``end``) must be what
    connects to the downstream join when the group is arrowed to another task.
    """
    from airflow.decorators import task, task_group
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    logical_date = pendulum.parse("20200101")

    with DAG(dag_id="s_t_dag", start_date=logical_date) as dag:

        @task
        def test_task():
            print("Hello world!")

        @task_group
        def inner():
            inner_start = EmptyOperator(task_id="start")
            inner_end = EmptyOperator(task_id="end")
            work = test_task.override(task_id="work")()
            inner_start >> work >> inner_end.as_teardown(setups=inner_start)

        @task_group
        def outer():
            outer_work = EmptyOperator(task_id="work")
            # group >> task: the group's leaves should attach to outer_work
            inner() >> outer_work

        dag_start = EmptyOperator(task_id="dag_start")
        dag_end = EmptyOperator(task_id="dag_end")
        dag_start >> outer() >> dag_end

    observed = sorted(
        (e["source_id"], e["target_id"], e.get("is_setup_teardown")) for e in dag_edges(dag)
    )
    assert observed == [
        ("dag_start", "outer.upstream_join_id", None),
        ("outer.downstream_join_id", "dag_end", None),
        ("outer.inner.downstream_join_id", "outer.work", None),
        ("outer.inner.start", "outer.inner.end", True),
        ("outer.inner.start", "outer.inner.work", None),
        ("outer.inner.work", "outer.inner.downstream_join_id", None),
        ("outer.inner.work", "outer.inner.end", None),
        ("outer.upstream_join_id", "outer.inner.start", None),
        ("outer.work", "outer.downstream_join_id", None),
    ]


def test_duplicate_group_id():
from airflow.exceptions import DuplicateTaskIdFound

Expand Down

0 comments on commit 808ed02

Please sign in to comment.