Skip to content

Commit

Permalink
Implement TransformedXComArg
Browse files Browse the repository at this point in the history
Some extra mechanism is added so we can continue to use XComArg like we
do right now, but also prevent TransformedXComArg from needing to
inherit a lot of the old XComArg nonsense like __str__ and __getitem__.
  • Loading branch information
uranusjr committed Jul 22, 2022
1 parent 23108c8 commit 28a06e4
Showing 1 changed file with 156 additions and 63 deletions.
219 changes: 156 additions & 63 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Union
#
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Sequence, Type, Union, overload

from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import AbstractOperator
Expand All @@ -32,15 +33,14 @@


class XComArg(DependencyMixin):
"""
Class that represents a XCom push from a previous operator.
Defaults to "return_value" as only key.
"""Reference to an XCom value pushed from another operator.
The implementation supports::
Current implementation supports
xcomarg >> op
xcomarg << op
op >> xcomarg (by BaseOperator code)
op << xcomarg (by BaseOperator code)
op >> xcomarg # By BaseOperator code
op << xcomarg # By BaseOperator code
**Example**: The moment you get a result from any operator (decorated or regular) you can ::
Expand All @@ -53,29 +53,129 @@ class XComArg(DependencyMixin):
This object can be used in legacy Operators via Jinja.
**Example**: You can make this result to be part of any generated string ::
**Example**: You can make this result to be part of any generated string::
any_op = AnyOperator()
xcomarg = any_op.output
op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")
:param operator: operator to which the XComArg belongs to
:param key: key value which is used for xcom_pull (key in the XCom table)
:param operator: Operator instance that the XComArg references.
:param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
i.e. the referenced operator's return value.
"""

operator: "Operator"
key: str

@overload
def __new__(cls: Type["XComArg"], operator: "Operator", key: str = XCOM_RETURN_KEY) -> "XComArg":
    """Signature used when user code writes ``XComArg(...)`` itself."""

@overload
def __new__(cls: Type["XComArg"]) -> "XComArg":
    """Signature used when Python allocates an instance of a subclass."""

def __new__(cls, *args, **kwargs) -> "XComArg":
    # Subclasses are allocated the ordinary way; only a direct ``XComArg(...)``
    # call is redirected to the concrete PlainXComArg implementation.
    if cls is not XComArg:
        return super().__new__(cls)
    return PlainXComArg(*args, **kwargs)

@staticmethod
def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
    """Return XComArg instances in an arbitrary value.

    Recursively traverse ``arg`` and look for XComArg instances in any
    collection objects (tuple, set, list, dict values), and in the attributes
    of operators that are named by their ``template_fields``.

    :param arg: Any value. Types not handled below yield nothing.
    :return: An iterator over every XComArg found, in traversal order.
    """
    if isinstance(arg, XComArg):
        yield arg
    elif isinstance(arg, (tuple, set, list)):
        for elem in arg:
            yield from XComArg.iter_xcom_args(elem)
    elif isinstance(arg, dict):
        for elem in arg.values():
            yield from XComArg.iter_xcom_args(elem)
    elif isinstance(arg, AbstractOperator):
        # ``template_fields`` lists attribute *names*. Recursing into the names
        # themselves (plain strings) can never find an XComArg, so inspect the
        # values stored under those names instead.
        for attr in arg.template_fields:
            yield from XComArg.iter_xcom_args(getattr(arg, attr))

@staticmethod
def apply_upstream_relationship(op: "Operator", arg: Any):
    """Make ``op`` depend on every operator referenced inside ``arg``.

    ``arg`` is searched "deeply" via ``XComArg.iter_xcom_args``, and the
    operator behind each XComArg found is set as an upstream of ``op``.
    """
    found_refs = XComArg.iter_xcom_args(arg)
    for xcom_ref in found_refs:
        op.set_upstream(xcom_ref.operator)

@property
def roots(self) -> List[DAGNode]:
    """Single-element list holding the referenced operator. Required by TaskMixin."""
    referenced = self.operator
    return [referenced]

@property
def leaves(self) -> List[DAGNode]:
    """Single-element list holding the referenced operator. Required by TaskMixin."""
    referenced = self.operator
    return [referenced]

def set_upstream(
    self,
    task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
    edge_modifier: Optional[EdgeModifier] = None,
):
    """Forward the call to the referenced operator's ``set_upstream``. Required by TaskMixin."""
    target = self.operator
    target.set_upstream(task_or_task_list, edge_modifier)

def set_downstream(
    self,
    task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
    edge_modifier: Optional[EdgeModifier] = None,
):
    """Forward the call to the referenced operator's ``set_downstream``. Required by TaskMixin."""
    target = self.operator
    target.set_downstream(task_or_task_list, edge_modifier)

def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
    """Not implemented on the base class; always raises ``NotImplementedError``."""
    raise NotImplementedError()

def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
    """Not implemented on the base class; always raises ``NotImplementedError``."""
    raise NotImplementedError()


class PlainXComArg(XComArg):
"""Reference to one single XCom without any additional semantics.
This class should not be accessed directly, but only through XComArg. The
class inheritance chain and ``__new__`` are implemented in this slightly
convoluted way because we want to
a. Allow the user to continue using XComArg directly for the simple
semantics (see documentation of the base class for details).
b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
references.
c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
``__str__``) to exist on other kinds of XComArg implementations since
they don't make sense.
:meta private:
"""

def __init__(self, operator: "Operator", key: str = XCOM_RETURN_KEY):
    """Store the referenced operator and the XCom key to read from it."""
    self.key = key
    self.operator = operator

def __eq__(self, other):
    """Compare by operator and key; defer to ``other`` if it is not a PlainXComArg."""
    if isinstance(other, PlainXComArg):
        return self.operator == other.operator and self.key == other.key
    return NotImplemented

def __getitem__(self, item: str) -> "XComArg":
    """Implements xcomresult['some_result_key'].

    :param item: XCom key to reference on the same operator.
    :return: A new reference to ``item`` on this reference's operator.
    :raises ValueError: If ``item`` is not a string.
    """
    if not isinstance(item, str):
        raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
    # Exactly one return: build the concrete class directly rather than
    # following it with a second, unreachable construction.
    return PlainXComArg(operator=self.operator, key=item)

def __iter__(self):
"""Override iterable protocol to raise error explicitly.
Expand All @@ -89,7 +189,7 @@ def __iter__(self):
This override catches the error eagerly, so an incorrectly implemented
DAG fails fast and avoids wasting resources on nonsensical iterating.
"""
raise TypeError(f"{self.__class__.__name__!r} object is not iterable")
raise TypeError("'XComArg' object is not iterable")

def __str__(self):
"""
Expand All @@ -113,31 +213,10 @@ def __str__(self):
xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}"
return xcom_pull

@property
def roots(self) -> List[DAGNode]:
"""Required by TaskMixin"""
return [self.operator]

@property
def leaves(self) -> List[DAGNode]:
"""Required by TaskMixin"""
return [self.operator]

def set_upstream(
self,
task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
edge_modifier: Optional[EdgeModifier] = None,
):
"""Proxy to underlying operator set_upstream method. Required by TaskMixin."""
self.operator.set_upstream(task_or_task_list, edge_modifier)

def set_downstream(
self,
task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
edge_modifier: Optional[EdgeModifier] = None,
):
"""Proxy to underlying operator set_downstream method. Required by TaskMixin."""
self.operator.set_downstream(task_or_task_list, edge_modifier)
def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
    """Wrap this reference in a MapXComArg carrying ``f`` as its only callable.

    :param f: Callable stored on the returned MapXComArg.
    :return: A MapXComArg built from this reference and ``[f]``.
    :raises ValueError: If this reference's key is not ``XCOM_RETURN_KEY``.
    """
    if self.key != XCOM_RETURN_KEY:
        # Give the DAG author a reason instead of an empty ValueError.
        raise ValueError("cannot map against non-return XCom")
    return MapXComArg(self, [f])

@provide_session
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
Expand All @@ -155,32 +234,46 @@ def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
)
return result

@staticmethod
def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
"""Return XComArg instances in an arbitrary value.

This recursively traverse ``arg`` and look for XComArg instances in any
collection objects, and instances with ``template_fields`` set.
"""
if isinstance(arg, XComArg):
yield arg
elif isinstance(arg, (tuple, set, list)):
for elem in arg:
yield from XComArg.iter_xcom_args(elem)
elif isinstance(arg, dict):
for elem in arg.values():
yield from XComArg.iter_xcom_args(elem)
elif isinstance(arg, AbstractOperator):
for elem in arg.template_fields:
yield from XComArg.iter_xcom_args(elem)
class _MapResult(Sequence):
def __init__(self, value: Union[Sequence, dict], callables: Sequence[Callable[[Any], Any]]) -> None:
    """Keep the underlying container and the callables to apply on item access."""
    self.callables = callables
    self.value = value

@staticmethod
def apply_upstream_relationship(op: "Operator", arg: Any):
"""Set dependency for XComArgs.
def __getitem__(self, index: Any) -> Any:
    """Look up ``index`` in the wrapped value, then pass it through each callable in order."""
    transformed = self.value[index]
    for func in self.callables:
        transformed = func(transformed)
    return transformed

This looks for XComArg objects in ``arg`` "deeply" (looking inside
collections objects and classes decorated with ``template_fields``), and
sets the relationship to ``op`` on any found.
"""
for ref in XComArg.iter_xcom_args(arg):
op.set_upstream(ref.operator)
def __len__(self) -> int:
    """Length of the wrapped value."""
    wrapped = self.value
    return len(wrapped)


class MapXComArg(XComArg):
    """An XCom reference with ``map()`` call(s) applied.

    This is based on an XComArg, but also applies a series of "transforms" that
    convert the pulled XCom value.

    :param arg: The plain XCom reference whose resolved value is wrapped.
    :param callables: Callables handed to the result wrapper, which applies
        them in order.
    """

    def __init__(self, arg: PlainXComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
        self.arg = arg
        self.callables = callables

    @property
    def operator(self) -> "Operator":  # type: ignore[override]
        """Operator of the wrapped reference."""
        return self.arg.operator

    @property
    def key(self) -> str:  # type: ignore[override]
        """XCom key of the wrapped reference."""
        return self.arg.key

    def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
        """Return a new MapXComArg on the same reference with ``f`` appended to the callables."""
        return MapXComArg(self.arg, [*self.callables, f])

    @provide_session
    def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
        """Resolve the wrapped reference and wrap the result in ``_MapResult``.

        :raises ValueError: If the resolved value is neither a sequence nor a dict.
        """
        value = self.arg.resolve(context, session=session)
        # The value is expected to have been validated when the XCom was
        # pushed, but an ``assert`` disappears under ``python -O``; check
        # explicitly so a bad value always fails here with a clear message.
        if not isinstance(value, (Sequence, dict)):
            raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
        return _MapResult(value, self.callables)

0 comments on commit 28a06e4

Please sign in to comment.