From 6341eb776ccaed39e981549c1494a67dd4fc7165 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 30 Jul 2020 11:09:45 -0700 Subject: [PATCH] change: don't require instance_type for image_uris.retrieve() if only one option --- src/sagemaker/image_uris.py | 4 ++++ tests/unit/sagemaker/image_uris/test_retrieve.py | 16 ++++++++++++++++ tests/unit/sagemaker/image_uris/test_xgboost.py | 6 +----- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 193794d815..2436424df4 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -173,6 +173,10 @@ def _processor(instance_type, available_processors): logger.info("Ignoring unnecessary instance type: %s.", instance_type) return None + if len(available_processors) == 1 and not instance_type: + logger.info("Defaulting to only supported image scope: %s.", available_processors[0]) + return available_processors[0] + if not instance_type: raise ValueError( "Empty SageMaker instance type. For options, see: " diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index dd17724a2d..27e189acc4 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -522,6 +522,22 @@ def test_retrieve_processor_type_from_version_specific_processor_config(config_f assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.1.0-py3" == uri +@patch("sagemaker.image_uris.config_for_framework") +def test_retrieve_default_processor_type_if_possible(config_for_framework): + config = copy.deepcopy(BASE_CONFIG) + config["processors"] = ["cpu"] + config_for_framework.return_value = config + + uri = image_uris.retrieve( + framework="useless-string", + version="1.0.0", + py_version="py3", + region="us-west-2", + image_scope="training", + ) + assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri + + @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) def test_retrieve_unsupported_processor_type(config_for_framework): with pytest.raises(ValueError) as e: diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 33ee7bb950..7887cb2d44 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -72,11 +72,7 @@ def test_xgboost_framework(xgboost_framework_version): for region in regions.regions(): uri = image_uris.retrieve( - framework="xgboost", - region=region, - version=xgboost_framework_version, - py_version="py3", - instance_type="ml.c4.xlarge", + framework="xgboost", region=region, version=xgboost_framework_version, py_version="py3", ) expected = expected_uris.framework_uri(