Skip to content

Commit abff377

Browse files
committed
run formatting
1 parent 07fa0a9 commit abff377

File tree

2 files changed

+40
-29
lines changed

2 files changed

+40
-29
lines changed

src/sagemaker/image_uris.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ def retrieve(
164164
config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)
165165

166166
original_version = version
167-
version = _validate_version_and_set_if_needed(version, config, framework)
167+
version = _validate_version_and_set_if_needed(
168+
version, config, framework, image_scope, base_framework_version
169+
)
168170
version_config = config["versions"][_version_for_config(version, config)]
169171

170172
if framework == HUGGING_FACE_FRAMEWORK:
@@ -199,9 +201,8 @@ def retrieve(
199201
container_version = sdk_version + "-" + container_version
200202

201203
if framework == HUGGING_FACE_FRAMEWORK:
202-
pt_or_tf_version = (
203-
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
204-
)
204+
hf_base_fw_regex = _get_hf_base_framework_version_regex()
205+
pt_or_tf_version = hf_base_fw_regex.match(base_framework_version).group(2)
205206
_version = original_version
206207

207208
if repo in [
@@ -255,6 +256,10 @@ def retrieve(
255256
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
256257

257258

259+
def _get_hf_base_framework_version_regex():
260+
return re.compile("^(pytorch|tensorflow)(.*)$")
261+
262+
258263
def _get_instance_type_family(instance_type):
259264
"""Return the family of the instance type.
260265
@@ -509,7 +514,7 @@ def _get_end_of_support_warn_message(version, config, framework):
509514
return ""
510515

511516

512-
def _validate_version_and_set_if_needed(version, config, framework):
517+
def _validate_version_and_set_if_needed(version, config, framework, scope, base_framework_version):
513518
"""Checks if the framework/algorithm version is one of the supported versions."""
514519
available_versions = list(config["versions"].keys())
515520
aliased_versions = list(config.get("version_aliases", {}).keys())
@@ -527,8 +532,21 @@ def _validate_version_and_set_if_needed(version, config, framework):
527532

528533
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
529534

530-
# For DLCs, warn if image is out of support. Use aliased version to grab information about the latest
531-
end_of_support_warning = _get_end_of_support_warn_message(version, config, framework)
535+
# For DLCs, warn if image is out of support.
536+
eos_version, eos_config, eos_framework = version, config, framework
537+
538+
# Default to underlying base framework for HF, except for neuron which has its own policies.
539+
inf_or_trn = "inf" in config.get("processors", []) or "trn" in config.get("processors", [])
540+
if framework == HUGGING_FACE_FRAMEWORK and base_framework_version and not inf_or_trn:
541+
hf_base_fw_regex = _get_hf_base_framework_version_regex()
542+
hf_match = hf_base_fw_regex.search(base_framework_version)
543+
if hf_match:
544+
eos_framework = hf_match.group(1)
545+
eos_version = hf_match.group(2)
546+
eos_config = _config_for_framework_and_scope(eos_framework, scope)
547+
end_of_support_warning = _get_end_of_support_warn_message(
548+
eos_version, eos_config, eos_framework
549+
)
532550
if end_of_support_warning:
533551
logger.warning(end_of_support_warning)
534552

tests/unit/sagemaker/image_uris/test_dlc_frameworks.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import io
1616
import contextlib
17-
import datetime
1817

1918
from packaging.version import Version
2019

@@ -93,21 +92,17 @@ def _test_image_uris(
9392
assert expected == uri
9493

9594

96-
def _test_end_of_support_messaging(
97-
framework,
98-
fw_version,
99-
py_version,
100-
scope,
101-
accelerator_type=None
102-
):
95+
def _test_end_of_support_messaging(framework, fw_version, py_version, scope, accelerator_type=None):
10396
base_args = {
10497
"framework": framework,
10598
"version": fw_version,
10699
"py_version": py_version,
107100
"image_scope": scope,
108101
}
109102

110-
config = image_uris._config_for_framework_and_scope(framework=framework, image_scope=scope, accelerator_type=accelerator_type)
103+
config = image_uris._config_for_framework_and_scope(
104+
framework=framework, image_scope=scope, accelerator_type=accelerator_type
105+
)
111106
version_support = config.get("version_support")
112107
aliases = config.get("version_aliases")
113108

@@ -116,7 +111,9 @@ def _test_end_of_support_messaging(
116111
# Version in version_support MUST be in an aliased value
117112
assert v in aliases.values()
118113

119-
expected_warning_message = image_uris._get_end_of_support_warn_message(fw_version, config, framework)
114+
expected_warning_message = image_uris._get_end_of_support_warn_message(
115+
fw_version, config, framework
116+
)
120117

121118
TYPES_AND_PROCESSORS = INSTANCE_TYPES_AND_PROCESSORS
122119
if framework == "pytorch" and Version(fw_version) >= Version("1.13"):
@@ -169,7 +166,7 @@ def test_tensorflow_training(tensorflow_training_version, tensorflow_training_py
169166
"tf_training_version": tensorflow_training_version,
170167
"py_version": tensorflow_training_py_version,
171168
}
172-
169+
173170
framework = "tensorflow"
174171
scope = "training"
175172

@@ -531,19 +528,15 @@ def _sagemaker_or_dlc_account(repo, region):
531528

532529
def test_end_of_support():
533530
config = {
534-
"version_aliases": {
535-
"1.6": "1.6.0"
536-
},
537-
"version_support": {
538-
"1.6.0": {
539-
"end_of_support": "2021-06-21T00:00:00.0Z"
540-
}
541-
}
531+
"version_aliases": {"1.6": "1.6.0"},
532+
"version_support": {"1.6.0": {"end_of_support": "2021-06-21T00:00:00.0Z"}},
542533
}
543534
dlc_support_policy = "https://aws.amazon.com/releasenotes/dlc-support-policy/"
544-
warn_message = image_uris._get_end_of_support_warn_message(config=config, framework="pytorch", version="1.6.0")
535+
warn_message = image_uris._get_end_of_support_warn_message(
536+
config=config, framework="pytorch", version="1.6.0"
537+
)
545538

546539
assert warn_message == (
547-
f"The pytorch 1.6.0 DLC has reached end of support. "
540+
f"The pytorch 1.6.0 DLC has reached end of support. "
548541
f"Please choose a supported version from our support policy - {dlc_support_policy}."
549-
)
542+
)

0 commit comments

Comments
 (0)