diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index acb943f1234bd..c61160b8c36ab 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -70,7 +70,7 @@ from airflow.utils.helpers import is_container from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with_row_locks +from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, tuple_in_condition, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import NOTSET, ArgNotSet, DagRunType @@ -1022,7 +1022,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES ): dummy_ti_ids.append(ti.task_id) else: - schedulable_ti_ids.append(ti.task_id) + schedulable_ti_ids.append((ti.task_id, ti.map_index)) count = 0 @@ -1032,7 +1032,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES .filter( TI.dag_id == self.dag_id, TI.run_id == self.run_id, - TI.task_id.in_(schedulable_ti_ids), + tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids), ) .update({TI.state: State.SCHEDULED}, synchronize_session=False) ) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 091920fbd6846..094a0ecb548c0 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -1157,3 +1157,24 @@ def _task_ids(tis): decision = dr.task_instance_scheduling_decisions(session=session) assert decision.schedulable_tis == [] assert result == [2, 4] + + +def test_schedule_tis_map_index(dag_maker, session): + with dag_maker(session=session, dag_id="test"): + task = BaseOperator(task_id='task_1') + + dr = DagRun(dag_id="test", run_id="test", run_type=DagRunType.MANUAL) + ti0 = TI(task=task, run_id=dr.run_id, map_index=0, state=TaskInstanceState.SUCCESS) + ti1 = TI(task=task, run_id=dr.run_id, map_index=1, state=None) + ti2 = TI(task=task, run_id=dr.run_id, map_index=2, state=TaskInstanceState.SUCCESS) + session.add_all((dr, ti0, ti1, ti2)) + session.flush() + + assert dr.schedule_tis((ti1,), session=session) == 1 + + session.refresh(ti0) + session.refresh(ti1) + session.refresh(ti2) + assert ti0.state == TaskInstanceState.SUCCESS + assert ti1.state == TaskInstanceState.SCHEDULED + assert ti2.state == TaskInstanceState.SUCCESS