Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ def transformer(
role=None,
volume_kms_key=None,
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
enable_network_isolation=None,
):
"""Return a ``Transformer`` that uses a SageMaker Model based on the
training job. It reuses the SageMaker Session and base job name used by
Expand Down Expand Up @@ -863,8 +864,18 @@ def transformer(
vpc_config_override (dict[str, list[str]]): Optional override for the
VpcConfig set on the model.
Default: use subnets and security groups from this Estimator.

* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.

enable_network_isolation (bool): Specifies whether container will
run in network isolation mode. Network isolation mode restricts
the container access to outside networks (such as the internet).
The container does not make any inbound or outbound network
calls. If True, a channel named "code" will be created for any
user entry script for inference. Also known as Internet-free mode.
If not specified, this setting is taken from the estimator's
current configuration.
"""
tags = tags or self.tags

Expand All @@ -876,8 +887,13 @@ def transformer(
model_name = self._current_job_name
else:
model_name = self.latest_training_job.name
if enable_network_isolation is None:
enable_network_isolation = self.enable_network_isolation()

model = self.create_model(
vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key
vpc_config_override=vpc_config_override,
model_kms_key=self.output_kms_key,
enable_network_isolation=enable_network_isolation,
)

# not all create_model() implementations have the same kwargs
Expand Down Expand Up @@ -1354,14 +1370,16 @@ def predict_wrapper(endpoint, session):

role = role or self.role

if "enable_network_isolation" not in kwargs:
kwargs["enable_network_isolation"] = self.enable_network_isolation()

return Model(
self.model_data,
image or self.train_image(),
role,
vpc_config=self.get_vpc_config(vpc_config_override),
sagemaker_session=self.sagemaker_session,
predictor_cls=predictor_cls,
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

Expand Down Expand Up @@ -1878,6 +1896,7 @@ def transformer(
volume_kms_key=None,
entry_point=None,
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
enable_network_isolation=None,
):
"""Return a ``Transformer`` that uses a SageMaker Model based on the
training job. It reuses the SageMaker Session and base job name used by
Expand Down Expand Up @@ -1922,9 +1941,19 @@ def transformer(
vpc_config_override (dict[str, list[str]]): Optional override for
the VpcConfig set on the model.
Default: use subnets and security groups from this Estimator.

* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.

enable_network_isolation (bool): Specifies whether container will
run in network isolation mode. Network isolation mode restricts
the container access to outside networks (such as the internet).
The container does not make any inbound or outbound network
calls. If True, a channel named "code" will be created for any
user entry script for inference. Also known as Internet-free mode.
If not specified, this setting is taken from the estimator's
current configuration.

Returns:
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
SageMaker Batch Transform job.
Expand All @@ -1933,12 +1962,16 @@ def transformer(
tags = tags or self.tags

if self.latest_training_job is not None:
if enable_network_isolation is None:
enable_network_isolation = self.enable_network_isolation()

model = self.create_model(
role=role,
model_server_workers=model_server_workers,
entry_point=entry_point,
vpc_config_override=vpc_config_override,
model_kms_key=self.output_kms_key,
enable_network_isolation=enable_network_isolation,
)
model._create_sagemaker_model(instance_type, tags=tags)

Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def create_model(
else:
image = None

if "enable_network_isolation" not in kwargs:
kwargs["enable_network_isolation"] = self.enable_network_isolation()

return SKLearnModel(
self.model_data,
role,
Expand All @@ -189,7 +192,6 @@ def create_model(
image=image or self.image_name,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

Expand Down
16 changes: 16 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def transformer(
endpoint_type=None,
entry_point=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
enable_network_isolation=None,
):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It
reuses the SageMaker Session and base job name used by the Estimator.
Expand Down Expand Up @@ -836,8 +837,18 @@ def transformer(
vpc_config_override (dict[str, list[str]]): Optional override for
the VpcConfig set on the model.
Default: use subnets and security groups from this Estimator.

* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.

enable_network_isolation (bool): Specifies whether container will
run in network isolation mode. Network isolation mode restricts
the container access to outside networks (such as the internet).
The container does not make any inbound or outbound network
calls. If True, a channel named "code" will be created for any
user entry script for inference. Also known as Internet-free mode.
If not specified, this setting is taken from the estimator's
current configuration.
"""
role = role or self.role

Expand All @@ -864,13 +875,18 @@ def transformer(
sagemaker_session=self.sagemaker_session,
)

if enable_network_isolation is None:
enable_network_isolation = self.enable_network_isolation()

model = self.create_model(
model_server_workers=model_server_workers,
role=role,
vpc_config_override=vpc_config_override,
endpoint_type=endpoint_type,
entry_point=entry_point,
enable_network_isolation=enable_network_isolation,
)

return model.transformer(
instance_count,
instance_type,
Expand Down
34 changes: 27 additions & 7 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,17 @@ def create_model(
model_server_workers=None,
entry_point=None,
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
enable_network_isolation=None,
**kwargs
):
if enable_network_isolation is None:
enable_network_isolation = self.enable_network_isolation()

return DummyFrameworkModel(
self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
entry_point=entry_point,
enable_network_isolation=self.enable_network_isolation(),
enable_network_isolation=enable_network_isolation,
role=role,
**kwargs
)
Expand Down Expand Up @@ -1357,7 +1361,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
base_job_name=base_name,
subnets=vpc_config["Subnets"],
security_group_ids=vpc_config["SecurityGroupIds"],
enable_network_isolation=True,
enable_network_isolation=False,
)
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)

Expand Down Expand Up @@ -1387,6 +1391,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
role=new_role,
model_server_workers=1,
vpc_config_override=new_vpc_config,
enable_network_isolation=True,
)

