Skip to content

Commit

Permalink
ExternalPythonOperator use version from sys.version_info (#38377)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Mar 23, 2024
1 parent 665d46c commit 48c8f35
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 15 deletions.
54 changes: 39 additions & 15 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@
from collections.abc import Container
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, Sequence, cast
from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, NamedTuple, Sequence, cast

import dill

from airflow.compat.functools import cache
from airflow.exceptions import (
AirflowConfigException,
AirflowException,
Expand Down Expand Up @@ -105,6 +106,40 @@ def my_task()""",
return python_task(python_callable=python_callable, multiple_outputs=multiple_outputs, **kwargs)


@cache
def _parse_version_info(text: str) -> tuple[int, int, int, str, int]:
"""Parse python version info from a text."""
parts = text.strip().split(".")
if len(parts) != 5:
msg = f"Invalid Python version info, expected 5 components separated by '.', but got {text!r}."
raise ValueError(msg)
try:
return int(parts[0]), int(parts[1]), int(parts[2]), parts[3], int(parts[4])
except ValueError:
msg = f"Unable to convert parts {parts} parsed from {text!r} to (int, int, int, str, int)."
raise ValueError(msg) from None


class _PythonVersionInfo(NamedTuple):
"""Provide the same interface as ``sys.version_info``."""

major: int
minor: int
micro: int
releaselevel: str
serial: int

@classmethod
def from_executable(cls, executable: str) -> _PythonVersionInfo:
"""Parse python version info from an executable."""
cmd = [executable, "-c", 'import sys; print(".".join(map(str, sys.version_info)))']
try:
result = subprocess.check_output(cmd, text=True)
except Exception as e:
raise ValueError(f"Error while executing command {cmd}: {e}")
return cls(*_parse_version_info(result.strip()))


class PythonOperator(BaseOperator):
"""
Executes a Python callable.
Expand Down Expand Up @@ -847,27 +882,16 @@ def execute_callable(self):
raise ValueError(f"Python Path '{python_path}' must be a file")
if not python_path.is_absolute():
raise ValueError(f"Python Path '{python_path}' must be an absolute path.")
python_version_as_list_of_strings = self._get_python_version_from_environment()
if (
python_version_as_list_of_strings
and str(python_version_as_list_of_strings[0]) != str(sys.version_info.major)
and (self.op_args or self.op_kwargs)
):
python_version = _PythonVersionInfo.from_executable(self.python)
if python_version.major != sys.version_info.major and (self.op_args or self.op_kwargs):
raise AirflowException(
"Passing op_args or op_kwargs is not supported across different Python "
"major versions for ExternalPythonOperator. Please use string_args."
f"Sys version: {sys.version_info}. "
f"Virtual environment version: {python_version_as_list_of_strings}"
f"Virtual environment version: {python_version}"
)
return self._execute_python_callable_in_subprocess(python_path)

def _get_python_version_from_environment(self) -> list[str]:
try:
result = subprocess.check_output([self.python, "--version"], text=True)
return result.strip().split(" ")[-1].split(".")
except Exception as e:
raise ValueError(f"Error while executing {self.python}: {e}")

def _iter_serializable_context_keys(self):
yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
if self._get_airflow_version_from_target_env():
Expand Down
60 changes: 60 additions & 0 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
PythonOperator,
PythonVirtualenvOperator,
ShortCircuitOperator,
_parse_version_info,
_PythonVersionInfo,
get_current_context,
)
from airflow.utils import timezone
Expand Down Expand Up @@ -1686,3 +1688,61 @@ def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_
assert isinstance(actual_skipped, Generator)
assert set(actual_skipped) == {op3}
assert actual_kwargs["execution_date"] == dagrun.logical_date


@pytest.mark.parametrize(
"text_input, expected_tuple",
[
pytest.param(" 2.7.18.final.0 ", (2, 7, 18, "final", 0), id="py27"),
pytest.param("3.10.13.final.0\n", (3, 10, 13, "final", 0), id="py310"),
pytest.param("\n3.13.0.alpha.3", (3, 13, 0, "alpha", 3), id="py313-alpha"),
],
)
def test_parse_version_info(text_input, expected_tuple):
assert _parse_version_info(text_input) == expected_tuple


@pytest.mark.parametrize(
"text_input",
[
pytest.param(" 2.7.18.final.0.3 ", id="more-than-5-parts"),
pytest.param("3.10.13\n", id="less-than-5-parts"),
pytest.param("Apache Airflow 3.0.0", id="garbage-input"),
],
)
def test_parse_version_invalid_parts(text_input):
with pytest.raises(ValueError, match="expected 5 components separated by '\.'"):
_parse_version_info(text_input)


@pytest.mark.parametrize(
"text_input",
[
pytest.param("2EOL.7.18.final.0", id="major-non-int"),
pytest.param("3.XXX.13.final.3", id="minor-non-int"),
pytest.param("3.13.0a.alpha.3", id="micro-non-int"),
pytest.param("3.8.18.alpha.beta", id="serial-non-int"),
],
)
def test_parse_version_invalid_parts_types(text_input):
with pytest.raises(ValueError, match="Unable to convert parts.*parsed from.*to"):
_parse_version_info(text_input)


def test_python_version_info_fail_subprocess(mocker):
mocked_subprocess = mocker.patch("subprocess.check_output")
mocked_subprocess.side_effect = RuntimeError("some error")

with pytest.raises(ValueError, match="Error while executing command.*some error"):
_PythonVersionInfo.from_executable("/dev/null")
mocked_subprocess.assert_called_once()


def test_python_version_info(mocker):
result = _PythonVersionInfo.from_executable(sys.executable)
assert result.major == sys.version_info.major
assert result.minor == sys.version_info.minor
assert result.micro == sys.version_info.micro
assert result.releaselevel == sys.version_info.releaselevel
assert result.serial == sys.version_info.serial
assert list(result) == list(sys.version_info)

0 comments on commit 48c8f35

Please sign in to comment.