Skip to content

Commit

Permalink
Revert "Fix pre-mature evaluation of tasks in mapped task group (#34337
Browse files Browse the repository at this point in the history
…)" (#35651)

This reverts commit 69938fd.
  • Loading branch information
ephraimbuddy committed Nov 15, 2023
1 parent ffbba9e commit 95bf5dd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 61 deletions.
18 changes: 0 additions & 18 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule as TR

if TYPE_CHECKING:
Expand Down Expand Up @@ -133,20 +132,6 @@ def _get_expanded_ti_count() -> int:
"""
return ti.task.get_mapped_ti_count(ti.run_id, session=session)

def _iter_expansion_dependencies() -> Iterator[str]:
from airflow.models.mappedoperator import MappedOperator

if isinstance(ti.task, MappedOperator):
for op in ti.task.iter_mapped_dependencies():
yield op.task_id
task_group = ti.task.task_group
if task_group and task_group.iter_mapped_task_groups():
yield from (
op.task_id
for tg in task_group.iter_mapped_task_groups()
for op in tg.iter_mapped_dependencies()
)

@functools.lru_cache
def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
"""Get the given task's map indexes relevant to the current ti.
Expand All @@ -157,9 +142,6 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
"""
if TYPE_CHECKING:
assert isinstance(ti.task.dag, DAG)
if isinstance(ti.task.task_group, MappedTaskGroup):
if upstream_id not in set(_iter_expansion_dependencies()):
return None
try:
expanded_ti_count = _get_expanded_ti_count()
except (NotFullyPopulated, NotMapped):
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,8 +1305,8 @@ def file_transforms(filename):
states = self.get_states(dr)
expected = {
"file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"},
"file_transforms.my_work": {2: "upstream_failed", 1: "upstream_failed", 0: "upstream_failed"},
"file_transforms.my_teardown": {2: "success", 1: "success", 0: "success"},
"file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"},
"file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"},
}

assert states == expected
Expand Down
47 changes: 6 additions & 41 deletions tests/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,23 +1165,19 @@ def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)]

# After running the first t1, the remaining t1 must be run before t2 is available.
# After running the first t1, the first t2 becomes immediately available.
tis["tg.t1", 0].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2)]
assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)]

# After running all t1, t2 is available.
tis["tg.t1", 1].run()
# Similarly for the subsequent t2 instances.
tis["tg.t1", 2].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)]

# Similarly for t2 instances. They both have to complete before t3 is available
tis["tg.t2", 0].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t2", 1), ("tg.t2", 2)]
assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)]

# But running t2 partially does not make t3 available.
tis["tg.t1", 1].run()
tis["tg.t2", 0].run()
tis["tg.t2", 2].run()
tis = _one_scheduling_decision_iteration()
assert sorted(tis) == [("tg.t2", 1)]
Expand Down Expand Up @@ -1411,34 +1407,3 @@ def w2():
(status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session)
assert status.reason.startswith("All setup tasks must complete successfully")
assert self.get_ti(dr, "w2").state == expected


def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session):
"""Test that one failed trigger rule works well in mapped task group"""
with dag_maker() as dag:

@dag.task
def t1():
return [1, 2, 3]

@task_group("tg1")
def tg1(a):
@dag.task()
def t2(a):
return a

@dag.task(trigger_rule=TriggerRule.ONE_FAILED)
def t3(a):
return a

t2(a) >> t3(a)

t = t1()
tg1.expand(a=t)

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id="t1")
ti.run()
dr.task_instance_scheduling_decisions()
ti3 = dr.get_task_instance(task_id="tg1.t3")
assert not ti3.state

0 comments on commit 95bf5dd

Please sign in to comment.