Skip to content
This repository was archived by the owner on Aug 26, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
CHANGELOG
=========

2.3.4
=====

* feature: add capture_error flag to process.check_error and process.create and to all functions that runs process: modules.run, modules,run_module, and entry_point.run

2.3.3
=====

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def read(file_name):

setup(
name='sagemaker_containers',
version='2.3.3',
version='2.3.4',
description='Open source library for creating containers to run on Amazon SageMaker.',

packages=packages,
Expand Down
14 changes: 12 additions & 2 deletions src/sagemaker_containers/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import textwrap

import six


class ClientError(Exception):
pass
Expand All @@ -27,12 +29,20 @@ class _CalledProcessError(ClientError):
cmd, return_code, output
"""

def __init__(self, cmd, return_code=None):
def __init__(self, cmd, return_code=None, output=None):
self.return_code = return_code
self.cmd = cmd
self.output = output

def __str__(self):
message = '%s:\nCommand "%s"' % (type(self).__name__, self.cmd)
if six.PY3 and self.output:
error_msg = '\n%s' % self.output.decode('latin1')
elif self.output:
error_msg = '\n%s' % self.output
else:
error_msg = ''

message = '%s:\nCommand "%s"%s' % (type(self).__name__, self.cmd, error_msg)
return message.strip()


Expand Down
21 changes: 13 additions & 8 deletions src/sagemaker_containers/_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,12 @@ def prepare(path, name): # type: (str, str) -> None
_files.write_file(os.path.join(path, 'MANIFEST.in'), data)


def install(path): # type: (str) -> None
def install(path, capture_error=False): # type: (str, bool) -> None
"""Install a Python module in the executing Python environment.
Args:
path (str): Real path location of the Python module.
capture_error (bool): Default false. If True, the running process captures the
stderr, and appends it to the returned Exception message in case of errors.
"""
cmd = '%s -m pip install -U . ' % _process.python_executable()

Expand All @@ -103,10 +105,11 @@ def install(path): # type: (str) -> None

logger.info('Installing module with the following command:\n%s', cmd)

_process.check_error(shlex.split(cmd), _errors.InstallModuleError, cwd=path)
_process.check_error(shlex.split(cmd), _errors.InstallModuleError, cwd=path, capture_error=capture_error)


def run(module_name, args=None, env_vars=None, wait=True): # type: (str, list, dict, bool) -> Popen
def run(module_name, args=None, env_vars=None, wait=True, capture_error=False):
# type: (str, list, dict, bool, bool) -> Popen
"""Run Python module as a script.

Search sys.path for the named module and execute its contents as the __main__ module.
Expand Down Expand Up @@ -154,6 +157,8 @@ def run(module_name, args=None, env_vars=None, wait=True): # type: (str, list,
module_name (str): module name in the same format required by python -m <module-name> cli command.
args (list): A list of program arguments.
env_vars (dict): A map containing the environment variables to be written.
capture_error (bool): Default false. If True, the running process captures the
stderr, and appends it to the returned Exception message in case of errors.
"""
args = args or []
env_vars = env_vars or {}
Expand All @@ -163,10 +168,10 @@ def run(module_name, args=None, env_vars=None, wait=True): # type: (str, list,
_logging.log_script_invocation(cmd, env_vars)

if wait:
return _process.check_error(cmd, _errors.ExecuteUserScriptError)
return _process.check_error(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)

else:
return _process.create(cmd, _errors.ExecuteUserScriptError)
return _process.create(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)


def import_module(uri, name=DEFAULT_MODULE_NAME, cache=None): # type: (str, str, bool) -> module
Expand Down Expand Up @@ -195,8 +200,8 @@ def import_module(uri, name=DEFAULT_MODULE_NAME, cache=None): # type: (str, str
six.reraise(_errors.ImportModuleError, _errors.ImportModuleError(e), sys.exc_info()[2])


def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, wait=True):
# type: (str, list, dict, str, bool, bool) -> Popen
def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, wait=True, capture_error=False):
# type: (str, list, dict, str, bool, bool, bool) -> Popen
"""Download, prepare and executes a compressed tar file from S3 or provided directory as a module.

SageMaker Python SDK saves the user provided scripts as compressed tar files in S3
Expand All @@ -222,7 +227,7 @@ def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, w

_env.write_env_vars(env_vars)

return run(name, args, env_vars, wait)
return run(name, args, env_vars, wait, capture_error)


def _warning_cache_deprecation(cache):
Expand Down
19 changes: 13 additions & 6 deletions src/sagemaker_containers/_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,26 @@
from sagemaker_containers import _env


def create(cmd, error_class, cwd=None, **kwargs):
def create(cmd, error_class, cwd=None, capture_error=False, **kwargs):
try:
return subprocess.Popen(cmd, env=os.environ, cwd=cwd or _env.code_dir, **kwargs)
stderr = subprocess.PIPE if capture_error else None
return subprocess.Popen(cmd, env=os.environ, cwd=cwd or _env.code_dir, stderr=stderr, **kwargs)
except Exception as e:
six.reraise(error_class, error_class(e), sys.exc_info()[2])


def check_error(cmd, error_class, **kwargs):
process = create(cmd, error_class, **kwargs)
return_code = process.wait()
def check_error(cmd, error_class, capture_error=False, **kwargs):
process = create(cmd, error_class, capture_error=capture_error, **kwargs)

if capture_error:
_, stderr = process.communicate()
return_code = process.poll()
else:
stderr = None
return_code = process.wait()

if return_code:
raise error_class(return_code=return_code, cmd=' '.join(cmd))
raise error_class(return_code=return_code, cmd=' '.join(cmd), output=stderr)
return process


Expand Down
24 changes: 15 additions & 9 deletions src/sagemaker_containers/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from sagemaker_containers import _env, _errors, _files, _logging, _modules, _process


def run(uri, user_entry_point, args, env_vars=None, wait=True):
# type: (str, str, list, dict, bool) -> subprocess.Popen
def run(uri, user_entry_point, args, env_vars=None, wait=True, capture_error=False):
# type: (str, str, list, dict, bool, bool) -> subprocess.Popen
"""Download, prepare and executes a compressed tar file from S3 or provided directory as an user
entrypoint. Runs the user entry point, passing env_vars as environment variables and args as command
arguments.
Expand Down Expand Up @@ -59,39 +59,45 @@ def run(uri, user_entry_point, args, env_vars=None, wait=True):
uri (str): the location of the module.
wait (bool): If True, holds the process executing the user entry-point.
If False, returns the process that is executing it.
capture_error (bool): Default false. If True, the running process captures the
stderr, and appends it to the returned Exception message in case of errors.

"""
env_vars = env_vars or {}
env_vars = env_vars.copy()

_files.download_and_extract(uri, user_entry_point, _env.code_dir)

install(user_entry_point, _env.code_dir)
install(user_entry_point, _env.code_dir, capture_error)

_env.write_env_vars(env_vars)

return _call(user_entry_point, args, env_vars, wait)
return _call(user_entry_point, args, env_vars, wait, capture_error)


def install(name, dst):
def install(name, dst, capture_error=False):
"""Install the user provided entry point to be executed as follow:
- add the path to sys path
- if the user entry point is a command, gives exec permissions to the script

Args:
name (str): name of the script or module.
dst (str): path to directory with the script or module.
capture_error (bool): Default false. If True, the running process captures the
stderr, and appends it to the returned Exception message in case of errors.
"""
if dst not in sys.path:
sys.path.insert(0, dst)

entrypoint_type = entry_point_type(dst, name)
if entrypoint_type is EntryPointType.PYTHON_PACKAGE:
_modules.install(dst)
_modules.install(dst, capture_error)
if entrypoint_type is EntryPointType.COMMAND:
os.chmod(os.path.join(dst, name), 511)


def _call(user_entry_point, args=None, env_vars=None, wait=True): # type: (str, list, dict, bool) -> Popen
def _call(user_entry_point, args=None, env_vars=None, wait=True, capture_error=False):
# type: (str, list, dict, bool, bool) -> Popen
args = args or []
env_vars = env_vars or {}

Expand All @@ -107,10 +113,10 @@ def _call(user_entry_point, args=None, env_vars=None, wait=True): # type: (str,
_logging.log_script_invocation(cmd, env_vars)

if wait:
return _process.check_error(cmd, _errors.ExecuteUserScriptError)
return _process.check_error(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)

else:
return _process.create(cmd, _errors.ExecuteUserScriptError)
return _process.create(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)


class EntryPointType(enum.Enum):
Expand Down
64 changes: 36 additions & 28 deletions test/functional/test_training_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ def framework_training_fn():
model.save(model_file)


@pytest.mark.parametrize('user_script', [USER_SCRIPT_WITH_SAVE, USER_SCRIPT_WITH_SAVE])
def test_training_framework(user_script):
@pytest.mark.parametrize('user_script, capture_error',
[[USER_SCRIPT_WITH_SAVE, False], [USER_SCRIPT_WITH_SAVE, True]])
def test_training_framework(user_script, capture_error):
with pytest.raises(ImportError):
importlib.import_module(modules.DEFAULT_MODULE_NAME)

Expand Down Expand Up @@ -234,18 +235,19 @@ def test_trainer_report_failure():
assert 'No such file or directory' in message


def framework_training_with_script_mode_fn():
def framework_training_with_script_mode_fn(capture_error):
training_env = sagemaker_containers.training_env()

entry_point.run(training_env.module_dir, training_env.user_entry_point, training_env.to_cmd_args(),
training_env.to_env_vars())
training_env.to_env_vars(), capture_error=capture_error)


