Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/sagemaker/image_uri_config/pytorch.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@
"1.12": "1.12.1",
"1.13": "1.13.1"
},
"version_support": {
"1.12.1": {
"end_of_support": "2023-07-01T00:00:00.0Z"
},
"1.13.1": {
"end_of_support": "2023-10-28T00:00:00.0Z"
}
},
"versions": {
"0.4.0": {
"py_versions": [
Expand Down Expand Up @@ -346,7 +354,7 @@
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "pytorch-inference"
"repository": "pytorch-inference"
},
"1.6.0": {
"py_versions": [
Expand Down Expand Up @@ -388,7 +396,7 @@
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "pytorch-inference"
"repository": "pytorch-inference"
},
"1.7.1": {
"py_versions": [
Expand Down Expand Up @@ -914,6 +922,14 @@
"1.12": "1.12.1",
"1.13": "1.13.1"
},
"version_support": {
"1.12.1": {
"end_of_support": "2023-07-01T00:00:00.0Z"
},
"1.13.1": {
"end_of_support": "2023-10-28T00:00:00.0Z"
}
},
"versions": {
"0.4.0": {
"py_versions": [
Expand Down
97 changes: 92 additions & 5 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
import os
import re
import datetime
from typing import Optional
from packaging.version import Version

Expand Down Expand Up @@ -163,7 +164,9 @@ def retrieve(
config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)

original_version = version
version = _validate_version_and_set_if_needed(version, config, framework)
version = _validate_version_and_set_if_needed(
version, config, framework, image_scope, base_framework_version
)
version_config = config["versions"][_version_for_config(version, config)]

if framework == HUGGING_FACE_FRAMEWORK:
Expand Down Expand Up @@ -198,9 +201,8 @@ def retrieve(
container_version = sdk_version + "-" + container_version

if framework == HUGGING_FACE_FRAMEWORK:
pt_or_tf_version = (
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
)
hf_base_fw_regex = _get_hf_base_framework_version_regex()
pt_or_tf_version = hf_base_fw_regex.match(base_framework_version).group(2)
_version = original_version

if repo in [
Expand Down Expand Up @@ -254,6 +256,10 @@ def retrieve(
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)


def _get_hf_base_framework_version_regex():
    """Build the pattern used to split a Hugging Face base framework version.

    The pattern is anchored at both ends. Group 1 captures the framework
    name (``pytorch`` or ``tensorflow``) and group 2 captures whatever
    follows it in the string.

    Returns:
        re.Pattern: The compiled regular expression.
    """
    base_framework_pattern = "^(pytorch|tensorflow)(.*)$"
    return re.compile(base_framework_pattern)


def _get_instance_type_family(instance_type):
"""Return the family of the instance type.

Expand Down Expand Up @@ -446,7 +452,69 @@ def _validate_accelerator_type(accelerator_type):
)


def _validate_version_and_set_if_needed(version, config, framework):
def _get_end_of_support_warn_message(version, config, framework):
    """Get end of support warning message if needed.

    Args:
        version (str): ML framework version. May be a full version or a key
            of the config's ``version_aliases`` mapping.
        config (dict): config json loaded into python dictionary
        framework (str): ML framework

    Returns:
        str: Warning message if version is nearing or out of support, else empty string
    """
    version_support = config.get("version_support")
    aliases = config.get("version_aliases")

    # If version_support and version_aliases are not implemented, do not return warning message
    if not (version_support and aliases):
        return ""

    dlc_support_policy = "https://aws.amazon.com/releasenotes/dlc-support-policy/"

    end_of_support_msg = (
        f"The {framework} {version} DLC has reached end of support. "
        f"Please choose a supported version from our support policy - {dlc_support_policy}."
    )

    # Resolve an alias (e.g. "1.13") to its full version; full versions pass through unchanged.
    long_version = aliases.get(version, version)

    # If we have version support, but the long version is not in it, assume this is out of support
    support_entry = version_support.get(long_version)
    if not support_entry:
        return end_of_support_msg

    end_of_support = support_entry.get("end_of_support")

    # If no end of support specified, assume there is indefinite support
    if not end_of_support:
        return ""

    # The trailing "Z" in the config timestamp denotes UTC. strptime() matches it as a
    # literal and returns a naive datetime, so attach UTC explicitly. Calling astimezone()
    # on the naive value would interpret it as local time and shift the result by the
    # host's UTC offset.
    end_of_support_dt = datetime.datetime.strptime(
        end_of_support, "%Y-%m-%dT%H:%M:%S.%fZ"
    ).replace(tzinfo=datetime.timezone.utc)

    # Ensure that the version is still supported
    current_dt = datetime.datetime.now(datetime.timezone.utc)
    if current_dt >= end_of_support_dt:
        return end_of_support_msg

    # Warn ahead of time once the remaining support window is 60 days or less.
    if (end_of_support_dt - current_dt).days <= 60:
        return (
            f"The {framework} {version} DLC is approaching end of support, "
            f"and patching will stop on {end_of_support_dt.strftime('%Y-%m-%d %Z')}. "
            f"Please choose a supported version from our support policy - {dlc_support_policy}."
        )

    # Version is still well in support window, do not warn
    return ""


def _validate_version_and_set_if_needed(version, config, framework, scope, base_framework_version):
"""Checks if the framework/algorithm version is one of the supported versions."""
available_versions = list(config["versions"].keys())
aliased_versions = list(config.get("version_aliases", {}).keys())
Expand All @@ -463,6 +531,25 @@ def _validate_version_and_set_if_needed(version, config, framework):
return available_versions[0]

_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))

# For DLCs, warn if image is out of support.
eos_version, eos_config, eos_framework = version, config, framework

# Default to underlying base framework for HF, except for neuron which has its own policies.
inf_or_trn = "inf" in config.get("processors", []) or "trn" in config.get("processors", [])
Copy link
Contributor Author

@arjkesh arjkesh Mar 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does mean that HF will need to keep up to date with the latest patch releases of the underlying frameworks for its images to stay in support.

Example: DLC release PT 1.13.2

HF image, if still 1.13.1, will display as out of support until HF releases 1.13.2 and updates

if framework == HUGGING_FACE_FRAMEWORK and base_framework_version and not inf_or_trn:
hf_base_fw_regex = _get_hf_base_framework_version_regex()
hf_match = hf_base_fw_regex.search(base_framework_version)
if hf_match:
eos_framework = hf_match.group(1)
eos_version = hf_match.group(2)
eos_config = _config_for_framework_and_scope(eos_framework, scope)
end_of_support_warning = _get_end_of_support_warn_message(
eos_version, eos_config, eos_framework
)
if end_of_support_warning:
logger.warning(end_of_support_warning)

return version


Expand Down
Loading