diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index df9c44910fa20..838c68102f4e1 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -38,6 +38,7 @@ from airflow.models.expandinput import NotFullyPopulated from airflow.models.serialized_dag import SerializedDagModel from airflow.sdk.definitions.dag import DAG, _run_task +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import ParamsDict from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.serialized_objects import DagSerialization @@ -429,6 +430,36 @@ def task_test(args, dag: DAG | None = None) -> None: passed_in_params = json.loads(args.task_params) sdk_task.params.update(passed_in_params) + if isinstance(sdk_task, MappedOperator): + from airflow.sdk.definitions._internal.expandinput import ( + DictOfListsExpandInput, + ListOfDictsExpandInput, + ) + + expand_input = sdk_task._get_specified_expand_input() + + expand_input_attr = sdk_task._expand_input_attr + if isinstance(expand_input, DictOfListsExpandInput): + new_expand_input_dict = dict(expand_input.value) + for k in expand_input.value: + if k in passed_in_params: + new_param = passed_in_params[k] + new_expand_input_dict[k] = [new_param] * (args.map_index + 1) + setattr(sdk_task, expand_input_attr, DictOfListsExpandInput(new_expand_input_dict)) + + if isinstance(expand_input, ListOfDictsExpandInput): + new_expand_input_list = expand_input.value + if not isinstance(new_expand_input_list, list): + new_expand_input_list = [passed_in_params] * (args.map_index + 1) + + current_mapping = new_expand_input_list[args.map_index] + if isinstance(current_mapping, dict): + new_expand_input_list[args.map_index] = {**current_mapping, **passed_in_params} + else: + new_expand_input_list[args.map_index] = passed_in_params + + setattr(sdk_task, expand_input_attr, ListOfDictsExpandInput(new_expand_input_list)) + if sdk_task.params and isinstance(sdk_task.params, ParamsDict): sdk_task.params.validate() diff --git a/airflow-core/tests/unit/cli/commands/test_task_command.py b/airflow-core/tests/unit/cli/commands/test_task_command.py index f39a8409bc243..5945079609613 100644 --- a/airflow-core/tests/unit/cli/commands/test_task_command.py +++ b/airflow-core/tests/unit/cli/commands/test_task_command.py @@ -43,6 +43,8 @@ from airflow.models.dagbag import DBDagBag from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.bash import BashOperator +from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk import task from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG from airflow.utils.session import create_session from airflow.utils.state import State, TaskInstanceState @@ -451,6 +453,129 @@ def test_mapped_task_render_with_template(self, dag_maker): assert 'echo "2022-01-01"' in output assert 'echo "2022-01-08"' in output + @pytest.mark.parametrize( + ("task_id", "task_params", "expected_output"), + [ + pytest.param("consumer", '{"op_args": [10]}', "output=10", id="xcom_args_override"), + pytest.param("consumer_literal", '{"op_args": [10]}', "output=10", id="literal_args_override"), + pytest.param("consumer_literal", "", "output=3", id="literal_args_remain_default"), + ], + ) + def test_mapped_task_test_with_expand_args_override( + self, task_id, task_params, expected_output, request, capsys, dag_maker + ): + dag_id = f"test_mapped_dag_{request.node.callspec.id}" + with dag_maker(dag_id) as dag: + + @task + def produce(): + return [[1], [2], [3]] + + def consume(value): + print(f"output={value}") + + PythonOperator.partial(task_id="consumer", python_callable=consume).expand(op_args=produce()) + PythonOperator.partial(task_id="consumer_literal", python_callable=consume).expand( + op_args=[[1], [2], [3]] + ) + + dr = dag_maker.create_dagrun( + run_id="test_run", + state=State.RUNNING, + logical_date=DEFAULT_DATE, + run_type=DagRunType.MANUAL, + triggered_by=DagRunTriggeredByType.CLI, + ) + + args = ["tasks", "test", dag_id, task_id, DEFAULT_DATE.isoformat(), "--map-index", "2"] + if task_params: + args += ["--task-params", task_params] + + task_command.task_test(self.parser.parse_args(args), dag=dag) + + ti = dr.get_task_instance(task_id=task_id, map_index=2) + assert ti is not None + assert ti.state == State.SUCCESS + captured = capsys.readouterr() + assert expected_output in captured.out + + @pytest.mark.parametrize( + ("task_id", "task_params", "expected_output"), + [ + pytest.param( + "consumer", + '{"op_args": [10], "op_kwargs": {"value2": 20}}', + "output=(10, 20)", + id="xcom_kwargs_full_override", + ), + pytest.param( + "consumer_literal", + '{"op_args": [10], "op_kwargs": {"value2": 20}}', + "output=(10, 20)", + id="literal_kwargs_full_override", + ), + pytest.param( + "consumer_literal", + '{"op_kwargs": {"value2": 10}}', + "output=(3, 10)", + id="literal_kwargs_partial_override", + ), + pytest.param( + "consumer_literal", + "", + "output=(3, 2)", + id="literal_kwargs_remain_default", + ), + ], + ) + def test_mapped_task_test_with_expand_kwargs_override( + self, task_id, task_params, expected_output, request, capsys, dag_maker + ): + dag_id = f"test_mapped_kwargs_dag_{request.node.callspec.id}" + with dag_maker(dag_id) as dag: + + @task + def produce_kwargs(): + return [ + {"op_args": [1], "op_kwargs": {"value2": 2}}, + {"op_args": [2], "op_kwargs": {"value2": 2}}, + {"op_args": [3], "op_kwargs": {"value2": 2}}, + ] + + def consume(value1, value2=2): + print(f"output=({value1}, {value2})") + + PythonOperator.partial(task_id="consumer", python_callable=consume).expand_kwargs( + produce_kwargs() + ) + PythonOperator.partial(task_id="consumer_literal", python_callable=consume).expand_kwargs( + [ + {"op_args": [1], "op_kwargs": {"value2": 2}}, + {"op_args": [2], "op_kwargs": {"value2": 2}}, + {"op_args": [3], "op_kwargs": {"value2": 2}}, + ] + ) + + dr = dag_maker.create_dagrun( + run_id="test_run", + state=State.RUNNING, + logical_date=DEFAULT_DATE, + run_type=DagRunType.MANUAL, + triggered_by=DagRunTriggeredByType.CLI, + ) + + args = ["tasks", "test", dag_id, task_id, DEFAULT_DATE.isoformat(), "--map-index", "2"] + if task_params: + args += ["--task-params", task_params] + + task_command.task_test(self.parser.parse_args(args), dag=dag) + + ti = dr.get_task_instance(task_id=task_id, map_index=2) + assert ti is not None + assert ti.state == State.SUCCESS + captured = capsys.readouterr() + assert expected_output in captured.out + def test_task_state(self): task_command.task_state( self.parser.parse_args(