diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index e01b6c1411..1842f1561c 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -825,6 +825,7 @@ def transformer( role=None, volume_kms_key=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + enable_network_isolation=None, ): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by @@ -863,8 +864,18 @@ def transformer( vpc_config_override (dict[str, list[str]]): Optional override for the VpcConfig set on the model. Default: use subnets and security groups from this Estimator. + * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. + + enable_network_isolation (bool): Specifies whether container will + run in network isolation mode. Network isolation mode restricts + the container access to outside networks (such as the internet). + The container does not make any inbound or outbound network + calls. If True, a channel named "code" will be created for any + user entry script for inference. Also known as Internet-free mode. + If not specified, this setting is taken from the estimator's + current configuration. """ tags = tags or self.tags @@ -876,8 +887,13 @@ def transformer( model_name = self._current_job_name else: model_name = self.latest_training_job.name + if enable_network_isolation is None: + enable_network_isolation = self.enable_network_isolation() + model = self.create_model( - vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key + vpc_config_override=vpc_config_override, + model_kms_key=self.output_kms_key, + enable_network_isolation=enable_network_isolation, ) # not all create_model() implementations have the same kwargs @@ -1354,6 +1370,9 @@ def predict_wrapper(endpoint, session): role = role or self.role + if "enable_network_isolation" not in kwargs: + kwargs["enable_network_isolation"] = self.enable_network_isolation() + return Model( self.model_data, image or self.train_image(), @@ -1361,7 +1380,6 @@ def predict_wrapper(endpoint, session): vpc_config=self.get_vpc_config(vpc_config_override), sagemaker_session=self.sagemaker_session, predictor_cls=predictor_cls, - enable_network_isolation=self.enable_network_isolation(), **kwargs ) @@ -1878,6 +1896,7 @@ def transformer( volume_kms_key=None, entry_point=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + enable_network_isolation=None, ): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by @@ -1922,9 +1941,19 @@ def transformer( vpc_config_override (dict[str, list[str]]): Optional override for the VpcConfig set on the model. Default: use subnets and security groups from this Estimator. + * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. + enable_network_isolation (bool): Specifies whether container will + run in network isolation mode. Network isolation mode restricts + the container access to outside networks (such as the internet). + The container does not make any inbound or outbound network + calls. If True, a channel named "code" will be created for any + user entry script for inference. Also known as Internet-free mode. + If not specified, this setting is taken from the estimator's + current configuration. + Returns: sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a SageMaker Batch Transform job. @@ -1933,12 +1962,16 @@ def transformer( tags = tags or self.tags if self.latest_training_job is not None: + if enable_network_isolation is None: + enable_network_isolation = self.enable_network_isolation() + model = self.create_model( role=role, model_server_workers=model_server_workers, entry_point=entry_point, vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key, + enable_network_isolation=enable_network_isolation, ) model._create_sagemaker_model(instance_type, tags=tags) diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 10524044cc..b96aa5216e 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -174,6 +174,9 @@ def create_model( else: image = None + if "enable_network_isolation" not in kwargs: + kwargs["enable_network_isolation"] = self.enable_network_isolation() + return SKLearnModel( self.model_data, role, @@ -189,7 +192,6 @@ def create_model( image=image or self.image_name, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - enable_network_isolation=self.enable_network_isolation(), **kwargs ) diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index aedfcdfcd8..0f31d217bf 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -791,6 +791,7 @@ def transformer( endpoint_type=None, entry_point=None, vpc_config_override=VPC_CONFIG_DEFAULT, + enable_network_isolation=None, ): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator. @@ -836,8 +837,18 @@ def transformer( vpc_config_override (dict[str, list[str]]): Optional override for the VpcConfig set on the model. Default: use subnets and security groups from this Estimator. + * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. + + enable_network_isolation (bool): Specifies whether container will + run in network isolation mode. Network isolation mode restricts + the container access to outside networks (such as the internet). + The container does not make any inbound or outbound network + calls. If True, a channel named "code" will be created for any + user entry script for inference. Also known as Internet-free mode. + If not specified, this setting is taken from the estimator's + current configuration. """ role = role or self.role @@ -864,13 +875,18 @@ def transformer( sagemaker_session=self.sagemaker_session, ) + if enable_network_isolation is None: + enable_network_isolation = self.enable_network_isolation() + model = self.create_model( model_server_workers=model_server_workers, role=role, vpc_config_override=vpc_config_override, endpoint_type=endpoint_type, entry_point=entry_point, + enable_network_isolation=enable_network_isolation, ) + return model.transformer( instance_count, instance_type, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 58a432459b..d9a98bc8cf 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -121,13 +121,17 @@ def create_model( model_server_workers=None, entry_point=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + enable_network_isolation=None, **kwargs ): + if enable_network_isolation is None: + enable_network_isolation = self.enable_network_isolation() + return DummyFrameworkModel( self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), entry_point=entry_point, - enable_network_isolation=self.enable_network_isolation(), + enable_network_isolation=enable_network_isolation, role=role, **kwargs ) @@ -1357,7 +1361,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa base_job_name=base_name, subnets=vpc_config["Subnets"], security_group_ids=vpc_config["SecurityGroupIds"], - enable_network_isolation=True, + enable_network_isolation=False, ) fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) @@ -1387,6 +1391,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa role=new_role, model_server_workers=1, vpc_config_override=new_vpc_config, + enable_network_isolation=True, ) sagemaker_session.create_model.assert_called_with( @@ -1437,8 +1442,8 @@ def test_ensure_latest_training_job_failure(sagemaker_session): assert "Estimator is not associated with a training job" in str(e) -@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock()) -def test_estimator_transformer_creation(sagemaker_session): +@patch("sagemaker.estimator.Estimator.create_model") +def test_estimator_transformer_creation(create_model, sagemaker_session): estimator = Estimator( image_name=IMAGE_NAME, role=ROLE, @@ -1450,6 +1455,12 @@ def test_estimator_transformer_creation(sagemaker_session): transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE) + create_model.assert_called_with( + vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + model_kms_key=estimator.output_kms_key, + enable_network_isolation=False, + ) + assert isinstance(transformer, Transformer) assert transformer.sagemaker_session == sagemaker_session assert transformer.instance_count == INSTANCE_COUNT @@ -1458,9 +1469,11 @@ def test_estimator_transformer_creation(sagemaker_session): assert transformer.tags is None -@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock()) -def test_estimator_transformer_creation_with_optional_params(sagemaker_session): +@patch("sagemaker.estimator.Estimator.create_model") +def test_estimator_transformer_creation_with_optional_params(create_model, sagemaker_session): base_name = "foo" + kms_key = "key" + estimator = Estimator( image_name=IMAGE_NAME, role=ROLE, @@ -1468,16 +1481,17 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session): train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session, base_job_name=base_name, + output_kms_key=kms_key, ) estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) strategy = "MultiRecord" assemble_with = "Line" - kms_key = "key" accept = "text/csv" max_concurrent_transforms = 1 max_payload = 6 env = {"FOO": "BAR"} + new_vpc_config = {"Subnets": ["x"], "SecurityGroupIds": ["y"]} transformer = estimator.transformer( INSTANCE_COUNT, @@ -1492,6 +1506,12 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session): max_payload=max_payload, env=env, role=ROLE, + vpc_config_override=new_vpc_config, + enable_network_isolation=True, + ) + + create_model.assert_called_with( + vpc_config_override=new_vpc_config, model_kms_key=kms_key, enable_network_isolation=True ) assert transformer.strategy == strategy diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index f8ec996d62..165f35b6f1 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -336,7 +336,7 @@ def test_create_model_with_optional_params(sagemaker_session): @patch("sagemaker.tensorflow.estimator.TensorFlow.create_model") -def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session): +def test_transformer_creation_with_optional_args(create_model, sagemaker_session): model = Mock() create_model.return_value = model @@ -348,38 +348,67 @@ def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session train_instance_type=INSTANCE_TYPE, ) tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name") + + strategy = "SingleRecord" + assemble_with = "Line" + output_path = "s3://{}/batch-output".format(BUCKET_NAME) + kms_key = "kms" + accept_type = "text/bytes" + env = {"foo": "bar"} + max_concurrent_transforms = 3 + max_payload = 100 + tags = {"Key": "foo", "Value": "bar"} + new_role = "role" + model_server_workers = 2 + vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]} + tf.transformer( INSTANCE_COUNT, INSTANCE_TYPE, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=kms_key, + accept=accept_type, + env=env, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + tags=tags, + role=new_role, + model_server_workers=model_server_workers, + volume_kms_key=kms_key, endpoint_type="tensorflow-serving", entry_point=SERVING_SCRIPT_FILE, + vpc_config_override=vpc_config, + enable_network_isolation=True, ) create_model.assert_called_with( + model_server_workers=model_server_workers, + role=new_role, + vpc_config_override=vpc_config, endpoint_type="tensorflow-serving", - model_server_workers=None, - role=ROLE, - vpc_config_override="VPC_CONFIG_DEFAULT", entry_point=SERVING_SCRIPT_FILE, + enable_network_isolation=True, ) model.transformer.assert_called_with( INSTANCE_COUNT, INSTANCE_TYPE, - accept=None, - assemble_with=None, - env=None, - max_concurrent_transforms=None, - max_payload=None, - output_kms_key=None, - output_path=None, - strategy=None, - tags=None, - volume_kms_key=None, + accept=accept_type, + assemble_with=assemble_with, + env=env, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + output_kms_key=kms_key, + output_path=output_path, + strategy=strategy, + tags=tags, + volume_kms_key=kms_key, ) @patch("sagemaker.tensorflow.estimator.TensorFlow.create_model") -def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session): +def test_transformer_creation_without_optional_args(create_model, sagemaker_session): model = Mock() create_model.return_value = model @@ -399,6 +428,7 @@ def test_transformer_creation_without_endpoint_type(create_model, sagemaker_sess role=ROLE, vpc_config_override="VPC_CONFIG_DEFAULT", entry_point=None, + enable_network_isolation=False, ) model.transformer.assert_called_with( INSTANCE_COUNT,