-
Notifications
You must be signed in to change notification settings - Fork 16.5k
Resolve XComArgs before trying to unmap MappedOperators #22975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
35b4e6e
1fbf52a
c691af9
207d689
5607648
ec74b41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,7 @@ | |
| from sqlalchemy import func, or_ | ||
| from sqlalchemy.orm.session import Session | ||
|
|
||
| from airflow import settings | ||
| from airflow.compat.functools import cache, cached_property | ||
| from airflow.exceptions import AirflowException, UnmappableOperator | ||
| from airflow.models.abstractoperator import ( | ||
|
|
@@ -473,6 +474,7 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: | |
| return DagAttributeTypes.OP, self.task_id | ||
|
|
||
| def _get_unmap_kwargs(self) -> Dict[str, Any]: | ||
|
|
||
| return { | ||
| "task_id": self.task_id, | ||
| "dag": self.dag, | ||
|
|
@@ -484,14 +486,26 @@ def _get_unmap_kwargs(self) -> Dict[str, Any]: | |
| **self.mapped_kwargs, | ||
| } | ||
|
|
||
| def unmap(self) -> "BaseOperator": | ||
| """Get the "normal" Operator after applying the current mapping.""" | ||
| def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator": | ||
| """ | ||
| Get the "normal" Operator after applying the current mapping. | ||
|
|
||
| If ``operator_class`` is not a class (i.e. this DAG has been deserialized) then this will return a | ||
| SerializedBaseOperator that aims to "look like" the real operator. | ||
|
|
||
| :param unmap_kwargs: Override the args to pass to the Operator constructor. Only used when | ||
| ``operator_class`` is still an actual class. | ||
|
|
||
| :meta private: | ||
| """ | ||
| if isinstance(self.operator_class, type): | ||
| # We can't simply specify task_id here because BaseOperator further | ||
| # mangles the task_id based on the task hierarchy (namely, group_id | ||
| # is prepended, and '__N' appended to deduplicate). Instead of | ||
| # recreating the whole logic here, we just overwrite task_id later. | ||
| op = self.operator_class(**self._get_unmap_kwargs(), _airflow_from_mapped=True) | ||
| if unmap_kwargs is None: | ||
| unmap_kwargs = self._get_unmap_kwargs() | ||
| op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True) | ||
| op.task_id = self.task_id | ||
| return op | ||
|
|
||
|
|
@@ -569,6 +583,7 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]: | |
| map_lengths[mapped_arg_name] += length | ||
| return map_lengths | ||
|
|
||
| @cache | ||
| def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]: | ||
| """Return dict of argument name to map length, or throw if some are not resolvable""" | ||
| expansion_kwargs = self._get_expansion_kwargs() | ||
|
|
@@ -686,33 +701,49 @@ def render_template_fields( | |
| """ | ||
| if not jinja_env: | ||
| jinja_env = self.get_template_env() | ||
| unmapped_task = self.unmap() | ||
| # Before we unmap we have to resolve the mapped arguments, otherwise the real operator constructor | ||
| # could be called with an XComArg, rather than the value it resolves to. | ||
| # | ||
| # We also need to resolve _all_ mapped arguments, even if they aren't marked as templated | ||
| kwargs = self._get_unmap_kwargs() | ||
|
|
||
| template_fields = set(self.template_fields) | ||
|
|
||
| # Ideally we'd like to pass in session as an argument to this function, but since operators _could_ | ||
| # override this we can't easily change this function signature. | ||
| # We can't use @provide_session, as that closes and expunges everything, which we don't want to do | ||
| # when we are so "deep" in the weeds here. | ||
| # | ||
| # Nor do we want to close the session -- that would expunge all the things from the internal cache | ||
| # which we don't want to do either | ||
| session = settings.Session() | ||
| self._resolve_expansion_kwargs(kwargs, template_fields, context, session) | ||
|
|
||
| unmapped_task = self.unmap(unmap_kwargs=kwargs) | ||
| self._do_render_template_fields( | ||
| parent=unmapped_task, | ||
| template_fields=unmapped_task.template_fields, | ||
| template_fields=template_fields, | ||
|
Comment on lines
-692
to
+725
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hope |
||
| context=context, | ||
| jinja_env=jinja_env, | ||
| seen_oids=set(), | ||
| session=session, | ||
| ) | ||
| return unmapped_task | ||
|
|
||
| def _render_template_field( | ||
uranusjr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self, | ||
| key: str, | ||
| value: Any, | ||
| context: Context, | ||
| jinja_env: Optional["jinja2.Environment"] = None, | ||
| seen_oids: Optional[Set] = None, | ||
| *, | ||
| session: Session, | ||
| ) -> Any: | ||
| """Override the ordinary template rendering to add more logic. | ||
|
|
||
| Specifically, if we're rendering a mapped argument, we need to "unmap" | ||
| the value as well to assign it to the unmapped operator. | ||
| """ | ||
| value = super()._render_template_field(key, value, context, jinja_env, seen_oids, session=session) | ||
| return self._expand_mapped_field(key, value, context, session=session) | ||
| def _resolve_expansion_kwargs( | ||
| self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session | ||
| ) -> None: | ||
| """Update mapped fields in place in kwargs dict""" | ||
| from airflow.models.xcom_arg import XComArg | ||
|
|
||
| expansion_kwargs = self._get_expansion_kwargs() | ||
|
|
||
| for k, v in expansion_kwargs.items(): | ||
| if isinstance(v, XComArg): | ||
| v = v.resolve(context, session=session) | ||
| v = self._expand_mapped_field(k, v, context, session=session) | ||
| template_fields.discard(k) | ||
| kwargs[k] = v | ||
|
|
||
| def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any: | ||
| map_index = context["ti"].map_index | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,9 +22,12 @@ | |
| from airflow.models.xcom import XCOM_RETURN_KEY | ||
| from airflow.utils.context import Context | ||
| from airflow.utils.edgemodifier import EdgeModifier | ||
| from airflow.utils.session import NEW_SESSION, provide_session | ||
| from airflow.utils.types import NOTSET | ||
|
|
||
| if TYPE_CHECKING: | ||
| from sqlalchemy.orm import Session | ||
|
|
||
| from airflow.models.operator import Operator | ||
|
|
||
|
|
||
|
|
@@ -136,12 +139,15 @@ def set_downstream( | |
| """Proxy to underlying operator set_downstream method. Required by TaskMixin.""" | ||
| self.operator.set_downstream(task_or_task_list, edge_modifier) | ||
|
|
||
| def resolve(self, context: Context) -> Any: | ||
| @provide_session | ||
| def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: | ||
|
Comment on lines
+142
to
+143
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIRC I wondered if we could do this previously, but eventually did not due to interface compatibility issues.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What interface issues might that be?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don’t quite remember now, maybe it has something to do with DagParam also has |
||
| """ | ||
| Pull XCom value for the existing arg. This method is run during ``op.execute()`` | ||
| in respectable context. | ||
| """ | ||
| result = context["ti"].xcom_pull(task_ids=self.operator.task_id, key=str(self.key), default=NOTSET) | ||
| result = context["ti"].xcom_pull( | ||
| task_ids=self.operator.task_id, key=str(self.key), default=NOTSET, session=session | ||
| ) | ||
| if result is NOTSET: | ||
| raise AirflowException( | ||
| f'XComArg result from {self.operator.task_id} at {context["ti"].dag_id} ' | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It’ll be difficult to actually hit this cache due to
session:(There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Session has a longer life time than you might expect due to SQLA's Pooling
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.