Skip to content

Commit

Permalink
Add support for PyTorch (#243)
Browse files Browse the repository at this point in the history
Add support for PyTorch framework.
  • Loading branch information
nadiaya committed Jun 20, 2018
1 parent b8f00ff commit 92eb47d
Show file tree
Hide file tree
Showing 17 changed files with 1,609 additions and 5 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
CHANGELOG
=========

1.4.3dev
========
1.5.0
=====
* feature: Add Support for PyTorch Framework
* feature: Estimators: add support for TensorFlow 1.7.0
* feature: Estimators: add support for TensorFlow 1.8.0
* feature: Allow Local Serving of Models in S3
Expand Down
16 changes: 15 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ You can install from source by cloning this repository and issuing a pip install

git clone https://github.com/aws/sagemaker-python-sdk.git
python setup.py sdist
pip install dist/sagemaker-1.4.2.tar.gz
pip install dist/sagemaker-1.5.0.tar.gz

Supported Python versions
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -236,6 +236,20 @@ More details at `Chainer SageMaker Estimators and Models`_.
.. _Chainer SageMaker Estimators and Models: src/sagemaker/chainer/README.rst


PyTorch SageMaker Estimators
-------------------------------

With PyTorch Estimators, you can train and host PyTorch models on Amazon SageMaker.

Supported versions of PyTorch: ``0.4.0``

You can visit the PyTorch repository at https://github.com/pytorch/pytorch.

More details at `PyTorch SageMaker Estimators and Models`_.

.. _PyTorch SageMaker Estimators and Models: src/sagemaker/pytorch/README.rst


AWS SageMaker Estimators
------------------------
Amazon SageMaker provides several built-in machine learning algorithms that you can use for a variety of problem types.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def read(fname):


setup(name="sagemaker",
version="1.4.2",
version="1.5.0",
description="Open source library for training and deploying models on Amazon SageMaker.",
packages=find_packages('src'),
package_dir={'': 'src'},
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def framework_name_from_image(image_name):
else:
# extract framework, python version and image tag
# We must support both the legacy and current image name format.
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer):(.*?)-(.*?)-(py2|py3)$')
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer|pytorch):(.*?)-(.*?)-(py2|py3)$')
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
name_match = name_pattern.match(sagemaker_match.group(8))
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))
Expand Down
711 changes: 711 additions & 0 deletions src/sagemaker/pytorch/README.rst

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions src/sagemaker/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from sagemaker.pytorch.estimator import PyTorch
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor

__all__ = [PyTorch, PyTorchModel, PyTorchPredictor]
16 changes: 16 additions & 0 deletions src/sagemaker/pytorch/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

PYTORCH_VERSION = '0.4'
PYTHON_VERSION = 'py3'
112 changes: 112 additions & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from sagemaker.estimator import Framework
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.pytorch.model import PyTorchModel


class PyTorch(Framework):
"""Handle end-to-end training and deployment of custom PyTorch code."""

__framework_name__ = "pytorch"

def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION,
framework_version=PYTORCH_VERSION, **kwargs):
"""
This ``Estimator`` executes an PyTorch script in a managed PyTorch execution environment, within a SageMaker
Training Job. The managed PyTorch environment is an Amazon-built Docker container that executes functions
defined in the supplied ``entry_point`` Python script.
Training is started by calling :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
After training is complete, calling :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
hosted SageMaker endpoint and returns an :class:`~sagemaker.amazon.pytorch.model.PyTorchPredictor` instance
that can be used to perform inference against the hosted model.
Technical documentation on preparing PyTorch scripts for SageMaker training and using the PyTorch Estimator is
available on the project home-page: https://github.com/aws/sagemaker-python-sdk
Args:
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
source_dir (str): Path (absolute or relative) to a directory with any other training
source code dependencies aside from tne entry point file (default: None). Structure within this
directory are preserved when training on Amazon SageMaker.
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
For convenience, this accepts other types for keys and values, but ``str()`` will be called
to convert them before training.
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
One of 'py2' or 'py3'.
framework_version (str): PyTorch version you want to use for executing your model training code.
List of supported versions https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
"""
super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
self.py_version = py_version
self.framework_version = framework_version

def train_image(self):
"""Return the Docker image to use for training.
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
find the image to use for model training.
Returns:
str: The URI of the Docker image.
"""
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
self.train_instance_type, framework_version=self.framework_version,
py_version=self.py_version)

def create_model(self, model_server_workers=None):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
Args:
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.
Returns:
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
"""
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version,
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)

@classmethod
def _prepare_init_params_from_job_description(cls, job_details):
"""Convert the job description to init params that can be handled by the class constructor
Args:
job_details: the returned job details from a describe_training_job API call.
Returns:
dictionary: The transformed init_params
"""
init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details)
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))

init_params['py_version'] = py_version
init_params['framework_version'] = framework_version_from_tag(tag)

training_job_name = init_params['base_job_name']

if framework != cls.__framework_name__:
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))

return init_params
94 changes: 94 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import sagemaker
from sagemaker.fw_utils import create_image_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
from sagemaker.utils import name_from_image


class PyTorchPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against PyTorch Endpoints.
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for PyTorch
inference."""

def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``PyTorchPredictor``.
Args:
endpoint_name (str): The name of the endpoint to perform inference on.
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
"""
super(PyTorchPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer)


class PyTorchModel(FrameworkModel):
"""An PyTorch SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

__framework_name__ = 'pytorch'

def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_VERSION,
framework_version=PYTORCH_VERSION, predictor_cls=PyTorchPredictor,
model_server_workers=None, **kwargs):
"""Initialize an PyTorchModel.
Args:
model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
After the endpoint is created, the inference code might use the IAM role,
if it needs to access an AWS resource.
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
image (str): A Docker image URI (default: None). If not specified, a default image for PyTorch will be used.
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
framework_version (str): PyTorch version you want to use for executing your model training code.
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
invoking this function on the created endpoint name.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
"""
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)
self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type):
"""Return a container definition with framework configuration set in model environment variables.
Args:
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
Returns:
dict[str, str]: A container definition object usable with the CreateModel API.
"""
deploy_image = self.image
if not deploy_image:
region_name = self.sagemaker_session.boto_session.region_name
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
self.framework_version, self.py_version)
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def mxnet_version(request):
return request.param


@pytest.fixture(scope='module', params=["0.4", "0.4.0"])
def pytorch_version(request):
return request.param


@pytest.fixture(scope='module', params=['4.0', '4.0.0'])
def chainer_version(request):
return request.param
Expand All @@ -96,6 +101,11 @@ def mxnet_full_version(request):
return request.param


@pytest.fixture(scope='module', params=["0.4.0"])
def pytorch_full_version(request):
return request.param


@pytest.fixture(scope='module', params=['4.0.0'])
def chainer_full_version(request):
return request.param
3 changes: 3 additions & 0 deletions tests/data/pytorch_mnist/failure_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if __name__ == '__main__':
"""For use with integration tests expecting failures."""
raise Exception('This failure is expected.')
Loading

0 comments on commit 92eb47d

Please sign in to comment.