From 7b5962e747576ccae7798e1d9c6c1398086efd29 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Sat, 4 May 2024 22:34:19 +0400 Subject: [PATCH] Refactor cloudpickle support in Python operators/decorators (#39270) * Refactor cloudpickle support in Python operators/decorators * Fixup missing marker * Return back skip TestPythonVirtualenvOperator::test_airflow_context for dill * TestPythonVirtualenvOperator::test_airflow_context xfail instead of skip * Catch only on ModuleNotFound error and simple reraise with warning * Limit test_airflow_context only for python 3.11 --- airflow/decorators/__init__.pyi | 61 ++- .../tutorial_taskflow_api_virtualenv.py | 2 +- airflow/operators/python.py | 153 ++++--- tests/decorators/test_branch_virtualenv.py | 2 +- tests/decorators/test_external_python.py | 227 +++++----- tests/decorators/test_python_virtualenv.py | 229 +++++----- tests/operators/test_python.py | 405 ++++++++---------- 7 files changed, 566 insertions(+), 513 deletions(-) diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index a10a1cf39e844..e88a535db503a 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -111,7 +111,7 @@ class TaskDecoratorCollection: # _PythonVirtualenvDecoratedOperator. requirements: None | Iterable[str] | str = None, python_version: None | str | int | float = None, - use_dill: bool = False, + serializer: Literal["pickle", "cloudpickle", "dill"] | None = None, system_site_packages: bool = True, templates_dict: Mapping[str, Any] | None = None, pip_install_options: list[str] | None = None, @@ -119,6 +119,7 @@ class TaskDecoratorCollection: index_urls: None | Collection[str] | str = None, venv_cache_path: None | str = None, show_return_value_in_logs: bool = True, + use_dill: bool = False, **kwargs, ) -> TaskDecorator: """Create a decorator to convert the decorated callable to a virtual environment task. @@ -129,6 +130,13 @@ class TaskDecoratorCollection: "requirements file" as specified by pip. :param python_version: The Python version to run the virtual environment with. Note that both 2 and 2.7 are acceptable forms. + :param serializer: Which serializer use to serialize the args and result. It can be one of the following: + + - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library. + - ``"cloudpickle"``: Use cloudpickle for serialize more complex types, + this requires to include cloudpickle in your requirements. + - ``"dill"``: Use dill for serialize more complex types, + this requires to include dill in your requirements. :param use_dill: Whether to use dill to serialize the args and result (pickle is default). This allow more complex types but requires you to include dill in your requirements. @@ -154,6 +162,9 @@ class TaskDecoratorCollection: logs. Defaults to True, which allows return value log output. It can be set to False to prevent log output of return value when you return huge data such as transmission a large amount of XCom to TaskAPI. + :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize + the args and result (pickle is default). This allows more complex types + but requires you to include dill in your requirements. """ @overload def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... @@ -164,9 +175,10 @@ class TaskDecoratorCollection: multiple_outputs: bool | None = None, # 'python_callable', 'op_args' and 'op_kwargs' since they are filled by # _PythonVirtualenvDecoratedOperator. - use_dill: bool = False, + serializer: Literal["pickle", "cloudpickle", "dill"] | None = None, templates_dict: Mapping[str, Any] | None = None, show_return_value_in_logs: bool = True, + use_dill: bool = False, **kwargs, ) -> TaskDecorator: """Create a decorator to convert the decorated callable to a virtual environment task. @@ -176,9 +188,13 @@ class TaskDecoratorCollection: (so usually start with "/" or "X:/" depending on the filesystem/os used). :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values. Dict will unroll to XCom values with keys as XCom keys. Defaults to False. - :param use_dill: Whether to use dill to serialize - the args and result (pickle is default). This allow more complex types - but requires you to include dill in your requirements. + :param serializer: Which serializer use to serialize the args and result. It can be one of the following: + + - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library. + - ``"cloudpickle"``: Use cloudpickle for serialize more complex types, + this requires to include cloudpickle in your requirements. + - ``"dill"``: Use dill for serialize more complex types, + this requires to include dill in your requirements. :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -187,6 +203,9 @@ class TaskDecoratorCollection: logs. Defaults to True, which allows return value log output. It can be set to False to prevent log output of return value when you return huge data such as transmission a large amount of XCom to TaskAPI. + :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize + the args and result (pickle is default). This allows more complex types + but requires you to include dill in your requirements. """ @overload def branch( # type: ignore[misc] @@ -211,7 +230,7 @@ class TaskDecoratorCollection: # _PythonVirtualenvDecoratedOperator. requirements: None | Iterable[str] | str = None, python_version: None | str | int | float = None, - use_dill: bool = False, + serializer: Literal["pickle", "cloudpickle", "dill"] | None = None, system_site_packages: bool = True, templates_dict: Mapping[str, Any] | None = None, pip_install_options: list[str] | None = None, @@ -219,6 +238,7 @@ class TaskDecoratorCollection: index_urls: None | Collection[str] | str = None, venv_cache_path: None | str = None, show_return_value_in_logs: bool = True, + use_dill: bool = False, **kwargs, ) -> TaskDecorator: """Create a decorator to wrap the decorated callable into a BranchPythonVirtualenvOperator. @@ -232,9 +252,13 @@ class TaskDecoratorCollection: "requirements file" as specified by pip. :param python_version: The Python version to run the virtual environment with. Note that both 2 and 2.7 are acceptable forms. - :param use_dill: Whether to use dill to serialize - the args and result (pickle is default). This allow more complex types - but requires you to include dill in your requirements. + :param serializer: Which serializer use to serialize the args and result. It can be one of the following: + + - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library. + - ``"cloudpickle"``: Use cloudpickle for serialize more complex types, + this requires to include cloudpickle in your requirements. + - ``"dill"``: Use dill for serialize more complex types, + this requires to include dill in your requirements. :param system_site_packages: Whether to include system_site_packages in your virtual environment. See virtualenv documentation for more information. @@ -253,6 +277,9 @@ class TaskDecoratorCollection: logs. Defaults to True, which allows return value log output. It can be set to False to prevent log output of return value when you return huge data such as transmission a large amount of XCom to TaskAPI. + :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize + the args and result (pickle is default). This allows more complex types + but requires you to include dill in your requirements. """ @overload def branch_virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... @@ -264,9 +291,10 @@ class TaskDecoratorCollection: multiple_outputs: bool | None = None, # 'python_callable', 'op_args' and 'op_kwargs' since they are filled by # _PythonVirtualenvDecoratedOperator. - use_dill: bool = False, + serializer: Literal["pickle", "cloudpickle", "dill"] | None = None, templates_dict: Mapping[str, Any] | None = None, show_return_value_in_logs: bool = True, + use_dill: bool = False, **kwargs, ) -> TaskDecorator: """Create a decorator to wrap the decorated callable into a BranchExternalPythonOperator. @@ -279,9 +307,13 @@ class TaskDecoratorCollection: (so usually start with "/" or "X:/" depending on the filesystem/os used). :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values. Dict will unroll to XCom values with keys as XCom keys. Defaults to False. - :param use_dill: Whether to use dill to serialize - the args and result (pickle is default). This allow more complex types - but requires you to include dill in your requirements. + :param serializer: Which serializer use to serialize the args and result. It can be one of the following: + + - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library. + - ``"cloudpickle"``: Use cloudpickle for serialize more complex types, + this requires to include cloudpickle in your requirements. + - ``"dill"``: Use dill for serialize more complex types, + this requires to include dill in your requirements. :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -290,6 +322,9 @@ class TaskDecoratorCollection: logs. Defaults to True, which allows return value log output. It can be set to False to prevent log output of return value when you return huge data such as transmission a large amount of XCom to TaskAPI. + :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize + the args and result (pickle is default). This allows more complex types + but requires you to include dill in your requirements. """ @overload def branch_external_python( diff --git a/airflow/example_dags/tutorial_taskflow_api_virtualenv.py b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py index 44134e445891d..3860876e6e687 100644 --- a/airflow/example_dags/tutorial_taskflow_api_virtualenv.py +++ b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py @@ -38,7 +38,7 @@ def tutorial_taskflow_api_virtualenv(): """ @task.virtualenv( - use_dill=True, + serializer="dill", # Use `dill` for advanced serialization. system_site_packages=False, requirements=["funcsigs"], ) diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 2368d78d80ce4..977ef54ecb617 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -23,7 +23,6 @@ import json import logging import os -import pickle import shutil import subprocess import sys @@ -36,6 +35,8 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, NamedTuple, Sequence, cast +import lazy_object_proxy + from airflow.compat.functools import cache from airflow.exceptions import ( AirflowConfigException, @@ -49,6 +50,7 @@ from airflow.models.taskinstance import _CURRENT_CONTEXT from airflow.models.variable import Variable from airflow.operators.branch import BranchMixIn +from airflow.typing_compat import Literal from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_get_dataset_events, context_merge from airflow.utils.file import get_unique_dag_module_name @@ -58,13 +60,6 @@ log = logging.getLogger(__name__) -if shutil.which("cloudpickle") or importlib.util.find_spec("cloudpickle"): - import cloudpickle as serialization_library -elif shutil.which("dill") or importlib.util.find_spec("dill"): - import dill as serialization_library -else: - log.debug("Neither dill and cloudpickle are installed. Please install one with: pip install [name]") - if TYPE_CHECKING: from pendulum.datetime import DateTime @@ -350,6 +345,41 @@ def get_tasks_to_skip(): return condition +def _load_pickle(): + import pickle + + return pickle + + +def _load_dill(): + try: + import dill + except ModuleNotFoundError: + log.error("Unable to import `dill` module. Please please make sure that it installed.") + raise + return dill + + +def _load_cloudpickle(): + try: + import cloudpickle + except ModuleNotFoundError: + log.error( + "Unable to import `cloudpickle` module. " + "Please install it with: pip install 'apache-airflow[cloudpickle]'" + ) + raise + return cloudpickle + + +_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"] +_SERIALIZERS: dict[_SerializerTypeDef, Any] = { + "pickle": lazy_object_proxy.Proxy(_load_pickle), + "dill": lazy_object_proxy.Proxy(_load_dill), + "cloudpickle": lazy_object_proxy.Proxy(_load_cloudpickle), +} + + class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): BASE_SERIALIZABLE_CONTEXT_KEYS = { "ds", @@ -400,8 +430,7 @@ def __init__( self, *, python_callable: Callable, - use_dill: bool = False, - use_cloudpickle: bool = False, + serializer: _SerializerTypeDef | None = None, op_args: Collection[Any] | None = None, op_kwargs: Mapping[str, Any] | None = None, string_args: Iterable[str] | None = None, @@ -409,6 +438,7 @@ def __init__( templates_exts: list[str] | None = None, expect_airflow: bool = True, skip_on_exit_code: int | Container[int] | None = None, + use_dill: bool = False, **kwargs, ): if ( @@ -428,15 +458,29 @@ def __init__( **kwargs, ) self.string_args = string_args or [] - if use_dill and use_cloudpickle: - raise AirflowException( - "Both 'use_dill' and 'use_cloudpickle' parameters are set to True. Please," - " choose only one." - ) + if use_dill: - use_cloudpickle = use_dill - self.use_cloudpickle = use_cloudpickle - self.pickling_library = serialization_library if self.use_cloudpickle else pickle + warnings.warn( + "`use_dill` is deprecated and will be removed in a future version. " + "Please provide serializer='dill' instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + if serializer: + raise AirflowException( + "Both 'use_dill' and 'serializer' parameters are set. Please set only one of them" + ) + serializer = "dill" + serializer = serializer or "pickle" + if serializer not in _SERIALIZERS: + msg = ( + f"Unsupported serializer {serializer!r}. " + f"Expected one of {', '.join(map(repr, _SERIALIZERS))}" + ) + raise AirflowException(msg) + self.pickling_library = _SERIALIZERS[serializer] + self.serializer: _SerializerTypeDef = serializer + self.expect_airflow = expect_airflow self.skip_on_exit_code = ( skip_on_exit_code @@ -461,6 +505,7 @@ def get_python_source(self): def _write_args(self, file: Path): if self.op_args or self.op_kwargs: + self.log.info("Use %r as serializer.", self.serializer) file.write_bytes(self.pickling_library.dumps({"args": self.op_args, "kwargs": self.op_kwargs})) def _write_string_args(self, file: Path): @@ -498,7 +543,7 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): "op_args": self.op_args, "op_kwargs": op_kwargs, "expect_airflow": self.expect_airflow, - "pickling_library": self.pickling_library.__name__, + "pickling_library": self.serializer, "python_callable": self.python_callable.__name__, "python_callable_source": self.get_python_source(), } @@ -567,12 +612,13 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): "requirements file" as specified by pip. :param python_version: The Python version to run the virtual environment with. Note that both 2 and 2.7 are acceptable forms. - :param use_dill: Whether to use dill to serialize - the args and result (pickle is default). This allow more complex types - but requires you to include dill in your requirements. - :param use_cloudpickle: Whether to use cloudpickle to serialize - the args and result (pickle is default). This allows more complex types - but requires you to include cloudpickle in your requirements. + :param serializer: Which serializer use to serialize the args and result. It can be one of the following: + + - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library. + - ``"cloudpickle"``: Use cloudpickle for serialize more complex types, + this requires to include cloudpickle in your requirements. + - ``"dill"``: Use dill for serialize more complex types, + this requires to include dill in your requirements. :param system_site_packages: Whether to include system_site_packages in your virtual environment. See virtualenv documentation for more information. @@ -601,6 +647,9 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): virtual environment will be cached, creates a sub-folder venv-{hash} whereas hash will be replaced with a checksum of requirements. If not provided the virtual environment will be created and deleted in a temp folder for every execution. + :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize + the args and result (pickle is default). This allows more complex types + but requires you to include dill in your requirements. """ template_fields: Sequence[str] = tuple( @@ -614,8 +663,7 @@ def __init__( python_callable: Callable, requirements: None | Iterable[str] | str = None, python_version: str | None = None, - use_dill: bool = False, - use_cloudpickle: bool = False, + serializer: _SerializerTypeDef | None = None, system_site_packages: bool = True, pip_install_options: list[str] | None = None, op_args: Collection[Any] | None = None, @@ -627,6 +675,7 @@ def __init__( skip_on_exit_code: int | Container[int] | None = None, index_urls: None | Collection[str] | str = None, venv_cache_path: None | os.PathLike[str] = None, + use_dill: bool = False, **kwargs, ): if ( @@ -646,13 +695,6 @@ def __init__( RemovedInAirflow3Warning, stacklevel=2, ) - if use_dill and use_cloudpickle: - raise AirflowException( - "Both 'use_dill' and 'use_cloudpickle' parameters are set to True. Please, " - "choose only one." - ) - if use_dill: - use_cloudpickle = use_dill if not is_venv_installed(): raise AirflowException("PythonVirtualenvOperator requires virtualenv, please install it.") if not requirements: @@ -673,7 +715,7 @@ def __init__( self.venv_cache_path = venv_cache_path super().__init__( python_callable=python_callable, - use_cloudpickle=use_cloudpickle, + serializer=serializer, op_args=op_args, op_kwargs=op_kwargs, string_args=string_args, @@ -681,15 +723,22 @@ def __init__( templates_exts=templates_exts, expect_airflow=expect_airflow, skip_on_exit_code=skip_on_exit_code, + use_dill=use_dill, **kwargs, ) def _requirements_list(self, exclude_cloudpickle: bool = False) -> list[str]: """Prepare a list of requirements that need to be installed for the virtual environment.""" requirements = [str(dependency) for dependency in self.requirements] - if not exclude_cloudpickle: - if not self.system_site_packages and self.use_cloudpickle and "cloudpickle" not in requirements: + if not self.system_site_packages: + if ( + self.serializer == "cloudpickle" + and not exclude_cloudpickle + and "cloudpickle" not in requirements + ): requirements.append("cloudpickle") + elif self.serializer == "dill" and "dill" not in requirements: + requirements.append("dill") requirements.sort() # Ensure a hash is stable return requirements @@ -856,13 +905,13 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator): (so usually start with "/" or "X:/" depending on the filesystem/os used). :param python_callable: A python function with no references to outside variables, defined with def, which will be run in a virtual environment. - :param use_dill: Whether to use dill to serialize - the args and result (pickle is default). This allow more complex types - but requires you to include dill in your requirements. - :param use_cloudpickle: Whether to use cloudpickle to serialize - the args and result (pickle is default). This allows more complex types - but if cloudpickle is not preinstalled in your virtual environment, the task will fail - with use_cloudpickle enabled. + :param serializer: Which serializer use to serialize the args and result. It can be one of the following: + + - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library. + - ``"cloudpickle"``: Use cloudpickle for serialize more complex types, + this requires to include cloudpickle in your requirements. + - ``"dill"``: Use dill for serialize more complex types, + this requires to include dill in your requirements. :param op_args: A list of positional arguments to pass to python_callable. :param op_kwargs: A dict of keyword arguments to pass to python_callable. :param string_args: Strings that are present in the global var virtualenv_string_args, @@ -880,6 +929,9 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator): :param skip_on_exit_code: If python_callable exits with this exit code, leave the task in ``skipped`` state (default: None). If set to ``None``, any non-zero exit code will be treated as a failure. + :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize + the args and result (pickle is default). This allows more complex types + but requires you to include dill in your requirements. """ template_fields: Sequence[str] = tuple({"python"}.union(PythonOperator.template_fields)) @@ -889,8 +941,7 @@ def __init__( *, python: str, python_callable: Callable, - use_dill: bool = False, - use_cloudpickle: bool = False, + serializer: _SerializerTypeDef | None = None, op_args: Collection[Any] | None = None, op_kwargs: Mapping[str, Any] | None = None, string_args: Iterable[str] | None = None, @@ -899,21 +950,16 @@ def __init__( expect_airflow: bool = True, expect_pendulum: bool = False, skip_on_exit_code: int | Container[int] | None = None, + use_dill: bool = False, **kwargs, ): if not python: raise ValueError("Python Path must be defined in ExternalPythonOperator") - if use_dill and use_cloudpickle: - raise AirflowException( - "Both 'use_dill' and 'use_cloudpickle' parameters are set to True. Please, choose only one." - ) - if use_dill: - use_cloudpickle = use_dill self.python = python self.expect_pendulum = expect_pendulum super().__init__( python_callable=python_callable, - use_cloudpickle=use_cloudpickle, + serializer=serializer, op_args=op_args, op_kwargs=op_kwargs, string_args=string_args, @@ -921,6 +967,7 @@ def __init__( templates_exts=templates_exts, expect_airflow=expect_airflow, skip_on_exit_code=skip_on_exit_code, + use_dill=use_dill, **kwargs, ) diff --git a/tests/decorators/test_branch_virtualenv.py b/tests/decorators/test_branch_virtualenv.py index 57db52f167746..d38a157632b3d 100644 --- a/tests/decorators/test_branch_virtualenv.py +++ b/tests/decorators/test_branch_virtualenv.py @@ -25,7 +25,7 @@ pytestmark = pytest.mark.db_test -class Test_BranchPythonVirtualenvDecoratedOperator: +class TestBranchPythonVirtualenvDecoratedOperator: # when run in "Parallel" test run environment, sometimes this test runs for a long time # because creating virtualenv and starting new Python interpreter creates a lot of IO/contention # possibilities. So we are increasing the timeout for this test to 3x of the default timeout diff --git a/tests/decorators/test_external_python.py b/tests/decorators/test_external_python.py index fe5c76101c6d2..034a51166adf6 100644 --- a/tests/decorators/test_external_python.py +++ b/tests/decorators/test_external_python.py @@ -18,21 +18,17 @@ from __future__ import annotations import datetime -import logging import subprocess import venv from datetime import timedelta -from pathlib import Path +from importlib.util import find_spec from subprocess import CalledProcessError -from tempfile import TemporaryDirectory import pytest from airflow.decorators import setup, task, teardown from airflow.utils import timezone -log = logging.getLogger(__name__) - pytestmark = pytest.mark.db_test @@ -40,6 +36,10 @@ END_DATE = timezone.datetime(2016, 1, 2) INTERVAL = timedelta(hours=12) FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1) +DILL_INSTALLED = find_spec("dill") is not None +DILL_MARKER = pytest.mark.skipif(not DILL_INSTALLED, reason="`dill` is not installed") +CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None +CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") TI_CONTEXT_ENV_VARS = [ "AIRFLOW_CTX_DAG_ID", @@ -49,114 +49,75 @@ ] -@pytest.fixture -def venv_python(): - with TemporaryDirectory() as d: - venv.create(d, with_pip=False) - yield Path(d) / "bin" / "python" +@pytest.fixture(scope="module") +def venv_python(tmp_path_factory): + venv_dir = tmp_path_factory.mktemp("venv") + venv.create(venv_dir, with_pip=False) + return (venv_dir / "bin" / "python").resolve(strict=True).as_posix() -@pytest.fixture -def venv_python_with_cloudpickle_and_dill(): - with TemporaryDirectory() as d: - venv.create(d, with_pip=True) - python_path = Path(d) / "bin" / "python" - subprocess.call([python_path, "-m", "pip", "install", "cloudpickle", "dill"]) - yield python_path +@pytest.fixture(scope="module") +def venv_python_with_cloudpickle_and_dill(tmp_path_factory): + venv_dir = tmp_path_factory.mktemp("venv_serializers") + venv.create(venv_dir, with_pip=True) + python_path = (venv_dir / "bin" / "python").resolve(strict=True).as_posix() + subprocess.call([python_path, "-m", "pip", "install", "cloudpickle", "dill"]) + return python_path class TestExternalPythonDecorator: - def test_with_cloudpickle_works(self, dag_maker, venv_python_with_cloudpickle_and_dill): - @task.external_python(python=venv_python_with_cloudpickle_and_dill, use_cloudpickle=True) - def f(): - """Import cloudpickle to double-check it is installed .""" - try: - import cloudpickle # noqa: F401 - except ImportError: - log.warning( - "Cloudpickle package is required to be installed." - " Please install it with: pip install [cloudpickle]" - ) - - with dag_maker(): - ret = f() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_with_templated_python_cloudpickle(self, dag_maker, venv_python_with_cloudpickle_and_dill): - # add template that produces empty string when rendered - templated_python_with_cloudpickle = venv_python_with_cloudpickle_and_dill.as_posix() + "{{ '' }}" - - @task.external_python(python=templated_python_with_cloudpickle, use_cloudpickle=True) + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], + ) + def test_with_serializer_works(self, serializer, dag_maker, venv_python_with_cloudpickle_and_dill): + @task.external_python(python=venv_python_with_cloudpickle_and_dill, serializer=serializer) def f(): - """Import cloudpickle to double-check it is installed .""" - try: - import cloudpickle # noqa: F401 - except ImportError: - log.warning( - "Cloudpickle package is required to be installed." - " Please install it with: pip install [cloudpickle]" - ) + """Import cloudpickle/dill to double-check it is installed .""" + import cloudpickle # noqa: F401 + import dill # noqa: F401 with dag_maker(): ret = f() ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_no_cloudpickle_installed_raises_exception_when_use_cloudpickle(self, dag_maker, venv_python): - @task.external_python(python=venv_python, use_cloudpickle=True) - def f(): - pass - - with dag_maker(): - ret = f() - - with pytest.raises(CalledProcessError): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_with_dill_works(self, dag_maker, venv_python_with_cloudpickle_and_dill): - @task.external_python(python=venv_python_with_cloudpickle_and_dill, use_dill=True) - def f(): - """Import dill to double-check it is installed .""" - try: - import dill # noqa: F401 - except ImportError: - import logging - - _log = logging.getLogger(__name__) - _log.warning( - "Dill package is required to be installed. Please install it with: pip install [dill]" - ) - - with dag_maker(): - ret = f() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_with_templated_python_dill(self, dag_maker, venv_python_with_cloudpickle_and_dill): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], + ) + def test_with_templated_python_serializer( + self, serializer, dag_maker, venv_python_with_cloudpickle_and_dill + ): # add template that produces empty string when rendered - templated_python_with_dill = venv_python_with_cloudpickle_and_dill.as_posix() + "{{ '' }}" + templated_python_with_cloudpickle = venv_python_with_cloudpickle_and_dill + "{{ '' }}" - @task.external_python(python=templated_python_with_dill, use_dill=True) + @task.external_python(python=templated_python_with_cloudpickle, serializer=serializer) def f(): - """Import dill to double-check it is installed .""" - try: - import dill # noqa: F401 - except ImportError: - import logging - - _log = logging.getLogger(__name__) - _log.warning( - "Dill package is required to be installed. Please install it with: pip install [dill]" - ) + """Import cloudpickle/dill to double-check it is installed .""" + import cloudpickle # noqa: F401 + import dill # noqa: F401 with dag_maker(): ret = f() ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_no_dill_installed_raises_exception_when_use_dill(self, dag_maker, venv_python): - @task.external_python(python=venv_python, use_dill=True) + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], + ) + def test_no_advanced_serializer_installed(self, serializer, dag_maker, venv_python): + @task.external_python(python=venv_python, serializer=serializer) def f(): pass @@ -177,8 +138,17 @@ def f(): with pytest.raises(CalledProcessError): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_with_args(self, dag_maker, venv_python): - @task.external_python(python=venv_python) + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_with_args(self, serializer, dag_maker, venv_python_with_cloudpickle_and_dill): + @task.external_python(python=venv_python_with_cloudpickle_and_dill, serializer=serializer) def f(a, b, c=False, d=False): if a == 0 and b == 1 and c and not d: return True @@ -190,8 +160,17 @@ def f(a, b, c=False, d=False): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_return_none(self, dag_maker, venv_python): - @task.external_python(python=venv_python) + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_return_none(self, serializer, dag_maker, venv_python_with_cloudpickle_and_dill): + @task.external_python(python=venv_python_with_cloudpickle_and_dill, serializer=serializer) def f(): return None @@ -200,8 +179,17 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_nonimported_as_arg(self, dag_maker, venv_python): - @task.external_python(python=venv_python) + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_nonimported_as_arg(self, serializer, dag_maker, venv_python_with_cloudpickle_and_dill): + @task.external_python(python=venv_python_with_cloudpickle_and_dill, serializer=serializer) def f(_): return None @@ -210,9 +198,20 @@ def f(_): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_marking_external_python_task_as_setup(self, dag_maker, venv_python): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_marking_external_python_task_as_setup( + self, serializer, dag_maker, venv_python_with_cloudpickle_and_dill + ): @setup - @task.external_python(python=venv_python) + @task.external_python(python=venv_python_with_cloudpickle_and_dill, serializer=serializer) def f(): return 1 @@ -224,9 +223,20 @@ def f(): assert setup_task.is_setup ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_marking_external_python_task_as_teardown(self, dag_maker, venv_python): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_marking_external_python_task_as_teardown( + self, serializer, dag_maker, venv_python_with_cloudpickle_and_dill + ): @teardown - @task.external_python(python=venv_python) + @task.external_python(python=venv_python_with_cloudpickle_and_dill, serializer=serializer) def f(): return 1 @@ -238,12 +248,21 @@ def f(): assert teardown_task.is_teardown ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False]) def test_marking_external_python_task_as_teardown_with_on_failure_fail( - self, dag_maker, on_failure_fail_dagrun, venv_python + self, serializer, dag_maker, on_failure_fail_dagrun, venv_python_with_cloudpickle_and_dill ): @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun) - @task.external_python(python=venv_python) + @task.external_python(python=venv_python_with_cloudpickle_and_dill, serializer=serializer) def f(): return 1 diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py index 15f37d9a46a91..b91bcaae36be0 100644 --- a/tests/decorators/test_python_virtualenv.py +++ b/tests/decorators/test_python_virtualenv.py @@ -18,26 +18,30 @@ from __future__ import annotations import datetime -import logging import sys +from importlib.util import find_spec from subprocess import CalledProcessError import pytest from airflow.decorators import setup, task, teardown +from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils import timezone -log = logging.getLogger(__name__) - pytestmark = pytest.mark.db_test DEFAULT_DATE = timezone.datetime(2016, 1, 1) PYTHON_VERSION = f"{sys.version_info.major}{sys.version_info.minor}" +DILL_INSTALLED = find_spec("dill") is not None +DILL_MARKER = pytest.mark.skipif(not DILL_INSTALLED, reason="`dill` is not installed") +CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None +CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") class TestPythonVirtualenvDecorator: + @CLOUDPICKLE_MARKER def test_add_cloudpickle(self, dag_maker): - @task.virtualenv(use_cloudpickle=True, system_site_packages=False) + @task.virtualenv(serializer="cloudpickle", system_site_packages=False) def f(): """Ensure cloudpickle is correctly installed.""" import cloudpickle # noqa: F401 @@ -47,25 +51,31 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + @DILL_MARKER def test_add_dill(self, dag_maker): - @task.virtualenv(use_dill=True, system_site_packages=False) + @task.virtualenv(serializer="dill", system_site_packages=False) def f(): """Ensure dill is correctly installed.""" - try: - import dill # noqa: F401 - except ImportError: - import logging - - _log = logging.getLogger(__name__) - _log.warning( - "Dill package is required to be installed. Please install it with: pip install [dill]" - ) + import dill # noqa: F401 with dag_maker(): ret = f() ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + @DILL_MARKER + def test_add_dill_use_dill(self, dag_maker): + @task.virtualenv(use_dill=True, system_site_packages=False) + def f(): + """Ensure dill is correctly installed.""" + import dill # noqa: F401 + + with pytest.warns(RemovedInAirflow3Warning, match="`use_dill` is deprecated and will be removed"): + with dag_maker(): + ret = f() + + ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + def test_no_requirements(self, dag_maker): """Tests that the python callable is invoked on task run.""" @@ -78,8 +88,15 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_no_system_site_packages(self, dag_maker): - @task.virtualenv(system_site_packages=False, python_version=PYTHON_VERSION, use_cloudpickle=True) + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], + ) + def test_no_system_site_packages(self, serializer, dag_maker): + @task.virtualenv(system_site_packages=False, python_version=PYTHON_VERSION, serializer=serializer) def f(): try: import funcsigs # noqa: F401 @@ -92,12 +109,19 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_system_site_packages_cloudpickle(self, dag_maker): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], + ) + def test_system_site_packages(self, serializer, dag_maker): @task.virtualenv( system_site_packages=False, requirements=["funcsigs"], python_version=PYTHON_VERSION, - use_cloudpickle=True, + serializer=serializer, ) def f(): import funcsigs # noqa: F401 @@ -107,27 +131,21 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_system_site_packages_dill(self, dag_maker): - @task.virtualenv( - system_site_packages=False, - requirements=["funcsigs"], - python_version=PYTHON_VERSION, - use_dill=True, - ) - def f(): - import funcsigs # noqa: F401 - - with dag_maker(): - ret = f() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_with_requirements_pinned_cloudpickle(self, dag_maker): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_with_requirements_pinned(self, serializer, dag_maker): @task.virtualenv( system_site_packages=False, requirements=["funcsigs==0.4"], python_version=PYTHON_VERSION, - use_cloudpickle=True, + serializer=serializer, ) def f(): import funcsigs @@ -140,25 +158,16 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_with_requirements_pinned_dill(self, dag_maker): - @task.virtualenv( - system_site_packages=False, - requirements=["funcsigs==0.4"], - python_version=PYTHON_VERSION, - use_dill=True, - ) - def f(): - import funcsigs - - if funcsigs.__version__ != "0.4": - raise Exception - - with dag_maker(): - ret = f() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_with_requirements_file_cloudpickle(self, dag_maker, tmp_path): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_with_requirements_file(self, serializer, dag_maker, tmp_path): requirements_file = tmp_path / "requirements.txt" requirements_file.write_text("funcsigs==0.4\nattrs==23.1.0") @@ -166,7 +175,7 @@ def test_with_requirements_file_cloudpickle(self, dag_maker, tmp_path): system_site_packages=False, requirements="requirements.txt", python_version=PYTHON_VERSION, - use_cloudpickle=True, + serializer=serializer, ) def f(): import funcsigs @@ -184,53 +193,21 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_with_requirements_file_dill(self, dag_maker, tmp_path): - requirements_file = tmp_path / "requirements.txt" - requirements_file.write_text("funcsigs==0.4\nattrs==23.1.0") - + @pytest.mark.parametrize( + "serializer, extra_requirements", + [ + pytest.param("pickle", [], id="pickle"), + pytest.param("dill", ["dill"], marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", ["cloudpickle"], marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, [], id="default"), + ], + ) + def test_unpinned_requirements(self, serializer, extra_requirements, dag_maker): @task.virtualenv( system_site_packages=False, - requirements="requirements.txt", + requirements=["funcsigs", *extra_requirements], python_version=PYTHON_VERSION, - use_dill=True, - ) - def f(): - import funcsigs - - if funcsigs.__version__ != "0.4": - raise Exception - - import attrs - - if attrs.__version__ != "23.1.0": - raise Exception - - with dag_maker(template_searchpath=tmp_path.as_posix()): - ret = f() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_unpinned_requirements_cloudpickle(self, dag_maker): - @task.virtualenv( - system_site_packages=False, - requirements=["funcsigs", "cloudpickle"], - python_version=PYTHON_VERSION, - use_cloudpickle=True, - ) - def f(): - import funcsigs # noqa: F401 - - with dag_maker(): - ret = f() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_unpinned_requirements_dill(self, dag_maker): - @task.virtualenv( - system_site_packages=False, - requirements=["funcsigs", "dill"], - python_version=PYTHON_VERSION, - use_dill=True, + serializer=serializer, ) def f(): import funcsigs # noqa: F401 @@ -240,7 +217,16 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_fail(self, dag_maker): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_fail(self, serializer, dag_maker): @task.virtualenv() def f(): raise Exception @@ -251,8 +237,17 @@ def f(): with pytest.raises(CalledProcessError): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_python_3_cloudpickle(self, dag_maker): - @task.virtualenv(python_version="3", use_cloudpickle=False, requirements=["cloudpickle"]) + @pytest.mark.parametrize( + "serializer, extra_requirements", + [ + pytest.param("pickle", [], id="pickle"), + pytest.param("dill", ["dill"], marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", ["cloudpickle"], marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, [], id="default"), + ], + ) + def test_python_3(self, serializer, extra_requirements, dag_maker): + @task.virtualenv(python_version="3", serializer=serializer, requirements=extra_requirements) def f(): import sys @@ -268,25 +263,17 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_python_3_dill(self, dag_maker): - @task.virtualenv(python_version="3", use_dill=False, requirements=["dill"]) - def f(): - import sys - - print(sys.version) - try: - {}.iteritems() - except AttributeError: - return - raise Exception - - with dag_maker(): - ret = f() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_with_args(self, dag_maker): - @task.virtualenv + @pytest.mark.parametrize( + "serializer, extra_requirements", + [ + pytest.param("pickle", [], id="pickle"), + pytest.param("dill", ["dill"], marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", ["cloudpickle"], marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, [], id="default"), + ], + ) + def test_with_args(self, serializer, extra_requirements, dag_maker): + @task.virtualenv(serializer=serializer, requirements=extra_requirements) def f(a, b, c=False, d=False): if a == 0 and b == 1 and c and not d: return True diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 8a31043fd87ad..2ab71be86e639 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -28,6 +28,7 @@ from collections import namedtuple from datetime import date, datetime, timedelta, timezone as _timezone from functools import partial +from importlib.util import find_spec from subprocess import CalledProcessError from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Generator @@ -65,20 +66,21 @@ from tests.test_utils import AIRFLOW_MAIN_FOLDER from tests.test_utils.db import clear_db_runs -log = logging.getLogger(__name__) +if TYPE_CHECKING: + from airflow.models.dagrun import DagRun pytestmark = pytest.mark.db_test -if TYPE_CHECKING: - from airflow.models.dagrun import DagRun - TI = TaskInstance DEFAULT_DATE = timezone.datetime(2016, 1, 1) TEMPLATE_SEARCHPATH = os.path.join(AIRFLOW_MAIN_FOLDER, "tests", "config_templates") LOGGER_NAME = "airflow.task.operators" DEFAULT_PYTHON_VERSION = f"{sys.version_info[0]}.{sys.version_info[1]}" -PY311 = sys.version_info >= (3, 11) +DILL_INSTALLED = find_spec("dill") is not None +DILL_MARKER = pytest.mark.skipif(not DILL_INSTALLED, reason="`dill` is not installed") +CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None +CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") class BasePythonTest: @@ -815,14 +817,23 @@ def f(templates_dict): task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"}) assert task.templates_dict == {"ds": self.ds_templated} - def test_deepcopy(self): - """Test that PythonVirtualenvOperator are deep-copyable.""" + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_deepcopy(self, serializer): + """Test that operator are deep-copyable.""" def f(): return 1 - task = PythonVirtualenvOperator(python_callable=f, task_id="task") - copy.deepcopy(task) + op = self.opcls(task_id="task", python_callable=f, **self.default_kwargs()) + copy.deepcopy(op) def test_virtualenv_serializable_context_fields(self, create_task_instance): """Ensure all template context fields are listed in the operator. @@ -892,6 +903,34 @@ def f(exit_code): ) assert ti.state == expected_state + @pytest.mark.parametrize( + "serializer", + [ + pytest.param( + "dill", + marks=pytest.mark.skipif( + DILL_INSTALLED, reason="For this test case `dill` shouldn't be installed" + ), + id="dill", + ), + pytest.param( + "cloudpickle", + marks=pytest.mark.skipif( + CLOUDPICKLE_INSTALLED, reason="For this test case `cloudpickle` shouldn't be installed" + ), + id="cloudpickle", + ), + ], + ) + def test_advanced_serializer_not_installed(self, serializer, caplog): + """Test case for check raising an error if dill/cloudpickle is not installed.""" + + def f(a): ... + + with pytest.raises(ModuleNotFoundError): + self.run_as_task(f, op_args=[42], serializer=serializer) + assert f"Unable to import `{serializer}` module." in caplog.text + venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path") @@ -927,56 +966,65 @@ def f(): with pytest.raises(AirflowException, match="requires virtualenv"): self.run_as_task(f) + @CLOUDPICKLE_MARKER def test_add_cloudpickle(self): def f(): """Ensure cloudpickle is correctly installed.""" - try: - import cloudpickle # noqa: F401 - except ImportError: - import logging - - _log = logging.getLogger(__name__) - _log.warning( - "Cloudpickle package is required to be installed." - " Please install it with: pip install [cloudpickle]" - ) + import cloudpickle # noqa: F401 - self.run_as_task(f, use_cloudpickle=True, system_site_packages=False) + self.run_as_task(f, serializer="cloudpickle", system_site_packages=False) + @DILL_MARKER def test_add_dill(self): def f(): """Ensure dill is correctly installed.""" - try: - import dill # noqa: F401 - except ImportError: - import logging + import dill # noqa: F401 - _log = logging.getLogger(__name__) - _log.warning( - "Dill package is required to be installed. Please install it with: pip install [dill]" - ) + self.run_as_task(f, serializer="dill", system_site_packages=False) - self.run_as_task(f, use_dill=True, system_site_packages=False) + @DILL_MARKER + def test_add_dill_use_dill(self): + def f(): + """Ensure dill is correctly installed.""" + import dill # noqa: F401 - def test_no_requirements(self): - """Tests that the python callable is invoked on task run.""" + with pytest.warns(RemovedInAirflow3Warning, match="`use_dill` is deprecated and will be removed"): + self.run_as_task(f, use_dill=True, system_site_packages=False) + def test_ambiguous_serializer(self): def f(): pass - self.run_as_task(f) + with pytest.warns(RemovedInAirflow3Warning, match="`use_dill` is deprecated and will be removed"): + with pytest.raises(AirflowException, match="Both 'use_dill' and 'serializer' parameters are set"): + self.run_as_task(f, use_dill=True, serializer="dill") - def test_no_system_site_packages_cloudpickle(self): + def test_invalid_serializer(self): def f(): - try: - import funcsigs # noqa: F401 - except ImportError: - return True - raise RuntimeError + """Ensure dill is correctly installed.""" + import dill # noqa: F401 + + with pytest.raises(AirflowException, match="Unsupported serializer 'airflow'"): + self.run_as_task(f, serializer="airflow") - self.run_as_task(f, system_site_packages=False, requirements=["cloudpickle"]) + def test_no_requirements(self): + """Tests that the python callable is invoked on task run.""" - def test_no_system_site_packages_dill(self): + def f(): + pass + + self.run_as_task(f) + + @pytest.mark.parametrize( + "serializer, extra_requirements", + [ + pytest.param("pickle", [], id="pickle"), + pytest.param("dill", ["dill"], marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", ["cloudpickle"], marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, [], id="default"), + ], + ) + def test_no_system_site_packages(self, serializer, extra_requirements): def f(): try: import funcsigs # noqa: F401 @@ -984,7 +1032,7 @@ def f(): return True raise RuntimeError - self.run_as_task(f, system_site_packages=False, requirements=["dill"]) + self.run_as_task(f, system_site_packages=False, requirements=extra_requirements) def test_system_site_packages(self): def f(): @@ -1017,29 +1065,35 @@ def f(): self.run_as_task(f, requirements=["funcsigs==0.4"], do_not_use_caching=True) - def test_unpinned_requirements_cloudpickle(self): - def f(): - import funcsigs # noqa: F401 - - self.run_as_task(f, requirements=["funcsigs", "cloudpickle"], system_site_packages=False) - - def test_unpinned_requirements_dill(self): - def f(): - import funcsigs # noqa: F401 - - self.run_as_task(f, requirements=["funcsigs", "dill"], system_site_packages=False) - - def test_range_requirements_cloudpickle(self): + @pytest.mark.parametrize( + "serializer, extra_requirements", + [ + pytest.param("pickle", [], id="pickle"), + pytest.param("dill", ["dill"], marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", ["cloudpickle"], marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, [], id="default"), + ], + ) + def test_unpinned_requirements(self, serializer, extra_requirements): def f(): import funcsigs # noqa: F401 - self.run_as_task(f, requirements=["funcsigs>1.0", "cloudpickle"], system_site_packages=False) + self.run_as_task(f, requirements=["funcsigs", *extra_requirements], system_site_packages=False) - def test_range_requirements_dill(self): + @pytest.mark.parametrize( + "serializer, extra_requirements", + [ + pytest.param("pickle", [], id="pickle"), + pytest.param("dill", ["dill"], marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", ["cloudpickle"], marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, [], id="default"), + ], + ) + def test_range_requirements(self, serializer, extra_requirements): def f(): import funcsigs # noqa: F401 - self.run_as_task(f, requirements=["funcsigs>1.0", "dill"], system_site_packages=False) + self.run_as_task(f, requirements=["funcsigs>1.0", *extra_requirements], system_site_packages=False) def test_requirements_file(self): def f(): @@ -1069,21 +1123,16 @@ def f(): pip_install_options=["--no-deps"], ) - def test_templated_requirements_file_cloudpickle(self): - def f(): - import funcsigs - - assert funcsigs.__version__ == "1.0.2" - - self.run_as_operator( - f, - requirements="requirements.txt", - use_cloudpickle=True, - params={"environ": "templated_unit_test"}, - system_site_packages=False, - ) - - def test_templated_requirements_file_dill(self): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_templated_requirements_file(self, serializer): def f(): import funcsigs @@ -1092,25 +1141,21 @@ def f(): self.run_as_operator( f, requirements="requirements.txt", - use_dill=True, + serializer=serializer, params={"environ": "templated_unit_test"}, system_site_packages=False, ) - def test_python_3_cloudpickle(self): - def f(): - import sys - - print(sys.version) - try: - {}.iteritems() - except AttributeError: - return - raise RuntimeError - - self.run_as_task(f, python_version="3", use_cloudpickle=False, requirements=["cloudpickle"]) - - def test_python_3_dill(self): + @pytest.mark.parametrize( + "serializer, extra_requirements", + [ + pytest.param("pickle", [], id="pickle"), + pytest.param("dill", ["dill"], marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", ["cloudpickle"], marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, [], id="default"), + ], + ) + def test_python_3_serializers(self, serializer, extra_requirements): def f(): import sys @@ -1121,19 +1166,16 @@ def f(): return raise RuntimeError - self.run_as_task(f, python_version="3", use_dill=False, requirements=["dill"]) - - def test_without_cloudpickle(self): - def f(a): - return a - - self.run_as_task(f, system_site_packages=False, use_cloudpickle=False, op_args=[4]) + with pytest.warns( + RemovedInAirflow3Warning, match="Passing non-string types.*python_version is deprecated" + ): + self.run_as_task(f, python_version=3, serializer=serializer, requirements=extra_requirements) - def test_without_dill(self): + def test_with_default(self): def f(a): return a - self.run_as_task(f, system_site_packages=False, use_dill=False, op_args=[4]) + self.run_as_task(f, system_site_packages=False, op_args=[4]) def test_with_index_urls(self): def f(a): @@ -1158,67 +1200,39 @@ def f(a): self.run_as_task(f, venv_cache_path=tmp_dir, op_args=[4]) # This tests might take longer than default 60 seconds as it is serializing a lot of - # context using cloudpickle (which is slow apparently). + # context using dill/cloudpickle (which is slow apparently). @pytest.mark.execution_timeout(120) @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") - @pytest.mark.skipif( - os.environ.get("PYTEST_PLAIN_ASSERTS") != "true", - reason="assertion rewriting breaks this test because cloudpickle will try to serialize " - "AssertRewritingHook including captured stdout and we need to run " - "it with `--assert=plain`pytest option and PYTEST_PLAIN_ASSERTS=true .", + @pytest.mark.parametrize( + "serializer", + [ + pytest.param( + "dill", + marks=[ + DILL_MARKER, + pytest.mark.xfail( + sys.version_info[:2] == (3, 11), + reason=( + "Also this test is failed on Python 3.11 because of impact of " + "regression in Python 3.11 connected likely with CodeType behaviour " + "https://github.com/python/cpython/issues/100316. " + "That likely causes that dill is not able to serialize the `conf` correctly. " + "Issue about fixing it is captured in https://github.com/apache/airflow/issues/35307" + ), + ), + ], + id="dill", + ), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], ) - def test_airflow_context(self): - def f( - # basic - ds_nodash, - inlets, - next_ds, - next_ds_nodash, - outlets, - params, - prev_ds, - prev_ds_nodash, - run_id, - task_instance_key_str, - test_mode, - tomorrow_ds, - tomorrow_ds_nodash, - ts, - ts_nodash, - ts_nodash_with_tz, - yesterday_ds, - yesterday_ds_nodash, - # pendulum-specific - execution_date, - next_execution_date, - prev_execution_date, - prev_execution_date_success, - prev_start_date_success, - prev_end_date_success, - # airflow-specific - macros, - conf, - dag, - dag_run, - task, - # other - **context, - ): - pass - - self.run_as_operator(f, use_cloudpickle=True, system_site_packages=True, requirements=None) - - # This tests might take longer than default 60 seconds as it is serializing a lot of - # context using dill (which is slow apparently). - @pytest.mark.execution_timeout(120) - @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") @pytest.mark.skipif( os.environ.get("PYTEST_PLAIN_ASSERTS") != "true", - reason="assertion rewriting breaks this test because dill will try to serialize " + reason="assertion rewriting breaks this test because serializer will try to serialize " "AssertRewritingHook including captured stdout and we need to run " - "it with `--assert=plain`pytest option and PYTEST_PLAIN_ASSERTS=true .", + "it with `--assert=plain` pytest option and PYTEST_PLAIN_ASSERTS=true .", ) - def test_airflow_context_dill(self): + def test_airflow_context(self, serializer): def f( # basic ds_nodash, @@ -1257,45 +1271,17 @@ def f( ): pass - self.run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None) + self.run_as_operator(f, serializer=serializer, system_site_packages=True, requirements=None) @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") - def test_pendulum_context(self): - def f( - # basic - ds_nodash, - inlets, - next_ds, - next_ds_nodash, - outlets, - prev_ds, - prev_ds_nodash, - run_id, - task_instance_key_str, - test_mode, - tomorrow_ds, - tomorrow_ds_nodash, - ts, - ts_nodash, - ts_nodash_with_tz, - yesterday_ds, - yesterday_ds_nodash, - # pendulum-specific - execution_date, - next_execution_date, - prev_execution_date, - prev_execution_date_success, - prev_start_date_success, - prev_end_date_success, - # other - **context, - ): - pass - - self.run_as_task(f, use_cloudpickle=True, system_site_packages=False, requirements=["pendulum"]) - - @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") - def test_pendulum_context_dill(self): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], + ) + def test_pendulum_context(self, serializer): def f( # basic ds_nodash, @@ -1327,38 +1313,19 @@ def f( ): pass - self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=["pendulum"]) - - @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") - def test_base_context(self): - def f( - # basic - ds_nodash, - inlets, - next_ds, - next_ds_nodash, - outlets, - prev_ds, - prev_ds_nodash, - run_id, - task_instance_key_str, - test_mode, - tomorrow_ds, - tomorrow_ds_nodash, - ts, - ts_nodash, - ts_nodash_with_tz, - yesterday_ds, - yesterday_ds_nodash, - # other - **context, - ): - pass - - self.run_as_task(f, use_cloudpickle=True, system_site_packages=False, requirements=None) + self.run_as_task(f, serializer=serializer, system_site_packages=False, requirements=["pendulum"]) @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") - def test_base_context_dill(self): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_base_context(self, serializer): def f( # basic ds_nodash, @@ -1383,7 +1350,7 @@ def f( ): pass - self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=None) + self.run_as_task(f, serializer=serializer, system_site_packages=False, requirements=None) # when venv tests are run in parallel to other test they create new processes and this might take @@ -1629,8 +1596,6 @@ def f(): # when venv tests are run in parallel to other test they create new processes and this might take # quite some time in shared docker environment and get some contention even between different containers # therefore we have to extend timeouts for those tests - - @pytest.mark.execution_timeout(120) @pytest.mark.virtualenv_operator class TestBranchPythonVirtualenvOperator(BaseTestBranchPythonVirtualenvOperator):