Commit

refactor: Refactored __new__ magic method of BaseOperatorMeta to avoid bad mixing classic and decorated operators (#37937)

* refactor: Refactored __new__ magic method of BaseOperatorMeta class to include an execute safeguard which prohibits calling the execute method manually from a Python callable or @task decorated python method.

* refactor: Get unit_test_mode from airflow config instead of directly retrieving it from env variable

* fix: Fixed import of AirflowException

* refactor: Use conf_vars instead of patch.dict to alter Airflow unit_test_mode config parameter and use pytest instead of unittest to run the test

* fix: Fixed match on expected exception message

* refactor: Added test case where HelloWorldOperator is called from within a PythonOperator

* refactor: Refactored TestBaseOperatorMeta by having a dedicated fixture for TaskInstance

* refactor: Reformatted some files that were failing due to static checks

* refactor: Try disabling unit_test_mode right before calling the run method of the task instance, in the hope the test won't fail on Airflow

* refactor: Re-ordered imports as asked by static checks

* refactor: Added license at top of new test module

* refactor: Refactored ExecutorSafeguard decorator as a class so we can toggle the test_mode on it when used in unit tests

* refactor: Removed patch of sqlalchemy Session on test methods

* refactor: Refactored ExecutorSafeguard decorator as a class so we can toggle the test_mode on it when used in unit tests

* Revert "refactor: Refactored ExecutorSafeguard decorator as a class so we can toggle the test_mode on it when used in unit tests"

This reverts commit fabea74.

* refactor: Refactored ExecutorSafeguard decorator as a class so we can toggle the test_mode on it when used in unit tests

* refactor: Added docstring to ExecutorSafeguard decorator

* docs: Reformatted docstring of ExecutorSafeguard

* refactor: Added missing white line between docstring and test_mode class var of ExecutorSafeguard class

* refactor: Fixed unit tests for ExecutorSafeguard

* refactor: Reformatted baseoperator and test_baseoperatormeta file as expected by static check

* fix: Fixed import of partial function from functools

* refactor: Refactored multiple patches into one context manager so it's Python 3.8 compatible

* refactor: Reformatted patch statement in TestExecutorSafeguard

* refactor: Added allow_mixing attribute to BaseOperator and added test case when allow_mixing is enabled

* refactor: Fixed static checks on baseoperator

* fix: Forgot to also add allow_mixin parameter to partial method in BaseOperator

* refactor: Trying to fix example in docstring of allow_mixin param in BaseOperator

* refactor: Added bool type to allow_mixing attribute of BaseOperator

* refactor: Added allow_mixin parameter

* refactor: Added init file in resources package

* refactor: Changed docstring of BaseOperator to raw string due to backslash present in example

* fix: Fixed name of allow_mixin parameter

* fix: Fixed check on allow_mixin property

* fix: Fixed allow_mixin property in test

* refactor: Added allow_mixin to schema.json and assertion in test_dag_serialization

* refactor: Refactored tests using dag_maker fixture and running them as db_test + fixed issue of nested calls with ExecutorSafeguard

* refactor: Moved import of Context under type checking

* refactor: Reformatted TestExecutorSafeguard

* refactor: Simplified check on test_mode in wrapper

* refactor: Improved check as a classic operator can also be called through a python function called by a Python operator

* refactor: Fixed additional static checks

* refactor: Fixed additional static checks

* refactor: Use sentinel instead of traceback to detect if execute was called outside TaskInstance

* refactor: Use test_mode from TaskInstance

* refactor: Removed unused imports

* refactor: Put message in one line

* Revert "refactor: Use test_mode from TaskInstance"

This reverts commit 8a35ebb

* refactor: Put message on one line

* refactor: Fixed passing of sentinel arg when resume_execution is being called

* refactor: Reformatted test

* refactor: Check if next_kwargs is not None

* refactor: Fixed DAG example in docstring of chain_linear method

* refactor: Renamed allow_mixin parameter of BaseOperator to allow_nested_operators

---------

Co-authored-by: David Blain <david.blain@infrabel.be>
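
The sentinel approach described in the messages above can be sketched without Airflow. This is a minimal illustration, not the committed implementation: `NestedExecutionError`, `HelloOperator`, `safeguard` and `run_task` are hypothetical names, and the test-mode and `DecoratedOperator` checks from the real decorator are left out.

```python
import logging
from functools import wraps

log = logging.getLogger("safeguard_sketch")

# Private marker object; only the trusted runner passes it along.
_sentinel = object()


class NestedExecutionError(RuntimeError):
    """Hypothetical stand-in for AirflowException."""


def safeguard(func):
    """Wrap ``execute`` so that calls made without the sentinel are flagged."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # The key is derived from the runtime class name, as in the commit.
        key = f"{type(self).__name__}__sentinel"
        token = kwargs.pop(key, None)
        if token is not _sentinel:
            message = f"{type(self).__name__}.{func.__name__} cannot be called outside the runner!"
            if not self.allow_nested_operators:
                raise NestedExecutionError(message)
            log.warning(message)
        return func(self, *args, **kwargs)

    return wrapper


class HelloOperator:
    """Hypothetical operator used only for this sketch."""

    def __init__(self, allow_nested_operators=True):
        self.allow_nested_operators = allow_nested_operators

    @safeguard
    def execute(self, context):
        return f"hello {context['name']}"


def run_task(task, context):
    """Hypothetical runner: the only caller that supplies the sentinel."""
    return task.execute(context, **{f"{type(task).__name__}__sentinel": _sentinel})
```

With `allow_nested_operators=False`, `run_task(op, ctx)` goes through while a direct `op.execute(ctx)` raises; with the default `True`, the direct call only logs a warning and still runs.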
dabla and David Blain committed Mar 24, 2024
1 parent 77d2fc7 commit 694826d
Showing 8 changed files with 273 additions and 5 deletions.
1 change: 1 addition & 0 deletions airflow/models/base.py
@@ -46,6 +46,7 @@ def _get_schema():

metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
mapper_registry = registry(metadata=metadata)
_sentinel = object()

Base: Any = mapper_registry.generate_base()

61 changes: 58 additions & 3 deletions airflow/models/baseoperator.py
@@ -32,6 +32,7 @@
import sys
import warnings
from datetime import datetime, timedelta
from functools import total_ordering, wraps
from inspect import signature
from types import FunctionType
from typing import (
@@ -75,6 +76,7 @@
DEFAULT_WEIGHT_RULE,
AbstractOperator,
)
from airflow.models.base import _sentinel
from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs
from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
@@ -215,6 +217,7 @@ def partial(**kwargs):
"weight_rule": DEFAULT_WEIGHT_RULE,
"inlets": [],
"outlets": [],
"allow_nested_operators": True,
}


@@ -265,6 +268,7 @@ def partial(
doc_yaml: str | None | ArgNotSet = NOTSET,
doc_rst: str | None | ArgNotSet = NOTSET,
logger_name: str | None | ArgNotSet = NOTSET,
allow_nested_operators: bool = True,
**kwargs,
) -> OperatorPartial:
from airflow.models.dag import DagContext
@@ -331,6 +335,7 @@ def partial(
"doc_rst": doc_rst,
"doc_yaml": doc_yaml,
"logger_name": logger_name,
"allow_nested_operators": allow_nested_operators,
}

# Inject DAG-level default args into args provided to this function.
@@ -365,6 +370,35 @@ def partial(
)


class ExecutorSafeguard:
    """
    The ExecutorSafeguard decorator.
    Checks if the execute method of an operator isn't manually called outside
    the TaskInstance as we want to avoid bad mixing between decorated and
    classic operators.
    """

    test_mode = conf.getboolean("core", "unit_test_mode")

    @classmethod
    def decorator(cls, func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            from airflow.decorators.base import DecoratedOperator

            sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None)

            if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator):
                message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!"
                if not self.allow_nested_operators:
                    raise AirflowException(message)
                self.log.warning(message)
            return func(self, *args, **kwargs)

        return wrapper


class BaseOperatorMeta(abc.ABCMeta):
"""Metaclass of BaseOperator."""

@@ -396,7 +430,7 @@ def _apply_defaults(cls, func: T) -> T:

fixup_decorator_warning_stack(func)

-@functools.wraps(func)
+@wraps(func)
def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
from airflow.models.dag import DagContext
from airflow.utils.task_group import TaskGroupContext
@@ -464,6 +498,9 @@ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
return cast(T, apply_defaults)

    def __new__(cls, name, bases, namespace, **kwargs):
        execute_method = namespace.get("execute")
        if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
            namespace["execute"] = ExecutorSafeguard().decorator(execute_method)
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        with contextlib.suppress(KeyError):
            # Update the partial descriptor with the class method, so it calls the actual function
@@ -475,9 +512,9 @@ def __new__(cls, name, bases, namespace, **kwargs):
        return new_cls


-@functools.total_ordering
+@total_ordering
class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
-    """
+    r"""
Abstract base class for all operators.
Since operators create objects that become nodes in the DAG, BaseOperator
@@ -672,6 +709,21 @@ class derived from this one results in the creation of a task object,
If set to `None` (default), the logger name will fall back to
`airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. SimpleHttpOperator will have
*airflow.task.operators.airflow.providers.http.operators.http.SimpleHttpOperator* as logger).
    :param allow_nested_operators: if True, when an operator is executed within another one a warning message
        will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested
        within another one). In future releases of Airflow this parameter will be removed and an exception
        will always be thrown when operators are nested within each other (default is True).

        **Example**: example of a bad operator mixin usage::

            @task(provide_context=True)
            def say_hello_world(**context):
                hello_world_task = BashOperator(
                    task_id="hello_world_task",
                    bash_command="python -c \"print('Hello, world!')\"",
                    dag=dag,
                )
                hello_world_task.execute(context)
    """

# Implementing Operator.
@@ -727,6 +779,7 @@ class derived from this one results in the creation of a task object,
"on_skipped_callback",
"do_xcom_push",
"multiple_outputs",
"allow_nested_operators",
}

# Defines if the operator supports lineage without manual definitions
@@ -807,6 +860,7 @@ def __init__(
doc_yaml: str | None = None,
doc_rst: str | None = None,
logger_name: str | None = None,
allow_nested_operators: bool = True,
**kwargs,
):
from airflow.models.dag import DagContext
@@ -956,6 +1010,7 @@ def __init__(

self._log_config_logger_name = "airflow.task.operators"
self._logger_name = logger_name
self.allow_nested_operators: bool = allow_nested_operators

# Lineage
self.inlets: list = []
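
The `__new__` change in the baseoperator.py diff above wraps `execute` at class-creation time, and only when the class body defines a concrete (non-abstract) `execute`. A rough standalone sketch of that idea follows; `OperatorMeta`, `mark_guarded`, `Base`, `Concrete` and `Inherits` are hypothetical names, and `mark_guarded` merely tags the wrapper instead of performing the real safeguard check.

```python
import abc
from functools import wraps


def mark_guarded(func):
    """Hypothetical stand-in for ExecutorSafeguard.decorator: tags the wrapper."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        return func(self, *args, **kwargs)

    wrapper.guarded = True
    return wrapper


class OperatorMeta(abc.ABCMeta):
    def __new__(mcs, name, bases, namespace, **kwargs):
        # Only look at what this class body defines, not at inherited methods.
        execute = namespace.get("execute")
        # abc.abstractmethod sets __isabstractmethod__ = True on the function.
        if callable(execute) and not getattr(execute, "__isabstractmethod__", False):
            namespace["execute"] = mark_guarded(execute)
        return super().__new__(mcs, name, bases, namespace, **kwargs)


class Base(metaclass=OperatorMeta):
    @abc.abstractmethod
    def execute(self, context):
        """Abstract: left unwrapped by the metaclass."""


class Concrete(Base):
    def execute(self, context):
        return context["value"] * 2


class Inherits(Concrete):
    """Defines no execute of its own, so nothing is wrapped again here."""
```

Because the check reads `namespace`, a subclass that does not override `execute` simply inherits the already wrapped method from its parent.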
4 changes: 4 additions & 0 deletions airflow/models/mappedoperator.py
@@ -656,6 +656,10 @@ def doc_yaml(self) -> str | None:
def doc_rst(self) -> str | None:
return self.partial_kwargs.get("doc_rst")

@property
def allow_nested_operators(self) -> bool:
return bool(self.partial_kwargs.get("allow_nested_operators"))

def get_dag(self) -> DAG | None:
"""Implement Operator."""
return self.dag
8 changes: 7 additions & 1 deletion airflow/models/taskinstance.py
@@ -86,7 +86,7 @@
XComForMappingNotPushed,
)
from airflow.listeners.listener import get_listener_manager
-from airflow.models.base import Base, StringID, TaskInstanceDependencies
+from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
from airflow.models.dagbag import DagBag
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
@@ -411,11 +411,17 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
    execute_callable_kwargs: dict[str, Any] = {}
    execute_callable: Callable
    if task_instance.next_method:
        if task_instance.next_method == "execute":
            if not task_instance.next_kwargs:
                task_instance.next_kwargs = {}
            task_instance.next_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel
        execute_callable = task_to_execute.resume_execution
        execute_callable_kwargs["next_method"] = task_instance.next_method
        execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs
    else:
        execute_callable = task_to_execute.execute
        if execute_callable.__name__ == "execute":
            execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel

    def _execute_callable(context, **execute_callable_kwargs):
        try:
16 changes: 16 additions & 0 deletions airflow/providers/microsoft/azure/serialization/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
3 changes: 2 additions & 1 deletion airflow/serialization/schema.json
@@ -289,7 +289,8 @@
"_is_mapped": { "const": true, "$comment": "only present when True" },
"expand_input": { "type": "object" },
"partial_kwargs": { "type": "object" },
-"map_index_template": { "type": "string" }
+"map_index_template": { "type": "string" },
+"allow_nested_operators": { "type": "boolean" }
"allow_nested_operators": { "type": "boolean" }
},
"dependencies": {
"expand_input": ["partial_kwargs", "_is_mapped"],
