Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions airflow-core/src/airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk.definitions.dag import DAG, _run_task
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.param import ParamsDict
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization
Expand Down Expand Up @@ -429,6 +430,36 @@ def task_test(args, dag: DAG | None = None) -> None:
passed_in_params = json.loads(args.task_params)
sdk_task.params.update(passed_in_params)

if isinstance(sdk_task, MappedOperator):
from airflow.sdk.definitions._internal.expandinput import (
DictOfListsExpandInput,
ListOfDictsExpandInput,
)

expand_input = sdk_task._get_specified_expand_input()

expand_input_attr = sdk_task._expand_input_attr
if isinstance(expand_input, DictOfListsExpandInput):
new_expand_input_dict = dict(expand_input.value)
for k in expand_input.value:
if k in passed_in_params:
new_param = passed_in_params[k]
new_expand_input_dict[k] = [new_param] * (args.map_index + 1)
setattr(sdk_task, expand_input_attr, DictOfListsExpandInput(new_expand_input_dict))

if isinstance(expand_input, ListOfDictsExpandInput):
new_expand_input_list = expand_input.value
if not isinstance(new_expand_input_list, list):
new_expand_input_list = [passed_in_params] * (args.map_index + 1)

current_mapping = new_expand_input_list[args.map_index]
if isinstance(current_mapping, dict):
new_expand_input_list[args.map_index] = {**current_mapping, **passed_in_params}
else:
new_expand_input_list[args.map_index] = passed_in_params

setattr(sdk_task, expand_input_attr, ListOfDictsExpandInput(new_expand_input_list))

if sdk_task.params and isinstance(sdk_task.params, ParamsDict):
sdk_task.params.validate()

Expand Down
125 changes: 125 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from airflow.models.dagbag import DBDagBag
from airflow.models.serialized_dag import SerializedDagModel
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import task
from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
Expand Down Expand Up @@ -451,6 +453,129 @@ def test_mapped_task_render_with_template(self, dag_maker):
assert 'echo "2022-01-01"' in output
assert 'echo "2022-01-08"' in output

@pytest.mark.parametrize(
("task_id", "task_params", "expected_output"),
[
pytest.param("consumer", '{"op_args": [10]}', "output=10", id="xcom_args_override"),
pytest.param("consumer_literal", '{"op_args": [10]}', "output=10", id="literal_args_override"),
pytest.param("consumer_literal", "", "output=3", id="literal_args_remain_default"),
],
)
def test_mapped_task_test_with_expand_args_override(
self, task_id, task_params, expected_output, request, capsys, dag_maker
):
dag_id = f"test_mapped_dag_{request.node.callspec.id}"
with dag_maker(dag_id) as dag:

@task
def produce():
return [[1], [2], [3]]

def consume(value):
print(f"output={value}")

PythonOperator.partial(task_id="consumer", python_callable=consume).expand(op_args=produce())
PythonOperator.partial(task_id="consumer_literal", python_callable=consume).expand(
op_args=[[1], [2], [3]]
)

dr = dag_maker.create_dagrun(
run_id="test_run",
state=State.RUNNING,
logical_date=DEFAULT_DATE,
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.CLI,
)

args = ["tasks", "test", dag_id, task_id, DEFAULT_DATE.isoformat(), "--map-index", "2"]
if task_params:
args += ["--task-params", task_params]

task_command.task_test(self.parser.parse_args(args), dag=dag)

ti = dr.get_task_instance(task_id=task_id, map_index=2)
assert ti is not None
assert ti.state == State.SUCCESS
captured = capsys.readouterr()
assert expected_output in captured.out

@pytest.mark.parametrize(
("task_id", "task_params", "expected_output"),
[
pytest.param(
"consumer",
'{"op_args": [10], "op_kwargs": {"value2": 20}}',
"output=(10, 20)",
id="xcom_kwargs_full_override",
),
pytest.param(
"consumer_literal",
'{"op_args": [10], "op_kwargs": {"value2": 20}}',
"output=(10, 20)",
id="literal_kwargs_full_override",
),
pytest.param(
"consumer_literal",
'{"op_kwargs": {"value2": 10}}',
"output=(3, 10)",
id="literal_kwargs_partial_override",
),
pytest.param(
"consumer_literal",
"",
"output=(3, 2)",
id="literal_kwargs_remain_default",
),
],
)
def test_mapped_task_test_with_expand_kwargs_override(
self, task_id, task_params, expected_output, request, capsys, dag_maker
):
dag_id = f"test_mapped_kwargs_dag_{request.node.callspec.id}"
with dag_maker(dag_id) as dag:

@task
def produce_kwargs():
return [
{"op_args": [1], "op_kwargs": {"value2": 2}},
{"op_args": [2], "op_kwargs": {"value2": 2}},
{"op_args": [3], "op_kwargs": {"value2": 2}},
]

def consume(value1, value2=2):
print(f"output=({value1}, {value2})")

PythonOperator.partial(task_id="consumer", python_callable=consume).expand_kwargs(
produce_kwargs()
)
PythonOperator.partial(task_id="consumer_literal", python_callable=consume).expand_kwargs(
[
{"op_args": [1], "op_kwargs": {"value2": 2}},
{"op_args": [2], "op_kwargs": {"value2": 2}},
{"op_args": [3], "op_kwargs": {"value2": 2}},
]
)

dr = dag_maker.create_dagrun(
run_id="test_run",
state=State.RUNNING,
logical_date=DEFAULT_DATE,
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.CLI,
)

args = ["tasks", "test", dag_id, task_id, DEFAULT_DATE.isoformat(), "--map-index", "2"]
if task_params:
args += ["--task-params", task_params]

task_command.task_test(self.parser.parse_args(args), dag=dag)

ti = dr.get_task_instance(task_id=task_id, map_index=2)
assert ti is not None
assert ti.state == State.SUCCESS
captured = capsys.readouterr()
assert expected_output in captured.out

def test_task_state(self):
task_command.task_state(
self.parser.parse_args(
Expand Down
Loading