sagemaker_session.create_model.assert_called_with(
Expand Down Expand Up @@ -1437,8 +1442,8 @@ def test_ensure_latest_training_job_failure(sagemaker_session):
assert "Estimator is not associated with a training job" in str(e)


@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock())
def test_estimator_transformer_creation(sagemaker_session):
@patch("sagemaker.estimator.Estimator.create_model")
def test_estimator_transformer_creation(create_model, sagemaker_session):
estimator = Estimator(
image_name=IMAGE_NAME,
role=ROLE,
Expand All @@ -1450,6 +1455,12 @@ def test_estimator_transformer_creation(sagemaker_session):

transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

create_model.assert_called_with(
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
model_kms_key=estimator.output_kms_key,
enable_network_isolation=False,
)

assert isinstance(transformer, Transformer)
assert transformer.sagemaker_session == sagemaker_session
assert transformer.instance_count == INSTANCE_COUNT
Expand All @@ -1458,26 +1469,29 @@ def test_estimator_transformer_creation(sagemaker_session):
assert transformer.tags is None


@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock())
def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
@patch("sagemaker.estimator.Estimator.create_model")
def test_estimator_transformer_creation_with_optional_params(create_model, sagemaker_session):
base_name = "foo"
kms_key = "key"

estimator = Estimator(
image_name=IMAGE_NAME,
role=ROLE,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
sagemaker_session=sagemaker_session,
base_job_name=base_name,
output_kms_key=kms_key,
)
estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)

strategy = "MultiRecord"
assemble_with = "Line"
kms_key = "key"
accept = "text/csv"
max_concurrent_transforms = 1
max_payload = 6
env = {"FOO": "BAR"}
new_vpc_config = {"Subnets": ["x"], "SecurityGroupIds": ["y"]}

transformer = estimator.transformer(
INSTANCE_COUNT,
Expand All @@ -1492,6 +1506,12 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
max_payload=max_payload,
env=env,
role=ROLE,
vpc_config_override=new_vpc_config,
enable_network_isolation=True,
)

create_model.assert_called_with(
vpc_config_override=new_vpc_config, model_kms_key=kms_key, enable_network_isolation=True
)

assert transformer.strategy == strategy
Expand Down
60 changes: 45 additions & 15 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def test_create_model_with_optional_params(sagemaker_session):


@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session):
def test_transformer_creation_with_optional_args(create_model, sagemaker_session):
model = Mock()
create_model.return_value = model

Expand All @@ -348,38 +348,67 @@ def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session
train_instance_type=INSTANCE_TYPE,
)
tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")

strategy = "SingleRecord"
assemble_with = "Line"
output_path = "s3://{}/batch-output".format(BUCKET_NAME)
kms_key = "kms"
accept_type = "text/bytes"
env = {"foo": "bar"}
max_concurrent_transforms = 3
max_payload = 100
tags = {"Key": "foo", "Value": "bar"}
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}

tf.transformer(
INSTANCE_COUNT,
INSTANCE_TYPE,
strategy=strategy,
assemble_with=assemble_with,
output_path=output_path,
output_kms_key=kms_key,
accept=accept_type,
env=env,
max_concurrent_transforms=max_concurrent_transforms,
max_payload=max_payload,
tags=tags,
role=new_role,
model_server_workers=model_server_workers,
volume_kms_key=kms_key,
endpoint_type="tensorflow-serving",
entry_point=SERVING_SCRIPT_FILE,
vpc_config_override=vpc_config,
enable_network_isolation=True,
)

create_model.assert_called_with(
model_server_workers=model_server_workers,
role=new_role,
vpc_config_override=vpc_config,
endpoint_type="tensorflow-serving",
model_server_workers=None,
role=ROLE,
vpc_config_override="VPC_CONFIG_DEFAULT",
entry_point=SERVING_SCRIPT_FILE,
enable_network_isolation=True,
)
model.transformer.assert_called_with(
INSTANCE_COUNT,
INSTANCE_TYPE,
accept=None,
assemble_with=None,
env=None,
max_concurrent_transforms=None,
max_payload=None,
output_kms_key=None,
output_path=None,
strategy=None,
tags=None,
volume_kms_key=None,
accept=accept_type,
assemble_with=assemble_with,
env=env,
max_concurrent_transforms=max_concurrent_transforms,
max_payload=max_payload,
output_kms_key=kms_key,
output_path=output_path,
strategy=strategy,
tags=tags,
volume_kms_key=kms_key,
)


@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session):
def test_transformer_creation_without_optional_args(create_model, sagemaker_session):
model = Mock()
create_model.return_value = model

Expand All @@ -399,6 +428,7 @@ def test_transformer_creation_without_endpoint_type(create_model, sagemaker_sess
role=ROLE,
vpc_config_override="VPC_CONFIG_DEFAULT",
entry_point=None,
enable_network_isolation=False,
)
model.transformer.assert_called_with(
INSTANCE_COUNT,
Expand Down