diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 8386df08a7..7e4f79a87e 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -53,6 +53,7 @@ def __init__( sagemaker_session=None, enable_network_isolation=False, model_kms_key=None, + image_config=None, ): """Initialize an SageMaker ``Model``. @@ -90,6 +91,10 @@ def __init__( or from the model container. model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked + image_config (dict[str, str]): Specifies whether the image of + model container is pulled from ECR, or private registry in your + VPC. By default it is set to pull model container image from + ECR. (default: None). """ self.model_data = model_data self.image_uri = image_uri @@ -106,6 +111,7 @@ def __init__( self._is_edge_packaged_model = False self._enable_network_isolation = enable_network_isolation self.model_kms_key = model_kms_key + self.image_config = image_config def register( self, @@ -279,7 +285,9 @@ def prepare_container_def( Returns: dict: A container definition object usable with the CreateModel API. """ - return sagemaker.container_def(self.image_uri, self.model_data, self.env) + return sagemaker.container_def( + self.image_uri, self.model_data, self.env, image_config=self.image_config + ) def enable_network_isolation(self): """Whether to enable network isolation when creating this Model diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 0a86228bd8..ddda685c99 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4018,7 +4018,7 @@ def update_args(args: Dict[str, Any], **kwargs): args.update({key: value}) -def container_def(image_uri, model_data_url=None, env=None, container_mode=None): +def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config=None): """Create a definition for executing a container as part of a SageMaker model. Args: @@ -4030,6 +4030,9 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None) * MultiModel: Indicates that model container can support hosting multiple models * SingleModel: Indicates that model container can support hosting a single model This is the default model container mode when container_mode = None + image_config (dict[str, str]): Specifies whether the image of model container is pulled + from ECR, or private registry in your VPC. By default it is set to pull model + container image from ECR. (default: None). Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if @@ -4042,6 +4045,8 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None) c_def["ModelDataUrl"] = model_data_url if container_mode: c_def["Mode"] = container_mode + if image_config: + c_def["ImageConfig"] = image_config return c_def diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 10d99db558..98b5e1e35f 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -54,6 +54,20 @@ def test_prepare_container_def_with_model_data_and_env(): assert expected == container_def +def test_prepare_container_def_with_image_config(): + image_config = {"RepositoryAccessMode": "Vpc"} + model = Model(MODEL_IMAGE, image_config=image_config) + + expected = { + "Image": MODEL_IMAGE, + "ImageConfig": {"RepositoryAccessMode": "Vpc"}, + "Environment": {}, + } + + container_def = model.prepare_container_def() + assert expected == container_def + + def test_model_enable_network_isolation(): model = Model(MODEL_IMAGE, MODEL_DATA) assert model.enable_network_isolation() is False