-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
infra: add unit tests for v2 migration script file updaters and modif…
…iers (#1536)
- Loading branch information
Showing
3 changed files
with
321 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
153 changes: 153 additions & 0 deletions
153
tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import sys | ||
|
||
import pasta | ||
import pytest | ||
|
||
from sagemaker.cli.compatibility.v2.modifiers import framework_version | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def skip_if_py2(): | ||
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed. | ||
if sys.version_info.major < 3: | ||
pytest.skip("v2 migration script doesn't support Python 2.") | ||
|
||
|
||
def test_node_should_be_modified_fw_constructor_no_fw_version(): | ||
fw_constructors = ( | ||
"TensorFlow()", | ||
"sagemaker.tensorflow.TensorFlow()", | ||
"TensorFlowModel()", | ||
"sagemaker.tensorflow.TensorFlowModel()", | ||
"MXNet()", | ||
"sagemaker.mxnet.MXNet()", | ||
"MXNetModel()", | ||
"sagemaker.mxnet.MXNetModel()", | ||
"Chainer()", | ||
"sagemaker.chainer.Chainer()", | ||
"ChainerModel()", | ||
"sagemaker.chainer.ChainerModel()", | ||
"PyTorch()", | ||
"sagemaker.pytorch.PyTorch()", | ||
"PyTorchModel()", | ||
"sagemaker.pytorch.PyTorchModel()", | ||
"SKLearn()", | ||
"sagemaker.sklearn.SKLearn()", | ||
"SKLearnModel()", | ||
"sagemaker.sklearn.SKLearnModel()", | ||
) | ||
|
||
modifier = framework_version.FrameworkVersionEnforcer() | ||
|
||
for constructor in fw_constructors: | ||
node = _ast_call(constructor) | ||
assert modifier.node_should_be_modified(node) is True | ||
|
||
|
||
def test_node_should_be_modified_fw_constructor_with_fw_version(): | ||
fw_constructors = ( | ||
"TensorFlow(framework_version='2.2')", | ||
"sagemaker.tensorflow.TensorFlow(framework_version='2.2')", | ||
"TensorFlowModel(framework_version='1.10')", | ||
"sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')", | ||
"MXNet(framework_version='1.6')", | ||
"sagemaker.mxnet.MXNet(framework_version='1.6')", | ||
"MXNetModel(framework_version='1.6')", | ||
"sagemaker.mxnet.MXNetModel(framework_version='1.6')", | ||
"PyTorch(framework_version='1.4')", | ||
"sagemaker.pytorch.PyTorch(framework_version='1.4')", | ||
"PyTorchModel(framework_version='1.4')", | ||
"sagemaker.pytorch.PyTorchModel(framework_version='1.4')", | ||
"Chainer(framework_version='5.0')", | ||
"sagemaker.chainer.Chainer(framework_version='5.0')", | ||
"ChainerModel(framework_version='5.0')", | ||
"sagemaker.chainer.ChainerModel(framework_version='5.0')", | ||
"SKLearn(framework_version='0.20.0')", | ||
"sagemaker.sklearn.SKLearn(framework_version='0.20.0')", | ||
"SKLearnModel(framework_version='0.20.0')", | ||
"sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')", | ||
) | ||
|
||
modifier = framework_version.FrameworkVersionEnforcer() | ||
|
||
for constructor in fw_constructors: | ||
node = _ast_call(constructor) | ||
assert modifier.node_should_be_modified(node) is False | ||
|
||
|
||
def test_node_should_be_modified_random_function_call(): | ||
node = _ast_call("sagemaker.session.Session()") | ||
modifier = framework_version.FrameworkVersionEnforcer() | ||
assert modifier.node_should_be_modified(node) is False | ||
|
||
|
||
def test_modify_node_tf(): | ||
classes = ( | ||
"TensorFlow" "sagemaker.tensorflow.TensorFlow", | ||
"TensorFlowModel", | ||
"sagemaker.tensorflow.TensorFlowModel", | ||
) | ||
_test_modify_node(classes, "1.11.0") | ||
|
||
|
||
def test_modify_node_mx(): | ||
classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel") | ||
_test_modify_node(classes, "1.2.0") | ||
|
||
|
||
def test_modify_node_chainer(): | ||
classes = ( | ||
"Chainer", | ||
"sagemaker.chainer.Chainer", | ||
"ChainerModel", | ||
"sagemaker.chainer.ChainerModel", | ||
) | ||
_test_modify_node(classes, "4.1.0") | ||
|
||
|
||
def test_modify_node_pt(): | ||
classes = ( | ||
"PyTorch", | ||
"sagemaker.pytorch.PyTorch", | ||
"PyTorchModel", | ||
"sagemaker.pytorch.PyTorchModel", | ||
) | ||
_test_modify_node(classes, "0.4.0") | ||
|
||
|
||
def test_modify_node_sklearn(): | ||
classes = ( | ||
"SKLearn", | ||
"sagemaker.sklearn.SKLearn", | ||
"SKLearnModel", | ||
"sagemaker.sklearn.SKLearnModel", | ||
) | ||
_test_modify_node(classes, "0.20.0") | ||
|
||
|
||
def _ast_call(code): | ||
return pasta.parse(code).body[0].value | ||
|
||
|
||
def _test_modify_node(classes, default_version): | ||
modifier = framework_version.FrameworkVersionEnforcer() | ||
for cls in classes: | ||
node = _ast_call("{}()".format(cls)) | ||
modifier.modify_node(node) | ||
|
||
expected_result = "{}(framework_version='{}')".format(cls, default_version) | ||
assert expected_result == pasta.dump(node) |
154 changes: 154 additions & 0 deletions
154
tests/unit/sagemaker/cli/compatibility/v2/test_file_updaters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import json | ||
import os | ||
|
||
from mock import call, Mock, mock_open, patch | ||
|
||
from sagemaker.cli.compatibility.v2 import files | ||
|
||
|
||
def test_init(): | ||
input_file = "input.py" | ||
output_file = "output.py" | ||
|
||
updater = files.FileUpdater(input_file, output_file) | ||
assert input_file == updater.input_path | ||
assert output_file == updater.output_path | ||
|
||
|
||
@patch("six.moves.builtins.open", mock_open()) | ||
@patch("os.makedirs") | ||
def test_make_output_dirs_if_needed_make_path(makedirs): | ||
output_dir = "dir" | ||
output_path = os.path.join(output_dir, "output.py") | ||
|
||
updater = files.FileUpdater("input.py", output_path) | ||
updater._make_output_dirs_if_needed() | ||
|
||
makedirs.assert_called_with(output_dir) | ||
|
||
|
||
@patch("six.moves.builtins.open", mock_open()) | ||
@patch("os.path.exists", return_value=True) | ||
def test_make_output_dirs_if_needed_overwrite_with_warning(os_path_exists, caplog): | ||
output_file = "output.py" | ||
|
||
updater = files.FileUpdater("input.py", output_file) | ||
updater._make_output_dirs_if_needed() | ||
|
||
assert "Overwriting file {}".format(output_file) in caplog.text | ||
|
||
|
||
@patch("pasta.dump") | ||
@patch("pasta.parse") | ||
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer") | ||
def test_py_file_update(ast_transformer, pasta_parse, pasta_dump): | ||
input_ast = Mock() | ||
pasta_parse.return_value = input_ast | ||
|
||
output_ast = Mock(_fields=[]) | ||
ast_transformer.return_value.visit.return_value = output_ast | ||
output_code = "print('goodbye')" | ||
pasta_dump.return_value = output_code | ||
|
||
input_file = "input.py" | ||
output_file = "output.py" | ||
|
||
input_code = "print('hello, world!')" | ||
open_mock = mock_open(read_data=input_code) | ||
with patch("six.moves.builtins.open", open_mock): | ||
updater = files.PyFileUpdater(input_file, output_file) | ||
updater.update() | ||
|
||
pasta_parse.assert_called_with(input_code) | ||
ast_transformer.return_value.visit.assert_called_with(input_ast) | ||
|
||
assert call(input_file) in open_mock.mock_calls | ||
assert call(output_file, "w") in open_mock.mock_calls | ||
|
||
open_mock().write.assert_called_with(output_code) | ||
pasta_dump.assert_called_with(output_ast) | ||
|
||
|
||
@patch("json.dump") | ||
@patch("pasta.dump") | ||
@patch("pasta.parse") | ||
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer") | ||
def test_update(ast_transformer, pasta_parse, pasta_dump, json_dump): | ||
notebook_template = """{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# code to be modified\\n", | ||
"%s" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} | ||
""" | ||
input_code = "print('hello, world!')" | ||
input_notebook = notebook_template % input_code | ||
|
||
input_ast = Mock() | ||
pasta_parse.return_value = input_ast | ||
|
||
output_ast = Mock(_fields=[]) | ||
ast_transformer.return_value.visit.return_value = output_ast | ||
output_code = "print('goodbye')" | ||
pasta_dump.return_value = "# code to be modified\n{}".format(output_code) | ||
|
||
input_file = "input.py" | ||
output_file = "output.py" | ||
|
||
open_mock = mock_open(read_data=input_notebook) | ||
with patch("six.moves.builtins.open", open_mock): | ||
updater = files.JupyterNotebookFileUpdater(input_file, output_file) | ||
updater.update() | ||
|
||
pasta_parse.assert_called_with("# code to be modified\n{}".format(input_code)) | ||
ast_transformer.return_value.visit.assert_called_with(input_ast) | ||
pasta_dump.assert_called_with(output_ast) | ||
|
||
assert call(input_file) in open_mock.mock_calls | ||
assert call(output_file, "w") in open_mock.mock_calls | ||
|
||
json_dump.assert_called_with(json.loads(notebook_template % output_code), open_mock(), indent=1) | ||
open_mock().write.assert_called_with("\n") |