Skip to content

Commit

Permalink
fix: add enable_network_isolation to generic Estimator class (#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenyu committed Sep 6, 2019
1 parent 228a81d commit 5ec22c0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ def __init__(
train_max_wait=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
enable_network_isolation=False,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -1008,9 +1009,18 @@ def __init__(
started. If the path is unset then SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
enable_network_isolation (bool): Specifies whether container will
run in network isolation mode. Network isolation mode restricts
the container access to outside networks (such as the Internet).
The container does not make any inbound or outbound network
calls. If ``True``, a channel named "code" will be created for any
user entry script for training. The user entry script, files in
source_dir (if specified), and dependencies will be uploaded in
a tar to S3. Also known as internet-free mode (default: ``False``).
"""
self.image_name = image_name
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
self._enable_network_isolation = enable_network_isolation
super(Estimator, self).__init__(
role,
train_instance_count,
Expand All @@ -1036,6 +1046,14 @@ def __init__(
checkpoint_local_path=checkpoint_local_path,
)

def enable_network_isolation(self):
"""If this Estimator can use network isolation when running.
Returns:
bool: Whether this Estimator can use network isolation or not.
"""
return self._enable_network_isolation

def train_image(self):
"""Returns the docker image to use for training.
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,24 @@ def test_generic_to_fit_with_encrypt_inter_container_traffic_flag(sagemaker_sess
assert args["encrypt_inter_container_traffic"] is True


def test_generic_to_fit_with_network_isolation(sagemaker_session):
e = Estimator(
IMAGE_NAME,
ROLE,
INSTANCE_COUNT,
INSTANCE_TYPE,
output_path=OUTPUT_PATH,
sagemaker_session=sagemaker_session,
enable_network_isolation=True,
)

e.fit()

sagemaker_session.train.assert_called_once()
args = sagemaker_session.train.call_args[1]
assert args["enable_network_isolation"]


def test_generic_to_deploy(sagemaker_session):
e = Estimator(
IMAGE_NAME,
Expand Down

2 comments on commit 5ec22c0

@bhavik161
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The training job works after making the changes but batch transformation fails as the method def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data): still does not support the enable_network_isolation paramters

@laurenyu
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @bhavik161, could you create a new issue that details the code you're running and the error you're seeing?

create_model_package_from_algorithm uses the CreateModelPackage API, which doesn't have any parameters for enabling network isolation from what I can see.

Please sign in to comment.