diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 6b63fd61a1..c39fe9f543 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -37,6 +37,7 @@ modifiers.training_input.TrainingInputConstructorRefactor(), modifiers.training_input.ShuffleConfigModuleRenamer(), modifiers.serde.SerdeConstructorRenamer(), + modifiers.serde.SerdeKeywordRemover(), modifiers.image_uris.ImageURIRetrieveRefactor(), ] diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 59b7a2c19d..83a173c392 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -157,6 +157,48 @@ def modify_node(self, node): ) +class SerdeKeywordRemover(Modifier): + """A class to remove Serde-related keyword arguments from call expressions.""" + + def node_should_be_modified(self, node): + """Checks if the ``ast.Call`` node uses deprecated keywords. + + In particular, this function checks if: + + - The ``ast.Call`` represents the ``create_model`` method. + - Either the serializer or deserializer keywords are used. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` contains keywords that should be removed. + """ + if not isinstance(node.func, ast.Attribute) or node.func.attr != "create_model": + return False + return any(keyword.arg in {"serializer", "deserializer"} for keyword in node.keywords) + + def modify_node(self, node): + """Removes the serializer and deserializer keywords, as applicable. + + Args: + node (ast.Call): a node that represents a ``create_model`` call. + + Returns: + ast.Call: the node that represents a ``create_model`` call without + serializer or deserializers keywords. + """ + i = 0 + while i < len(node.keywords): + keyword = node.keywords[i] + if keyword.arg in {"serializer", "deserializer"}: + node.keywords.pop(i) + else: + i += 1 + return node + + class SerdeObjectRenamer(Modifier): """A class to rename SerDe objects imported from ``sagemaker.predictor``.""" diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py index 11da4fde6d..c71468bf41 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py @@ -346,3 +346,42 @@ def test_deserializer_module_modify_node(src, expected): node = pasta.parse(src) modified_node = modifier.modify_node(node) assert expected == pasta.dump(modified_node) + + +@pytest.mark.parametrize( + "src, expected", + [ + ('estimator.create_model(entry_point="inference.py")', False), + ("estimator.create_model(serializer=CSVSerializer())", True), + ("estimator.create_model(deserializer=CSVDeserializer())", True), + ( + "estimator.create_model(serializer=CSVSerializer(), deserializer=CSVDeserializer())", + True, + ), + ("estimator.deploy(serializer=CSVSerializer())", False), + ], +) +def test_create_model_call_node_should_be_modified(src, expected): + modifier = serde.SerdeKeywordRemover() + node = ast_call(src) + assert modifier.node_should_be_modified(node) is expected + + +@pytest.mark.parametrize( + "src, expected", + [ + ( + 'estimator.create_model(entry_point="inference.py", serializer=CSVSerializer())', + 'estimator.create_model(entry_point="inference.py")', + ), + ( + 'estimator.create_model(entry_point="inference.py", deserializer=CSVDeserializer())', + 'estimator.create_model(entry_point="inference.py")', + ), + ], +) +def test_create_model_call_modify_node(src, expected): + modifier = serde.SerdeKeywordRemover() + node = ast_call(src) + modified_node = modifier.modify_node(node) + assert expected == pasta.dump(modified_node)