Skip to content

Commit

Permalink
Support multiple XCom output in the BaseOperator (apache#37297)
Browse files Browse the repository at this point in the history
* Support multiple XCom output in the BaseOperator

* consolidate task_flow and normal operator multiple_outputs

* revert sftp provider change
  • Loading branch information
hussein-awala authored and sunank200 committed Feb 21, 2024
1 parent 7110fa2 commit 55ddb2c
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 22 deletions.
20 changes: 0 additions & 20 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import typing_extensions

from airflow.datasets import Dataset
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
BaseOperator,
Expand Down Expand Up @@ -195,7 +194,6 @@ def __init__(
task_id: str,
op_args: Collection[Any] | None = None,
op_kwargs: Mapping[str, Any] | None = None,
multiple_outputs: bool = False,
kwargs_to_upstream: dict[str, Any] | None = None,
**kwargs,
) -> None:
Expand Down Expand Up @@ -227,7 +225,6 @@ def __init__(
else:
signature.bind(*op_args, **op_kwargs)

self.multiple_outputs = multiple_outputs
self.op_args = op_args
self.op_kwargs = op_kwargs
super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)
Expand Down Expand Up @@ -257,23 +254,6 @@ def _handle_output(self, return_value: Any, context: Context, xcom_push: Callabl
for item in return_value:
if isinstance(item, Dataset):
self.outlets.append(item)
if not self.multiple_outputs or return_value is None:
return return_value
if isinstance(return_value, dict):
for key in return_value.keys():
if not isinstance(key, str):
raise AirflowException(
"Returned dictionary keys must be strings when using "
f"multiple_outputs, found {key} ({type(key)}) instead"
)
for key, value in return_value.items():
if isinstance(value, Dataset):
self.outlets.append(value)
xcom_push(context, key, value)
else:
raise AirflowException(
f"Returned output was type {type(return_value)} expected dictionary for multiple_outputs"
)
return return_value

def _hook_apply_defaults(self, *args, **kwargs):
Expand Down
1 change: 0 additions & 1 deletion airflow/decorators/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(
task_id: str,
**kwargs,
) -> None:
kwargs.pop("multiple_outputs")
kwargs["task_id"] = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
super().__init__(**kwargs)

Expand Down
5 changes: 5 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,8 @@ class derived from this one results in the creation of a task object,
:param do_xcom_push: if True, an XCom is pushed containing the Operator's
result
:param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each
key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom.
:param task_group: The TaskGroup to which the task should belong. This is typically provided when not
using a TaskGroup as a context manager.
:param doc: Add documentation or notes to your Task objects that is visible in
Expand Down Expand Up @@ -713,6 +715,7 @@ class derived from this one results in the creation of a task object,
"on_retry_callback",
"on_skipped_callback",
"do_xcom_push",
"multiple_outputs",
}

# Defines if the operator supports lineage without manual definitions
Expand Down Expand Up @@ -782,6 +785,7 @@ def __init__(
max_active_tis_per_dagrun: int | None = None,
executor_config: dict | None = None,
do_xcom_push: bool = True,
multiple_outputs: bool = False,
inlets: Any | None = None,
outlets: Any | None = None,
task_group: TaskGroup | None = None,
Expand Down Expand Up @@ -929,6 +933,7 @@ def __init__(
self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
self.do_xcom_push: bool = do_xcom_push
self.multiple_outputs: bool = multiple_outputs

self.doc_md = doc_md
self.doc_json = doc_json
Expand Down
16 changes: 15 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from collections import defaultdict
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
from urllib.parse import quote

import dill
Expand Down Expand Up @@ -450,6 +450,20 @@ def _execute_callable(context, **execute_callable_kwargs):
else:
xcom_value = None
if xcom_value is not None: # If the task returns a result, push an XCom containing it.
if task_to_execute.multiple_outputs:
if not isinstance(xcom_value, Mapping):
raise AirflowException(
f"Returned output was type {type(xcom_value)} "
"expected dictionary for multiple_outputs"
)
for key in xcom_value.keys():
if not isinstance(key, str):
raise AirflowException(
"Returned dictionary keys must be strings when using "
f"multiple_outputs, found {key} ({type(key)}) instead"
)
for key, value in xcom_value.items():
task_instance.xcom_push(key=key, value=value, session=session)
task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session)
_record_task_map_for_downstreams(
task_instance=task_instance, task=task_orig, value=xcom_value, session=session
Expand Down
2 changes: 2 additions & 0 deletions docs/apache-airflow/core-concepts/xcoms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ You can also use XComs in :ref:`templates <concepts:jinja-templating>`::

XComs are a relative of :doc:`variables`, with the main difference being that XComs are per-task-instance and designed for communication within a DAG run, while Variables are global and designed for overall configuration and value sharing.

If you want to push multiple XComs at once or rename the pushed XCom key, you can use set ``do_xcom_push`` and ``multiple_outputs`` arguments to ``True``, and then return a dictionary of values.

.. note::

If the first task run is not succeeded then on every retry task XComs will be cleared to make the task run idempotent.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"email_on_retry",
"post_execute",
"pre_execute",
"multiple_outputs",
# Doesn't matter, not used anywhere.
"default_args",
# Deprecated and is aliased to max_active_tis_per_dag.
Expand Down
61 changes: 61 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,67 @@ def test_xcom_push_flag(self, dag_maker):
ti.run()
assert ti.xcom_pull(task_ids=task_id, key=XCOM_RETURN_KEY) is None

def test_xcom_without_multiple_outputs(self, dag_maker):
"""
Tests the option for Operators to push XComs without multiple outputs
"""
value = {"key1": "value1", "key2": "value2"}
task_id = "test_xcom_push_without_multiple_outputs"

with dag_maker(dag_id="test_xcom"):
task = PythonOperator(
task_id=task_id,
python_callable=lambda: value,
do_xcom_push=True,
)
ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
ti.task = task
ti.run()
assert ti.xcom_pull(task_ids=task_id, key=XCOM_RETURN_KEY) == value

def test_xcom_with_multiple_outputs(self, dag_maker):
"""
Tests the option for Operators to push XComs with multiple outputs
"""
value = {"key1": "value1", "key2": "value2"}
task_id = "test_xcom_push_with_multiple_outputs"

with dag_maker(dag_id="test_xcom"):
task = PythonOperator(
task_id=task_id,
python_callable=lambda: value,
do_xcom_push=True,
multiple_outputs=True,
)
ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
ti.task = task
ti.run()
assert ti.xcom_pull(task_ids=task_id, key=XCOM_RETURN_KEY) == value
assert ti.xcom_pull(task_ids=task_id, key="key1") == "value1"
assert ti.xcom_pull(task_ids=task_id, key="key2") == "value2"

def test_xcom_with_multiple_outputs_and_no_mapping_result(self, dag_maker):
"""
Tests the option for Operators to push XComs with multiple outputs and no mapping result
"""
value = "value"
task_id = "test_xcom_push_with_multiple_outputs"

with dag_maker(dag_id="test_xcom"):
task = PythonOperator(
task_id=task_id,
python_callable=lambda: value,
do_xcom_push=True,
multiple_outputs=True,
)
ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
ti.task = task
with pytest.raises(AirflowException) as ctx:
ti.run()
assert f"Returned output was type {type(value)} expected dictionary for multiple_outputs" in str(
ctx.value
)

def test_post_execute_hook(self, dag_maker):
"""
Test that post_execute hook is called with the Operator's result.
Expand Down
1 change: 1 addition & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,7 @@ def test_no_new_fields_added_to_base_operator(self):
"wait_for_downstream": False,
"wait_for_past_depends_before_skipping": False,
"weight_rule": "downstream",
"multiple_outputs": False,
}, """
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Expand Down

0 comments on commit 55ddb2c

Please sign in to comment.