Skip to content

Commit

Permalink
fix: add logic to use asimov image for TF 1.14 py2 (#997)
Browse files Browse the repository at this point in the history
* fix: add logic to use asimov image for TF 1.14 py2
  • Loading branch information
chuyang-deng committed Aug 23, 2019
1 parent fa79418 commit 6f3cf42
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 13 deletions.
3 changes: 0 additions & 3 deletions doc/using_tf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ models on SageMaker Hosting.

For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`.

.. warning::
The TensorFlow estimator is available only for Python 3, starting by the TensorFlow version 1.14.

.. warning::
We have added a new format of your TensorFlow training script with TensorFlow version 1.11.
This new way gives the user script more flexibility.
Expand Down
21 changes: 20 additions & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,26 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
is_py3 = py_version == "py3" or py_version is None
is_merged_versions = _is_merged_versions(framework, framework_version)
return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None
return (
(not is_gov_region)
and is_merged_versions
and (is_py3 or _is_tf_14_or_later(framework, framework_version))
and accelerator_type is None
)


def _is_tf_14_or_later(framework, framework_version):
"""
Args:
framework:
framework_version:
"""
# Asimov team now owns Tensorflow 1.14.0 py2 and py3
asimov_lowest_tf_py2 = [1, 14, 0]
version = [int(s) for s in framework_version.split(".")]
return (
framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2[0 : len(version)]
)


def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ class TensorFlow(Framework):
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""

_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
# 1.14.0 now supports py2
# we will need to update this version number if future versions do not support py2 anymore
_LOWEST_PYTHON_2_ONLY_VERSION = [1, 14]

def __init__(
Expand Down Expand Up @@ -343,7 +345,7 @@ def _validate_args(

if py_version == "py2" and self._only_python_3_supported():
msg = (
"Python 2 containers are only available until TensorFlow version 1.13.1. "
"Python 2 containers are only available until TensorFlow version 1.14.0. "
"Please use a Python 3 container."
)
raise AttributeError(msg)
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def test_create_image_uri_merged_py2():
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2"
)

image_uri = fw_utils.create_image_uri(
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py2"
)
assert (
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.14-gpu-py2"
)

image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2")
assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2"

Expand Down
10 changes: 2 additions & 8 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,18 +957,12 @@ def test_script_mode_deprecated_args(sagemaker_session):

def test_py2_version_deprecated(sagemaker_session):
with pytest.raises(AttributeError) as e:
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14", py_version="py2")
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14.1", py_version="py2")

msg = "Python 2 containers are only available until TensorFlow version 1.13.1. Please use a Python 3 container."
msg = "Python 2 containers are only available until TensorFlow version 1.14.0. Please use a Python 3 container."
assert msg in str(e.value)


def test_py3_is_default_version_after_tf1_14(sagemaker_session):
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14")

assert estimator.py_version == "py3"


def test_py3_is_default_version_before_tf1_14(sagemaker_session):
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.13")

Expand Down

0 comments on commit 6f3cf42

Please sign in to comment.