diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index f35901a2c4..6b63fd61a1 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -37,6 +37,7 @@ modifiers.training_input.TrainingInputConstructorRefactor(), modifiers.training_input.ShuffleConfigModuleRenamer(), modifiers.serde.SerdeConstructorRenamer(), + modifiers.image_uris.ImageURIRetrieveRefactor(), ] IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()] @@ -55,6 +56,7 @@ modifiers.training_input.ShuffleConfigImportFromRenamer(), modifiers.serde.SerdeImportFromAmazonCommonRenamer(), modifiers.serde.SerdeImportFromPredictorRenamer(), + modifiers.image_uris.ImageURIRetrieveImportFromRenamer(), ] diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py index f6e1ead061..75f5e1dbeb 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py @@ -24,4 +24,5 @@ tfs, training_params, training_input, + image_uris, ) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py b/src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py new file mode 100644 index 0000000000..fe0ba9df2d --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py @@ -0,0 +1,134 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Classes to modify image uri retrieve methods for Python SDK v2.0 and later.""" +from __future__ import absolute_import + +import ast + +from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier + +GET_IMAGE_URI_NAME = "get_image_uri" +GET_IMAGE_URI_NAMESPACES = ( + "sagemaker", + "sagemaker.amazon_estimator", + "sagemaker.amazon.amazon_estimator", + "amazon_estimator", + "amazon.amazon_estimator", +) + + +class ImageURIRetrieveRefactor(Modifier): + """A class to refactor *get_image_uri() method.""" + + def node_should_be_modified(self, node): + """Checks if the ``ast.Call`` node calls a function of interest. + + This looks for the following calls: + + - ``sagemaker.get_image_uri`` + - ``sagemaker.amazon_estimator.get_image_uri`` + - ``get_image_uri`` + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` instantiates a class of interest. + """ + return matching.matches_name_or_namespaces( + node, GET_IMAGE_URI_NAME, GET_IMAGE_URI_NAMESPACES + ) + + def modify_node(self, node): + """Modifies the ``ast.Call`` node to call ``image_uris.retrieve`` instead. + And switch the first two parameters from (region, repo) to (framework, region) + + Args: + node (ast.Call): a node that represents a *image_uris.retrieve call. + """ + original_args = [None] * 3 + for kw in node.keywords: + if kw.arg == "repo_name": + original_args[0] = ast.Str(kw.value.s) + elif kw.arg == "repo_region": + original_args[1] = ast.Str(kw.value.s) + elif kw.arg == "repo_version": + original_args[2] = ast.Str(kw.value.s) + + if len(node.args) > 0: + original_args[1] = ast.Str(node.args[0].s) + if len(node.args) > 1: + original_args[0] = ast.Str(node.args[1].s) + if len(node.args) > 2: + original_args[2] = ast.Str(node.args[2].s) + + args = [] + for arg in original_args: + if arg: + args.append(arg) + + func = node.func + has_sagemaker = False + while hasattr(func, "value"): + if hasattr(func.value, "id") and func.value.id == "sagemaker": + has_sagemaker = True + break + func = func.value + + if has_sagemaker: + node.func = ast.Attribute( + value=ast.Attribute(attr="image_uris", value=ast.Name(id="sagemaker")), + attr="retrieve", + ) + else: + node.func = ast.Attribute(value=ast.Name(id="image_uris"), attr="retrieve") + node.args = args + node.keywords = [] + return node + + +class ImageURIRetrieveImportFromRenamer(Modifier): + """A class to update import statements of ``get_image_uri``.""" + + def node_should_be_modified(self, node): + """Checks if the import statement imports ``get_image_uri`` from the correct module. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the import statement imports ``get_image_uri`` from the correct module. + """ + return node.module in GET_IMAGE_URI_NAMESPACES and any( + name.name == GET_IMAGE_URI_NAME for name in node.names + ) + + def modify_node(self, node): + """Changes the ``ast.ImportFrom`` node's name from ``get_image_uri`` to ``image_uris``. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.AST: the original node, which has been potentially modified. + """ + for name in node.names: + if name.name == GET_IMAGE_URI_NAME: + name.name = "image_uris" + if node.module in GET_IMAGE_URI_NAMESPACES: + node.module = "sagemaker" + return node diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_algorithm_image_uris.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_algorithm_image_uris.py new file mode 100644 index 0000000000..ea43497d8a --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_algorithm_image_uris.py @@ -0,0 +1,114 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pasta +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import image_uris +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import + + +@pytest.fixture +def methods(): + return ( + "get_image_uri('us-west-2', 'sagemaker-xgboost')", + "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-xgboost')", + "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-xgboost')", + "sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost', repo_version='1')", + ) + + +@pytest.fixture +def import_statements(): + return ( + "from sagemaker import get_image_uri", + "from sagemaker.amazon_estimator import get_image_uri", + "from sagemaker.amazon.amazon_estimator import get_image_uri", + ) + + +def test_method_node_should_be_modified(methods): + modifier = image_uris.ImageURIRetrieveRefactor() + for method in methods: + node = ast_call(method) + assert modifier.node_should_be_modified(node) + + +def test_methodnode_should_be_modified_random_call(): + modifier = image_uris.ImageURIRetrieveRefactor() + node = ast_call("create_image_uri()") + assert not modifier.node_should_be_modified(node) + + +def test_method_modify_node(methods, caplog): + modifier = image_uris.ImageURIRetrieveRefactor() + + method = "get_image_uri('us-west-2', 'xgboost')" + node = ast_call(method) + modifier.modify_node(node) + assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node) + + method = "amazon_estimator.get_image_uri('us-west-2', 'xgboost')" + node = ast_call(method) + modifier.modify_node(node) + assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node) + + method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='xgboost')" + node = ast_call(method) + modifier.modify_node(node) + assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node) + + method = "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='xgboost')" + node = ast_call(method) + modifier.modify_node(node) + assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node) + + method = ( + "sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'xgboost', repo_version='1')" + ) + node = ast_call(method) + modifier.modify_node(node) + assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2', '1')" == pasta.dump(node) + + +def test_import_from_node_should_be_modified_image_uris_input(import_statements): + modifier = image_uris.ImageURIRetrieveImportFromRenamer() + + statement = "from sagemaker import get_image_uri" + node = ast_import(statement) + assert modifier.node_should_be_modified(node) + + statement = "from sagemaker.amazon_estimator import get_image_uri" + node = ast_import(statement) + assert modifier.node_should_be_modified(node) + + statement = "from sagemaker.amazon.amazon_estimator import get_image_uri" + node = ast_import(statement) + assert modifier.node_should_be_modified(node) + + +def test_import_from_node_should_be_modified_random_import(): + modifier = image_uris.ImageURIRetrieveImportFromRenamer() + node = ast_import("from sagemaker.amazon_estimator import registry") + assert not modifier.node_should_be_modified(node) + + +def test_import_from_modify_node(import_statements): + modifier = image_uris.ImageURIRetrieveImportFromRenamer() + expected_result = "from sagemaker import image_uris" + + for import_statement in import_statements: + node = ast_import(import_statement) + modifier.modify_node(node) + assert expected_result == pasta.dump(node)