diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index f9751684e35cf..6cacd239b5c68 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -1718,6 +1718,56 @@ def check_task_2(my_input): mock_task_object_1.assert_called() mock_task_object_2.assert_not_called() + @pytest.mark.parametrize( + ("mark_success_pattern", "expected_mapped_tasks_num"), + [ + ("make_items|process_item", 1), + ("process_item", 3), + ], + ) + def test_dag_test_mark_success_pattern_allows_mapped_task( + self, + mark_success_pattern, + expected_mapped_tasks_num, + testing_dag_bundle, + ): + upstream_callable = mock.MagicMock() + mapped_callable = mock.MagicMock() + + @task_decorator + def make_items(): + upstream_callable() + return [1, 2, 3] + + with DAG( + dag_id="test_dag", + schedule=None, + start_date=DEFAULT_DATE, + ) as dag: + PythonOperator.partial( + task_id="process_item", + python_callable=mapped_callable, + ).expand(op_args=make_items()) + + sync_dag_to_db(dag) + + dr = dag.test(mark_success_pattern=mark_success_pattern) + + assert dr.state == DagRunState.SUCCESS + + upstream_ti = dr.get_task_instance("make_items") + assert upstream_ti is not None + assert upstream_ti.state == TaskInstanceState.SUCCESS + if "make_items" in mark_success_pattern: + upstream_callable.assert_not_called() + else: + upstream_callable.assert_called() + + mapped_tis = [ti for ti in dr.get_task_instances() if ti.task_id == "process_item"] + assert len(mapped_tis) == expected_mapped_tasks_num + assert all(ti.state == TaskInstanceState.SUCCESS for ti in mapped_tis) + mapped_callable.assert_not_called() + def test_dag_connection_file(self, tmp_path, testing_dag_bundle): test_connections_string = """ --- diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index fd47de467fea4..5f067cb4e33cf 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1388,6 +1388,24 @@ def test( executor = ExecutorLoader.get_default_executor() executor.start() + def is_marked_success(ti): + return ( + re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None + if mark_success_pattern is not None + else False + ) + + # If a mapped TI is marked successful, pre-expand it to a single mapped index + # without waiting for the upstream XCom value. + # If the upstream dependencies later run and produce an XCom during this Dag run, + # _revise_map_indexes_if_mapped() will update the mapped TI count to match the + # upstream XCom length. + for ti in dr.get_task_instances(session=session): + task = self.task_dict[ti.task_id] + if task.is_mapped and is_marked_success(ti): + ti.map_index = 0 + session.commit() + while dr.state == DagRunState.RUNNING: session.expire_all() schedulable_tis, _ = dr.update_state(session=session) @@ -1408,12 +1426,6 @@ def test( for ti in scheduled_tis: task = self.task_dict[ti.task_id] - mark_success = ( - re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None - if mark_success_pattern is not None - else False - ) - if use_executor: if executor.has_task(ti): continue @@ -1440,7 +1452,7 @@ def test( else: # Run the task locally try: - if mark_success: + if is_marked_success(ti): ti.set_state(TaskInstanceState.SUCCESS) log.info("[DAG TEST] Marking success for %s on %s", task, ti.logical_date) else: