Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

import collections.abc
import functools
import inspect
import re
Expand All @@ -39,7 +38,6 @@

import attr
import typing_extensions
from sqlalchemy.orm import Session

from airflow.compat.functools import cache, cached_property
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -68,6 +66,9 @@
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
import jinja2 # Slow import.
from sqlalchemy.orm import Session

from airflow.models.mappedoperator import Mappable


Expand Down Expand Up @@ -430,6 +431,7 @@ def _get_unmap_kwargs(self) -> Dict[str, Any]:
self.mapped_op_kwargs,
fail_reason="mapping already partial",
)
self._combined_op_kwargs = op_kwargs
return {
"dag": self.dag,
"task_group": self.task_group,
Expand All @@ -441,13 +443,38 @@ def _get_unmap_kwargs(self) -> Dict[str, Any]:
**self.mapped_kwargs,
}

def _expand_mapped_field(self, key: str, content: Any, context: Context, *, session: Session) -> Any:
if key != "op_kwargs" or not isinstance(content, collections.abc.Mapping):
return content
# The magic super() doesn't work here, so we use the explicit form.
# Not using super(..., self) to work around pyupgrade bug.
sup: Any = super(DecoratedMappedOperator, DecoratedMappedOperator)
return {k: sup._expand_mapped_field(self, k, v, context, session=session) for k, v in content.items()}
def _resolve_expansion_kwargs(
self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: "Session"
) -> None:
expansion_kwargs = self._get_expansion_kwargs()

self._already_resolved_op_kwargs = set()
for k, v in expansion_kwargs.items():
if isinstance(v, XComArg):
self._already_resolved_op_kwargs.add(k)
v = v.resolve(context, session=session)
v = self._expand_mapped_field(k, v, context, session=session)
kwargs['op_kwargs'][k] = v
template_fields.discard(k)

def render_template(
self,
value: Any,
context: Context,
jinja_env: Optional["jinja2.Environment"] = None,
seen_oids: Optional[Set] = None,
) -> Any:
if hasattr(self, '_combined_op_kwargs') and value is self._combined_op_kwargs:
# Avoid rendering values that came out of resolved XComArgs
return {
k: v
if k in self._already_resolved_op_kwargs
else super(DecoratedMappedOperator, DecoratedMappedOperator).render_template(
self, v, context, jinja_env=jinja_env, seen_oids=seen_oids
)
for k, v in value.items()
}
return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids)


class Task(Generic[Function]):
Expand Down
5 changes: 0 additions & 5 deletions airflow/jobs/local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ def signal_handler(signum, frame):
try:
self.task_runner.start()

# Unmap the task _after_ it has forked/execed. (This is a bit of a kludge, but if we unmap before
# fork, then the "run_raw_task" command will see the mapping index and a non-mapped task and
# fail)
self.task_instance.task = self.task_instance.task.unmap()

heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

# task callback invocation happens either here or in
Expand Down
22 changes: 3 additions & 19 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
Union,
)

from sqlalchemy.orm import Session

from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand All @@ -52,6 +50,7 @@

if TYPE_CHECKING:
import jinja2 # Slow import.
from sqlalchemy.orm import Session

from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DAG
Expand Down Expand Up @@ -330,7 +329,7 @@ def _do_render_template_fields(
jinja_env: "jinja2.Environment",
seen_oids: Set,
*,
session: Session = NEW_SESSION,
session: "Session" = NEW_SESSION,
) -> None:
for attr_name in template_fields:
try:
Expand All @@ -342,29 +341,14 @@ def _do_render_template_fields(
)
if not value:
continue
rendered_content = self._render_template_field(
attr_name,
rendered_content = self.render_template(
value,
context,
jinja_env,
seen_oids,
session=session,
)
setattr(parent, attr_name, rendered_content)

def _render_template_field(
self,
key: str,
value: Any,
context: Context,
jinja_env: Optional["jinja2.Environment"] = None,
seen_oids: Optional[Set] = None,
*,
session: Session,
) -> Any:
"""Override point for MappedOperator to perform further resolution."""
return self.render_template(value, context, jinja_env, seen_oids)

def render_template(
self,
content: Any,
Expand Down
75 changes: 53 additions & 22 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from sqlalchemy import func, or_
from sqlalchemy.orm.session import Session

from airflow import settings
from airflow.compat.functools import cache, cached_property
from airflow.exceptions import AirflowException, UnmappableOperator
from airflow.models.abstractoperator import (
Expand Down Expand Up @@ -473,6 +474,7 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
return DagAttributeTypes.OP, self.task_id

def _get_unmap_kwargs(self) -> Dict[str, Any]:

return {
"task_id": self.task_id,
"dag": self.dag,
Expand All @@ -484,14 +486,26 @@ def _get_unmap_kwargs(self) -> Dict[str, Any]:
**self.mapped_kwargs,
}

def unmap(self) -> "BaseOperator":
"""Get the "normal" Operator after applying the current mapping."""
def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator":
"""
Get the "normal" Operator after applying the current mapping.

If ``operator_class`` is not a class (i.e. this DAG has been deserialized) then this will return a
SerializedBaseOperator that aims to "look like" the real operator.

:param unmap_kwargs: Override the args to pass to the Operator constructor. Only used when
``operator_class`` is still an actual class.

:meta private:
"""
if isinstance(self.operator_class, type):
# We can't simply specify task_id here because BaseOperator further
# mangles the task_id based on the task hierarchy (namely, group_id
# is prepended, and '__N' appended to deduplicate). Instead of
# recreating the whole logic here, we just overwrite task_id later.
op = self.operator_class(**self._get_unmap_kwargs(), _airflow_from_mapped=True)
if unmap_kwargs is None:
unmap_kwargs = self._get_unmap_kwargs()
op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True)
op.task_id = self.task_id
return op

Expand Down Expand Up @@ -569,6 +583,7 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
map_lengths[mapped_arg_name] += length
return map_lengths

@cache
def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
Comment on lines +586 to 587
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’ll be difficult to actually hit this cache due to session :(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Session has a longer lifetime than you might expect due to SQLAlchemy's pooling.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In [2]: @provide_session
   ...: def a(session):
   ...:     print(id(session))
   ...:     b()

In [4]: @provide_session
   ...: def b(session):
   ...:     print(id(session))

In [5]: a()
140533469520240
140533469520240

"""Return dict of argument name to map length, or throw if some are not resolvable"""
expansion_kwargs = self._get_expansion_kwargs()
Expand Down Expand Up @@ -686,33 +701,49 @@ def render_template_fields(
"""
if not jinja_env:
jinja_env = self.get_template_env()
unmapped_task = self.unmap()
# Before we unmap we have to resolve the mapped arguments, otherwise the real operator constructor
# could be called with an XComArg, rather than the value it resolves to.
#
# We also need to resolve _all_ mapped arguments, even if they aren't marked as templated
kwargs = self._get_unmap_kwargs()

template_fields = set(self.template_fields)

# Ideally we'd like to pass in session as an argument to this function, but since operators _could_
# override this we can't easily change this function signature.
# We can't use @provide_session, as that closes and expunges everything, which we don't want to do
# when we are so "deep" in the weeds here.
#
# Nor do we want to close the session -- that would expunge all the things from the internal cache
# which we don't want to do either
session = settings.Session()
self._resolve_expansion_kwargs(kwargs, template_fields, context, session)

unmapped_task = self.unmap(unmap_kwargs=kwargs)
self._do_render_template_fields(
parent=unmapped_task,
template_fields=unmapped_task.template_fields,
template_fields=template_fields,
Comment on lines -692 to +725
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope template_fields being unordered won’t cause bugs in user code.

context=context,
jinja_env=jinja_env,
seen_oids=set(),
session=session,
)
return unmapped_task

def _render_template_field(
self,
key: str,
value: Any,
context: Context,
jinja_env: Optional["jinja2.Environment"] = None,
seen_oids: Optional[Set] = None,
*,
session: Session,
) -> Any:
"""Override the ordinary template rendering to add more logic.

Specifically, if we're rendering a mapped argument, we need to "unmap"
the value as well to assign it to the unmapped operator.
"""
value = super()._render_template_field(key, value, context, jinja_env, seen_oids, session=session)
return self._expand_mapped_field(key, value, context, session=session)
def _resolve_expansion_kwargs(
self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session
) -> None:
"""Update mapped fields in place in kwargs dict"""
from airflow.models.xcom_arg import XComArg

expansion_kwargs = self._get_expansion_kwargs()

for k, v in expansion_kwargs.items():
if isinstance(v, XComArg):
v = v.resolve(context, session=session)
v = self._expand_mapped_field(k, v, context, session=session)
template_fields.discard(k)
kwargs[k] = v

def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
map_index = context["ti"].map_index
Expand Down
10 changes: 8 additions & 2 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.operator import Operator


Expand Down Expand Up @@ -136,12 +139,15 @@ def set_downstream(
"""Proxy to underlying operator set_downstream method. Required by TaskMixin."""
self.operator.set_downstream(task_or_task_list, edge_modifier)

def resolve(self, context: Context) -> Any:
@provide_session
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
Comment on lines +142 to +143
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC I wondered if we could do this previously, but eventually did not due to interface compatibility issues.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What interface issues might those be?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t quite remember now; maybe it has something to do with DagParam also having a ``resolve`` method. I guess we can do it and see what happens.

"""
Pull XCom value for the existing arg. This method is run during ``op.execute()``
in respectable context.
"""
result = context["ti"].xcom_pull(task_ids=self.operator.task_id, key=str(self.key), default=NOTSET)
result = context["ti"].xcom_pull(
task_ids=self.operator.task_id, key=str(self.key), default=NOTSET, session=session
)
if result is NOTSET:
raise AirflowException(
f'XComArg result from {self.operator.task_id} at {context["ti"].dag_id} '
Expand Down
36 changes: 36 additions & 0 deletions docs/apache-airflow/concepts/dynamic-task-mapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,42 @@ Currently it is only possible to map against a dict, a list, or one of those typ

If an upstream task returns an unmappable type, the mapped task will fail at run-time with an ``UnmappableXComTypePushed`` exception. For instance, you can't have the upstream task return a plain string – it must be a list or a dict.

How do templated fields and mapped arguments interact?
======================================================

All arguments to an operator can be mapped, even those that do not accept templated parameters.

If a field is marked as being templated and is mapped, it **will not be templated**.

For example, this will print ``{{ ds }}`` and not a date stamp:

.. code-block:: python

@task
def make_list():
return ["{{ ds }}"]


@task
def printer(val):
print(val)


printer.expand(val=make_list())

If you want to interpolate values, either call ``task.render_template`` yourself or use interpolation:

.. code-block:: python

@task
def make_list(ds):
return [ds]


@task
def make_list(**context):
return [context["task"].render_template("{{ ds }}", context)]

Placing limits on mapped tasks
==============================

Expand Down
40 changes: 40 additions & 0 deletions tests/decorators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
from airflow.decorators.base import DecoratedMappedOperator
from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.models.xcom_arg import XComArg
from airflow.utils import timezone
from airflow.utils.state import State
Expand Down Expand Up @@ -641,3 +645,39 @@ def task2(arg1, arg2):
mapped_task1 = dag.get_task("task1")
assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30) # Operator default.
mapped_task1.unmap().retry_delay == timedelta(seconds=300) # Operator default.


def test_mapped_render_template_fields(dag_maker, session):
@task_decorator
def fn(arg1, arg2):
...

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcom_arg = XComArg(task1)
mapped = fn.partial(arg2='{{ ti.task_id }}').expand(arg1=xcom_arg)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)

ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)

session.add(
TaskMap(
dag_id=dr.dag_id,
task_id=task1.task_id,
run_id=dr.run_id,
map_index=-1,
length=1,
keys=None,
)
)
session.flush()

mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
mapped_ti.map_index = 0
op = mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert op

assert op.op_kwargs['arg1'] == "{{ ds }}"
assert op.op_kwargs['arg2'] == "fn"
Loading