Skip to content

Commit

Permalink
Support user-supplied requirements.txt for TensorFlow (#80)
Browse files Browse the repository at this point in the history
* Support requirements.txt for general case

* Support requirements.txt for TensorFlow

* Clarify wording in docstring

* Remove added integ test after discussion on #80

* Fix import statements

* Add validation for requirements file (and corresponding unit tests)

* Fix typos and flake8 errors

* Add dummy requirements file for unit tests

* Document argument for requirements file in README

* Add link to pip user guide for 'requirements.txt' description
  • Loading branch information
laurenyu committed Mar 2, 2018
1 parent 5846fb6 commit 8a3dea2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 21 deletions.
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,10 @@ you can specify these as keyword arguments.
other training source code dependencies aside from the entry point
file. Structure within this directory will be preserved when training
on SageMaker.
- ``requirements_file (str)`` Path to a ``requirements.txt`` file. The path should
be within and relative to ``source_dir``. This is a file containing a list of items to be
installed using pip install. Details on the format can be found in the
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
- ``hyperparameters (dict[str,ANY])`` Hyperparameters that will be used for training.
Will be made accessible as a dict[] to the training code on
SageMaker. Some hyperparameters will be interpreted by TensorFlow and can be used to
Expand Down
36 changes: 29 additions & 7 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class TensorFlow(Framework):

__framework_name__ = 'tensorflow'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version="py2",
framework_version=TF_VERSION, **kwargs):
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
framework_version=TF_VERSION, requirements_file='', **kwargs):
"""Initialize an ``TensorFlow`` estimator.
Args:
training_steps (int): Perform this many steps of training. `None`, the default means train forever.
Expand All @@ -120,6 +120,9 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
framework_version (str): TensorFlow version you want to use for executing your model training code.
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
relative to ``source_dir``. Details on the format can be found in the
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
**kwargs: Additional kwargs passed to the Framework constructor.
"""
super(TensorFlow, self).__init__(**kwargs)
Expand All @@ -129,6 +132,22 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
self.training_steps = training_steps
self.evaluation_steps = evaluation_steps

self._validate_requirements_file(requirements_file)
self.requirements_file = requirements_file

def _validate_requirements_file(self, requirements_file):
    """Check that ``requirements_file`` can be resolved against ``source_dir``.

    An empty value means no requirements file was supplied, so nothing is checked.

    Args:
        requirements_file (str): Path to a ``requirements.txt`` file, relative to ``self.source_dir``.

    Raises:
        ValueError: If ``source_dir`` is not set, if the path is absolute, or if ``source_dir`` is a
            local directory that does not contain the file.
    """
    if not requirements_file:
        return

    if not self.source_dir:
        raise ValueError('Must specify source_dir along with a requirements file.')

    if os.path.isabs(requirements_file):
        raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(requirements_file))

    # A source_dir given as an S3 URI cannot be inspected with os.path: probing the local
    # filesystem for it would always report the requirements file as missing.
    if self.source_dir.lower().startswith('s3://'):
        return

    if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
        raise ValueError('Requirements file {} does not exist.'.format(requirements_file))

def fit(self, inputs, wait=True, logs=True, job_name=None, run_tensorboard_locally=False):
"""Train a model using the input training dataset.
Expand Down Expand Up @@ -228,11 +247,13 @@ def create_model(self, model_server_workers=None):
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
"""
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version,
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env,
name=self._current_job_name, container_log_level=self.container_log_level,
code_location=self.code_location, py_version=self.py_version,
framework_version=self.framework_version, model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session)

def hyperparameters(self):
"""Return hyperparameters used by your custom TensorFlow code during model training."""
Expand All @@ -243,7 +264,8 @@ def hyperparameters(self):

additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
'training_steps': self.training_steps,
'evaluation_steps': self.evaluation_steps}
'evaluation_steps': self.evaluation_steps,
'sagemaker_requirements': self.requirements_file}

hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
return hyperparameters
1 change: 1 addition & 0 deletions tests/data/dummy_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fake-requirement-for-unit-tests==1.0.0
45 changes: 31 additions & 14 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,22 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging

import json
import logging
import os

import pytest
from mock import Mock, patch

from sagemaker.fw_utils import create_image_uri
from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.session import s3_input
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow import defaults
from sagemaker.fw_utils import create_image_uri
from sagemaker.tensorflow import TensorFlowPredictor, TensorFlowModel
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowPredictor, TensorFlowModel

DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py')
SCRIPT_FILE = 'dummy_script.py'
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
REQUIREMENTS_FILE = 'dummy_requirements.txt'
TIMESTAMP = '2017-11-06-14:14:15.673'
TIME = 1510006209.073025
BUCKET_NAME = 'mybucket'
Expand Down Expand Up @@ -85,6 +86,7 @@ def _create_train_job(tf_version):
'training_steps': '1000',
'evaluation_steps': '10',
'sagemaker_program': json.dumps('dummy_script.py'),
'sagemaker_requirements': '"{}"'.format(REQUIREMENTS_FILE),
'sagemaker_submit_directory': json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(
BUCKET_NAME, JOB_NAME)),
'sagemaker_enable_cloudwatch_metrics': 'false',
Expand All @@ -100,10 +102,10 @@ def _create_train_job(tf_version):

def _build_tf(sagemaker_session, framework_version=defaults.TF_VERSION, train_instance_type=None,
checkpoint_path=None, enable_cloudwatch_metrics=False, base_job_name=None,
training_steps=None, evalutation_steps=None, **kwargs):
training_steps=None, evaluation_steps=None, **kwargs):
return TensorFlow(entry_point=SCRIPT_PATH,
training_steps=training_steps,
evaluation_steps=evalutation_steps,
evaluation_steps=evaluation_steps,
framework_version=framework_version,
role=ROLE,
sagemaker_session=sagemaker_session,
Expand Down Expand Up @@ -158,6 +160,20 @@ def test_tf_deploy_model_server_workers_unset(sagemaker_session):
assert MODEL_SERVER_WORKERS_PARAM_NAME.upper() not in sagemaker_session.method_calls[3][1][2]['Environment']


def test_tf_invalid_requirements_path(sagemaker_session):
    """An absolute requirements path is rejected as not being relative to ``source_dir``."""
    absolute_path = '/foo/bar/requirements.txt'
    expected_message = 'Requirements file {} is not a path relative to source_dir.'.format(absolute_path)

    with pytest.raises(ValueError) as excinfo:
        _build_tf(sagemaker_session, source_dir=DATA_DIR, requirements_file=absolute_path)

    assert expected_message in str(excinfo.value)


def test_tf_nonexistent_requirements_path(sagemaker_session):
    """A relative requirements path that is absent from ``source_dir`` is rejected."""
    missing_path = 'nonexistent_requirements.txt'
    expected_message = 'Requirements file {} does not exist.'.format(missing_path)

    with pytest.raises(ValueError) as excinfo:
        _build_tf(sagemaker_session, source_dir=DATA_DIR, requirements_file=missing_path)

    assert expected_message in str(excinfo.value)


def test_create_model(sagemaker_session, tf_version):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
Expand Down Expand Up @@ -186,9 +202,9 @@ def test_create_model(sagemaker_session, tf_version):
@patch('time.strftime', return_value=TIMESTAMP)
@patch('time.time', return_value=TIME)
def test_tf(time, strftime, sagemaker_session, tf_version):
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE, framework_version=tf_version)
tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, training_steps=1000,
evaluation_steps=10, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
framework_version=tf_version, requirements_file=REQUIREMENTS_FILE, source_dir=DATA_DIR)

inputs = 's3://mybucket/train'

Expand All @@ -210,6 +226,7 @@ def test_tf(time, strftime, sagemaker_session, tf_version):
assert {'Environment':
{'SAGEMAKER_SUBMIT_DIRECTORY': 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME),
'SAGEMAKER_PROGRAM': 'dummy_script.py',
'SAGEMAKER_REQUIREMENTS': 'dummy_requirements.txt',
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
'SAGEMAKER_REGION': 'us-west-2',
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
Expand Down Expand Up @@ -315,7 +332,7 @@ def test_tf_training_and_evaluation_steps_not_set(sagemaker_session):
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name)

tf = _build_tf(sagemaker_session, training_steps=None, evalutation_steps=None, output_path=output_path)
tf = _build_tf(sagemaker_session, training_steps=None, evaluation_steps=None, output_path=output_path)
tf.fit(inputs=s3_input('s3://mybucket/train'))
assert tf.hyperparameters()['training_steps'] == 'null'
assert tf.hyperparameters()['evaluation_steps'] == 'null'
Expand All @@ -325,7 +342,7 @@ def test_tf_training_and_evaluation_steps(sagemaker_session):
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name)

tf = _build_tf(sagemaker_session, training_steps=123, evalutation_steps=456, output_path=output_path)
tf = _build_tf(sagemaker_session, training_steps=123, evaluation_steps=456, output_path=output_path)
tf.fit(inputs=s3_input('s3://mybucket/train'))
assert tf.hyperparameters()['training_steps'] == '123'
assert tf.hyperparameters()['evaluation_steps'] == '456'
Expand Down

0 comments on commit 8a3dea2

Please sign in to comment.