From 15a295fb436b0ce78b9aad298d51f79cd9e99335 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 30 Jul 2020 09:28:47 -0700 Subject: [PATCH] breaking: move ShuffleConfig from sagemaker.session to sagemaker.inputs --- .../cli/compatibility/v2/ast_transformer.py | 2 + .../v2/modifiers/training_input.py | 70 ++++++++++++++++++ src/sagemaker/inputs.py | 20 ++++- src/sagemaker/session.py | 15 ---- .../v2/modifiers/test_shuffle_config.py | 74 +++++++++++++++++++ tests/unit/test_estimator.py | 2 +- 6 files changed, 165 insertions(+), 18 deletions(-) create mode 100644 tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_shuffle_config.py diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 228b68b594..f35901a2c4 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -35,6 +35,7 @@ modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(), modifiers.training_params.TrainPrefixRemover(), modifiers.training_input.TrainingInputConstructorRefactor(), + modifiers.training_input.ShuffleConfigModuleRenamer(), modifiers.serde.SerdeConstructorRenamer(), ] @@ -51,6 +52,7 @@ modifiers.predictors.PredictorImportFromRenamer(), modifiers.tfs.TensorFlowServingImportFromRenamer(), modifiers.training_input.TrainingInputImportFromRenamer(), + modifiers.training_input.ShuffleConfigImportFromRenamer(), modifiers.serde.SerdeImportFromAmazonCommonRenamer(), modifiers.serde.SerdeImportFromPredictorRenamer(), ] diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py b/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py index 171d52f570..b181cdf43e 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py @@ -100,3 +100,73 @@ def modify_node(self, node): if node.module == "sagemaker.session": node.module = "sagemaker.inputs" return node + + +class ShuffleConfigModuleRenamer(Modifier): + """A class to change ``ShuffleConfig`` usage to use ``sagemaker.inputs.ShuffleConfig``.""" + + def node_should_be_modified(self, node): + """Checks if the ``ast.Call`` node instantiates a class of interest. + + This looks for the following calls: + + - ``sagemaker.session.ShuffleConfig`` + - ``session.ShuffleConfig`` + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` instantiates a class of interest. + """ + if isinstance(node.func, ast.Name): + return False + + return matching.matches_name_or_namespaces( + node, "ShuffleConfig", ("sagemaker.session", "session") + ) + + def modify_node(self, node): + """Modifies the ``ast.Call`` node to call ``sagemaker.inputs.ShuffleConfig``. + + Args: + node (ast.Call): a node that represents a ``sagemaker.session.ShuffleConfig`` + constructor. + + Returns: + ast.Call: the original node, with its namespace changed to use the ``inputs`` module. + """ + _rename_namespace(node, "session") + return node + + +class ShuffleConfigImportFromRenamer(Modifier): + """A class to update import statements of ``ShuffleConfig``.""" + + def node_should_be_modified(self, node): + """Checks if the import statement imports ``sagemaker.session.ShuffleConfig``. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the import statement imports ``sagemaker.session.ShuffleConfig``. + """ + return node.module == "sagemaker.session" and any( + name.name == "ShuffleConfig" for name in node.names + ) + + def modify_node(self, node): + """Changes the ``ast.ImportFrom`` node's namespace to ``sagemaker.inputs``. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.ImportFrom: the original node, with its module modified to ``"sagemaker.inputs"``. + """ + node.module = "sagemaker.inputs" + return node diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index a846eb5d70..e9eadf5344 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -72,8 +72,8 @@ def __init__( found in a specified AugmentedManifestFile. target_attribute_name (str): The name of the attribute will be predicted (classified) in a SageMaker AutoML job. It is required if the input is for SageMaker AutoML job. - shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on - this channel. See the SageMaker API documentation for more info: + shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables + shuffling on this channel. See the SageMaker API documentation for more info: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html """ self.config = { @@ -102,6 +102,22 @@ def __init__( self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} +class ShuffleConfig(object): + """For configuring channel shuffling using a seed. + + For more detail, see the AWS documentation: + https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + """ + + def __init__(self, seed): + """Create a ShuffleConfig. + + Args: + seed (long): the long value used to seed the shuffled sequence. + """ + self.seed = seed + + class FileSystemInput(object): """Amazon SageMaker channel configurations for file system data sources. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index d5a9712e6a..7888a86268 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3362,21 +3362,6 @@ def get_execution_role(sagemaker_session=None): raise ValueError(message.format(arn)) -class ShuffleConfig(object): - """ - Used to configure channel shuffling using a seed. See SageMaker documentation for - more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html - """ - - def __init__(self, seed): - """ - Create a ShuffleConfig. - Args: - seed (long): the long value used to seed the shuffled sequence. - """ - self.seed = seed - - def _create_model_request( name, role, container_def=None, tags=None ): # pylint: disable=redefined-outer-name diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_shuffle_config.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_shuffle_config.py new file mode 100644 index 0000000000..2b5d607eb2 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_shuffle_config.py @@ -0,0 +1,74 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pasta +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import training_input +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import + + +@pytest.fixture +def constructors(): + return ( + "sagemaker.session.ShuffleConfig(seed)", + "session.ShuffleConfig(seed)", + ) + + +@pytest.fixture +def modified_constructors(constructors): + return [c.replace("session", "inputs") for c in constructors] + + +def test_constructor_node_should_be_modified(constructors): + modifier = training_input.ShuffleConfigModuleRenamer() + for constructor in constructors: + node = ast_call(constructor) + assert modifier.node_should_be_modified(node) + + +def test_constructor_node_should_be_modified_random_call(): + modifier = training_input.ShuffleConfigModuleRenamer() + node = ast_call("FileSystemInput()") + assert not modifier.node_should_be_modified(node) + + +def test_constructor_modify_node(constructors, modified_constructors): + modifier = training_input.ShuffleConfigModuleRenamer() + + for before, expected in zip(constructors, modified_constructors): + node = ast_call(before) + modifier.modify_node(node) + assert expected == pasta.dump(node) + + +def test_import_from_node_should_be_modified_training_input(): + modifier = training_input.ShuffleConfigImportFromRenamer() + node = ast_import("from sagemaker.session import ShuffleConfig") + assert modifier.node_should_be_modified(node) + + +def test_import_from_node_should_be_modified_random_import(): + modifier = training_input.ShuffleConfigImportFromRenamer() + node = ast_import("from sagemaker.session import Session") + assert not modifier.node_should_be_modified(node) + + +def test_import_from_modify_node(): + modifier = training_input.ShuffleConfigImportFromRenamer() + node = ast_import("from sagemaker.session import ShuffleConfig") + + modifier.modify_node(node) + assert "from sagemaker.inputs import ShuffleConfig" == pasta.dump(node) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 9796a6d63b..e6a82d4d71 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -26,9 +26,9 @@ from sagemaker import TrainingInput, utils, vpc_utils from sagemaker.algorithm import AlgorithmEstimator from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob +from sagemaker.inputs import ShuffleConfig from sagemaker.model import FrameworkModel from sagemaker.predictor import Predictor -from sagemaker.session import ShuffleConfig from sagemaker.transformer import Transformer MODEL_DATA = "s3://bucket/model.tar.gz"