File tree Expand file tree Collapse file tree 4 files changed +7
-43
lines changed Expand file tree Collapse file tree 4 files changed +7
-43
lines changed Original file line number Diff line number Diff line change 4949SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge" , "ml.p3.2xlarge" )
5050
5151
52- def is_version_equal_or_higher (lowest_version , framework_version ):
53- """Determine whether the ``framework_version`` is equal to or higher than
54- ``lowest_version``
55-
56- Args:
57- lowest_version (List[int]): lowest version represented in an integer
58- list
59- framework_version (str): framework version string
60-
61- Returns:
62- bool: Whether or not ``framework_version`` is equal to or higher than
63- ``lowest_version``
64- """
65- version_list = [int (s ) for s in framework_version .split ("." )]
66- return version_list >= lowest_version [0 : len (version_list )]
67-
68-
69- def is_version_equal_or_lower (highest_version , framework_version ):
70- """Determine whether the ``framework_version`` is equal to or lower than
71- ``highest_version``
72-
73- Args:
74- highest_version (List[int]): highest version represented in an integer
75- list
76- framework_version (str): framework version string
77-
78- Returns:
79- bool: Whether or not ``framework_version`` is equal to or lower than
80- ``highest_version``
81- """
82- version_list = [int (s ) for s in framework_version .split ("." )]
83- return version_list <= highest_version [0 : len (version_list )]
84-
85-
8652def validate_source_dir (script , directory ):
8753 """Validate that the source directory exists and it contains the user script
8854 Args:
Original file line number Diff line number Diff line change 1515
1616import logging
1717
18+ from packaging .version import Version
19+
1820from sagemaker .estimator import Framework
1921from sagemaker .fw_utils import (
2022 framework_name_from_image ,
2123 framework_version_from_tag ,
22- is_version_equal_or_higher ,
2324 python_deprecation_warning ,
2425 validate_version_or_image_args ,
2526 warn_if_parameter_server_with_multi_gpu ,
@@ -157,9 +158,7 @@ def __init__(
157158
158159 if "enable_sagemaker_metrics" not in kwargs :
159160 # enable sagemaker metrics for MXNet v1.6 or greater:
160- if self .framework_version and is_version_equal_or_higher (
161- [1 , 6 ], self .framework_version
162- ):
161+ if self .framework_version and Version (self .framework_version ) >= Version ("1.6" ):
163162 kwargs ["enable_sagemaker_metrics" ] = True
164163
165164 super (MXNet , self ).__init__ (
Original file line number Diff line number Diff line change 1515
1616import logging
1717
18+ from packaging .version import Version
19+
1820from sagemaker .estimator import Framework
1921from sagemaker .fw_utils import (
2022 framework_name_from_image ,
2123 framework_version_from_tag ,
22- is_version_equal_or_higher ,
2324 python_deprecation_warning ,
2425 validate_version_or_image_args ,
2526)
@@ -116,9 +117,7 @@ def __init__(
116117
117118 if "enable_sagemaker_metrics" not in kwargs :
118119 # enable sagemaker metrics for PT v1.3 or greater:
119- if self .framework_version and is_version_equal_or_higher (
120- [1 , 3 ], self .framework_version
121- ):
120+ if self .framework_version and Version (self .framework_version ) >= Version ("1.3" ):
122121 kwargs ["enable_sagemaker_metrics" ] = True
123122
124123 super (PyTorch , self ).__init__ (
Original file line number Diff line number Diff line change @@ -129,7 +129,7 @@ def __init__(
129129
130130 if "enable_sagemaker_metrics" not in kwargs :
131131 # enable sagemaker metrics for TF v1.15 or greater:
132- if framework_version and fw . is_version_equal_or_higher ([ 1 , 15 ], framework_version ):
132+ if framework_version and version . Version ( framework_version ) >= version . Version ( "1.15" ):
133133 kwargs ["enable_sagemaker_metrics" ] = True
134134
135135 super (TensorFlow , self ).__init__ (image_uri = image_uri , ** kwargs )
You can’t perform that action at this time.
0 commit comments