Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@
TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule"
""":meta private:"""

_TRIGGER_TIMEOUT_BATCH_SIZE = 1000
"""Maximum number of task instances to lock per trigger-timeout batch."""


def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]:
"""
Expand Down Expand Up @@ -2878,25 +2881,45 @@ def check_trigger_timeouts(
self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
) -> None:
"""Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
result = session.execute(
update(TI)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < timezone.utcnow(),
while True:
task_instance_ids = []
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
now = timezone.utcnow()
candidates = (
select(TI.id)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < now,
)
.order_by(TI.id)
.limit(_TRIGGER_TIMEOUT_BATCH_SIZE)
)
.values(
state=TaskInstanceState.SCHEDULED,
next_method=TRIGGER_FAIL_REPR,
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
scheduled_dttm=timezone.utcnow(),
trigger_id=None,
task_instance_ids = list(
session.scalars(
with_row_locks(candidates, of=TI, session=session, skip_locked=True)
).all()
)
)
num_timed_out_tasks = getattr(result, "rowcount", 0)
if num_timed_out_tasks:
self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
if task_instance_ids:
result = session.execute(
update(TI)
.where(TI.id.in_(task_instance_ids))
.values(
state=TaskInstanceState.SCHEDULED,
next_method=TRIGGER_FAIL_REPR,
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
scheduled_dttm=now,
trigger_id=None,
)
.execution_options(synchronize_session=False)
)
num_timed_out_tasks = getattr(result, "rowcount", 0)
if num_timed_out_tasks:
self.log.info(
"Timed out %i deferred tasks without fired triggers", num_timed_out_tasks
)
if len(task_instance_ids) < _TRIGGER_TIMEOUT_BATCH_SIZE:
break

