Skip to content

Commit

Permalink
Add Local Mode support (#115)
Browse files Browse the repository at this point in the history
* Add Local Mode support.

When passing "local" as the instance type for any estimator,
training and deployment happen locally.

Similarly, passing "local_gpu" uses nvidia-docker-compose and
supports GPU training.
  • Loading branch information
iquintero committed Apr 1, 2018
1 parent 4f92fbd commit 6184b22
Show file tree
Hide file tree
Showing 13 changed files with 1,116 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CHANGELOG
1.1.dev4
========
* feature: Frameworks: Use more idiomatic ECR repository naming scheme
* feature: Add Support for Local Mode

1.1.3
========
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def read(fname):
],

# Declare minimal set for installation
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=1.0.0'],
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=1.0.0', 'urllib3>=1.2',
'PyYAML>=3.2'],

extras_require={
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor
from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor

from sagemaker.local.local_session import LocalSession

from sagemaker.model import Model
from sagemaker.predictor import RealTimePredictor
from sagemaker.session import Session
Expand All @@ -34,5 +36,5 @@
LinearLearnerModel, LinearLearnerPredictor,
LDA, LDAModel, LDAPredictor,
FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor,
Model, NTM, NTMModel, NTMPredictor, RealTimePredictor, Session,
Model, NTM, NTMModel, NTMPredictor, RealTimePredictor, Session, LocalSession,
container_def, s3_input, production_variant, get_execution_role]
15 changes: 13 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sagemaker.fw_utils import tar_and_upload_dir
from sagemaker.fw_utils import parse_s3_url
from sagemaker.fw_utils import UploadedCode
from sagemaker.local.local_session import LocalSession
from sagemaker.model import Model
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
Expand Down Expand Up @@ -78,7 +79,17 @@ def __init__(self, role, train_instance_count, train_instance_type,
self.train_volume_size = train_volume_size
self.train_max_run = train_max_run
self.input_mode = input_mode
self.sagemaker_session = sagemaker_session or Session()

if self.train_instance_type in ('local', 'local_gpu'):
self.local_mode = True
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
raise RuntimeError("Distributed Training in Local GPU is not supported")

self.sagemaker_session = LocalSession()
else:
self.local_mode = False
self.sagemaker_session = sagemaker_session or Session()

self.base_job_name = base_job_name
self._current_job_name = None
self.output_path = output_path
Expand Down Expand Up @@ -303,7 +314,7 @@ def start_new(cls, estimator, inputs):
"""Create a new Amazon SageMaker training job from the estimator.
Args:
estimator (sagemaker.estimator.Framework): Estimator object created by the user.
estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user.
inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
Returns:
Expand Down
24 changes: 14 additions & 10 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,23 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
str: The appropriate image URI based on the given parameters.
"""

if not instance_type.startswith('ml.'):
# Handle Local Mode
if instance_type.startswith('local'):
device_type = 'cpu' if instance_type == 'local' else 'gpu'
elif not instance_type.startswith('ml.'):
raise ValueError('{} is not a valid SageMaker instance type. See: '
'https://aws.amazon.com/sagemaker/pricing/instance-types/'.format(instance_type))
family = instance_type.split('.')[1]

# For some frameworks, we have optimized images for specific families, e.g c5 or p3. In those cases,
# we use the family name in the image tag. In other cases, we use 'cpu' or 'gpu'.
if family in optimized_families:
device_type = family
elif family[0] in ['g', 'p']:
device_type = 'gpu'
else:
device_type = 'cpu'
family = instance_type.split('.')[1]

# For some frameworks, we have optimized images for specific families, e.g c5 or p3. In those cases,
# we use the family name in the image tag. In other cases, we use 'cpu' or 'gpu'.
if family in optimized_families:
device_type = family
elif family[0] in ['g', 'p']:
device_type = 'gpu'
else:
device_type = 'cpu'

tag = "{}-{}-{}".format(framework_version, device_type, py_version)
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
Expand Down
12 changes: 12 additions & 0 deletions src/sagemaker/local/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

0 comments on commit 6184b22

Please sign in to comment.