Skip to content

Commit

Permalink
Remove supplemental containers from sdk (#29)
Browse files Browse the repository at this point in the history
* Remove supplemental containers from sdk

* bump version
  • Loading branch information
owen-t committed Dec 23, 2017
1 parent ba5023f commit 157d867
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 64 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def read(fname):


setup(name="sagemaker",
version="1.0.0",
version="1.0.1",
description="Open source library for training and deploying models on Amazon SageMaker.",
packages=find_packages('src'),
package_dir={'': 'src'},
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None):
"""
container_def = self.prepare_container_def(instance_type)
model_name = self.name or name_from_image(container_def['Image'])
self.sagemaker_session.create_model(model_name, self.role, container_def, [])
self.sagemaker_session.create_model(model_name, self.role, container_def)
production_variant = sagemaker.production_variant(model_name, instance_type, initial_instance_count)
self.endpoint_name = endpoint_name or model_name
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant])
Expand Down
19 changes: 4 additions & 15 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
self.sagemaker_client.create_training_job(**train_request)

def create_model(self, name, role, primary_container, supplemental_containers=None):
def create_model(self, name, role, primary_container):
"""Create an Amazon SageMaker ``Model``.
Specify the S3 location of the model artifacts and Docker image containing
Expand All @@ -253,36 +253,27 @@ def create_model(self, name, role, primary_container, supplemental_containers=No
primary_container (str or dict[str, str]): Docker image which defines the inference code.
You can also specify the return value of ``sagemaker.container_def()``, which is used to create
more advanced container configurations, including model containers which need artifacts from S3.
supplemental_containers (list[str or dict[str, str]]): List of Docker images which define
additional containers that need to be run in addition to the primary container (default: None).
You can also specify the return values of ``sagemaker.container_def()``, which the API uses to create
more advanced container configurations, including model containers which need artifacts from S3.
Returns:
str: Name of the Amazon SageMaker ``Model`` created.
"""
role = self.expand_role(role)
primary_container = _expand_container_def(primary_container)
if supplemental_containers is None:
supplemental_containers = []
supplemental_containers = [_expand_container_def(sc) for sc in supplemental_containers]
LOGGER.info('Creating model with name: {}'.format(name))
LOGGER.debug("create_model request: {}".format({
'name': name,
'role': role,
'primary_container': primary_container,
'supplemental_containers': supplemental_containers
'primary_container': primary_container
}))

self.sagemaker_client.create_model(ModelName=name,
PrimaryContainer=primary_container,
SupplementalContainers=supplemental_containers,
ExecutionRoleArn=role)

return name

def create_model_from_job(self, training_job_name, name=None, role=None, primary_container_image=None,
model_data_url=None, env={}, supplemental_containers=None):
model_data_url=None, env={}):
"""Create an Amazon SageMaker ``Model`` from a SageMaker Training Job.
Args:
Expand All @@ -296,8 +287,6 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
model_data_url (str): S3 location of the model data (default: None). If None, defaults to
the ``ModelS3Artifacts`` of ``training_job_name``.
env (dict[string,string]): Model environment variables (default: {}).
supplemental_containers (list[dict[str, str]]): A list of supplemental Docker containers
(default: None). Defines the ``SupplementalContainers`` property on the created ``Model``.
Returns:
str: The name of the created ``Model``.
Expand All @@ -309,7 +298,7 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
primary_container_image or training_job['AlgorithmSpecification']['TrainingImage'],
model_data_url=model_data_url or training_job['ModelArtifacts']['S3ModelArtifacts'],
env=env)
return self.create_model(name, role, primary_container, supplemental_containers)
return self.create_model(name, role, primary_container)

def create_endpoint_config(self, name, model_name, initial_instance_count, instance_type):
"""Create an Amazon SageMaker endpoint configuration.
Expand Down
49 changes: 13 additions & 36 deletions tests/unit/test_create_deploy_entities.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pytest
from mock import Mock

Expand All @@ -36,46 +36,23 @@ def sagemaker_session():


def test_create_model(sagemaker_session):
supplemental_containers = [FULL_CONTAINER_DEF, FULL_CONTAINER_DEF]

returned_name = sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, primary_container=FULL_CONTAINER_DEF,
supplemental_containers=supplemental_containers)
returned_name = sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, primary_container=FULL_CONTAINER_DEF)

assert returned_name == MODEL_NAME
sagemaker_session.sagemaker_client.create_model.assert_called_once_with(
ModelName=MODEL_NAME,
PrimaryContainer=FULL_CONTAINER_DEF,
SupplementalContainers=supplemental_containers,
ExecutionRoleArn=EXPANDED_ROLE)


def test_create_model_no_supplemental_containers(sagemaker_session):
sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, primary_container=FULL_CONTAINER_DEF)

_1, _2, create_model_kwargs = sagemaker_session.sagemaker_client.create_model.mock_calls[0]
assert create_model_kwargs['SupplementalContainers'] == []


def test_create_model_expand_primary_container(sagemaker_session):
sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, primary_container=IMAGE)

_1, _2, create_model_kwargs = sagemaker_session.sagemaker_client.create_model.mock_calls[0]
assert create_model_kwargs['PrimaryContainer'] == {'Environment': {}, 'Image': IMAGE}


def test_create_model_expand_supplemental_containers(sagemaker_session):
supp_image1 = 'suppimage1'
supp_image2 = 'suppimage2'
supplemental_containers = [supp_image1, supp_image2]

sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, primary_container=IMAGE,
supplemental_containers=supplemental_containers)

expected = [{'Environment': {}, 'Image': supp_image1}, {'Environment': {}, 'Image': supp_image2}]
_1, _2, create_model_kwargs = sagemaker_session.sagemaker_client.create_model.mock_calls[0]
assert create_model_kwargs['SupplementalContainers'] == expected


def test_create_endpoint_config(sagemaker_session):
returned_name = sagemaker_session.create_endpoint_config(name=ENDPOINT_CONFIG_NAME, model_name=MODEL_NAME,
initial_instance_count=INITIAL_INSTANCE_COUNT,
Expand Down
12 changes: 1 addition & 11 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,7 @@ def test_create_model_from_job(sagemaker_session):
ModelName='jobname',
PrimaryContainer={
'Environment': {}, 'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
'Image': 'myimage'},
SupplementalContainers=[])
'Image': 'myimage'})


def test_create_model_from_job_with_image(sagemaker_session):
Expand All @@ -355,15 +354,6 @@ def test_create_model_from_job_with_container_def(sagemaker_session):
assert c_def['Environment'] == {'a': 'b'}


def test_create_model_from_job_with_supplemental_containers(sagemaker_session):
ims = sagemaker_session
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
ims.create_model_from_job(JOB_NAME, supplemental_containers=[sagemaker.container_def('some-image')])
[create_model_call] = ims.sagemaker_client.create_model.call_args_list
[c_def] = create_model_call[1]['SupplementalContainers']
assert c_def['Image'] == 'some-image'


def test_endpoint_from_production_variants(sagemaker_session):
ims = sagemaker_session
ims.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})
Expand Down

0 comments on commit 157d867

Please sign in to comment.