diff --git a/.flake8 b/.flake8 index de5d4c0..9618fa0 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] -application_import_names = sagemaker_containers, test -import-order-style = google \ No newline at end of file +application_import_names = sagemaker_containers, test, libchangehostname +import-order-style = google diff --git a/setup.py b/setup.py index a692f8c..ad6ca88 100644 --- a/setup.py +++ b/setup.py @@ -16,25 +16,31 @@ import os import sys -from setuptools import find_packages, setup +import setuptools def read(file_name): return open(os.path.join(os.path.dirname(__file__), file_name)).read() -packages = find_packages(where='src', exclude=('test',)) +packages = setuptools.find_packages(where='src', exclude=('test',)) packages.append('sagemaker_containers.etc') required_packages = [ - 'boto3', 'six', 'pip', 'flask', 'gunicorn', 'gevent', 'inotify_simple', 'werkzeug' + 'numpy', 'boto3', 'six', 'pip', 'flask', 'gunicorn', 'typing', + 'gevent', 'inotify_simple', 'werkzeug', 'paramiko' ] # enum is introduced in Python 3.4. Installing enum back port if sys.version_info < (3, 4): required_packages.append('enum34 >= 1.1.6') -setup( +gethostname = setuptools.Extension('libchangehostname', + sources=['src/sagemaker_containers/c/libchangehostname.c'], + extra_compile_args=['-Wall', '-shared', '-export-dynamic', + '-ldl']) + +setuptools.setup( name='sagemaker_containers', version='2.3.5', description='Open source library for creating containers to run on Amazon SageMaker.', @@ -46,6 +52,7 @@ def read(file_name): }, package_data={'sagemaker_containers.etc': ['*']}, py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob('src/*.py')], + ext_modules=[gethostname], long_description=read('README.md'), author='Amazon Web Services', url='https://github.com/aws/sagemaker-containers/', @@ -64,11 +71,11 @@ def read(file_name): install_requires=required_packages, extras_require={ - 'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'mock', 'sagemaker', 'numpy'] + 'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'mock', 'sagemaker==1.16.2'] }, entry_points={ - 'console_scripts': ['serve=sagemaker_containers.cli.serve:main', - 'train=sagemaker_containers.cli.train:main'], + 'console_scripts': ['serve=sagemaker_containers.cli.serve:main', + 'train=sagemaker_containers.cli.train:main'], } ) diff --git a/src/sagemaker_containers/c/libchangehostname.c b/src/sagemaker_containers/c/libchangehostname.c new file mode 100644 index 0000000..56e9f02 --- /dev/null +++ b/src/sagemaker_containers/c/libchangehostname.c @@ -0,0 +1,68 @@ +// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the 'License'). You +// may not use this file except in compliance with the License. A copy of +// the License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the 'license' file accompanying this file. This file is +// distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +// ANY KIND, either express or implied. See the License for the specific +// language governing permissions and limitations under the License. + +#include +#include + + +int libchangehostname(char *name, size_t len) +{ + const char *val = getenv("SM_CURRENT_HOST"); + + strncpy(name, val, len); + return 0; +} + + +static PyObject* libchangehostname_call(PyObject* self, PyObject* args) { + long unsigned command; + char name[40]; + + if (!PyArg_ParseTuple(args, "k", &command)) { + return NULL; + } + + libchangehostname(name, command); + + return Py_BuildValue("s", name); +} + +static PyMethodDef LibchangehostnameMethods[] = { + { + "call", + libchangehostname_call, + METH_VARARGS, + }, + {NULL, NULL, 0, NULL}, // sentinel +}; + +#if PY_MAJOR_VERSION >= 3 +static PyModuleDef libchangehostnamemodule = { + PyModuleDef_HEAD_INIT, + "libchangehostname", + "Returns the value of $SM_CURRENT_HOST", + -1, + LibchangehostnameMethods, +}; + +PyMODINIT_FUNC PyInit_libchangehostname() { + return PyModule_Create(&libchangehostnamemodule); +} +#else +PyMODINIT_FUNC initlibchangehostname() { + PyObject* module; + + module = Py_InitModule3( + "libchangehostname", LibchangehostnameMethods, "Returns the value of $SM_CURRENT_HOST"); +} +#endif diff --git a/test/unit/c/test_libchangehostname.py b/test/unit/c/test_libchangehostname.py new file mode 100644 index 0000000..ce63ae6 --- /dev/null +++ b/test/unit/c/test_libchangehostname.py @@ -0,0 +1,38 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import sys + +from mock import patch +import pytest + +import libchangehostname +from sagemaker_containers import _errors, _process + + +def test_libchangehostname_with_env_set_in_another_process(): + with patch.dict(os.environ, {'SM_CURRENT_HOST': 'algo-5'}): + py_cmd = "import libchangehostname\nassert libchangehostname.call(30) == 'algo-5'" + _process.check_error([sys.executable, '-c', py_cmd], _errors.ExecuteUserScriptError) + + +def test_libchangehostname_with_env_set(): + with patch.dict(os.environ, {'SM_CURRENT_HOST': 'algo-3'}): + assert libchangehostname.call(30) == 'algo-3' + + +def test_libchangehostname_with_env_not_set(): + py_cmd = "import libchangehostname\nassert libchangehostname.call(30) == 'algo-9'" + + with pytest.raises(_errors.ExecuteUserScriptError): + _process.check_error([sys.executable, '-c', py_cmd], _errors.ExecuteUserScriptError)