diff --git a/airflow-core/docs/administration-and-deployment/cluster-policies.rst b/airflow-core/docs/administration-and-deployment/cluster-policies.rst index fd7d0b622f540..e8d637d08d0d9 100644 --- a/airflow-core/docs/administration-and-deployment/cluster-policies.rst +++ b/airflow-core/docs/administration-and-deployment/cluster-policies.rst @@ -37,10 +37,9 @@ There are three main types of cluster policy: task running in a DagRun. The ``task_policy`` defined is applied to all the task instances that will be executed in the future. * ``task_instance_mutation_hook``: Takes a :class:`~airflow.models.taskinstance.TaskInstance` parameter called - ``task_instance``. The ``task_instance_mutation_hook`` applies not to a task but to the instance of a task that - relates to a particular DagRun. It is executed in a "worker", not in the Dag file processor, just before the - task instance is executed. The policy is only applied to the currently executed run (i.e. instance) of that - task. + ``task_instance``. Runs scheduler-side, just before the task instance is dispatched to an executor. It may + fire more than once for a given task instance — at creation and on each retry transition, where the next + attempt gets a fresh ``task_instance.id``. Implementations should be idempotent. The Dag and Task cluster policies can raise the :class:`~airflow.exceptions.AirflowClusterPolicyViolation` exception to indicate that the Dag/task they were passed is not compliant and should not be loaded. diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index ca2b68f34be2f..197389bd3ff24 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1991,6 +1991,8 @@ def schedule_tis( reschedule_ti_ids: set[UUID] = set() debug_try_number_check = self.log.isEnabledFor(logging.DEBUG) expected_try_number_by_ti_id: dict[UUID, tuple[int, int, str | None]] = {} + count = 0 + had_retry_mutation = False for ti in schedulable_tis: if not ti.is_schedulable: empty_ti_ids.append(ti.id) @@ -2000,19 +2002,32 @@ def schedule_tis( # If not, we'll add this "ti" into "schedulable_ti_ids" and later # execute it to run in the worker. elif not ti.defer_task(session=session): + # Retries flow through the ORM so refresh_from_task can re-apply task + # defaults and run task_instance_mutation_hook against the about-to-run + # try_number. First attempts and reschedules stay on the bulk UPDATE path. + if ti.task is not None and ti.state == TaskInstanceState.UP_FOR_RETRY: + ti.try_number += 1 + ti.state = TaskInstanceState.SCHEDULED + ti.scheduled_dttm = timezone.utcnow() + ti.refresh_from_task(ti.task) + count += 1 + had_retry_mutation = True + continue + schedulable_ti_ids.append(ti.id) - if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE: + is_reschedule = ti.state == TaskInstanceState.UP_FOR_RESCHEDULE + if is_reschedule: reschedule_ti_ids.add(ti.id) if debug_try_number_check: expected_try_number_by_ti_id[ti.id] = ( - ti.try_number - if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE - else ti.try_number + 1, + ti.try_number if is_reschedule else ti.try_number + 1, ti.try_number, ti.state, ) - - count = 0 + if had_retry_mutation: + # Airflow disables SQLA autoflush, so retry-branch mutations need an + # explicit flush to be visible to the bulk UPDATE/SELECTs that follow. + session.flush() # Don't only check if the TI.id is in id_chunk # but also check if the TI.state is in the schedulable states. # Plus, a scheduled empty operator should not be scheduled again. diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index a9f9dfb136c17..40cb766034805 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -2286,6 +2286,153 @@ def test_schedule_tis_up_for_reschedule_does_not_increment_try_number(dag_maker, assert refreshed_ti.try_number == 3 +def test_schedule_tis_reapplies_mutation_hook_on_retry(dag_maker, session, monkeypatch): + def reroute_retries(task_instance): + if (task_instance.try_number or 0) >= 2: + task_instance.queue = "retry_queue" + + monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", reroute_retries) + + with dag_maker(session=session) as dag: + BashOperator(task_id="task", bash_command="echo 1", queue="default") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + ti.refresh_from_task(dag.get_task("task")) + ti.state = TaskInstanceState.UP_FOR_RETRY + ti.try_number = 1 + session.commit() + + assert dr.schedule_tis((ti,), session=session) == 1 + session.commit() + + session.expire_all() + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.state == TaskInstanceState.SCHEDULED + assert refreshed_ti.try_number == 2 + assert refreshed_ti.queue == "retry_queue" + + +def test_schedule_tis_does_not_apply_retry_mutation_hook_on_first_attempt(dag_maker, session, monkeypatch): + def reroute_retries(task_instance): + if (task_instance.try_number or 0) >= 1: + task_instance.queue = "retry_queue" + + monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", reroute_retries) + + with dag_maker(session=session) as dag: + BashOperator(task_id="task", bash_command="echo 1", queue="default") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + ti.refresh_from_task(dag.get_task("task")) + ti.state = None + ti.try_number = 0 + ti.queue = "default" + session.merge(ti) + session.commit() + + assert dr.schedule_tis((ti,), session=session) == 1 + session.commit() + + session.expire_all() + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.state == TaskInstanceState.SCHEDULED + assert refreshed_ti.try_number == 1 + assert refreshed_ti.queue == "default" + + +def test_schedule_tis_recomputes_priority_weight_for_dynamic_strategy(dag_maker, session): + from airflow import plugins_manager + + from tests_common.test_utils.mock_plugins import mock_plugin_manager + from unit.plugins.priority_weight_strategy import ( + DecreasingPriorityStrategy, + TestPriorityWeightStrategyPlugin, + ) + + try: + with mock_plugin_manager(plugins=[TestPriorityWeightStrategyPlugin]): + with dag_maker(session=session) as dag: + BashOperator( + task_id="task", + bash_command="echo 1", + weight_rule=DecreasingPriorityStrategy(), + ) + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + ti.refresh_from_task(dag.get_task("task")) + ti.state = TaskInstanceState.UP_FOR_RETRY + ti.try_number = 1 + session.merge(ti) + session.commit() + + assert dr.schedule_tis((ti,), session=session) == 1 + session.commit() + + session.expire_all() + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + # DecreasingPriorityStrategy returns max(3 - try + 1, 1). After schedule_tis + # bumps try_number 1 -> 2, the strategy is re-evaluated against try=2 giving 2. + assert refreshed_ti.try_number == 2 + assert refreshed_ti.priority_weight == 2 + finally: + # mock_plugin_manager clears strategy caches on entry but not on exit. + plugins_manager.get_priority_weight_strategy_plugins.cache_clear() + + +def test_schedule_tis_does_not_touch_priority_weight_for_static_strategy(dag_maker, session): + with dag_maker(session=session) as dag: + BashOperator(task_id="task", bash_command="echo 1", priority_weight=42) + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + ti.refresh_from_task(dag.get_task("task")) + ti.state = TaskInstanceState.UP_FOR_RETRY + ti.try_number = 1 + original_priority = ti.priority_weight + original_queue = ti.queue + session.commit() + + assert dr.schedule_tis((ti,), session=session) == 1 + session.commit() + + session.expire_all() + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.try_number == 2 + assert refreshed_ti.priority_weight == original_priority + assert refreshed_ti.queue == original_queue + + def test_schedule_tis_empty_operator_is_noop_if_ti_already_running(dag_maker, session): with dag_maker(session=session) as dag: EmptyOperator(task_id="empty_task") diff --git a/airflow-core/tests/unit/plugins/priority_weight_strategy.py b/airflow-core/tests/unit/plugins/priority_weight_strategy.py index ba1ff367889d0..d12f2ca284549 100644 --- a/airflow-core/tests/unit/plugins/priority_weight_strategy.py +++ b/airflow-core/tests/unit/plugins/priority_weight_strategy.py @@ -44,7 +44,7 @@ class DecreasingPriorityStrategy(PriorityWeightStrategy): """A priority weight strategy that decreases the priority weight with each attempt.""" def get_weight(self, ti: TaskInstance): - return max(3 - ti.try_number + 1, 1) + return max(3 - (ti.try_number or 0) + 1, 1) class TestPriorityWeightStrategyPlugin(AirflowPlugin):