Skip to content

Commit

Permalink
Ensure DAG-level references are filled on unmap (#33083)
Browse files Browse the repository at this point in the history
Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>
(cherry picked from commit bcfadcf)
  • Loading branch information
uranusjr authored and ephraimbuddy committed Aug 4, 2023
1 parent be31ac2 commit faf9de4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
2 changes: 2 additions & 0 deletions airflow/models/mappedoperator.py
Expand Up @@ -659,6 +659,8 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->

op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
SerializedBaseOperator.populate_operator(op, self.operator_class)
if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies.
SerializedBaseOperator.set_task_dag_references(op, self.dag)
return op

def _get_specified_expand_input(self) -> ExpandInput:
Expand Down
59 changes: 41 additions & 18 deletions airflow/serialization/serialized_objects.py
Expand Up @@ -735,6 +735,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
All operators are cast to SerializedBaseOperator after deserialization.
Class-specific attributes used by the UI are moved to object attributes.
Creating a SerializedBaseOperator is a three-step process:
1. Instantiate a :class:`SerializedBaseOperator` object.
2. Populate attributes with :func:`SerializedBaseOperator.populate_operator`.
3. When the task's containing DAG is available, fix references to the DAG
with :func:`SerializedBaseOperator.set_task_dag_references`.
"""

_decorated_fields = {"executor_config"}
Expand Down Expand Up @@ -875,6 +882,13 @@ def _serialize_deps(cls, op_deps: Iterable[BaseTIDep]) -> list[str]:

@classmethod
def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
"""Populate operator attributes with serialized values.
This covers simple attributes that don't reference other things in the
DAG. Setting references (such as ``op.dag`` and task dependencies) is
done in ``set_task_dag_references`` instead, which is called after the
DAG is hydrated.
"""
if "label" not in encoded_op:
# Handle deserialization of old data before the introduction of TaskGroup
encoded_op["label"] = encoded_op["task_id"]
Expand Down Expand Up @@ -982,6 +996,32 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
# Used to determine if an Operator is inherited from EmptyOperator
setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))

@staticmethod
def set_task_dag_references(task: Operator, dag: DAG) -> None:
    """Attach an already-populated operator to its containing DAG.

    ``populate_operator`` fills in the attributes that do not point at other
    objects in the DAG. This function completes the job once the DAG object
    exists: it sets ``task.dag``, fills in missing start/end dates from the
    DAG, re-parents a subdag, resolves expand-input references, and records
    this task as upstream of each of its downstream tasks.

    :param task: The operator to fix up; modified in place.
    :param dag: The DAG that contains ``task``.
    """
    task.dag = dag

    # A task without its own start/end date falls back to the DAG's value.
    for attr_name in ("start_date", "end_date"):
        if getattr(task, attr_name, None) is not None:
            continue
        setattr(task, attr_name, getattr(dag, attr_name, None))

    if task.subdag is not None:
        task.subdag.parent_dag = dag

    # Swap expand_input / op_kwargs_expand_input references for the real
    # objects they resolve to within this DAG.
    for attr_name in ("expand_input", "op_kwargs_expand_input"):
        candidate = getattr(task, attr_name, None)
        if isinstance(candidate, _ExpandInputRef):
            setattr(task, attr_name, candidate.deref(dag))

    # Write the upstream IDs directly instead of going through set_upstream
    # and friends, which do more than is wanted here.
    for downstream_id in task.downstream_task_ids:
        downstream_task = dag.task_dict[downstream_id]
        downstream_task.upstream_task_ids.add(task.task_id)

@classmethod
def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
"""Deserializes an operator from a JSON object."""
Expand Down Expand Up @@ -1328,24 +1368,7 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
setattr(dag, k, None)

for task in dag.task_dict.values():
task.dag = dag

for date_attr in ["start_date", "end_date"]:
if getattr(task, date_attr) is None:
setattr(task, date_attr, getattr(dag, date_attr))

if task.subdag is not None:
setattr(task.subdag, "parent_dag", dag)

# Dereference expand_input and op_kwargs_expand_input.
for k in ("expand_input", "op_kwargs_expand_input"):
kwargs_ref = getattr(task, k, None)
if isinstance(kwargs_ref, _ExpandInputRef):
setattr(task, k, kwargs_ref.deref(dag))

for task_id in task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
SerializedBaseOperator.set_task_dag_references(task, dag)

return dag

Expand Down
21 changes: 21 additions & 0 deletions tests/serialization/test_serialized_objects.py
Expand Up @@ -96,3 +96,24 @@ def test_use_pydantic_models():
deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True) # does not raise

assert isinstance(deserialized[0][0], TaskInstancePydantic)


def test_serialized_mapped_operator_unmap(dag_maker):
    """Check that tasks of a round-tripped DAG, including an unmapped one, point at that DAG."""
    from airflow.serialization.serialized_objects import SerializedDAG
    from tests.test_utils.mock_operators import MockOperator

    # Build a DAG with one plain task and one mapped task.
    with dag_maker(dag_id="dag") as dag:
        MockOperator(task_id="task1", arg1="x")
        MockOperator.partial(task_id="task2").expand(arg1=["a", "b"])

    # Round-trip through the dict representation.
    encoded = SerializedDAG.to_dict(dag)
    restored_dag = SerializedDAG.from_dict(encoded)
    assert restored_dag.dag_id == "dag"

    plain_task = restored_dag.get_task("task1")
    assert plain_task.dag is restored_dag

    mapped_task = restored_dag.get_task("task2")
    assert mapped_task.dag is restored_dag

    # Unmapping must carry the DAG reference over to the resulting operator.
    unmapped_task = mapped_task.unmap(None)
    assert unmapped_task.dag is restored_dag

0 comments on commit faf9de4

Please sign in to comment.