From 0839ce1d48c7a78f34d97c83f2793c58dadc1abb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Jun 2024 07:09:00 +0800 Subject: [PATCH] FLYTE_INTERNAL_IMAGE should have higher precedence (#2523) Signed-off-by: Kevin Su Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Co-authored-by: Thomas J. Fan --- flytekit/configuration/default_images.py | 9 +++++++-- flytekit/core/constants.py | 3 +++ tests/flytekit/unit/configuration/test_image_config.py | 6 +++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/flytekit/configuration/default_images.py b/flytekit/configuration/default_images.py index 63e02e771d..47353cf5af 100644 --- a/flytekit/configuration/default_images.py +++ b/flytekit/configuration/default_images.py @@ -4,6 +4,8 @@ import typing from contextlib import suppress +from flytekit.core.constants import FLYTE_INTERNAL_IMAGE_ENV_VAR + class PythonVersion(enum.Enum): PYTHON_3_8 = (3, 8) @@ -35,13 +37,16 @@ def default_image(cls) -> str: if default_image is not None: return default_image - default_image_str = os.environ.get("FLYTE_INTERNAL_IMAGE", cls.find_image_for()) - return default_image_str + return cls.find_image_for() @classmethod def find_image_for( cls, python_version: typing.Optional[PythonVersion] = None, flytekit_version: typing.Optional[str] = None ) -> str: + default_image_str = os.getenv(FLYTE_INTERNAL_IMAGE_ENV_VAR) + if default_image_str: + return default_image_str + if python_version is None: python_version = PythonVersion((sys.version_info.major, sys.version_info.minor)) diff --git a/flytekit/core/constants.py b/flytekit/core/constants.py index 8b85479fcc..6e8b0705f1 100644 --- a/flytekit/core/constants.py +++ b/flytekit/core/constants.py @@ -9,3 +9,6 @@ START_NODE_ID = "start-node" END_NODE_ID = "end-node" + +# If set this environment variable overrides the default container image and the default base image in ImageSpec. +FLYTE_INTERNAL_IMAGE_ENV_VAR = "FLYTE_INTERNAL_IMAGE" diff --git a/tests/flytekit/unit/configuration/test_image_config.py b/tests/flytekit/unit/configuration/test_image_config.py index 2597d5befa..8ae3b2d6fd 100644 --- a/tests/flytekit/unit/configuration/test_image_config.py +++ b/tests/flytekit/unit/configuration/test_image_config.py @@ -8,6 +8,7 @@ import flytekit from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages, PythonVersion +from flytekit.core.constants import FLYTE_INTERNAL_IMAGE_ENV_VAR @pytest.mark.parametrize( @@ -33,12 +34,15 @@ def test_set_both(python_version_enum, flytekit_version, expected_image_string): assert DefaultImages.find_image_for(python_version_enum, flytekit_version) == expected_image_string -def test_image_config_auto(): +def test_image_config_auto(monkeypatch): x = ImageConfig.auto_default_image() assert x.images[0].name == "default" version_str = f"{sys.version_info.major}.{sys.version_info.minor}" assert x.images[0].full == f"cr.flyte.org/flyteorg/flytekit:py{version_str}-latest" + monkeypatch.setenv(FLYTE_INTERNAL_IMAGE_ENV_VAR, "test") + assert DefaultImages.find_image_for() == "test" + def test_image_from_flytectl_config(): image_config = ImageConfig.auto(