diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index b833c1d7a6..3193524056 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -79,6 +79,14 @@ "1.12": "1.12.1", "1.13": "1.13.1" }, + "version_support": { + "1.12.1": { + "end_of_support": "2023-07-01T00:00:00.0Z" + }, + "1.13.1": { + "end_of_support": "2023-10-28T00:00:00.0Z" + } + }, "versions": { "0.4.0": { "py_versions": [ @@ -346,7 +354,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "pytorch-inference" + "repository": "pytorch-inference" }, "1.6.0": { "py_versions": [ @@ -388,7 +396,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "pytorch-inference" + "repository": "pytorch-inference" }, "1.7.1": { "py_versions": [ @@ -914,6 +922,14 @@ "1.12": "1.12.1", "1.13": "1.13.1" }, + "version_support": { + "1.12.1": { + "end_of_support": "2023-07-01T00:00:00.0Z" + }, + "1.13.1": { + "end_of_support": "2023-10-28T00:00:00.0Z" + } + }, "versions": { "0.4.0": { "py_versions": [ diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index b9101acd96..b3c93f11bf 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -17,6 +17,7 @@ import logging import os import re +import datetime from typing import Optional from packaging.version import Version @@ -163,7 +164,9 @@ def retrieve( config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type) original_version = version - version = _validate_version_and_set_if_needed(version, config, framework) + version = _validate_version_and_set_if_needed( + version, config, framework, image_scope, base_framework_version + ) version_config = config["versions"][_version_for_config(version, config)] if framework == HUGGING_FACE_FRAMEWORK: @@ -198,9 +201,8 @@ def retrieve( container_version = sdk_version + "-" + container_version if framework == HUGGING_FACE_FRAMEWORK: - pt_or_tf_version = ( - re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2) - ) + hf_base_fw_regex = _get_hf_base_framework_version_regex() + pt_or_tf_version = hf_base_fw_regex.match(base_framework_version).group(2) _version = original_version if repo in [ @@ -254,6 +256,10 @@ def retrieve( return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo) +def _get_hf_base_framework_version_regex(): + return re.compile("^(pytorch|tensorflow)(.*)$") + + def _get_instance_type_family(instance_type): """Return the family of the instance type. @@ -446,7 +452,69 @@ def _validate_accelerator_type(accelerator_type): ) -def _validate_version_and_set_if_needed(version, config, framework): +def _get_end_of_support_warn_message(version, config, framework): + """ + Get end of support warning message if needed. + + Args: + version (str): ML framework version + config (dict): config json loaded into python dictionary + framework (str): ML framework + + Returns: + str: Warning message if version is nearing or out of support, else empty string + """ + version_support = config.get("version_support") + aliases = config.get("version_aliases") + + # If version_support and version_aliases are not implemented, do not return warning message + if not (version_support and aliases): + return "" + + dlc_support_policy = "https://aws.amazon.com/releasenotes/dlc-support-policy/" + + end_of_support_msg = ( + f"The {framework} {version} DLC has reached end of support. " + f"Please choose a supported version from our support policy - {dlc_support_policy}." + ) + + long_version = version + if version in aliases: + long_version = aliases[version] + + # If we have version support, but the long version is not in it, assume this is out of support + if not version_support.get(long_version): + return end_of_support_msg + + end_of_support = version_support[long_version].get("end_of_support") + + # If no end of support specified, assume there is indefinite support + if not end_of_support: + return "" + + # Convert json object to UTC timezone string + end_of_support_dt = ( + datetime.datetime.strptime(end_of_support, "%Y-%m-%dT%H:%M:%S.%fZ") + .replace(tzinfo=None) + .astimezone(tz=datetime.timezone.utc) + ) + # Ensure that the version is still supported + current_dt = datetime.datetime.now(datetime.timezone.utc) + time_delt_days = (end_of_support_dt - current_dt).days + if current_dt >= end_of_support_dt: + return end_of_support_msg + if time_delt_days <= 60: + return ( + f"The {framework} {version} DLC is approaching end of support, " + f"and patching will stop on {end_of_support_dt.strftime('%Y-%m-%d %Z')}. " + f"Please choose a supported version from our support policy - {dlc_support_policy}." + ) + + # Version is still well in support window, do not warn + return "" + + +def _validate_version_and_set_if_needed(version, config, framework, scope, base_framework_version): """Checks if the framework/algorithm version is one of the supported versions.""" available_versions = list(config["versions"].keys()) aliased_versions = list(config.get("version_aliases", {}).keys()) @@ -463,6 +531,25 @@ def _validate_version_and_set_if_needed(version, config, framework): return available_versions[0] _validate_arg(version, available_versions + aliased_versions, "{} version".format(framework)) + + # For DLCs, warn if image is out of support. + eos_version, eos_config, eos_framework = version, config, framework + + # Default to underlying base framework for HF, except for neuron which has its own policies. + inf_or_trn = "inf" in config.get("processors", []) or "trn" in config.get("processors", []) + if framework == HUGGING_FACE_FRAMEWORK and base_framework_version and not inf_or_trn: + hf_base_fw_regex = _get_hf_base_framework_version_regex() + hf_match = hf_base_fw_regex.search(base_framework_version) + if hf_match: + eos_framework = hf_match.group(1) + eos_version = hf_match.group(2) + eos_config = _config_for_framework_and_scope(eos_framework, scope) + end_of_support_warning = _get_end_of_support_warn_message( + eos_version, eos_config, eos_framework + ) + if end_of_support_warning: + logger.warning(end_of_support_warning) + return version diff --git a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py index 0f00f7eb22..a61180d3de 100644 --- a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py +++ b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py @@ -12,6 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import io +import contextlib + from packaging.version import Version from sagemaker import image_uris @@ -89,6 +92,47 @@ def _test_image_uris( assert expected == uri +def _test_end_of_support_messaging(framework, fw_version, py_version, scope, accelerator_type=None): + base_args = { + "framework": framework, + "version": fw_version, + "py_version": py_version, + "image_scope": scope, + } + + config = image_uris._config_for_framework_and_scope( + framework=framework, image_scope=scope, accelerator_type=accelerator_type + ) + version_support = config.get("version_support") + aliases = config.get("version_aliases") + + if version_support: + for v, _ in version_support.items(): + # Version in version_support MUST be in an aliased value + assert v in aliases.values() + + expected_warning_message = image_uris._get_end_of_support_warn_message( + fw_version, config, framework + ) + + TYPES_AND_PROCESSORS = INSTANCE_TYPES_AND_PROCESSORS + if framework == "pytorch" and Version(fw_version) >= Version("1.13"): + """Handle P2 deprecation""" + TYPES_AND_PROCESSORS = RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS + + for instance_type, _processor in TYPES_AND_PROCESSORS: + retrieve_io = io.StringIO() + with contextlib.redirect_stderr(retrieve_io): + _uri = image_uris.retrieve(region=REGION, instance_type=instance_type, **base_args) + assert expected_warning_message in retrieve_io.getvalue() + + for region in DLC_ALTERNATE_REGION_ACCOUNTS: + retrieve_io = io.StringIO() + with contextlib.redirect_stderr(retrieve_io): + _uri = image_uris.retrieve(region=region, instance_type="ml.c4.xlarge", **base_args) + assert expected_warning_message in retrieve_io.getvalue() + + def test_chainer(chainer_version, chainer_py_version): expected_fn_args = { "chainer_version": chainer_version, @@ -123,15 +167,25 @@ def test_tensorflow_training(tensorflow_training_version, tensorflow_training_py "py_version": tensorflow_training_py_version, } + framework = "tensorflow" + scope = "training" + _test_image_uris( - "tensorflow", + framework, tensorflow_training_version, tensorflow_training_py_version, - "training", + scope, _expected_tf_training_uri, expected_fn_args, ) + _test_end_of_support_messaging( + framework=framework, + fw_version=tensorflow_training_version, + py_version=tensorflow_training_py_version, + scope=scope, + ) + def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", region=REGION): version = Version(tf_training_version) @@ -155,15 +209,25 @@ def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", def test_tensorflow_inference(tensorflow_inference_version, tensorflow_inference_py_version): + framework = "tensorflow" + scope = "inference" + _test_image_uris( - "tensorflow", + framework, tensorflow_inference_version, tensorflow_inference_py_version, - "inference", + scope, _expected_tf_inference_uri, {"tf_inference_version": tensorflow_inference_version}, ) + _test_end_of_support_messaging( + framework=framework, + fw_version=tensorflow_inference_version, + py_version=tensorflow_inference_py_version, + scope=scope, + ) + def test_tensorflow_eia(tensorflow_eia_version): base_args = { @@ -222,16 +286,25 @@ def test_mxnet_training(mxnet_training_version, mxnet_training_py_version): "mxnet_version": mxnet_training_version, "py_version": mxnet_training_py_version, } + framework = "mxnet" + scope = "training" _test_image_uris( - "mxnet", + framework, mxnet_training_version, mxnet_training_py_version, - "training", + scope, _expected_mxnet_training_uri, expected_fn_args, ) + _test_end_of_support_messaging( + framework=framework, + fw_version=mxnet_training_version, + py_version=mxnet_training_py_version, + scope=scope, + ) + def _expected_mxnet_training_uri(mxnet_version, py_version, processor="cpu", region=REGION): version = Version(mxnet_version) @@ -258,15 +331,25 @@ def test_mxnet_inference(mxnet_inference_version, mxnet_inference_py_version): "py_version": mxnet_inference_py_version, } + framework = "mxnet" + scope = "inference" + _test_image_uris( - "mxnet", + framework, mxnet_inference_version, mxnet_inference_py_version, - "inference", + scope, _expected_mxnet_inference_uri, expected_fn_args, ) + _test_end_of_support_messaging( + framework=framework, + fw_version=mxnet_inference_version, + py_version=mxnet_inference_py_version, + scope=scope, + ) + def test_mxnet_eia(mxnet_eia_version, mxnet_eia_py_version): base_args = { @@ -319,17 +402,26 @@ def _expected_mxnet_inference_uri( def test_pytorch_training(pytorch_training_version, pytorch_training_py_version): + framework = "pytorch" + scope = "training" + _test_image_uris( - "pytorch", + framework, pytorch_training_version, pytorch_training_py_version, - "training", + scope, _expected_pytorch_training_uri, { "pytorch_version": pytorch_training_version, "py_version": pytorch_training_py_version, }, ) + _test_end_of_support_messaging( + framework=framework, + fw_version=pytorch_training_version, + py_version=pytorch_training_py_version, + scope=scope, + ) def _expected_pytorch_training_uri(pytorch_version, py_version, processor="cpu", region=REGION): @@ -350,11 +442,14 @@ def _expected_pytorch_training_uri(pytorch_version, py_version, processor="cpu", def test_pytorch_inference(pytorch_inference_version, pytorch_inference_py_version): + framework = "pytorch" + scope = "inference" + _test_image_uris( - "pytorch", + framework, pytorch_inference_version, pytorch_inference_py_version, - "inference", + scope, _expected_pytorch_inference_uri, { "pytorch_version": pytorch_inference_version, @@ -362,6 +457,13 @@ def test_pytorch_inference(pytorch_inference_version, pytorch_inference_py_versi }, ) + _test_end_of_support_messaging( + framework=framework, + fw_version=pytorch_inference_version, + py_version=pytorch_inference_py_version, + scope=scope, + ) + def _expected_pytorch_inference_uri(pytorch_version, py_version, processor="cpu", region=REGION): version = Version(pytorch_version) @@ -422,3 +524,19 @@ def _sagemaker_or_dlc_account(repo, region): ) else: return DLC_ACCOUNT if region == REGION else DLC_ALTERNATE_REGION_ACCOUNTS[region] + + +def test_end_of_support(): + config = { + "version_aliases": {"1.6": "1.6.0"}, + "version_support": {"1.6.0": {"end_of_support": "2021-06-21T00:00:00.0Z"}}, + } + dlc_support_policy = "https://aws.amazon.com/releasenotes/dlc-support-policy/" + warn_message = image_uris._get_end_of_support_warn_message( + config=config, framework="pytorch", version="1.6.0" + ) + + assert warn_message == ( + f"The pytorch 1.6.0 DLC has reached end of support. " + f"Please choose a supported version from our support policy - {dlc_support_policy}." + )