Skip to content

Commit

Permalink
Make sure we can get out of a faulty scheduler state (#27834)
Browse files Browse the repository at this point in the history
* Make sure we can get out of a faulty scheduler state

This PR fixes the case where the database ends up in a faulty state.
The faulty state is one in which both the unmapped task instance and the mapped task instances exist at the same time.

So we have instances with map_index [-1, 0, 1].
The -1 task instances should be removed in this case.

(cherry picked from commit 73d9352)
  • Loading branch information
stijndehaes authored and ephraimbuddy committed Jan 11, 2023
1 parent d174ef1 commit d34bd50
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 16 deletions.
36 changes: 25 additions & 11 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
# are not done yet, so the task can't fail yet.
if not self.dag or not self.dag.partial:
unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
indexes_to_map: Iterable[int] = ()
elif total_length < 1:
# If the upstream maps this to a zero-length value, simply mark
# the unmapped task instance as SKIPPED (if needed).
Expand All @@ -494,18 +493,33 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
total_length,
)
unmapped_ti.state = TaskInstanceState.SKIPPED
indexes_to_map = ()
else:
# Otherwise convert this into the first mapped index, and create
# TaskInstance for other indexes.
unmapped_ti.map_index = 0
self.log.debug("Updated in place to become %s", unmapped_ti)
all_expanded_tis.append(unmapped_ti)
indexes_to_map = range(1, total_length)
state = unmapped_ti.state
elif not total_length:
zero_index_ti_exists = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == run_id,
TaskInstance.map_index == 0,
)
.count()
> 0
)
if not zero_index_ti_exists:
# Otherwise convert this into the first mapped index, and create
# TaskInstance for other indexes.
unmapped_ti.map_index = 0
self.log.debug("Updated in place to become %s", unmapped_ti)
all_expanded_tis.append(unmapped_ti)
session.flush()
else:
self.log.debug("Deleting the original task instance: %s", unmapped_ti)
session.delete(unmapped_ti)
state = unmapped_ti.state

if total_length is None or total_length < 1:
# Nothing to fixup.
indexes_to_map = ()
indexes_to_map: Iterable[int] = ()
else:
# Only create "missing" ones.
current_max_mapping = (
Expand Down
9 changes: 5 additions & 4 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,8 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
"""Try to expand the ti, if needed.
If the ti needs expansion, newly created task instances are
returned. The original ti is modified in-place and assigned the
returned as well as the original ti.
The original ti is also modified in-place and assigned the
``map_index`` of 0.
If the ti does not need expansion, either because the task is not
Expand All @@ -781,8 +782,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
except NotMapped: # Not a mapped task, nothing needed.
return None
if expanded_tis:
assert expanded_tis[0] is ti
return expanded_tis[1:]
return expanded_tis
return ()

# Check dependencies.
Expand All @@ -798,12 +798,13 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
# in the scheduler to ensure that the mapped task is correctly
# expanded before executed. Also see _revise_map_indexes_if_mapped
# docstring for additional information.
new_tis = None
if schedulable.map_index < 0:
new_tis = _expand_mapped_task_if_needed(schedulable)
if new_tis is not None:
additional_tis.extend(new_tis)
expansion_happened = True
if schedulable.state in SCHEDULEABLE_STATES:
if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
ready_tis.append(schedulable)

Expand Down
43 changes: 42 additions & 1 deletion tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@
from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.decorators import task, task_group
from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance as TI, clear_task_instances
from airflow.models import (
DAG,
DagBag,
DagModel,
DagRun,
TaskInstance,
TaskInstance as TI,
clear_task_instances,
)
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskmap import TaskMap
from airflow.operators.empty import EmptyOperator
Expand Down Expand Up @@ -1285,6 +1293,39 @@ def task_2(arg2):
]


def test_mapped_literal_faulty_state_in_db(dag_maker, session):
    """
    Check that scheduling decisions survive a corrupted set of task instances.

    The corrupted state is one where the unmapped task instance (map_index -1) exists
    next to the already expanded mapped task instances, i.e. map indexes [-1, 0, 1].
    The stray -1 task instance should be dropped, so the number of schedulable task
    instances must stay the same before and after the faulty record is inserted.
    """

    with dag_maker(session=session) as dag:

        @task
        def task_1():
            return [1, 2]

        @task
        def task_2(arg2):
            ...

        task_2.expand(arg2=task_1())

    dag_run = dag_maker.create_dagrun()

    # Run the upstream task so that the mapped task can be expanded.
    upstream_ti = dag_run.get_task_instance(task_id="task_1")
    upstream_ti.run()

    first_decision = dag_run.task_instance_scheduling_decisions()
    assert len(first_decision.schedulable_tis) == 2

    # Re-create the faulty state: add an unmapped task instance for the mapped task.
    faulty_ti = TaskInstance(dag.get_task("task_2"), dag_run.execution_date, dag_run.run_id)
    session.add(faulty_ti)
    session.flush()

    # The extra record must not change what is schedulable.
    second_decision = dag_run.task_instance_scheduling_decisions()
    assert len(second_decision.schedulable_tis) == 2


def test_mapped_literal_length_with_no_change_at_runtime_doesnt_call_verify_integrity(dag_maker, session):
"""
Test that when there's no change to mapped task indexes at runtime, the dagrun.verify_integrity
Expand Down
51 changes: 51 additions & 0 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,57 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
assert indices == expected


def test_expand_mapped_task_failed_state_in_db(dag_maker, session):
    """
    Check that ``expand_mapped_task`` cleans up a corrupted set of task instances.

    The corrupted state is one where the unmapped task instance (map_index -1) exists
    next to the mapped task instances, i.e. map indexes [-1, 0, 1]. After expansion
    only the mapped task instances should remain, with their states untouched.
    """
    literal = [1, 2]
    with dag_maker(session=session):
        task1 = BaseOperator(task_id="op1")
        mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output)

    dag_run = dag_maker.create_dagrun()

    def _index_states():
        # Snapshot of (map_index, state) for every task instance of the mapped task
        # in this run, ordered by map_index.
        query = session.query(TaskInstance.map_index, TaskInstance.state)
        query = query.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dag_run.run_id)
        return query.order_by(TaskInstance.map_index).all()

    # Record the length of the upstream output for this run.
    upstream_map = TaskMap(
        dag_id=dag_run.dag_id,
        task_id=task1.task_id,
        run_id=dag_run.run_id,
        map_index=-1,
        length=len(literal),
        keys=None,
    )
    session.add(upstream_map)

    for map_index in range(len(literal)):
        # Give the existing TIs a state to make sure we don't change them
        session.add(
            TaskInstance(
                mapped,
                run_id=dag_run.run_id,
                map_index=map_index,
                state=TaskInstanceState.SUCCESS,
            )
        )
    session.flush()

    # Make sure we have the faulty state in the database
    assert _index_states() == [(-1, None), (0, "success"), (1, "success")]

    mapped.expand_mapped_task(dag_run.run_id, session=session)

    # The -1 index should be cleaned up
    assert _index_states() == [(0, "success"), (1, "success")]


def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
Expand Down

0 comments on commit d34bd50

Please sign in to comment.