Add support for async fit() (#59)
When calling fit(wait=False), it returns immediately. The training
job will carry on even if the process exits. By using attach(), the
estimator can be retrieved by providing the training job name.
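
A minimal sketch of the asynchronous workflow (the estimator construction is elided and the variable names are illustrative):

.. code:: python

    # Start training without blocking; the job keeps running even if this process exits.
    estimator.fit(inputs, wait=False)
    job_name = estimator.latest_training_job.name

    # Later, possibly from another process, re-attach by job name (use the same estimator class).
    attached_estimator = type(estimator).attach(training_job_name=job_name)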

_prepare_init_params_from_job_description() is now a classmethod instead
of a static method. Each class is responsible for implementing its own
logic to convert a training job description into arguments that can be
passed to its __init__().
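
As an illustration of that contract (a hedged sketch, not code from this commit; the class name is hypothetical), an override typically calls the parent implementation and then reshapes the dictionary to match its own constructor:

.. code:: python

    from sagemaker.estimator import EstimatorBase

    class MyEstimator(EstimatorBase):

        @classmethod
        def _prepare_init_params_from_job_description(cls, job_details):
            # Start from the base mapping of the describe_training_job response.
            init_params = super(MyEstimator, cls)._prepare_init_params_from_job_description(job_details)

            # Rename keys so they line up with this class's __init__() arguments.
            init_params['image_name'] = init_params.pop('image')
            return init_params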
iquintero committed Feb 1, 2018
1 parent 354ded3 commit e1d79d5
Showing 16 changed files with 556 additions and 142 deletions.
23 changes: 21 additions & 2 deletions README.rst
@@ -97,7 +97,7 @@ SageMaker Python SDK provides several high-level abstractions for working with A
- **Estimators**: Encapsulate training on SageMaker. Can be ``fit()`` to run training, then the resulting model ``deploy()`` ed to a SageMaker Endpoint.
- **Models**: Encapsulate built ML models. Can be ``deploy()`` ed to a SageMaker Endpoint.
- **Predictors**: Provide real-time inference and transformation using Python data-types against a SageMaker Endpoint.
- **Session**: Provides a collection of convience methods for working with SageMaker resources.
- **Session**: Provides a collection of convenience methods for working with SageMaker resources.

Estimator and Model implementations for MXNet, TensorFlow, and Amazon ML algorithms are included. There's also an Estimator that runs SageMaker compatible custom Docker containers, allowing you to run your own ML algorithms via SageMaker Python SDK.

@@ -1150,6 +1150,7 @@ Optional arguments
- ``wait (bool)``: Defaults to True, whether to block and wait for the
training script to complete before returning.
If set to False, it will return immediately, and can later be attached to.
- ``logs (bool)``: Defaults to True, whether to show logs produced by training
job in the Python session. Only meaningful when wait is True.
- ``run_tensorboard_locally (bool)``: Defaults to False. Executes TensorBoard in a different
@@ -1178,9 +1179,25 @@ the ``TensorFlow`` estimator parameter ``training_steps`` is finished or when th
job execution time reaches the ``TensorFlow`` estimator parameter ``train_max_run``.
When the training job finishes, a `TensorFlow serving <https://www.tensorflow.org/serving/serving_basic>`_
with the result of the training is generated and saved to the S3 location define by
with the result of the training is generated and saved to the S3 location defined by
the ``TensorFlow`` estimator parameter ``output_path``.

If the ``wait=False`` flag is passed to ``fit``, it returns immediately. The training job continues running
asynchronously. At a later time, a ``TensorFlow`` estimator can be obtained by attaching to the existing training job. If
the training job is not finished, attaching will show the standard output of the training job and wait until it completes.
After attaching, the estimator can be deployed as usual (see the deployment sketch after the example below).

.. code:: python

    tf_estimator.fit(your_input_data, wait=False)
    training_job_name = tf_estimator.latest_training_job.name

    # after some time, or in a separate python notebook, we can attach to it again.
    tf_estimator = TensorFlow.attach(training_job_name=training_job_name)
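
As a follow-up sketch (the instance count and type below are illustrative values, not part of this change), the attached estimator can then be deployed to an endpoint:

.. code:: python

    # Deploy the attached estimator once the training job has completed.
    predictor = tf_estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')
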
The evaluation process
""""""""""""""""""""""
@@ -1244,6 +1261,8 @@ You can access TensorBoard locally at http://localhost:6006 or using your SakeMa
`https*workspace_base_url*proxy/6006/ <proxy/6006/>`_ (TensorBoard will not work if you forget to put the slash,
'/', at the end of the URL). If TensorBoard started on a different port, adjust these URLs to match.
Note that TensorBoard is not supported when passing wait=False to ``fit``.
Deploying TensorFlow Serving models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
25 changes: 25 additions & 0 deletions src/sagemaker/amazon/amazon_estimator.py
@@ -65,6 +65,31 @@ def data_location(self, data_location):
data_location = data_location + '/'
self._data_location = data_location

@classmethod
def _prepare_init_params_from_job_description(cls, job_details):
"""Convert the job description to init params that can be handled by the class constructor
Args:
job_details: the returned job details from a describe_training_job API call.
Returns:
dictionary: The transformed init_params
"""
init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(job_details)

# The hyperparam names may not be the same as the class attribute that holds them,
# for instance: local_lloyd_init_method is called local_init_method. We need to map these
# and pass the correct name to the constructor.
for attribute, value in cls.__dict__.items():
if isinstance(value, hp):
if value.name in init_params['hyperparameters']:
init_params[attribute] = init_params['hyperparameters'][value.name]

del init_params['hyperparameters']
del init_params['image']
return init_params

def fit(self, records, mini_batch_size=None, **kwargs):
"""Fit this Estimator on serialized Record objects, stored in S3.
190 changes: 137 additions & 53 deletions src/sagemaker/estimator.py
@@ -152,8 +152,60 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
self.latest_training_job = _TrainingJob.start_new(self, inputs)
if wait:
self.latest_training_job.wait(logs=logs)
else:
raise NotImplemented('Asynchronous fit not available')

@classmethod
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
"""Create an Estimator from existing training job data.
Args:
init_params (dict): The init_params the training job was created with.
hyperparameters (dict): The hyperparameters the training job was created with.
image (str): Container image (if any) the training job was created with
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
Returns: An instance of the calling Estimator Class.
"""
raise NotImplementedError()

@classmethod
def attach(cls, training_job_name, sagemaker_session=None, job_details=None):
"""Attach to an existing training job.
Create an Estimator bound to an existing training job. Each subclass is responsible for implementing
``_prepare_init_params_from_job_description()``, as this method delegates the actual conversion of a training
job description to the arguments that the class constructor expects. After attaching, if the training job has a
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
If the training job is in progress, attach will block and display log messages
from the training job, until the training job completes.
Args:
training_job_name (str): The name of the training job to attach to.
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
Examples:
>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name
Later on:
>>> attached_estimator = Estimator.attach(training_job_name)
>>> attached_estimator.deploy()
Returns:
Instance of the calling ``Estimator`` Class with the attached training job.
"""
sagemaker_session = sagemaker_session or Session()

job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
init_params = cls._prepare_init_params_from_job_description(job_details)

estimator = cls(sagemaker_session=sagemaker_session, **init_params)
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
training_job_name=init_params['base_job_name'])
estimator.latest_training_job.wait()
return estimator

def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
@@ -202,21 +254,33 @@ def create_model(self, **kwargs):
"""
pass

@staticmethod
def _prepare_estimator_params_from_job_description(job_details):
estimator_params = dict()
@classmethod
def _prepare_init_params_from_job_description(cls, job_details):
"""Convert the job description to init params that can be handled by the class constructor
Args:
job_details: the returned job details from a describe_training_job API call.
Returns:
dictionary: The transformed init_params
"""
init_params = dict()

estimator_params['role'] = job_details['RoleArn']
estimator_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
estimator_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
estimator_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
estimator_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
estimator_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
estimator_params['base_job_name'] = job_details['TrainingJobName']
estimator_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
estimator_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
init_params['role'] = job_details['RoleArn']
init_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
init_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
init_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
init_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
init_params['base_job_name'] = job_details['TrainingJobName']
init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']

return estimator_params, job_details['HyperParameters'], job_details['AlgorithmSpecification']['TrainingImage']
init_params['hyperparameters'] = job_details['HyperParameters']
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage']

return init_params

def delete_endpoint(self):
"""Delete an Amazon SageMaker ``Endpoint``.
@@ -333,7 +397,8 @@ class Estimator(EstimatorBase):

def __init__(self, image_name, role, train_instance_count, train_instance_type,
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None,
hyperparameters=None):
"""Initialize an ``Estimator`` instance.
Args:
Expand Down Expand Up @@ -365,9 +430,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
"""
self.image_name = image_name
self.hyperparam_dict = {}
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
train_volume_size, train_max_run, input_mode,
output_path, output_kms_key, base_job_name, sagemaker_session)
@@ -422,6 +488,22 @@ def predict_wrapper(endpoint, session):
return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session,
predictor_cls=predictor_cls, **kwargs)

@classmethod
def _prepare_init_params_from_job_description(cls, job_details):
"""Convert the job description to init params that can be handled by the class constructor
Args:
job_details: the returned job details from a describe_training_job API call.
Returns:
dictionary: The transformed init_params
"""
init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details)

init_params['image_name'] = init_params.pop('image')
return init_params


class Framework(EstimatorBase):
"""Base class that cannot be instantiated directly.
@@ -528,12 +610,37 @@ def hyperparameters(self):
return self._json_encode_hyperparameters(self._hyperparameters)

@classmethod
def attach(cls, training_job_name, sagemaker_session=None, **kwargs):
def _prepare_init_params_from_job_description(cls, job_details):
"""Convert the job description to init params that can be handled by the class constructor
Args:
job_details: the returned job details from a describe_training_job API call.
Returns:
dictionary: The transformed init_params
"""
init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details)

init_params['entry_point'] = json.loads(init_params['hyperparameters'].get(SCRIPT_PARAM_NAME))
init_params['source_dir'] = json.loads(init_params['hyperparameters'].get(DIR_PARAM_NAME))
init_params['enable_cloudwatch_metrics'] = json.loads(
init_params['hyperparameters'].get(CLOUDWATCH_METRICS_PARAM_NAME))
init_params['container_log_level'] = json.loads(
init_params['hyperparameters'].get(CONTAINER_LOG_LEVEL_PARAM_NAME))

init_params['hyperparameters'] = {k: json.loads(v) for k, v in init_params['hyperparameters'].items()}

return init_params

@classmethod
def attach(cls, training_job_name, sagemaker_session=None):
"""Attach to an existing training job.
Create an Estimator bound to an existing training job. After attaching, if
the training job has a Complete status, it can be ``deploy()`` ed to create
a SageMaker Endpoint and return a ``Predictor``.
Create an Estimator bound to an existing training job. Each subclass is responsible for implementing
``_prepare_init_params_from_job_description()``, as this method delegates the actual conversion of a training
job description to the arguments that the class constructor expects. After attaching, if the training job has a
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
If the training job is in progress, attach will block and display log messages
from the training job, until the training job completes.
@@ -543,41 +650,18 @@ def attach(cls, training_job_name, sagemaker_session=None, **kwargs):
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
Examples:
>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name
Later on:
>>> attached_estimator = Estimator.attach(training_job_name)
>>> attached_estimator.deploy()
Returns:
sagemaker.estimator.Framework: ``Estimator`` with the attached training job.
Instance of the calling ``Estimator`` Class with the attached training job.
"""
sagemaker_session = sagemaker_session or Session()

if training_job_name is not None:
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
init_params, hp, _ = cls._prepare_estimator_params_from_job_description(job_details)

else:
# this case is only valid when called from inheriting class and then the class must declare framework
if not hasattr(cls, '__framework_name__'):
raise ValueError('must specify training_job name')
init_params = dict(kwargs)
hp = init_params.pop('hyperparameters')

# parameters for framework classes
framework_init_params = dict()
framework_init_params['entry_point'] = json.loads(hp.get(SCRIPT_PARAM_NAME))
framework_init_params['source_dir'] = json.loads(hp.get(DIR_PARAM_NAME))
framework_init_params['enable_cloudwatch_metrics'] = json.loads(hp.get(CLOUDWATCH_METRICS_PARAM_NAME))
framework_init_params['container_log_level'] = json.loads(hp.get(CONTAINER_LOG_LEVEL_PARAM_NAME))

# drop json and remove other SageMaker specific additions
hyperparameters = {entry: json.loads(hp[entry]) for entry in hp}
framework_init_params['hyperparameters'] = hyperparameters

init_params.update(framework_init_params)

estimator = cls(sagemaker_session=sagemaker_session, **init_params)
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
training_job_name=init_params['base_job_name'])
estimator.latest_training_job.wait()
estimator = super(Framework, cls).attach(training_job_name, sagemaker_session)
estimator.uploaded_code = UploadedCode(estimator.source_dir, estimator.entry_point)
return estimator