def framework_training_with_run_modules_fn():
def framework_training_with_run_modules_fn(capture_error):
training_env = sagemaker_containers.training_env()

modules.run_module(training_env.module_dir, training_env.to_cmd_args(),
training_env.to_env_vars(), training_env.module_name)
training_env.to_env_vars(), training_env.module_name,
capture_error=capture_error)


def test_parameter_server():
Expand All @@ -261,10 +263,10 @@ def test_parameter_server():
process.kill()


@pytest.mark.parametrize('user_script, training_fn', [
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn],
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn]])
def test_script_mode(user_script, training_fn):
@pytest.mark.parametrize('user_script, training_fn, capture_error', [
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn, True],
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn, False]])
def test_script_mode(user_script, training_fn, capture_error):
channel = test.Channel.create(name='training')

features = [1, 2, 3, 4]
Expand All @@ -278,7 +280,7 @@ def test_script_mode(user_script, training_fn):

test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])

assert execute_an_wrap_exit(training_fn) == trainer.SUCCESS_CODE
assert execute_an_wrap_exit(training_fn, capture_error=capture_error) == trainer.SUCCESS_CODE

model_path = os.path.join(env.model_dir, 'saved_model')

Expand All @@ -290,10 +292,10 @@ def test_script_mode(user_script, training_fn):
assert model.optimizer == 'SGD'


