Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,17 @@ def __init__(
:class:`~sagemaker.estimator.EstimatorBase`.
"""
super(AmazonAlgorithmEstimatorBase, self).__init__(
role, train_instance_count, train_instance_type, **kwargs
role,
train_instance_count,
train_instance_type,
enable_network_isolation=enable_network_isolation,
**kwargs
)

data_location = data_location or "s3://{}/sagemaker-record-sets/".format(
self.sagemaker_session.default_bucket()
)
self._data_location = data_location
self._enable_network_isolation = enable_network_isolation

def train_image(self):
"""Placeholder docstring"""
Expand All @@ -101,14 +104,6 @@ def hyperparameters(self):
"""Placeholder docstring"""
return hp.serialize_all(self)

def enable_network_isolation(self):
"""If this Estimator can use network isolation when running.

Returns:
bool: Whether this Estimator can use network isolation or not.
"""
return self._enable_network_isolation

@property
def data_location(self):
"""Placeholder docstring"""
Expand Down
45 changes: 16 additions & 29 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
debugger_hook_config=None,
tensorboard_output_config=None,
enable_sagemaker_metrics=None,
enable_network_isolation=False,
):
"""Initialize an ``EstimatorBase`` instance.

Expand Down Expand Up @@ -199,6 +200,11 @@ def __init__(
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
(default: ``None``).
enable_network_isolation (bool): Specifies whether container will
run in network isolation mode (default: ``False``). Network
isolation mode restricts the container access to outside networks
(such as the Internet). The container does not make any inbound or
outbound network calls. Also known as Internet-free mode.
"""
self.role = role
self.train_instance_count = train_instance_count
Expand Down Expand Up @@ -260,6 +266,7 @@ def __init__(
self.collection_configs = None

self.enable_sagemaker_metrics = enable_sagemaker_metrics
self._enable_network_isolation = enable_network_isolation

@abstractmethod
def train_image(self):
Expand Down Expand Up @@ -290,7 +297,7 @@ def enable_network_isolation(self):
Returns:
bool: Whether this Estimator needs network isolation or not.
"""
return False
return self._enable_network_isolation

def prepare_workflow_for_training(self, job_name=None):
"""Calls _prepare_for_training. Used when setting up a workflow.
Expand Down Expand Up @@ -1219,21 +1226,17 @@ def __init__(
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
enable_network_isolation (bool): Specifies whether container will
run in network isolation mode. Network isolation mode restricts
the container access to outside networks (such as the Internet).
The container does not make any inbound or outbound network
calls. If ``True``, a channel named "code" will be created for any
user entry script for training. The user entry script, files in
source_dir (if specified), and dependencies will be uploaded in
a tar to S3. Also known as internet-free mode (default: ``False``).
run in network isolation mode (default: ``False``). Network
isolation mode restricts the container access to outside networks
(such as the Internet). The container does not make any inbound or
outbound network calls. Also known as Internet-free mode.
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
(default: ``None``).
"""
self.image_name = image_name
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
self._enable_network_isolation = enable_network_isolation
super(Estimator, self).__init__(
role,
train_instance_count,
Expand Down Expand Up @@ -1261,16 +1264,9 @@ def __init__(
debugger_hook_config=debugger_hook_config,
tensorboard_output_config=tensorboard_output_config,
enable_sagemaker_metrics=enable_sagemaker_metrics,
enable_network_isolation=enable_network_isolation,
)

def enable_network_isolation(self):
"""If this Estimator can use network isolation when running.

Returns:
bool: Whether this Estimator can use network isolation or not.
"""
return self._enable_network_isolation

def train_image(self):
"""Returns the docker image to use for training.

Expand Down Expand Up @@ -1498,15 +1494,15 @@ def __init__(
>>> |------ train.py
>>> |------ common
>>> |------ virtual-env

enable_network_isolation (bool): Specifies whether container will
run in network isolation mode. Network isolation mode restricts
the container access to outside networks (such as the internet).
The container does not make any inbound or outbound network
calls. If True, a channel named "code" will be created for any
user entry script for training. The user entry script, files in
source_dir (if specified), and dependencies will be uploaded in
a tar to S3. Also known as internet-free mode (default: `False`
).
a tar to S3. Also known as internet-free mode (default: `False`).
git_config (dict[str, str]): Git configurations used for cloning
files, including ``repo``, ``branch``, ``commit``,
``2FA_enabled``, ``username``, ``password`` and ``token``. The
Expand Down Expand Up @@ -1579,7 +1575,7 @@ def __init__(
You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.EstimatorBase`.
"""
super(Framework, self).__init__(**kwargs)
super(Framework, self).__init__(enable_network_isolation=enable_network_isolation, **kwargs)
if entry_point.startswith("s3://"):
raise ValueError(
"Invalid entry point script: {}. Must be a path to a local file.".format(
Expand All @@ -1599,7 +1595,6 @@ def __init__(
self.container_log_level = container_log_level
self.code_location = code_location
self.image_name = image_name
self._enable_network_isolation = enable_network_isolation

self.uploaded_code = None

Expand All @@ -1608,14 +1603,6 @@ def __init__(
self.checkpoint_local_path = checkpoint_local_path
self.enable_sagemaker_metrics = enable_sagemaker_metrics

def enable_network_isolation(self):
"""Return True if this Estimator can use network isolation to run.

Returns:
bool: Whether this Estimator can use network isolation or not.
"""
return self._enable_network_isolation

def _prepare_for_training(self, job_name=None):
"""Set hyperparameters needed for training. This method will also
validate ``source_dir``.
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def test_framework_all_init_args(sagemaker_session):
checkpoint_s3_uri="s3://bucket/checkpoint",
checkpoint_local_path="file://local/checkpoint",
enable_sagemaker_metrics=True,
enable_network_isolation=True,
)
_TrainingJob.start_new(f, "s3://mydata", None)
sagemaker_session.train.assert_called_once()
Expand Down Expand Up @@ -247,6 +248,7 @@ def test_framework_all_init_args(sagemaker_session):
"checkpoint_s3_uri": "s3://bucket/checkpoint",
"checkpoint_local_path": "file://local/checkpoint",
"enable_sagemaker_metrics": True,
"enable_network_isolation": True,
}


Expand Down