Skip to content

Commit

Permalink
infra: add unit tests for v2 migration script file updaters and modif…
Browse files Browse the repository at this point in the history
…iers (#1536)
  • Loading branch information
laurenyu committed Jun 1, 2020
1 parent a680be1 commit 614fe7e
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/sagemaker/cli/compatibility/v2/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def update(self):
updated code to an output file.
"""

def _make_output_dirs_if_needed(self):
"""Checks if the directory path for ``self.output_path`` exists,
and creates the directories if not. This function also logs a warning if
``self.output_path`` already exists.
"""
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

if os.path.exists(self.output_path):
LOGGER.warning("Overwriting file %s", self.output_path)


class PyFileUpdater(FileUpdater):
"""A class for updating Python (``*.py``) files."""
Expand Down Expand Up @@ -88,12 +100,7 @@ def _write_output_file(self, output):
Args:
output (ast.Module): AST to save as the output file.
"""
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

if os.path.exists(self.output_path):
LOGGER.warning("Overwriting file %s", self.output_path)
self._make_output_dirs_if_needed()

with open(self.output_path, "w") as output_file:
output_file.write(pasta.dump(output))
Expand Down Expand Up @@ -168,12 +175,7 @@ def _write_output_file(self, output):
Args:
output (dict): JSON to save as the output file.
"""
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

if os.path.exists(self.output_path):
LOGGER.warning("Overwriting file %s", self.output_path)
self._make_output_dirs_if_needed()

with open(self.output_path, "w") as output_file:
json.dump(output, output_file, indent=1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import sys

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import framework_version


@pytest.fixture(autouse=True)
def skip_if_py2():
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
if sys.version_info.major < 3:
pytest.skip("v2 migration script doesn't support Python 2.")


def test_node_should_be_modified_fw_constructor_no_fw_version():
fw_constructors = (
"TensorFlow()",
"sagemaker.tensorflow.TensorFlow()",
"TensorFlowModel()",
"sagemaker.tensorflow.TensorFlowModel()",
"MXNet()",
"sagemaker.mxnet.MXNet()",
"MXNetModel()",
"sagemaker.mxnet.MXNetModel()",
"Chainer()",
"sagemaker.chainer.Chainer()",
"ChainerModel()",
"sagemaker.chainer.ChainerModel()",
"PyTorch()",
"sagemaker.pytorch.PyTorch()",
"PyTorchModel()",
"sagemaker.pytorch.PyTorchModel()",
"SKLearn()",
"sagemaker.sklearn.SKLearn()",
"SKLearnModel()",
"sagemaker.sklearn.SKLearnModel()",
)

modifier = framework_version.FrameworkVersionEnforcer()

for constructor in fw_constructors:
node = _ast_call(constructor)
assert modifier.node_should_be_modified(node) is True


def test_node_should_be_modified_fw_constructor_with_fw_version():
fw_constructors = (
"TensorFlow(framework_version='2.2')",
"sagemaker.tensorflow.TensorFlow(framework_version='2.2')",
"TensorFlowModel(framework_version='1.10')",
"sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')",
"MXNet(framework_version='1.6')",
"sagemaker.mxnet.MXNet(framework_version='1.6')",
"MXNetModel(framework_version='1.6')",
"sagemaker.mxnet.MXNetModel(framework_version='1.6')",
"PyTorch(framework_version='1.4')",
"sagemaker.pytorch.PyTorch(framework_version='1.4')",
"PyTorchModel(framework_version='1.4')",
"sagemaker.pytorch.PyTorchModel(framework_version='1.4')",
"Chainer(framework_version='5.0')",
"sagemaker.chainer.Chainer(framework_version='5.0')",
"ChainerModel(framework_version='5.0')",
"sagemaker.chainer.ChainerModel(framework_version='5.0')",
"SKLearn(framework_version='0.20.0')",
"sagemaker.sklearn.SKLearn(framework_version='0.20.0')",
"SKLearnModel(framework_version='0.20.0')",
"sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')",
)

modifier = framework_version.FrameworkVersionEnforcer()

for constructor in fw_constructors:
node = _ast_call(constructor)
assert modifier.node_should_be_modified(node) is False


def test_node_should_be_modified_random_function_call():
node = _ast_call("sagemaker.session.Session()")
modifier = framework_version.FrameworkVersionEnforcer()
assert modifier.node_should_be_modified(node) is False


def test_modify_node_tf():
classes = (
"TensorFlow" "sagemaker.tensorflow.TensorFlow",
"TensorFlowModel",
"sagemaker.tensorflow.TensorFlowModel",
)
_test_modify_node(classes, "1.11.0")


def test_modify_node_mx():
classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel")
_test_modify_node(classes, "1.2.0")


def test_modify_node_chainer():
classes = (
"Chainer",
"sagemaker.chainer.Chainer",
"ChainerModel",
"sagemaker.chainer.ChainerModel",
)
_test_modify_node(classes, "4.1.0")


def test_modify_node_pt():
classes = (
"PyTorch",
"sagemaker.pytorch.PyTorch",
"PyTorchModel",
"sagemaker.pytorch.PyTorchModel",
)
_test_modify_node(classes, "0.4.0")


def test_modify_node_sklearn():
classes = (
"SKLearn",
"sagemaker.sklearn.SKLearn",
"SKLearnModel",
"sagemaker.sklearn.SKLearnModel",
)
_test_modify_node(classes, "0.20.0")


def _ast_call(code):
return pasta.parse(code).body[0].value


def _test_modify_node(classes, default_version):
modifier = framework_version.FrameworkVersionEnforcer()
for cls in classes:
node = _ast_call("{}()".format(cls))
modifier.modify_node(node)

expected_result = "{}(framework_version='{}')".format(cls, default_version)
assert expected_result == pasta.dump(node)
154 changes: 154 additions & 0 deletions tests/unit/sagemaker/cli/compatibility/v2/test_file_updaters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os

from mock import call, Mock, mock_open, patch

from sagemaker.cli.compatibility.v2 import files


def test_init():
input_file = "input.py"
output_file = "output.py"

updater = files.FileUpdater(input_file, output_file)
assert input_file == updater.input_path
assert output_file == updater.output_path


@patch("six.moves.builtins.open", mock_open())
@patch("os.makedirs")
def test_make_output_dirs_if_needed_make_path(makedirs):
output_dir = "dir"
output_path = os.path.join(output_dir, "output.py")

updater = files.FileUpdater("input.py", output_path)
updater._make_output_dirs_if_needed()

makedirs.assert_called_with(output_dir)


@patch("six.moves.builtins.open", mock_open())
@patch("os.path.exists", return_value=True)
def test_make_output_dirs_if_needed_overwrite_with_warning(os_path_exists, caplog):
output_file = "output.py"

updater = files.FileUpdater("input.py", output_file)
updater._make_output_dirs_if_needed()

assert "Overwriting file {}".format(output_file) in caplog.text


@patch("pasta.dump")
@patch("pasta.parse")
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer")
def test_py_file_update(ast_transformer, pasta_parse, pasta_dump):
input_ast = Mock()
pasta_parse.return_value = input_ast

output_ast = Mock(_fields=[])
ast_transformer.return_value.visit.return_value = output_ast
output_code = "print('goodbye')"
pasta_dump.return_value = output_code

input_file = "input.py"
output_file = "output.py"

input_code = "print('hello, world!')"
open_mock = mock_open(read_data=input_code)
with patch("six.moves.builtins.open", open_mock):
updater = files.PyFileUpdater(input_file, output_file)
updater.update()

pasta_parse.assert_called_with(input_code)
ast_transformer.return_value.visit.assert_called_with(input_ast)

assert call(input_file) in open_mock.mock_calls
assert call(output_file, "w") in open_mock.mock_calls

open_mock().write.assert_called_with(output_code)
pasta_dump.assert_called_with(output_ast)


@patch("json.dump")
@patch("pasta.dump")
@patch("pasta.parse")
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer")
def test_update(ast_transformer, pasta_parse, pasta_dump, json_dump):
notebook_template = """{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# code to be modified\\n",
"%s"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
"""
input_code = "print('hello, world!')"
input_notebook = notebook_template % input_code

input_ast = Mock()
pasta_parse.return_value = input_ast

output_ast = Mock(_fields=[])
ast_transformer.return_value.visit.return_value = output_ast
output_code = "print('goodbye')"
pasta_dump.return_value = "# code to be modified\n{}".format(output_code)

input_file = "input.py"
output_file = "output.py"

open_mock = mock_open(read_data=input_notebook)
with patch("six.moves.builtins.open", open_mock):
updater = files.JupyterNotebookFileUpdater(input_file, output_file)
updater.update()

pasta_parse.assert_called_with("# code to be modified\n{}".format(input_code))
ast_transformer.return_value.visit.assert_called_with(input_ast)
pasta_dump.assert_called_with(output_ast)

assert call(input_file) in open_mock.mock_calls
assert call(output_file, "w") in open_mock.mock_calls

json_dump.assert_called_with(json.loads(notebook_template % output_code), open_mock(), indent=1)
open_mock().write.assert_called_with("\n")

0 comments on commit 614fe7e

Please sign in to comment.