Skip to content

Commit

Permalink
Revert "fix issue-987 error by adding instance_type in endpoint_name (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mvsusp committed Oct 1, 2019
1 parent a907597 commit 3ce2d95
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 36 deletions.
5 changes: 2 additions & 3 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,14 +446,13 @@ def deploy(
if endpoint_name:
self.endpoint_name = endpoint_name
else:
self.endpoint_name = self.name + "-" + instance_type.replace(".", "-")
self.endpoint_name = self.name
if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix):
self.endpoint_name += compiled_model_suffix

if update_endpoint:
self.sagemaker_session.delete_endpoint_config(endpoint_config_name=self.endpoint_name)
endpoint_config_name = self.sagemaker_session.create_endpoint_config(
name=self.endpoint_name,
name=self.name,
model_name=self.name,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
Expand Down
28 changes: 10 additions & 18 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,47 +1297,39 @@ def endpoint_from_model_data(
"""

model_environment_vars = model_environment_vars or {}
model_name = name or name_from_image(deployment_image)
endpoint_name = name or (
name_from_image(deployment_image) + "-" + instance_type.replace(".", "-")
)
name = name or name_from_image(deployment_image)
model_vpc_config = vpc_utils.sanitize(model_vpc_config)

if _deployment_entity_exists(
lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
lambda: self.sagemaker_client.describe_endpoint(EndpointName=name)
):
raise ValueError(
'Endpoint with name "{}" already exists; please pick a different name.'.format(
endpoint_name
)
'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
)

if not _deployment_entity_exists(
lambda: self.sagemaker_client.describe_model(ModelName=model_name)
lambda: self.sagemaker_client.describe_model(ModelName=name)
):
primary_container = container_def(
image=deployment_image, model_data_url=model_s3_location, env=model_environment_vars
)
self.create_model(
name=model_name,
role=role,
container_defs=primary_container,
vpc_config=model_vpc_config,
name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
)

if not _deployment_entity_exists(
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_name)
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
):
self.create_endpoint_config(
name=endpoint_name,
model_name=model_name,
name=name,
model_name=name,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
accelerator_type=accelerator_type,
)

self.create_endpoint(endpoint_name=endpoint_name, config_name=endpoint_name, wait=wait)
return endpoint_name
self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
return name

def endpoint_from_production_variants(
self, name, production_variants, tags=None, kms_key=None, wait=True
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def test_deploy_model_with_update_endpoint(
EndpointConfigName=new_config_name
)

assert old_config_name == new_config_name
assert old_config_name != new_config_name
assert new_config["ProductionVariants"][0]["InstanceType"] == cpu_instance_type
assert new_config["ProductionVariants"][0]["InitialInstanceCount"] == 1

Expand Down
13 changes: 6 additions & 7 deletions tests/unit/test_endpoint_from_model_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
DEPLOY_ROLE = "mydeployrole"
ENV_VARS = {"PYTHONUNBUFFERED": "TRUE", "some": "nonsense"}
NAME_FROM_IMAGE = "namefromimage"
DEFAULT_ENDPOINT_NAME = "namefromimage-ml-c4-xlarge"
REGION = "us-west-2"


Expand Down Expand Up @@ -65,28 +64,28 @@ def test_all_defaults_no_existing_entities(name_from_image_mock, sagemaker_sessi
)

sagemaker_session.sagemaker_client.describe_endpoint.assert_called_once_with(
EndpointName=DEFAULT_ENDPOINT_NAME
EndpointName=NAME_FROM_IMAGE
)
sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(
ModelName=NAME_FROM_IMAGE
)
sagemaker_session.sagemaker_client.describe_endpoint_config.assert_called_once_with(
EndpointConfigName=DEFAULT_ENDPOINT_NAME
EndpointConfigName=NAME_FROM_IMAGE
)
sagemaker_session.create_model.assert_called_once_with(
name=NAME_FROM_IMAGE, role=DEPLOY_ROLE, container_defs=CONTAINER_DEF, vpc_config=None
)
sagemaker_session.create_endpoint_config.assert_called_once_with(
name=DEFAULT_ENDPOINT_NAME,
name=NAME_FROM_IMAGE,
model_name=NAME_FROM_IMAGE,
initial_instance_count=INITIAL_INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
accelerator_type=None,
)
sagemaker_session.create_endpoint.assert_called_once_with(
endpoint_name=DEFAULT_ENDPOINT_NAME, config_name=DEFAULT_ENDPOINT_NAME, wait=False
endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False
)
assert returned_name == DEFAULT_ENDPOINT_NAME
assert returned_name == NAME_FROM_IMAGE


@patch("sagemaker.session.name_from_image", return_value=NAME_FROM_IMAGE)
Expand Down Expand Up @@ -153,7 +152,7 @@ def test_model_and_endpoint_config_exist(name_from_image_mock, sagemaker_session
sagemaker_session.create_model.assert_not_called()
sagemaker_session.create_endpoint_config.assert_not_called()
sagemaker_session.create_endpoint.assert_called_once_with(
endpoint_name=DEFAULT_ENDPOINT_NAME, config_name=DEFAULT_ENDPOINT_NAME, wait=False
endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False
)


Expand Down
14 changes: 7 additions & 7 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"
ENTRY_POINT = "blah.py"
INSTANCE_TYPE = "p2.xlarge"
ROLE = "some-role"

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
Expand All @@ -39,7 +40,6 @@
IMAGE_NAME = "fakeimage"
REGION = "us-west-2"
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
ENDPOINT_NAME = "{}-{}".format(MODEL_NAME, INSTANCE_TYPE.replace(".", "-"))
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
BRANCH = "test-branch-git-config"
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_deploy(sagemaker_session, tmpdir):
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
ENDPOINT_NAME,
MODEL_NAME,
[
{
"InitialVariantWeight": 1,
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_deploy_tags(sagemaker_session, tmpdir):
tags = [{"ModelName": "TestModel"}]
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
ENDPOINT_NAME,
MODEL_NAME,
[
{
"InitialVariantWeight": 1,
Expand All @@ -280,7 +280,7 @@ def test_deploy_accelerator_type(tfo, time, sagemaker_session):
instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE
)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
ENDPOINT_NAME,
MODEL_NAME,
[
{
"InitialVariantWeight": 1,
Expand All @@ -305,7 +305,7 @@ def test_deploy_kms_key(tfo, time, sagemaker_session):
model = DummyFrameworkModel(sagemaker_session)
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
ENDPOINT_NAME,
MODEL_NAME,
[
{
"InitialVariantWeight": 1,
Expand Down Expand Up @@ -350,7 +350,7 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
accelerator_type=ACCELERATOR_TYPE,
)
sagemaker_session.create_endpoint_config.assert_called_with(
name=endpoint_name,
name=model.name,
model_name=model.name,
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
Expand All @@ -359,7 +359,7 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
kms_key=None,
)
config_name = sagemaker_session.create_endpoint_config(
name=endpoint_name,
name=model.name,
model_name=model.name,
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
Expand Down

0 comments on commit 3ce2d95

Please sign in to comment.