Skip to content

Commit

Permalink
fix: enable kms support for repack_model (#1061)
Browse files Browse the repository at this point in the history
* fix: enable kms support for repack_model

Currently repack_model doesn't accept a KMS key. This change
adds a kms_key argument to the function. In addition, repack_model
will always use the output_kms_key inside the Estimator if it's set.
  • Loading branch information
icywang86rui committed Sep 25, 2019
1 parent d368524 commit 76d46d0
Show file tree
Hide file tree
Showing 17 changed files with 69 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(
self.center_factor = center_factor
self.eval_metrics = eval_metrics

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing
the latest s3 model data produced by this Estimator.
Expand All @@ -158,12 +158,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the KMeansModel constructor.
"""
return KMeansModel(
self.model_data,
self.role,
self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=5000, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
self.max_iterations = max_iterations
self.tol = tol

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.LDAModel` referencing the latest
s3 model data produced by this Estimator.
Expand All @@ -132,12 +132,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the LDAModel constructor.
"""
return LDAModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training( # pylint: disable=signature-differs
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/linear_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(
"value greater than 2."
)

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.LinearLearnerModel` referencing
the latest s3 model data produced by this Estimator.
Expand All @@ -382,12 +382,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the LinearLearnerModel constructor.
"""
return LinearLearnerModel(
self.model_data,
self.role,
self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an
``Endpoint``.
Expand All @@ -186,6 +187,7 @@ def create_model(
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
**kwargs: Additional kwargs passed to the ChainerModel constructor.
Returns:
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def deploy(
)
model = self._compiled_models[family]
else:
kwargs["model_kms_key"] = self.output_kms_key
model = self.create_model(**kwargs)
model.name = model_name
return model.deploy(
Expand Down Expand Up @@ -734,7 +735,9 @@ def transformer(
model_name = self._current_job_name
else:
model_name = self.latest_training_job.name
model = self.create_model(vpc_config_override=vpc_config_override)
model = self.create_model(
vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key
)

# not all create_model() implementations have the same kwargs
model.name = model_name
Expand Down Expand Up @@ -1716,6 +1719,7 @@ def transformer(
model_server_workers=model_server_workers,
entry_point=entry_point,
vpc_config_override=vpc_config_override,
model_kms_key=self.output_kms_key,
)
model._create_sagemaker_model(instance_type, tags=tags)

Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
vpc_config=None,
sagemaker_session=None,
enable_network_isolation=False,
model_kms_key=None,
):
"""Initialize an SageMaker ``Model``.
Expand Down Expand Up @@ -114,6 +115,8 @@ def __init__(
network isolation in the endpoint, isolating the model
container. No inbound or outbound network calls can be made to
or from the model container.
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
"""
self.model_data = model_data
self.image = image
Expand All @@ -127,6 +130,7 @@ def __init__(
self.endpoint_name = None
self._is_compiled_model = False
self._enable_network_isolation = enable_network_isolation
self.model_kms_key = model_kms_key

def prepare_container_def(
self, instance_type, accelerator_type=None
Expand Down Expand Up @@ -799,6 +803,7 @@ def _upload_code(self, key_prefix, repack=False):
model_uri=self.model_data,
repacked_model_uri=repacked_model_data,
sagemaker_session=self.sagemaker_session,
kms_key=self.model_kms_key,
)

self.repacked_model_data = repacked_model_data
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def create_model(
source_dir=None,
dependencies=None,
image_name=None,
**kwargs
):
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
``Endpoint``.
Expand Down Expand Up @@ -171,6 +172,7 @@ def create_model(
Examples:
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.
**kwargs: Additional kwargs passed to the MXNetModel constructor.
Returns:
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an
``Endpoint``.
Expand All @@ -139,6 +140,7 @@ def create_model(
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
**kwargs: Additional kwargs passed to the PyTorchModel constructor.
Returns:
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``RLEstimatorModel`` object that can be deployed
to an Endpoint.
Expand All @@ -189,6 +190,7 @@ def create_model(
folders will be copied to SageMaker in the same folder where the
entry_point is copied. If the ```source_dir``` points to S3,
code will be uploaded and the S3 location will be used instead.
**kwargs: Additional kwargs passed to the FrameworkModel constructor.
Returns:
sagemaker.model.FrameworkModel: Depending on input parameters returns
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a ``Model`` object that can be used for creating SageMaker model entities,
deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
Expand Down Expand Up @@ -537,6 +538,8 @@ def create_model(
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
set to ``None``.
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
**kwargs: Additional kwargs passed to ``sagemaker.tensorflow.serving.Model`` constructor
and ``sagemaker.tensorflow.model.TensorFlowModel`` constructor.
Returns:
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
Expand All @@ -552,6 +555,7 @@ def create_model(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

return self._create_default_model(
Expand All @@ -561,6 +565,7 @@ def create_model(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

def _create_tfs_model(
Expand All @@ -570,6 +575,7 @@ def _create_tfs_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return Model(
Expand All @@ -585,6 +591,7 @@ def _create_tfs_model(
source_dir=source_dir,
dependencies=dependencies,
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

def _create_default_model(
Expand All @@ -595,6 +602,7 @@ def _create_default_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return TensorFlowModel(
Expand All @@ -615,6 +623,7 @@ def _create_default_model(
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=dependencies or self.dependencies,
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

def hyperparameters(self):
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/tensorflow/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
self.model_data,
model_data,
self.sagemaker_session,
kms_key=self.model_kms_key,
)
else:
model_data = self.model_data
Expand Down
12 changes: 9 additions & 3 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def repack_model(
model_uri,
repacked_model_uri,
sagemaker_session,
kms_key=None,
):
"""Unpack model tarball and creates a new model tarball with the provided
code script.
Expand Down Expand Up @@ -400,6 +401,7 @@ def repack_model(
model will be saved
sagemaker_session (sagemaker.session.Session): a sagemaker session to
interact with S3.
kms_key (str): KMS key ARN for encrypting the repacked model file
Returns:
str: path to the new packed model
Expand All @@ -417,10 +419,10 @@ def repack_model(
with tarfile.open(tmp_model_path, mode="w:gz") as t:
t.add(model_dir, arcname=os.path.sep)

_save_model(repacked_model_uri, tmp_model_path, sagemaker_session)
_save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key)


def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
"""
Args:
repacked_model_uri:
Expand All @@ -432,8 +434,12 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
bucket, key = url.netloc, url.path.lstrip("/")
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

if kms_key:
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
else:
extra_args = None
sagemaker_session.boto_session.resource("s3").Object(bucket, new_key).upload_file(
tmp_model_path
tmp_model_path, ExtraArgs=extra_args
)
else:
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
Expand Down
18 changes: 13 additions & 5 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tests.integ
from tests.integ import timeout
from tests.integ import kms_utils
from tests.integ.retry import retries
from tests.integ.s3_utils import assert_s3_files_exist

Expand Down Expand Up @@ -67,16 +68,14 @@ def test_mnist(sagemaker_session, instance_type):

def test_server_side_encryption(sagemaker_session):
boto_session = sagemaker_session.boto_session
with tests.integ.kms_utils.bucket_with_encryption(boto_session, ROLE) as (
bucket_with_kms,
kms_key,
):
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
output_path = os.path.join(
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
)

estimator = TensorFlow(
entry_point=SCRIPT,
entry_point="training.py",
source_dir=TFS_RESOURCE_PATH,
role=ROLE,
train_instance_count=1,
train_instance_type="ml.c5.xlarge",
Expand All @@ -99,6 +98,15 @@ def test_server_side_encryption(sagemaker_session):
inputs=inputs, job_name=unique_name_from_base("test-server-side-encryption")
)

endpoint_name = unique_name_from_base("test-server-side-encryption")
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
estimator.deploy(
initial_instance_count=1,
instance_type="ml.c5.xlarge",
endpoint_name=endpoint_name,
entry_point=os.path.join(TFS_RESOURCE_PATH, "inference.py"),
)


@pytest.mark.canary_quick
def test_mnist_distributed(sagemaker_session, instance_type):
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,15 @@ def create_model(
model_server_workers=None,
entry_point=None,
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
**kwargs
):
return DummyFrameworkModel(
self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
entry_point=entry_point,
enable_network_isolation=self.enable_network_isolation(),
role=role,
**kwargs
)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,15 @@ def test_model(sagemaker_session):

@patch("sagemaker.utils.repack_model")
def test_model_mms_version(repack_model, sagemaker_session):
model_kms_key = "kms-key"
model = MXNetModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
framework_version=MXNetModel._LOWEST_MMS_VERSION,
sagemaker_session=sagemaker_session,
name="test-mxnet-model",
model_kms_key=model_kms_key,
)
predictor = model.deploy(1, GPU)

Expand All @@ -433,6 +435,7 @@ def test_model_mms_version(repack_model, sagemaker_session):
model_uri=MODEL_DATA,
repacked_model_uri="s3://mybucket/test-mxnet-model/model.tar.gz",
sagemaker_session=sagemaker_session,
kms_key=model_kms_key,
)

assert model.model_data == MODEL_DATA
Expand Down

0 comments on commit 76d46d0

Please sign in to comment.