Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
sagemaker_session=None,
enable_network_isolation=False,
model_kms_key=None,
image_config=None,
):
"""Initialize an SageMaker ``Model``.

Expand Down Expand Up @@ -90,6 +91,10 @@ def __init__(
or from the model container.
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
image_config (dict[str, str]): Specifies whether the image of
model container is pulled from ECR, or private registry in your
VPC. By default it is set to pull model container image from
ECR. (default: None).
"""
self.model_data = model_data
self.image_uri = image_uri
Expand All @@ -106,6 +111,7 @@ def __init__(
self._is_edge_packaged_model = False
self._enable_network_isolation = enable_network_isolation
self.model_kms_key = model_kms_key
self.image_config = image_config

def register(
self,
Expand Down Expand Up @@ -279,7 +285,9 @@ def prepare_container_def(
Returns:
dict: A container definition object usable with the CreateModel API.
"""
return sagemaker.container_def(self.image_uri, self.model_data, self.env)
return sagemaker.container_def(
self.image_uri, self.model_data, self.env, image_config=self.image_config
)

def enable_network_isolation(self):
"""Whether to enable network isolation when creating this Model
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4018,7 +4018,7 @@ def update_args(args: Dict[str, Any], **kwargs):
args.update({key: value})


def container_def(image_uri, model_data_url=None, env=None, container_mode=None):
def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config=None):
"""Create a definition for executing a container as part of a SageMaker model.

Args:
Expand All @@ -4030,6 +4030,9 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None)
* MultiModel: Indicates that model container can support hosting multiple models
* SingleModel: Indicates that model container can support hosting a single model
This is the default model container mode when container_mode = None
image_config (dict[str, str]): Specifies whether the image of model container is pulled
from ECR, or private registry in your VPC. By default it is set to pull model
container image from ECR. (default: None).

Returns:
dict[str, str]: A complete container definition object usable with the CreateModel API if
Expand All @@ -4042,6 +4045,8 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None)
c_def["ModelDataUrl"] = model_data_url
if container_mode:
c_def["Mode"] = container_mode
if image_config:
c_def["ImageConfig"] = image_config
return c_def


Expand Down
14 changes: 14 additions & 0 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def test_prepare_container_def_with_model_data_and_env():
assert expected == container_def


def test_prepare_container_def_with_image_config():
image_config = {"RepositoryAccessMode": "Vpc"}
model = Model(MODEL_IMAGE, image_config=image_config)

expected = {
"Image": MODEL_IMAGE,
"ImageConfig": {"RepositoryAccessMode": "Vpc"},
"Environment": {},
}

container_def = model.prepare_container_def()
assert expected == container_def


def test_model_enable_network_isolation():
model = Model(MODEL_IMAGE, MODEL_DATA)
assert model.enable_network_isolation() is False
Expand Down