4 changes: 2 additions & 2 deletions airflow/jobs/local_task_job_runner.py
@@ -229,8 +229,8 @@ def handle_task_exit(self, return_code: int) -> None:

if not self.task_instance.test_mode and not is_deferral:
if conf.getboolean("scheduler", "schedule_after_task_execution", fallback=True):
self.task_instance.schedule_downstream_tasks(max_tis_per_query=self.job.max_tis_per_query)

# self.task_instance.schedule_downstream_tasks(max_tis_per_query=self.job.max_tis_per_query)
pass
def on_kill(self):
self.task_runner.terminate()
self.task_runner.on_finish()
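For context (not introduced by this PR): the mini-scheduler call being commented out above is already gated by a configuration switch, so the same behaviour can also be toggled without a code change. A minimal sketch of reading that switch, assuming only stock Airflow configuration:

    from airflow.configuration import conf

    # schedule_after_task_execution controls whether the local task job tries to
    # schedule downstream tasks itself right after a task instance finishes.
    mini_scheduler_enabled = conf.getboolean(
        "scheduler", "schedule_after_task_execution", fallback=True
    )
    print(mini_scheduler_enabled)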
11 changes: 11 additions & 0 deletions airflow/models/abstractoperator.py
@@ -293,6 +293,17 @@ def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
if t.is_teardown and not t == self:
yield t

def get_upstreams_only_setups(self) -> Iterable[Operator]:
"""
Return only the upstream setup tasks.

This method is meant to be used when checking task dependencies, where we need
to wait for all upstream setups to complete before the task can run.
"""
for task in self.get_upstreams_only_setups_and_teardowns():
if task.is_setup:
yield task

def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""Return mapped nodes that are direct dependencies of the current task.

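A rough usage sketch of the new get_upstreams_only_setups helper (hypothetical DAG, assumes this PR is applied): for a chain s1 >> work >> t1 where s1 is a setup and t1 its teardown, calling the helper on `work` should yield only the setup.

    import pendulum
    from airflow import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG("setup_example", start_date=pendulum.datetime(2023, 1, 1), schedule=None) as d:
        s1 = EmptyOperator(task_id="s1").as_setup()
        work = EmptyOperator(task_id="work")
        t1 = EmptyOperator(task_id="t1").as_teardown(setups=s1)
        s1 >> work >> t1

    setups = [t.task_id for t in d.get_task("work").get_upstreams_only_setups()]
    print(setups)  # expected under these assumptions: ['s1']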
3 changes: 2 additions & 1 deletion airflow/models/baseoperator.py
@@ -87,7 +87,7 @@
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.ti_deps.deps.trigger_rule_dep import IndirectSetupTasksDep, TriggerRuleDep
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.context import Context
@@ -1117,6 +1117,7 @@ def has_dag(self):
{
NotInRetryPeriodDep(),
PrevDagrunDep(),
IndirectSetupTasksDep(),
TriggerRuleDep(),
NotPreviouslySkippedDep(),
}
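The set extended above is the operator's dependency registry: each entry is a BaseTIDep instance whose statuses the scheduler evaluates before a task instance may run, which is how the new IndirectSetupTasksDep gets picked up. As a rough sketch of what such a dep looks like (illustrative only, not part of this PR):

    from airflow.ti_deps.deps.base_ti_dep import BaseTIDep

    class AlwaysPassesDep(BaseTIDep):
        """Toy dependency that never blocks a task instance."""

        NAME = "Always passes"
        IGNORABLE = True
        IS_TASK_DEP = True

        def _get_dep_statuses(self, ti, session, dep_context):
            # A real dep would inspect ti, session and dep_context here.
            yield self._passing_status(reason="Nothing to check.")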
12 changes: 10 additions & 2 deletions airflow/models/taskinstance.py
@@ -1161,7 +1161,14 @@ def are_dependencies_met(
def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
"""Get failed Dependencies."""
dep_context = dep_context or DepContext()
for dep in dep_context.deps | self.task.deps:
deps_to_check = dep_context.deps | self.task.deps
from airflow.ti_deps.deps.trigger_rule_dep import IndirectSetupTasksDep

setup_dep = IndirectSetupTasksDep()
if setup_dep in deps_to_check:
deps_to_check.remove(setup_dep)
deps_to_check = (setup_dep, *deps_to_check)
for dep in deps_to_check:
for dep_status in dep.get_dep_statuses(self, session, dep_context):
self.log.debug(
"%s dependency '%s' PASSED: %s, %s",
@@ -1170,8 +1177,9 @@ def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session
dep_status.passed,
dep_status.reason,
)

if not dep_status.passed:
if dep == setup_dep:
dep_context.flag_upstream_failed = False
yield dep_status

def __repr__(self) -> str:
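The intent of the change above is that IndirectSetupTasksDep is always evaluated first; when it reports a failure it switches flag_upstream_failed off, so the trigger-rule dep evaluated afterwards does not also mutate the task instance's state. A standalone sketch of the move-to-front step (plain Python, no Airflow objects):

    def order_deps(deps: set, first) -> tuple:
        """Return the deps as a tuple with `first` moved to the front if present."""
        remaining = set(deps)
        if first in remaining:
            remaining.remove(first)
            return (first, *remaining)
        return tuple(remaining)

    print(order_deps({"TriggerRuleDep", "PrevDagrunDep", "IndirectSetupTasksDep"}, "IndirectSetupTasksDep"))
    # ('IndirectSetupTasksDep', ...) -- the remaining entries keep arbitrary set order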
268 changes: 208 additions & 60 deletions airflow/ti_deps/deps/trigger_rule_dep.py
@@ -20,6 +20,7 @@
import collections
import collections.abc
import functools
import logging
from typing import TYPE_CHECKING, Iterator, NamedTuple

from sqlalchemy import and_, func, or_, select
@@ -81,6 +82,202 @@ def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTISta
)


def _get_expanded_ti_count(task, ti, session) -> int:
"""Get how many tis the current task is supposed to be expanded into.

Unlike the closure this replaces, this module-level helper is not wrapped in
``functools.lru_cache``: its arguments are ORM/session objects that are not
reliably hashable, so callers should avoid invoking it more than once per task.
"""
return task.get_mapped_ti_count(ti.run_id, session=session)


def _get_relevant_upstream_map_indexes(
*, ti, upstream_tasks, upstream_id: str, session
) -> int | range | None:
"""Get the given task's map indexes relevant to the current ti.

Ideally this queries the database only when needed, and at most once for each
task (instead of once for each expanded task instance of the same task). The
``upstream_tasks`` argument is a plain dict and therefore not hashable, so this
helper cannot be wrapped in ``functools.lru_cache`` like the old closure was.
"""
from airflow.models.abstractoperator import NotMapped
from airflow.models.expandinput import NotFullyPopulated

try:
expanded_ti_count = _get_expanded_ti_count(ti.task, ti, session)
except (NotFullyPopulated, NotMapped):
return None
return ti.get_relevant_upstream_map_indexes(
upstream_tasks[upstream_id],
expanded_ti_count,
session=session,
)


def _iter_upstream_conditions(*, ti, task, upstream_tasks, session) -> Iterator[ColumnOperators]:
"""
Get filter conditions for the upstream tasks we are concerned with.

:param task: the task whose upstream dependencies are being filtered
:param upstream_tasks: mapping of task_id to the upstream tasks we care about
"""
from airflow.models.taskinstance import TaskInstance

# Optimization: If the current task is not in a mapped task group,
# it depends on all upstream task instances.
if task.get_closest_mapped_task_group() is None:
yield TaskInstance.task_id.in_(upstream_tasks)
return
# Otherwise we need to figure out which map indexes are depended on
# for each upstream by the current task instance.
for upstream_id in upstream_tasks:
map_indexes = _get_relevant_upstream_map_indexes(
ti=ti, upstream_tasks=upstream_tasks, upstream_id=upstream_id, session=session
)
if map_indexes is None: # All tis of this upstream are dependencies.
yield (TaskInstance.task_id == upstream_id)
continue
# At this point we know we want to depend on only selected tis
# of this upstream task. Since the upstream may not have been
# expanded at this point, we also depend on the non-expanded ti
# to ensure at least one ti is included for the task.
yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
if isinstance(map_indexes, range) and map_indexes.step == 1:
yield and_(
TaskInstance.task_id == upstream_id,
TaskInstance.map_index >= map_indexes.start,
TaskInstance.map_index < map_indexes.stop,
)
elif isinstance(map_indexes, collections.abc.Container):
yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
else:
yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)


log = logging.getLogger(__name__)


class IndirectSetupTasksDep(BaseTIDep):
"""Determines if a task's upstream tasks are in a state that allows a given task instance to run."""

NAME = "Indirect setup tasks"
IGNORABLE = True
IS_TASK_DEP = True

def _get_dep_statuses(
self,
ti: TaskInstance,
session: Session,
dep_context: DepContext,
) -> Iterator[TIDepStatus]:
if ti.task.trigger_rule == TR.ALWAYS:
yield self._passing_status(reason="The task had a always trigger rule set.")
return
if not ti.task.upstream_task_ids:
yield self._passing_status(reason="The task instance does not have any upstream tasks.")
return
if ti.task.is_teardown:
yield self._passing_status(reason="Indirect setup does not apply to teardowns.")
return
relevant_setups = {x.task_id: x for x in ti.task.get_upstreams_only_setups()}
if relevant_setups:
log.warning(f"evaluating setup for task {ti.task_id}")
yield from self._evaluate_setup(
ti=ti, dep_context=dep_context, relevant_setups=relevant_setups, session=session
)
else:
log.warning(f"NOT evaluating setup for task {ti.task_id}")

def _evaluate_setup(self, *, ti, dep_context, relevant_setups, session):
# we know that this task has indirect setups
# Optimization: Don't need to hit the database if all upstreams are
# "simple" tasks (no task or task group mapping involved).
from airflow.models.operator import needs_expansion
from airflow.models.taskinstance import TaskInstance

upstream_tasks = relevant_setups

finished_upstream_tis = (
x
for x in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
if x.task_id in upstream_tasks
)
counter: dict[str, int] = collections.Counter()
for ti_ in finished_upstream_tis:
# every relevant upstream here is a setup, so count each finished ti exactly once
counter.update({ti_.state: 1})
success = counter.get(TaskInstanceState.SUCCESS, 0)
skipped = counter.get(TaskInstanceState.SKIPPED, 0)
failed = counter.get(TaskInstanceState.FAILED, 0)
upstream_failed = counter.get(TaskInstanceState.UPSTREAM_FAILED, 0)
done = sum(counter.values())
if not any(needs_expansion(t) for t in upstream_tasks.values()):
upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
else:
task_id_counts = session.execute(
select(TaskInstance.task_id, func.count(TaskInstance.task_id))
.where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
.where(
or_(
*_iter_upstream_conditions(
ti=ti, task=ti.task, upstream_tasks=upstream_tasks, session=session
)
)
)
.group_by(TaskInstance.task_id)
).all()
upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)
upstream_done = done >= upstream_setup

changed = False
new_state = None
if upstream_done:
log.warning(f"task={ti.task_id}: upstream done")
if success >= done:
# log.warning(f"task={ti.task_id}: all success")
pass
elif upstream_failed or failed:
# log.warning(f"task={ti.task_id}: one failed")
new_state = TaskInstanceState.UPSTREAM_FAILED
elif skipped:
# log.warning(f"task={ti.task_id}: one skipped")
new_state = TaskInstanceState.SKIPPED
else:
# log.warning(f"task={ti.task_id}: fail by default")
new_state = TaskInstanceState.UPSTREAM_FAILED
else:
log.warning(f"task={ti.task_id}: upstream not done: {done=} {upstream_setup=}")
# log.warning(f"task={ti.task_id}: upstream not done: {done.__class__=} {upstream_setup.__class__=}")

if new_state is not None:
log.warning(f"updating {ti.task_id=} to state {new_state}")
if new_state == TaskInstanceState.SKIPPED and dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
if not past_depends_met:
yield self._failing_status(
reason=("Task should be skipped but the past depends are not met")
)
return
changed = ti.set_state(new_state, session)

if changed:
dep_context.have_changed_ti_states = True

if not upstream_done:
reason = (
f"All setups must be completed, but found {len(upstream_tasks) - done} task(s) "
"that were not done. "
)
yield self._failing_status(reason=reason)
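As a rough standalone illustration of the counting in _evaluate_setup above (toy values, no Airflow objects): the dep tallies the states of the finished upstream setup instances, and only once every expected setup is done does it decide whether the task may proceed, should be skipped, or should be marked upstream_failed.

    import collections

    # Hypothetical finished states for three upstream setup task instances.
    finished_setup_states = ["success", "success", "failed"]
    expected_setups = 3

    counter = collections.Counter(finished_setup_states)
    done = sum(counter.values())

    if done >= expected_setups:  # all setups have finished
        if counter["success"] >= done:
            new_state = None  # nothing to change; the task can be scheduled
        elif counter["failed"] or counter["upstream_failed"]:
            new_state = "upstream_failed"
        elif counter["skipped"]:
            new_state = "skipped"
        else:
            new_state = "upstream_failed"  # fail by default
    else:
        new_state = None  # keep waiting for the remaining setups

    print(done, new_state)  # 3 upstream_failed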


class TriggerRuleDep(BaseTIDep):
"""Determines if a task's upstream tasks are in a state that allows a given task instance to run."""

@@ -116,42 +313,13 @@ def _evaluate_trigger_rule(
:param dep_context: The current dependency context.
:param session: Database session.
"""
from airflow.models.abstractoperator import NotMapped
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.operator import needs_expansion
from airflow.models.taskinstance import TaskInstance

task = ti.task
upstream_tasks = {t.task_id: t for t in task.upstream_list}
trigger_rule = task.trigger_rule

@functools.lru_cache
def _get_expanded_ti_count() -> int:
"""Get how many tis the current task is supposed to be expanded into.

This extra closure allows us to query the database only when needed,
and at most once.
"""
return task.get_mapped_ti_count(ti.run_id, session=session)

@functools.lru_cache
def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
"""Get the given task's map indexes relevant to the current ti.

This extra closure allows us to query the database only when needed,
and at most once for each task (instead of once for each expanded
task instance of the same task).
"""
try:
expanded_ti_count = _get_expanded_ti_count()
except (NotFullyPopulated, NotMapped):
return None
return ti.get_relevant_upstream_map_indexes(
upstream_tasks[upstream_id],
expanded_ti_count,
session=session,
)

def _is_relevant_upstream(upstream: TaskInstance) -> bool:
"""Whether a task instance is a "relevant upstream" of the current task."""
# Not actually an upstream task.
@@ -167,7 +335,9 @@ def _is_relevant_upstream(upstream: TaskInstance) -> bool:
return True
# Now we need to perform fine-grained check on whether this specific
# upstream ti's map index is relevant.
relevant = _get_relevant_upstream_map_indexes(upstream.task_id)
relevant = _get_relevant_upstream_map_indexes(
ti=ti, upstream_tasks=upstream_tasks, upstream_id=upstream.task_id, session=session
)
if relevant is None:
return True
if relevant == upstream.map_index:
@@ -192,35 +362,6 @@ def _is_relevant_upstream(upstream: TaskInstance) -> bool:
success_setup = upstream_states.success_setup
skipped_setup = upstream_states.skipped_setup

def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
# Optimization: If the current task is not in a mapped task group,
# it depends on all upstream task instances.
if task.get_closest_mapped_task_group() is None:
yield TaskInstance.task_id.in_(upstream_tasks)
return
# Otherwise we need to figure out which map indexes are depended on
# for each upstream by the current task instance.
for upstream_id in upstream_tasks:
map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
if map_indexes is None: # All tis of this upstream are dependencies.
yield (TaskInstance.task_id == upstream_id)
continue
# At this point we know we want to depend on only selected tis
# of this upstream task. Since the upstream may not have been
# expanded at this point, we also depend on the non-expanded ti
# to ensure at least one ti is included for the task.
yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
if isinstance(map_indexes, range) and map_indexes.step == 1:
yield and_(
TaskInstance.task_id == upstream_id,
TaskInstance.map_index >= map_indexes.start,
TaskInstance.map_index < map_indexes.stop,
)
elif isinstance(map_indexes, collections.abc.Container):
yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
else:
yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)

# Optimization: Don't need to hit the database if all upstreams are
# "simple" tasks (no task or task group mapping involved).
if not any(needs_expansion(t) for t in upstream_tasks.values()):
@@ -230,7 +371,13 @@ def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
task_id_counts = session.execute(
select(TaskInstance.task_id, func.count(TaskInstance.task_id))
.where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
.where(or_(*_iter_upstream_conditions()))
.where(
or_(
*_iter_upstream_conditions(
ti=ti, task=task, upstream_tasks=upstream_tasks, session=session
)
)
)
.group_by(TaskInstance.task_id)
).all()
upstream = sum(count for _, count in task_id_counts)
@@ -288,6 +435,7 @@ def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
# if at least one setup ran, we'll let it run
new_state = TaskInstanceState.UPSTREAM_FAILED
if new_state is not None:
log.warning(f"trigger rule dep changing {ti.task_id} to {new_state}")
if new_state == TaskInstanceState.SKIPPED and dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
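Taken together, a task instance's runnability can be inspected through its failed dependency statuses; with this change a failing indirect setup surfaces as an "Indirect setup tasks" status ahead of the trigger-rule one. A small sketch, assuming real ti and session objects are available (for example inside a test harness):

    from airflow.ti_deps.dep_context import DepContext

    def explain_blockers(ti, session):
        """Print every failed dependency status for a task instance (sketch only)."""
        for status in ti.get_failed_dep_statuses(dep_context=DepContext(), session=session):
            print(f"{status.dep_name}: {status.reason}")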