Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ There are three main types of cluster policy:
task running in a DagRun. The ``task_policy`` defined is applied to all the task instances that will be
executed in the future.
* ``task_instance_mutation_hook``: Takes a :class:`~airflow.models.taskinstance.TaskInstance` parameter called
``task_instance``. The ``task_instance_mutation_hook`` applies not to a task but to the instance of a task that
relates to a particular DagRun. It is executed in a "worker", not in the Dag file processor, just before the
task instance is executed. The policy is only applied to the currently executed run (i.e. instance) of that
task.
``task_instance``. The hook runs on the scheduler side, just before the task instance is dispatched to an
executor. It may fire more than once for a given task instance: at creation and again on each retry
transition, where the next attempt gets a fresh ``task_instance.id``. Implementations should be idempotent.

The Dag and Task cluster policies can raise the :class:`~airflow.exceptions.AirflowClusterPolicyViolation`
exception to indicate that the Dag/task they were passed is not compliant and should not be loaded.
Expand Down
27 changes: 21 additions & 6 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,8 @@ def schedule_tis(
reschedule_ti_ids: set[UUID] = set()
debug_try_number_check = self.log.isEnabledFor(logging.DEBUG)
expected_try_number_by_ti_id: dict[UUID, tuple[int, int, str | None]] = {}
count = 0
had_retry_mutation = False
for ti in schedulable_tis:
if not ti.is_schedulable:
empty_ti_ids.append(ti.id)
Expand All @@ -2000,19 +2002,32 @@ def schedule_tis(
# If not, we'll add this "ti" into "schedulable_ti_ids" and later
# execute it to run in the worker.
elif not ti.defer_task(session=session):
# Retries flow through the ORM so refresh_from_task can re-apply task
# defaults and run task_instance_mutation_hook against the about-to-run
# try_number. First attempts and reschedules stay on the bulk UPDATE path.
if ti.task is not None and ti.state == TaskInstanceState.UP_FOR_RETRY:
ti.try_number += 1
ti.state = TaskInstanceState.SCHEDULED
ti.scheduled_dttm = timezone.utcnow()
ti.refresh_from_task(ti.task)
count += 1
had_retry_mutation = True
continue

schedulable_ti_ids.append(ti.id)
if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
is_reschedule = ti.state == TaskInstanceState.UP_FOR_RESCHEDULE
if is_reschedule:
reschedule_ti_ids.add(ti.id)
if debug_try_number_check:
expected_try_number_by_ti_id[ti.id] = (
ti.try_number
if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE
else ti.try_number + 1,
ti.try_number if is_reschedule else ti.try_number + 1,
ti.try_number,
ti.state,
)

count = 0
if had_retry_mutation:
# Airflow disables SQLA autoflush, so retry-branch mutations need an
# explicit flush to be visible to the bulk UPDATE/SELECTs that follow.
session.flush()
# Don't only check if the TI.id is in id_chunk
# but also check if the TI.state is in the schedulable states.
# Plus, a scheduled empty operator should not be scheduled again.
Expand Down
147 changes: 147 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2286,6 +2286,153 @@ def test_schedule_tis_up_for_reschedule_does_not_increment_try_number(dag_maker,
assert refreshed_ti.try_number == 3


def test_schedule_tis_reapplies_mutation_hook_on_retry(dag_maker, session, monkeypatch):
    """Scheduling a retry must run ``task_instance_mutation_hook`` against the bumped try_number.

    The patched hook only reroutes attempts numbered 2 or higher. The task instance starts
    in ``UP_FOR_RETRY`` with ``try_number == 1``, so finding the persisted row on
    ``retry_queue`` shows the hook saw ``try_number == 2``.
    """

    def reroute_retries(task_instance):
        attempt = task_instance.try_number or 0
        if attempt >= 2:
            task_instance.queue = "retry_queue"

    monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", reroute_retries)

    with dag_maker(session=session) as dag:
        BashOperator(task_id="task", bash_command="echo 1", queue="default")

    dag_run = dag_maker.create_dagrun(session=session)
    retry_ti = dag_run.get_task_instance("task", session=session)
    retry_ti.refresh_from_task(dag.get_task("task"))
    retry_ti.state = TaskInstanceState.UP_FOR_RETRY
    retry_ti.try_number = 1
    session.commit()

    scheduled_count = dag_run.schedule_tis((retry_ti,), session=session)
    assert scheduled_count == 1
    session.commit()

    # Expire cached ORM state so the assertions read what was actually persisted.
    session.expire_all()
    persisted_query = select(TI).where(
        TI.dag_id == retry_ti.dag_id,
        TI.task_id == retry_ti.task_id,
        TI.run_id == retry_ti.run_id,
        TI.map_index == retry_ti.map_index,
    )
    persisted_ti = session.scalar(persisted_query)

    assert persisted_ti.state == TaskInstanceState.SCHEDULED
    assert persisted_ti.try_number == 2
    assert persisted_ti.queue == "retry_queue"


def test_schedule_tis_does_not_apply_retry_mutation_hook_on_first_attempt(dag_maker, session, monkeypatch):
    """A first attempt must be scheduled without re-running the mutation hook.

    The patched hook would reroute any task instance whose ``try_number`` is at least 1.
    The task instance starts with no state and ``try_number == 0``; after scheduling it
    is expected at ``try_number == 1`` but still on the ``default`` queue.
    """

    def reroute_retries(task_instance):
        attempt = task_instance.try_number or 0
        if attempt >= 1:
            task_instance.queue = "retry_queue"

    monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", reroute_retries)

    with dag_maker(session=session) as dag:
        BashOperator(task_id="task", bash_command="echo 1", queue="default")

    dag_run = dag_maker.create_dagrun(session=session)
    first_ti = dag_run.get_task_instance("task", session=session)
    first_ti.refresh_from_task(dag.get_task("task"))
    # Reset to a pristine, never-run task instance on the default queue.
    first_ti.state = None
    first_ti.try_number = 0
    first_ti.queue = "default"
    session.merge(first_ti)
    session.commit()

    scheduled_count = dag_run.schedule_tis((first_ti,), session=session)
    assert scheduled_count == 1
    session.commit()

    # Expire cached ORM state so the assertions read what was actually persisted.
    session.expire_all()
    persisted_query = select(TI).where(
        TI.dag_id == first_ti.dag_id,
        TI.task_id == first_ti.task_id,
        TI.run_id == first_ti.run_id,
        TI.map_index == first_ti.map_index,
    )
    persisted_ti = session.scalar(persisted_query)

    assert persisted_ti.state == TaskInstanceState.SCHEDULED
    assert persisted_ti.try_number == 1
    assert persisted_ti.queue == "default"


def test_schedule_tis_recomputes_priority_weight_for_dynamic_strategy(dag_maker, session):
    """Scheduling a retry must re-evaluate a priority weight that depends on try_number.

    The task uses ``DecreasingPriorityStrategy`` as its ``weight_rule`` and starts in
    ``UP_FOR_RETRY`` with ``try_number == 1``. After ``schedule_tis`` the persisted row
    is expected to carry ``try_number == 2`` and the weight computed for that attempt.
    """
    from airflow import plugins_manager

    from tests_common.test_utils.mock_plugins import mock_plugin_manager
    from unit.plugins.priority_weight_strategy import (
        DecreasingPriorityStrategy,
        TestPriorityWeightStrategyPlugin,
    )

    try:
        # Keep the strategy plugin registered for the whole scenario.
        with mock_plugin_manager(plugins=[TestPriorityWeightStrategyPlugin]):
            with dag_maker(session=session) as dag:
                BashOperator(
                    task_id="task",
                    bash_command="echo 1",
                    weight_rule=DecreasingPriorityStrategy(),
                )

            dr = dag_maker.create_dagrun(session=session)
            ti = dr.get_task_instance("task", session=session)
            ti.refresh_from_task(dag.get_task("task"))
            # Simulate a task instance whose first attempt failed and is awaiting retry.
            ti.state = TaskInstanceState.UP_FOR_RETRY
            ti.try_number = 1
            session.merge(ti)
            session.commit()

            assert dr.schedule_tis((ti,), session=session) == 1
            session.commit()

            # Expire cached ORM state so the row is re-read from the database.
            session.expire_all()
            refreshed_ti = session.scalar(
                select(TI).where(
                    TI.dag_id == ti.dag_id,
                    TI.task_id == ti.task_id,
                    TI.run_id == ti.run_id,
                    TI.map_index == ti.map_index,
                )
            )
            # DecreasingPriorityStrategy returns max(3 - try + 1, 1). After schedule_tis
            # bumps try_number 1 -> 2, the strategy is re-evaluated against try=2 giving 2.
            assert refreshed_ti.try_number == 2
            assert refreshed_ti.priority_weight == 2
    finally:
        # mock_plugin_manager clears strategy caches on entry but not on exit.
        plugins_manager.get_priority_weight_strategy_plugins.cache_clear()


def test_schedule_tis_does_not_touch_priority_weight_for_static_strategy(dag_maker, session):
    """Scheduling a retry must leave a fixed priority weight and the queue unchanged.

    The task is declared with ``priority_weight=42`` and no mutation hook is patched in,
    so only ``try_number`` is expected to change (1 -> 2) when the retry is scheduled.
    """
    with dag_maker(session=session) as dag:
        BashOperator(task_id="task", bash_command="echo 1", priority_weight=42)

    dag_run = dag_maker.create_dagrun(session=session)
    retry_ti = dag_run.get_task_instance("task", session=session)
    retry_ti.refresh_from_task(dag.get_task("task"))
    retry_ti.state = TaskInstanceState.UP_FOR_RETRY
    retry_ti.try_number = 1
    # Snapshot the values that scheduling the retry is expected to preserve.
    original_priority = retry_ti.priority_weight
    original_queue = retry_ti.queue
    session.commit()

    scheduled_count = dag_run.schedule_tis((retry_ti,), session=session)
    assert scheduled_count == 1
    session.commit()

    # Expire cached ORM state so the assertions read what was actually persisted.
    session.expire_all()
    persisted_query = select(TI).where(
        TI.dag_id == retry_ti.dag_id,
        TI.task_id == retry_ti.task_id,
        TI.run_id == retry_ti.run_id,
        TI.map_index == retry_ti.map_index,
    )
    persisted_ti = session.scalar(persisted_query)

    assert persisted_ti.try_number == 2
    assert persisted_ti.priority_weight == original_priority
    assert persisted_ti.queue == original_queue


def test_schedule_tis_empty_operator_is_noop_if_ti_already_running(dag_maker, session):
with dag_maker(session=session) as dag:
EmptyOperator(task_id="empty_task")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DecreasingPriorityStrategy(PriorityWeightStrategy):
"""A priority weight strategy that decreases the priority weight with each attempt."""

def get_weight(self, ti: TaskInstance):
    """Return the priority weight for ``ti``: one less per attempt, never below 1.

    :param ti: object exposing a ``try_number`` attribute; a falsy value
        (``None`` or ``0``) is treated as attempt 0.
    :return: ``max(3 - try_number + 1, 1)``.
    """
    # Coerce a missing try_number to 0 so the subtraction cannot raise TypeError.
    return max(3 - (ti.try_number or 0) + 1, 1)


class TestPriorityWeightStrategyPlugin(AirflowPlugin):
Expand Down
Loading