# [START find_and_purge_task_instances_without_heartbeats]
def _find_and_purge_task_instances_without_heartbeats(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,7 @@ def _are_premature_tis(
ignore_in_retry_period=True,
ignore_in_reschedule_period=True,
finished_tis=finished_tis,
ensure_fresh_tis_before_state_change=True,
)
# there might be runnable tasks that are up for retry and for some reason(retry delay, etc.) are
# not ready yet, so we set the flags to count them in
Expand Down
40 changes: 31 additions & 9 deletions airflow-core/src/airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@

log = logging.getLogger(__name__)

_TRIGGER_ID_CLEANUP_BATCH_SIZE = 1000
"""Maximum number of task instances to lock per trigger-id cleanup batch."""


class TriggerFailureReason(str, Enum):
"""
Expand Down Expand Up @@ -226,16 +229,35 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
Triggers have a one-to-many relationship to task instances, so we need to clean those up first.
Afterward we can drop the triggers not referenced by anyone.
"""
# Update all task instances with trigger IDs that are not DEFERRED to remove them
for attempt in run_with_db_retries():
with attempt:
session.execute(
update(TaskInstance)
.where(
TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.is_not(None)
# Clear task-instance trigger references in primary-key order to avoid locking the same rows in
# a different order than scheduler timeout handling.
while True:
task_instance_ids = []
for attempt in run_with_db_retries():
with attempt:
candidates = (
select(TaskInstance.id)
.where(
TaskInstance.state != TaskInstanceState.DEFERRED,
TaskInstance.trigger_id.is_not(None),
)
.order_by(TaskInstance.id)
.limit(_TRIGGER_ID_CLEANUP_BATCH_SIZE)
)
task_instance_ids = list(
session.scalars(
with_row_locks(candidates, of=TaskInstance, session=session, skip_locked=True)
).all()
)
.values(trigger_id=None)
)
if task_instance_ids:
session.execute(
update(TaskInstance)
.where(TaskInstance.id.in_(task_instance_ids))
.values(trigger_id=None)
.execution_options(synchronize_session=False)
)
if len(task_instance_ids) < _TRIGGER_ID_CLEANUP_BATCH_SIZE:
break

# Get all triggers that have no task instances, assets, or callbacks depending on them and delete them
ids = (
Expand Down
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/ti_deps/dep_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class DepContext:
trigger rule
:param ignore_ti_state: Ignore the task instance's previous failure/success
:param finished_tis: A list of all the finished task instances of this run
:param ensure_fresh_tis_before_state_change: Re-query finished task instances before writing
a terminal state based on trigger-rule dependency evaluation
"""

deps: set = attr.ib(factory=set)
Expand All @@ -80,6 +82,7 @@ class DepContext:
ignore_ti_state: bool = False
ignore_unmapped_tasks: bool = False
finished_tis: list[TaskInstance] | None = None
ensure_fresh_tis_before_state_change: bool = False
description: str | None = None

have_changed_ti_states: bool = False
Expand All @@ -103,3 +106,7 @@ def ensure_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskIns
else:
finished_tis = self.finished_tis
return finished_tis

def refresh_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskInstance]:
    """
    Discard the cached finished task instances and load them again.

    :param dag_run: The dag run passed through to ``ensure_finished_tis``
    :param session: The database session passed through to ``ensure_finished_tis``
    :return: Whatever ``ensure_finished_tis`` returns once the cache has been cleared
    """
    # Resetting the cache to None first means ensure_finished_tis cannot hand back the old list.
    self.finished_tis = None
    reloaded_tis = self.ensure_finished_tis(dag_run=dag_run, session=session)
    return reloaded_tis
23 changes: 21 additions & 2 deletions airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,23 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnElement]:
else:
yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)

def _has_stale_finished_upstreams(relevant_ids: set[str] | KeysView[str]) -> bool:
    """
    Tell whether the cached finished task instances disagree with a fresh reload.

    Only upstreams that are relevant for ``relevant_ids`` and have a non-None state are
    compared, keyed by ``(task_id, map_index)``.

    :param relevant_ids: Task ids of the upstreams whose states matter for this evaluation
    :return: False when the dep context did not request the freshness check or holds no
        cached finished task instances; otherwise True if the cached and reloaded states differ
    """
    if not dep_context.ensure_fresh_tis_before_state_change:
        return False
    if dep_context.finished_tis is None:
        return False

    def _state_by_key(
        candidates: list[TaskInstance],
    ) -> dict[tuple[str, int], str]:
        # Map (task_id, map_index) -> state for every relevant upstream that has a state.
        states = {}
        for candidate in candidates:
            if candidate.state is None:
                continue
            if not _is_relevant_upstream(candidate, relevant_ids):
                continue
            states[(candidate.task_id, candidate.map_index)] = candidate.state
        return states

    # Capture the cached view first: the refresh below resets dep_context.finished_tis.
    states_before = _state_by_key(dep_context.finished_tis)
    reloaded_tis = dep_context.refresh_finished_tis(ti.get_dagrun(session), session)
    states_after = _state_by_key(reloaded_tis)
    return states_before != states_after

def _evaluate_setup_constraint(
*, relevant_setups: Mapping[str, Operator]
) -> Iterator[tuple[TIDepStatus, bool]]:
Expand Down Expand Up @@ -324,7 +341,8 @@ def _evaluate_setup_constraint(
changed,
)
return
changed = ti.set_state(new_state, session)
if not _has_stale_finished_upstreams(relevant_setups.keys()):
changed = ti.set_state(new_state, session)

if changed:
dep_context.have_changed_ti_states = True
Expand Down Expand Up @@ -457,7 +475,8 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
reason="Task should be skipped but the past depends are not met"
)
return
changed = ti.set_state(new_state, session)
if not _has_stale_finished_upstreams(task.upstream_task_ids):
changed = ti.set_state(new_state, session)

if changed:
dep_context.have_changed_ti_states = True
Expand Down
39 changes: 39 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6831,6 +6831,45 @@ def test_timeout_triggers(self, dag_maker):
assert ti1.next_method == "__fail__"
assert ti2.state == State.DEFERRED

def test_timeout_triggers_processes_more_than_one_batch(self, dag_maker, monkeypatch):
    """Every timed-out deferred task instance is rescheduled, even when they exceed one batch."""
    import airflow.jobs.scheduler_job_runner as scheduler_job_runner_module

    # Shrink the batch size so the five task instances created below cannot fit in one batch.
    monkeypatch.setattr(scheduler_job_runner_module, "_TRIGGER_TIMEOUT_BATCH_SIZE", 2)

    session = settings.Session()
    with dag_maker(
        dag_id="test_timeout_triggers_processes_more_than_one_batch",
        start_date=DEFAULT_DATE,
        schedule="@once",
        max_active_runs=5,
        session=session,
    ):
        EmptyOperator(task_id="dummy1")

    # A trigger timeout one minute in the past, shared by all task instances.
    expired_at = timezone.utcnow() - datetime.timedelta(seconds=60)
    deferred_tis = []
    for run_number in range(5):
        run = dag_maker.create_dagrun(
            run_id=f"test_batch_{run_number}",
            logical_date=DEFAULT_DATE + datetime.timedelta(seconds=run_number),
        )
        ti = run.get_task_instance("dummy1", session)
        ti.state = State.DEFERRED
        ti.trigger_timeout = expired_at
        deferred_tis.append(ti)
    session.flush()

    self.job_runner = SchedulerJobRunner(job=Job())
    self.job_runner.check_trigger_timeouts(session=session)

    for ti in deferred_tis:
        # Reload from the database so the assertions see what the scheduler wrote.
        session.refresh(ti)
        assert ti.state == State.SCHEDULED
        assert ti.next_method == "__fail__"

def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker, testing_dag_bundle, session):
"""
Tests that it will retry on DB error like deadlock when updating timeout triggers.
Expand Down
61 changes: 59 additions & 2 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
select,
update,
)
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session as SASession, joinedload

from airflow import settings
from airflow._shared.observability.metrics.stats import Stats
Expand Down Expand Up @@ -1965,6 +1965,64 @@ def consumer(*args):
assert dr.state == DagRunState.FAILED


def test_stale_finished_tis_do_not_cause_stuck_upstream_failed(dag_maker, session):
    """
    A stale list of finished task instances must not leave downstream tasks stuck.

    A second session plays the scheduler. It loads the finished task instances, then the
    fixture session rewrites and commits the upstream states, and only afterwards does the
    scheduler session evaluate the unfinished task instances using its outdated list.
    """

    def _attach_tasks(task_instances, dag_run):
        # Task instances loaded from the database need their task set from the run's dag.
        dag = dag_run.get_dag()
        for task_instance in task_instances:
            task_instance.task = dag.get_task(task_instance.task_id)

    with dag_maker("test_upstream_failed_race", session=session):
        fail_task = EmptyOperator(task_id="fail_task")
        t0 = EmptyOperator(task_id="t0")
        t1 = EmptyOperator(task_id="t1")
        t2 = EmptyOperator(task_id="t2")
        fail_task >> t0 >> t1 >> t2

    dr = dag_maker.create_dagrun(state=DagRunState.RUNNING)
    initial_tis = {ti.task_id: ti for ti in dr.task_instances}

    # Starting point: the head of the chain failed and its direct downstream was marked accordingly.
    initial_tis["fail_task"].state = TaskInstanceState.FAILED
    initial_tis["t0"].state = TaskInstanceState.UPSTREAM_FAILED
    session.flush()
    session.commit()

    expected_states = {
        "fail_task": TaskInstanceState.SUCCESS,
        "t0": None,
        "t1": None,
        "t2": None,
    }

    scheduler_session = SASession(bind=session.get_bind())
    try:
        scheduler_dr = scheduler_session.get(DagRun, dr.id)
        scheduler_dr.dag = dr.dag

        # The scheduler session loads the finished task instances before the states change.
        stale_finished = scheduler_dr.get_task_instances(state=State.finished, session=scheduler_session)
        _attach_tasks(stale_finished, scheduler_dr)

        # Meanwhile the other session rewrites the upstream states and commits.
        session.expire_all()
        rewritten_tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)}
        rewritten_tis["fail_task"].state = TaskInstanceState.SUCCESS
        rewritten_tis["t0"].state = None
        session.flush()
        session.commit()

        # The scheduler now evaluates the unfinished task instances with the outdated list.
        pending_tis = scheduler_dr.get_task_instances(state=State.unfinished, session=scheduler_session)
        _attach_tasks(pending_tis, scheduler_dr)
        scheduler_dr._are_premature_tis(
            unfinished_tis=pending_tis,
            finished_tis=stale_finished,
            session=scheduler_session,
        )
        scheduler_session.flush()
        scheduler_session.commit()

        session.expire_all()
        observed_states = {ti.task_id: ti.state for ti in dr.get_task_instances(session=session)}
        assert observed_states == expected_states
    finally:
        scheduler_session.close()