@pytest.mark.parametrize('user_script, training_fn', [
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn],
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn]])
def test_script_mode_local_directory(user_script, training_fn, tmpdir):
@pytest.mark.parametrize('user_script, training_fn, capture_error', [
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn, False],
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn, True]])
def test_script_mode_local_directory(user_script, training_fn, capture_error, tmpdir):
channel = test.Channel.create(name='training')

features = [1, 2, 3, 4]
Expand All @@ -311,7 +313,7 @@ def test_script_mode_local_directory(user_script, training_fn, tmpdir):

test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel], local=True)

assert execute_an_wrap_exit(training_fn) == trainer.SUCCESS_CODE
assert execute_an_wrap_exit(training_fn, capture_error=capture_error) == trainer.SUCCESS_CODE

model_path = os.path.join(env.model_dir, 'saved_model')

Expand All @@ -329,10 +331,10 @@ def test_script_mode_local_directory(user_script, training_fn, tmpdir):
"""


@pytest.mark.parametrize('training_fn', [
framework_training_with_script_mode_fn,
framework_training_with_run_modules_fn])
def test_script_mode_client_error(training_fn):
@pytest.mark.parametrize('training_fn, capture_error', [
(framework_training_with_script_mode_fn, True),
(framework_training_with_run_modules_fn, False)])
def test_script_mode_client_error(training_fn, capture_error):
channel = test.Channel.create(name='training')

module = test.UserModule(test.File(name='user_script.py', data=USER_MODE_SCRIPT_WITH_ERROR))
Expand All @@ -342,16 +344,18 @@ def test_script_mode_client_error(training_fn):
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])

with pytest.raises(errors.ExecuteUserScriptError) as e:
training_fn()
training_fn(capture_error)

message = str(e.value)
assert 'ExecuteUserScriptError' in message
if capture_error:
assert 'ZeroDivisionError' in message


@pytest.mark.parametrize('training_fn', [
framework_training_with_script_mode_fn,
framework_training_with_run_modules_fn])
def test_script_mode_client_import_error(training_fn):
@pytest.mark.parametrize('training_fn, capture_error', [
[framework_training_with_script_mode_fn, True],
[framework_training_with_run_modules_fn, False]])
def test_script_mode_client_import_error(training_fn, capture_error):
channel = test.Channel.create(name='training')

requirements_file = test.File('requirements.txt', '42/0')
Expand All @@ -364,20 +368,24 @@ def test_script_mode_client_import_error(training_fn):
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])

with pytest.raises(errors.InstallModuleError) as e:
training_fn()
training_fn(capture_error)

message = str(e.value)
assert 'InstallModuleError:' in message

if capture_error:
assert "Invalid requirement: \'42/0\'" in message
assert "It looks like a path. File \'42/0\' does not exist." in message


def failure_message():
with open(os.path.join(env.output_dir, 'failure')) as f:
return f.read()


def execute_an_wrap_exit(fn):
def execute_an_wrap_exit(fn, **kargs):
try:
fn()
fn(**kargs)
return trainer.SUCCESS_CODE
except ValueError as e:
return int(str(e))
Loading