Skip to content

Commit

Permalink
feature: enable sklearn for network isolation mode (#1051)
Browse files Browse the repository at this point in the history
  • Loading branch information
icywang86rui authored and laurenyu committed Sep 20, 2019
1 parent 7dbb149 commit 308f121
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,10 @@ def _framework_env_vars(self):
"""Placeholder docstring"""
if self.uploaded_code:
script_name = self.uploaded_code.script_name
dir_name = self.uploaded_code.s3_prefix
if self.enable_network_isolation():
dir_name = "/opt/ml/model/code"
else:
dir_name = self.uploaded_code.s3_prefix
else:
script_name = self.entry_point
dir_name = "file://" + self.source_dir
Expand Down
7 changes: 5 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,13 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
)

deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
model_data_uri = (
self.repacked_model_data if self.enable_network_isolation() else self.model_data
)
return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)
19 changes: 19 additions & 0 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ def test_prepare_container_def(time, sagemaker_session):
}


@patch("shutil.rmtree", MagicMock())
@patch("tarfile.open", MagicMock())
@patch("os.listdir", MagicMock(return_value=["blah.py"]))
@patch("time.strftime", return_value=TIMESTAMP)
def test_prepare_container_def_with_network_isolation(time, sagemaker_session):
model = DummyFrameworkModel(sagemaker_session, enable_network_isolation=True)
assert model.prepare_container_def(INSTANCE_TYPE) == {
"Environment": {
"SAGEMAKER_PROGRAM": ENTRY_POINT,
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
"SAGEMAKER_REGION": REGION,
"SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
},
"Image": MODEL_IMAGE,
"ModelDataUrl": MODEL_DATA,
}


@patch("shutil.rmtree", MagicMock())
@patch("tarfile.open", MagicMock())
@patch("os.path.exists", MagicMock(return_value=True))
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sagemaker.sklearn import defaults
from sagemaker.sklearn import SKLearn
from sagemaker.sklearn import SKLearnPredictor, SKLearnModel
from sagemaker.fw_utils import UploadedCode

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
Expand Down Expand Up @@ -168,6 +169,25 @@ def test_create_model(sagemaker_session):
assert model_values["Image"] == default_image_uri


@patch("sagemaker.model.FrameworkModel._upload_code")
def test_create_model_with_network_isolation(upload, sagemaker_session):
source_dir = "s3://mybucket/source"
repacked_model_data = "s3://mybucket/prefix/model.tar.gz"

sklearn_model = SKLearnModel(
model_data=source_dir,
role=ROLE,
sagemaker_session=sagemaker_session,
entry_point=SCRIPT_PATH,
enable_network_isolation=True,
)
sklearn_model.uploaded_code = UploadedCode(s3_prefix=repacked_model_data, script_name="script")
sklearn_model.repacked_model_data = repacked_model_data
model_values = sklearn_model.prepare_container_def(CPU)
assert model_values["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] == "/opt/ml/model/code"
assert model_values["ModelDataUrl"] == repacked_model_data


def test_create_model_from_estimator(sagemaker_session, sklearn_version):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
Expand Down

0 comments on commit 308f121

Please sign in to comment.