From d34bd50683a91ab8fe5f64845752a429c45f44fa Mon Sep 17 00:00:00 2001
From: Stijn De Haes
Date: Mon, 5 Dec 2022 02:12:36 +0100
Subject: [PATCH] Make sure we can get out of a faulty scheduler state (#27834)

* Make sure we can get out of a faulty scheduler state

This PR fixes the case where the database is in a faulty state: both the
unmapped task instance and the mapped task instances exist at the same
time, i.e. we have instances with map_index [-1, 0, 1]. The stale -1 task
instance should be removed in this case.

(cherry picked from commit 73d9352225bcc1f086b63f1c767d25b2d7c4c221)
---
 airflow/models/abstractoperator.py  | 36 +++++++++++++-------
 airflow/models/dagrun.py            |  9 ++---
 tests/models/test_dagrun.py         | 43 +++++++++++++++++++++++-
 tests/models/test_mappedoperator.py | 51 +++++++++++++++++++++++++++++
 4 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index d5d6ad082f3d88..ba0a8954ae1832 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -484,7 +484,6 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
                 # are not done yet, so the task can't fail yet.
                 if not self.dag or not self.dag.partial:
                     unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
-                indexes_to_map: Iterable[int] = ()
             elif total_length < 1:
                 # If the upstream maps this to a zero-length value, simply mark
                 # the unmapped task instance as SKIPPED (if needed).
@@ -494,18 +493,33 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
                     total_length,
                 )
                 unmapped_ti.state = TaskInstanceState.SKIPPED
-                indexes_to_map = ()
             else:
-                # Otherwise convert this into the first mapped index, and create
-                # TaskInstance for other indexes.
-                unmapped_ti.map_index = 0
-                self.log.debug("Updated in place to become %s", unmapped_ti)
-                all_expanded_tis.append(unmapped_ti)
-                indexes_to_map = range(1, total_length)
-            state = unmapped_ti.state
-        elif not total_length:
+                zero_index_ti_exists = (
+                    session.query(TaskInstance)
+                    .filter(
+                        TaskInstance.dag_id == self.dag_id,
+                        TaskInstance.task_id == self.task_id,
+                        TaskInstance.run_id == run_id,
+                        TaskInstance.map_index == 0,
+                    )
+                    .count()
+                    > 0
+                )
+                if not zero_index_ti_exists:
+                    # Otherwise convert this into the first mapped index, and create
+                    # TaskInstance for other indexes.
+                    unmapped_ti.map_index = 0
+                    self.log.debug("Updated in place to become %s", unmapped_ti)
+                    all_expanded_tis.append(unmapped_ti)
+                    session.flush()
+                else:
+                    self.log.debug("Deleting the original task instance: %s", unmapped_ti)
+                    session.delete(unmapped_ti)
+            state = unmapped_ti.state
+
+        if total_length is None or total_length < 1:
             # Nothing to fixup.
-            indexes_to_map = ()
+            indexes_to_map: Iterable[int] = ()
         else:
             # Only create "missing" ones.
             current_max_mapping = (
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 9e02f4775f549b..ae3a390653e011 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -768,7 +768,8 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
             """Try to expand the ti, if needed.
 
             If the ti needs expansion, newly created task instances are
-            returned. The original ti is modified in-place and assigned the
+            returned, as well as the original ti.
+            The original ti is also modified in-place and assigned the
             ``map_index`` of 0.
 
             If the ti does not need expansion, either because the task is not
@@ -781,8 +782,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
             except NotMapped:  # Not a mapped task, nothing needed.
                 return None
             if expanded_tis:
-                assert expanded_tis[0] is ti
-                return expanded_tis[1:]
+                return expanded_tis
             return ()
 
         # Check dependencies.
@@ -798,12 +798,13 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
             # in the scheduler to ensure that the mapped task is correctly
             # expanded before executed. Also see _revise_map_indexes_if_mapped
             # docstring for additional information.
+            new_tis = None
             if schedulable.map_index < 0:
                 new_tis = _expand_mapped_task_if_needed(schedulable)
                 if new_tis is not None:
                     additional_tis.extend(new_tis)
                     expansion_happened = True
-            if schedulable.state in SCHEDULEABLE_STATES:
+            if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
                 ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                 ready_tis.append(schedulable)
 
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index ffee25f5e89a9d..34b67ba543eafc 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -29,7 +29,15 @@
 from airflow import settings
 from airflow.callbacks.callback_requests import DagCallbackRequest
 from airflow.decorators import task, task_group
-from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance as TI, clear_task_instances
+from airflow.models import (
+    DAG,
+    DagBag,
+    DagModel,
+    DagRun,
+    TaskInstance,
+    TaskInstance as TI,
+    clear_task_instances,
+)
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskmap import TaskMap
 from airflow.operators.empty import EmptyOperator
@@ -1285,6 +1293,39 @@ def task_2(arg2):
     ]
 
 
+def test_mapped_literal_faulty_state_in_db(dag_maker, session):
+    """
+    This test tries to recreate a faulty state in the database and checks if we can recover from it.
+    The faulty state is that mapped task instances and the unmapped task instance exist at the same time.
+    So we have instances with map_index [-1, 0, 1]. The stale -1 task instance should be removed in this case.
+    """
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_1():
+            return [1, 2]
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id="task_1")
+    ti.run()
+    decision = dr.task_instance_scheduling_decisions()
+    assert len(decision.schedulable_tis) == 2
+
+    # We insert a faulty record
+    session.add(TaskInstance(dag.get_task("task_2"), dr.execution_date, dr.run_id))
+    session.flush()
+
+    decision = dr.task_instance_scheduling_decisions()
+    assert len(decision.schedulable_tis) == 2
+
+
 def test_mapped_literal_length_with_no_change_at_runtime_doesnt_call_verify_integrity(dag_maker, session):
     """
     Test that when there's no change to mapped task indexes at runtime, the dagrun.verify_integrity
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 1998563d704337..036a12fac4c2ac 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -228,6 +228,57 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
     assert indices == expected
 
 
+def test_expand_mapped_task_failed_state_in_db(dag_maker, session):
+    """
+    This test tries to recreate a faulty state in the database and checks if we can recover from it.
+    The faulty state is that mapped task instances and the unmapped task instance exist at the same time.
+    So we have instances with map_index [-1, 0, 1]. The stale -1 task instance should be removed in this case.
+    """
+    literal = [1, 2]
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output)
+
+    dr = dag_maker.create_dagrun()
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=len(literal),
+            keys=None,
+        )
+    )
+
+    for index in range(2):
+        # Give the existing TIs a state to make sure we don't change them
+        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS)
+        session.add(ti)
+    session.flush()
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+    # Make sure we have the faulty state in the database
+    assert indices == [(-1, None), (0, "success"), (1, "success")]
+
+    mapped.expand_mapped_task(dr.run_id, session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+    # The -1 index should be cleaned up
+    assert indices == [(0, "success"), (1, "success")]
+
+
 def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
     with dag_maker(session=session):
         task1 = BaseOperator(task_id="op1")
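
Note (editorial illustration, not part of the patch or of Airflow itself): the recovery rule the diff adds to expand_mapped_task boils down to this decision: if a task instance with map_index == 0 already exists for the task and run, the unmapped map_index == -1 row is a stale leftover and gets deleted; otherwise the -1 row is promoted in place to map_index 0. The sketch below restates that decision with plain SQLAlchemy against Airflow's task_instance table (dag_id, task_id, run_id, map_index columns). The function name and the raw SQL are illustrative assumptions; the patch itself performs the equivalent work through the ORM session inside the scheduler.

    from sqlalchemy import text
    from sqlalchemy.engine import Connection

    def reconcile_unmapped_row(conn: Connection, dag_id: str, task_id: str, run_id: str) -> None:
        """Sketch of the map_index -1/0 reconciliation the patch performs (illustrative only)."""
        params = {"d": dag_id, "t": task_id, "r": run_id}
        # Is there already a mapped row with map_index == 0 for this task and run?
        zero_exists = conn.execute(
            text(
                "SELECT COUNT(*) FROM task_instance "
                "WHERE dag_id = :d AND task_id = :t AND run_id = :r AND map_index = 0"
            ),
            params,
        ).scalar()
        if zero_exists:
            # Faulty state: -1 and 0 coexist, so drop the stale unmapped (-1) row.
            conn.execute(
                text(
                    "DELETE FROM task_instance "
                    "WHERE dag_id = :d AND task_id = :t AND run_id = :r AND map_index = -1"
                ),
                params,
            )
        else:
            # Normal expansion: promote the unmapped row to the first mapped index.
            conn.execute(
                text(
                    "UPDATE task_instance SET map_index = 0 "
                    "WHERE dag_id = :d AND task_id = :t AND run_id = :r AND map_index = -1"
                ),
                params,
            )

Because the check runs on the scheduler path, a run that already contains the bad [-1, 0, 1] combination heals itself on the next scheduling loop instead of requiring manual cleanup, which is the behaviour the two new tests exercise.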