111 changes: 83 additions & 28 deletions airflow/models/baseoperator.py
@@ -134,7 +134,9 @@ def _apply_defaults(cls, func: T) -> T:
if param.name != 'self' and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
}
non_optional_args = {
name for (name, param) in non_variadic_params.items() if param.default == param.empty
name
for name, param in non_variadic_params.items()
if param.default == param.empty and name != "task_id"
}

class autostacklevel_warn:
@@ -746,9 +748,8 @@ def __init__(
self.doc_rst = doc_rst
self.doc = doc

# Private attributes
self._upstream_task_ids: Set[str] = set()
self._downstream_task_ids: Set[str] = set()
self.upstream_task_ids: Set[str] = set()
self.downstream_task_ids: Set[str] = set()

if dag:
self.dag = dag
@@ -1260,16 +1261,6 @@ def resolve_template_files(self) -> None:
self.log.exception(e)
self.prepare_template()

@property
def upstream_task_ids(self) -> Set[str]:
"""@property: set of ids of tasks directly upstream"""
return self._upstream_task_ids

@property
def downstream_task_ids(self) -> Set[str]:
"""@property: set of ids of tasks directly downstream"""
return self._downstream_task_ids

@provide_session
def clear(
self,
@@ -1429,9 +1420,9 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
downstream.
"""
if upstream:
return self._upstream_task_ids
return self.upstream_task_ids
else:
return self._downstream_task_ids
return self.downstream_task_ids

def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]:
"""
@@ -1577,7 +1568,7 @@ def get_serialized_fields(cls):
- {
'inlets',
'outlets',
'_upstream_task_ids',
'upstream_task_ids',
'default_args',
'dag',
'_dag',
@@ -1672,8 +1663,12 @@ def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
return False


def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, value: Dict[str, Any]):
if isinstance(str, cls):
def _validate_kwarg_names_for_mapping(
cls: Union[str, Type[BaseOperator]],
func_name: str,
value: Dict[str, Any],
) -> None:
if isinstance(cls, str):
# Serialized version -- would have been validated at parse time
return
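The original guard had `isinstance`'s arguments inverted: `isinstance(str, cls)` passes the `str` type as the object and the operator class (or string) as the type, so it could never return True, and it would raise `TypeError` whenever `cls` actually was a string. A minimal illustration of the fix (the class name here is made up):

```python
class FakeOperator:
    pass

cls = "FakeOperator"   # serialized form: operator_class is a plain string

isinstance(cls, str)   # True -- the fixed check returns early as intended
# isinstance(str, cls) # the pre-fix order; here `cls` is a str instance, not
#                      # a type, so this raises:
#                      # TypeError: isinstance() arg 2 must be a type or tuple of types
```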

@@ -1706,25 +1701,49 @@ def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, v
class MappedOperator(DAGNode):
"""Object representing a mapped operator in a DAG"""

operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)
def __repr__(self) -> str:
return (
f'MappedOperator(task_type={self.task_type}, '
f'task_id={self.task_id!r}, partial_kwargs={self.partial_kwargs!r}, '
f'mapped_kwargs={self.mapped_kwargs!r}, dag={self.dag})'
)

operator_class: Union[Type[BaseOperator], str]
task_type: str = attr.ib()
task_id: str
partial_kwargs: Dict[str, Any]
mapped_kwargs: Dict[str, Any] = attr.ib(
validator=lambda self, _, v: _validate_kwarg_names_for_mapping(self.operator_class, "map", v)
)
dag: Optional["DAG"] = None
upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)

task_group: Optional["TaskGroup"] = attr.ib(repr=False)
upstream_task_ids: Set[str] = attr.ib(factory=set)
downstream_task_ids: Set[str] = attr.ib(factory=set)

task_group: Optional["TaskGroup"] = attr.ib()
# BaseOperator-like interface -- needed so we can add ourselves to the dag.tasks
start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
start_date: Optional[pendulum.DateTime] = attr.ib(default=None)
end_date: Optional[pendulum.DateTime] = attr.ib(default=None)
owner: str = attr.ib(repr=False, default=conf.get("operators", "DEFAULT_OWNER"))
max_active_tis_per_dag: Optional[int] = attr.ib(default=None)

# Needed for SerializedBaseOperator
_is_dummy: bool = attr.ib()

deps: Iterable[BaseTIDep] = attr.ib()
operator_extra_links: Iterable['BaseOperatorLink'] = ()
params: Union[ParamsDict, dict] = attr.ib(factory=ParamsDict)
template_fields: Iterable[str] = attr.ib()

@_is_dummy.default
def _is_dummy_default(self):
from airflow.operators.dummy import DummyOperator

return issubclass(self.operator_class, DummyOperator)

@deps.default
def _deps_from_class(self):
return self.operator_class.deps
Comment on lines +1743 to +1745

Member: I'm assuming many of these defaults (this one, _is_dummy, and template_fields, I believe) don't need to handle the case where operator_class is a str, because in that case the values would have been supplied explicitly instead. (It's kind of bad that it's designed this way, but I guess that can be said for many things in the current serialisation implementation...)

Member Author: Yes, exactly that. I've already started thinking about how to refactor/rearchitect the serialization and deserialization.
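A minimal sketch of the pattern being discussed, using bare attrs semantics (the class and field names mirror the diff; this is not the PR's actual code):

```python
import attr

@attr.s
class MappedOperatorSketch:
    # A real operator class at parse time, or a plain string such as
    # "BashOperator" after deserialization.
    operator_class = attr.ib()
    template_fields = attr.ib()

    @template_fields.default
    def _template_fields_default(self):
        # Safe only when operator_class is a real class.  attrs evaluates a
        # field's default only when the argument is omitted from __init__, and
        # the deserializer always passes template_fields explicitly -- so in
        # the str case this default never runs.
        return self.operator_class.template_fields
```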


@classmethod
def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator":
dag: Optional["DAG"] = getattr(operator, '_dag', None)
@@ -1733,17 +1752,20 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) ->
# are mapped, we want to _remove_ that task from the dag
dag._remove_task(operator.task_id)

operator_init_kwargs: dict = operator._BaseOperator__init_kwargs # type: ignore
return MappedOperator(
operator_class=type(operator),
task_id=operator.task_id,
task_group=getattr(operator, 'task_group', None),
dag=getattr(operator, '_dag', None),
start_date=operator.start_date,
end_date=operator.end_date,
partial_kwargs=operator._BaseOperator__init_kwargs, # type: ignore
partial_kwargs={k: v for k, v in operator_init_kwargs.items() if k != "task_id"},
mapped_kwargs=mapped_kwargs,
owner=operator.owner,
max_active_tis_per_dag=operator.max_active_tis_per_dag,
deps=operator.deps,
params=operator.params,
)

@classmethod
@@ -1781,14 +1803,22 @@ def __attrs_post_init__(self):

@task_type.default
def _default_task_type(self):
return self.operator_class.__name__
# Can be a string if we are de-serialized
val = self.operator_class
if isinstance(val, str):
return val.rsplit('.', 1)[-1]
return val.__name__

@task_group.default
def _default_task_group(self):
from airflow.utils.task_group import TaskGroupContext

return TaskGroupContext.get_current_task_group(self.dag)

@template_fields.default
def _template_fields_default(self):
return self.operator_class.template_fields

@property
def node_id(self):
return self.task_id
@@ -1820,6 +1850,31 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
"""Required by DAGNode."""
return DagAttributeTypes.OP, self.task_id

@property
def inherits_from_dummy_operator(self):
"""Used to determine if an Operator is inherited from DummyOperator"""
return self._is_dummy

# The _serialized_fields are lazily loaded when get_serialized_fields() method is called
__serialized_fields: ClassVar[Optional[FrozenSet[str]]] = None

@classmethod
def get_serialized_fields(cls):
if cls.__serialized_fields is None:
fields_dict = attr.fields_dict(cls)
cls.__serialized_fields = frozenset(
fields_dict.keys()
- {
'deps',
'inherits_from_dummy_operator',
'operator_extra_links',
'upstream_task_ids',
'task_type',
}
| {'template_fields'}
)
return cls.__serialized_fields


# TODO: Deprecate for Airflow 3.0
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
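For orientation, a hedged usage sketch of the `from_operator` constructor changed earlier in this file's diff. Only the signature and the task_id/validation behaviour come from the diff; the operator, its arguments, and the mapped values are hypothetical:

```python
from airflow.models.baseoperator import MappedOperator
from airflow.operators.bash import BashOperator

op = BashOperator(task_id="make_file", bash_command="echo base")

# Detaches `op` from its DAG (if any) and replaces it with a MappedOperator.
# The mapped kwarg names are validated against BashOperator, and task_id is
# now excluded from partial_kwargs, per the change above.
mapped = MappedOperator.from_operator(op, {"bash_command": ["echo a", "echo b"]})
```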
4 changes: 2 additions & 2 deletions airflow/models/dag.py
@@ -2102,8 +2102,8 @@ def filter_task_group(group, parent_group):
for t in dag.tasks:
# Removing upstream/downstream references to tasks that did not
# make the cut
t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys())
t._downstream_task_ids = t.downstream_task_ids.intersection(dag.task_dict.keys())
t.upstream_task_ids.intersection_update(dag.task_dict)
t.downstream_task_ids.intersection_update(dag.task_dict)

if len(dag.tasks) < len(self.tasks):
dag.partial = True
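Worth spelling out in the change above: `intersection_update` mutates the existing set in place (the old code built a new set and rebound the private attribute), and intersecting with `dag.task_dict` iterates over the dict's keys. A self-contained illustration:

```python
task_dict = {"a": "task-a", "b": "task-b"}  # stands in for dag.task_dict
upstream_task_ids = {"a", "c"}

# Mutates the existing set rather than creating and rebinding a new one;
# passing a dict intersects against its keys.
upstream_task_ids.intersection_update(task_dict)
assert upstream_task_ids == {"a"}
```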
12 changes: 2 additions & 10 deletions airflow/models/taskmixin.py
@@ -128,20 +128,12 @@ def label(self) -> Optional[str]:

start_date: Optional[pendulum.DateTime]
end_date: Optional[pendulum.DateTime]
upstream_task_ids: Set[str]
downstream_task_ids: Set[str]

def has_dag(self) -> bool:
return self.dag is not None

@property
@abstractmethod
def upstream_task_ids(self) -> Set[str]:
raise NotImplementedError()

@property
@abstractmethod
def downstream_task_ids(self) -> Set[str]:
raise NotImplementedError()

@property
def log(self) -> "Logger":
raise NotImplementedError()
12 changes: 10 additions & 2 deletions airflow/serialization/schema.json
@@ -196,7 +196,7 @@
"items": { "type": "string" }
},
"subdag": { "$ref": "#/definitions/dag" },
"_downstream_task_ids": {
"downstream_task_ids": {
"type": "array",
"items": { "type": "string" }
},
@@ -211,7 +211,15 @@
"doc_md": { "type": "string" },
"doc_json": { "type": "string" },
"doc_yaml": { "type": "string" },
"doc_rst": { "type": "string" }
"doc_rst": { "type": "string" },
"_is_mapped": { "const": true, "$comment": "only present when True" },
"mapped_kwargs": { "type": "object" },
"partial_kwargs": { "type": "object" }
},
"dependencies": {
"mapped_kwargs": ["partial_kwargs", "_is_mapped"],
"partial_kwargs": ["mapped_kwargs", "_is_mapped"],
"_is_mapped": ["mapped_kwargs", "partial_kwargs"]
},
"additionalProperties": true
},
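Per the `dependencies` block added above, a serialized operator that carries any one of `_is_mapped`, `mapped_kwargs`, or `partial_kwargs` must carry the other two, while a plain (unmapped) operator omits all three. A hypothetical fragment that would validate against this schema (field values are illustrative only):

```json
{
  "task_id": "make_file",
  "downstream_task_ids": [],
  "_is_mapped": true,
  "partial_kwargs": { "owner": "airflow" },
  "mapped_kwargs": { "bash_command": ["echo a", "echo b"] }
}
```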