def test_mapped_task_all_finish_before_downstream(dag_maker, session):
with dag_maker(session=session) as dag:

Expand Down Expand Up @@ -2260,7 +2318,6 @@ def test_schedule_tis_only_one_scheduler_update_succeeds_when_competing(dag_make
assert refreshed_ti.try_number == 1


@pytest.mark.xfail(reason="We can't keep this behaviour with remote workers where scheduler can't reach xcom")
@pytest.mark.need_serialized_dag
def test_schedule_tis_start_trigger(dag_maker, session):
"""
Expand Down
32 changes: 32 additions & 0 deletions airflow-core/tests/unit/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,38 @@ def test_clean_unused(session, dag_maker):
assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id}


def test_clean_unused_clears_trigger_ids_in_batches(session, dag_maker, monkeypatch):
    """Trigger references on non-deferred task instances are all cleared across several batches."""
    import airflow.models.trigger as trigger_module

    # Shrink the batch size so the five task instances below cannot be handled in one batch.
    monkeypatch.setattr(trigger_module, "_TRIGGER_ID_CLEANUP_BATCH_SIZE", 2)

    triggers = []
    for position in range(5):
        triggers.append(
            Trigger(classpath=f"airflow.triggers.testing.SuccessTrigger{position}", kwargs={})
        )
    session.add_all(triggers)
    session.flush()

    with dag_maker(session=session, dag_id="test_clean_unused_clears_trigger_ids_in_batches"):
        for position in range(5):
            EmptyOperator(task_id=f"fake{position}")

    dag_run = dag_maker.create_dagrun(logical_date=timezone.utcnow())
    tis_by_task_id = {}
    for ti in dag_run.task_instances:
        tis_by_task_id[ti.task_id] = ti

    # Pair each trigger with a finished (non-deferred) task instance that still points at it.
    for position, trigger in enumerate(triggers):
        ti = tis_by_task_id[f"fake{position}"]
        ti.state = State.SUCCESS
        ti.trigger_id = trigger.id
    session.flush()

    Trigger.clean_unused(session=session)

    for ti in tis_by_task_id.values():
        session.refresh(ti)
        assert ti.trigger_id is None
    remaining_triggers = session.scalar(select(func.count()).select_from(Trigger))
    assert remaining_triggers == 0


@patch.object(TriggererCallback, "handle_event")
def test_submit_event(mock_callback_handle_event, session, create_task_instance):
"""
Expand Down
Loading