Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions airflow-core/tests/unit/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,56 @@ def check_task_2(my_input):
mock_task_object_1.assert_called()
mock_task_object_2.assert_not_called()

@pytest.mark.parametrize(
("mark_success_pattern", "expected_mapped_tasks_num"),
[
("make_items|process_item", 1),
("process_item", 3),
],
)
def test_dag_test_mark_success_pattern_allows_mapped_task(
self,
mark_success_pattern,
expected_mapped_tasks_num,
testing_dag_bundle,
):
upstream_callable = mock.MagicMock()
mapped_callable = mock.MagicMock()

@task_decorator
def make_items():
upstream_callable()
return [1, 2, 3]

with DAG(
dag_id="test_dag",
schedule=None,
start_date=DEFAULT_DATE,
) as dag:
PythonOperator.partial(
task_id="process_item",
python_callable=mapped_callable,
).expand(op_args=make_items())

sync_dag_to_db(dag)

dr = dag.test(mark_success_pattern=mark_success_pattern)

assert dr.state == DagRunState.SUCCESS

upstream_ti = dr.get_task_instance("make_items")
assert upstream_ti is not None
assert upstream_ti.state == TaskInstanceState.SUCCESS
if "make_items" in mark_success_pattern:
upstream_callable.assert_not_called()
else:
upstream_callable.assert_called()

mapped_tis = [ti for ti in dr.get_task_instances() if ti.task_id == "process_item"]
assert len(mapped_tis) == expected_mapped_tasks_num
assert all(ti.state == TaskInstanceState.SUCCESS for ti in mapped_tis)
mapped_callable.assert_not_called()

def test_dag_connection_file(self, tmp_path, testing_dag_bundle):
test_connections_string = """
---
Expand Down
26 changes: 19 additions & 7 deletions task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,24 @@ def test(
executor = ExecutorLoader.get_default_executor()
executor.start()

def is_marked_success(ti):
return (
re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None
if mark_success_pattern is not None
else False
)

# If a mapped TI is marked successful, pre-expand it to a single mapped index
# without waiting for the upstream XCom value.
# If the upstream dependencies later run and produce an XCom during this Dag run,
# _revise_map_indexes_if_mapped() will update the mapped TI count to match the
# upstream XCom length.
for ti in dr.get_task_instances(session=session):
task = self.task_dict[ti.task_id]
if task.is_mapped and is_marked_success(ti):
ti.map_index = 0
session.commit()

while dr.state == DagRunState.RUNNING:
session.expire_all()
schedulable_tis, _ = dr.update_state(session=session)
Expand All @@ -1408,12 +1426,6 @@ def test(
for ti in scheduled_tis:
task = self.task_dict[ti.task_id]

mark_success = (
re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None
if mark_success_pattern is not None
else False
)

if use_executor:
if executor.has_task(ti):
continue
Expand All @@ -1440,7 +1452,7 @@ def test(
else:
# Run the task locally
try:
if mark_success:
if is_marked_success(ti):
ti.set_state(TaskInstanceState.SUCCESS)
log.info("[DAG TEST] Marking success for %s on %s", task, ti.logical_date)
else:
Expand Down
Loading