From f73db16ba67980ab71273a406d88ab538b646fdc Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 30 Jul 2020 14:57:00 -0700 Subject: [PATCH] breaking: deprecate is_version_equal_or_higher and is_version_equal_or_lower --- src/sagemaker/fw_utils.py | 34 --------------------------- src/sagemaker/mxnet/estimator.py | 7 +++--- src/sagemaker/pytorch/estimator.py | 7 +++--- src/sagemaker/tensorflow/estimator.py | 2 +- 4 files changed, 7 insertions(+), 43 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index a7f3455851..f8d316fe3a 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -49,40 +49,6 @@ SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge") -def is_version_equal_or_higher(lowest_version, framework_version): - """Determine whether the ``framework_version`` is equal to or higher than - ``lowest_version`` - - Args: - lowest_version (List[int]): lowest version represented in an integer - list - framework_version (str): framework version string - - Returns: - bool: Whether or not ``framework_version`` is equal to or higher than - ``lowest_version`` - """ - version_list = [int(s) for s in framework_version.split(".")] - return version_list >= lowest_version[0 : len(version_list)] - - -def is_version_equal_or_lower(highest_version, framework_version): - """Determine whether the ``framework_version`` is equal to or lower than - ``highest_version`` - - Args: - highest_version (List[int]): highest version represented in an integer - list - framework_version (str): framework version string - - Returns: - bool: Whether or not ``framework_version`` is equal to or lower than - ``highest_version`` - """ - version_list = [int(s) for s in framework_version.split(".")] - return version_list <= highest_version[0 : len(version_list)] - - def validate_source_dir(script, directory): """Validate that the source directory exists and it contains the user script Args: diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index c1a181e5b3..f5fe29c324 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -15,11 +15,12 @@ import logging +from packaging.version import Version + from sagemaker.estimator import Framework from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, - is_version_equal_or_higher, python_deprecation_warning, validate_version_or_image_args, warn_if_parameter_server_with_multi_gpu, @@ -157,9 +158,7 @@ def __init__( if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for MXNet v1.6 or greater: - if self.framework_version and is_version_equal_or_higher( - [1, 6], self.framework_version - ): + if self.framework_version and Version(self.framework_version) >= Version("1.6"): kwargs["enable_sagemaker_metrics"] = True super(MXNet, self).__init__( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 39282e63d7..97646cea0a 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -15,11 +15,12 @@ import logging +from packaging.version import Version + from sagemaker.estimator import Framework from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, - is_version_equal_or_higher, python_deprecation_warning, validate_version_or_image_args, ) @@ -116,9 +117,7 @@ def __init__( if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for PT v1.3 or greater: - if self.framework_version and is_version_equal_or_higher( - [1, 3], self.framework_version - ): + if self.framework_version and Version(self.framework_version) >= Version("1.3"): kwargs["enable_sagemaker_metrics"] = True super(PyTorch, self).__init__( diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 0bf4e41b26..3a5a2721b9 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -129,7 +129,7 @@ def __init__( if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for TF v1.15 or greater: - if framework_version and fw.is_version_equal_or_higher([1, 15], framework_version): + if framework_version and version.Version(framework_version) >= version.Version("1.15"): kwargs["enable_sagemaker_metrics"] = True super(TensorFlow, self).__init__(image_uri=image_uri, **kwargs)