Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3075,8 +3075,8 @@ def _prepare_for_triton(self):
export_path.mkdir(parents=True)

if self.model:
self.secret_key = "dummy secret key for onnx backend"

# ONNX path: export model to ONNX format for Triton's native ONNX backend.
# No pickle is created or loaded at runtime, so no HMAC signing is needed.
if self.framework == Framework.PYTORCH:
self._export_pytorch_to_onnx(
export_path=export_path, model=self.model, schema_builder=self.schema_builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ def auto_complete_config(auto_complete_model_config):
def initialize(self, args: dict) -> None:
"""Placeholder docstring"""
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
with open(str(serve_path), mode="rb") as f:
inference_spec, schema_builder = cloudpickle.load(f)
buffer = f.read()
perform_integrity_check(buffer=buffer, metadata_path=str(metadata_path))

# TODO: HMAC signing for integrity check
with open(str(serve_path), mode="rb") as f:
inference_spec, schema_builder = cloudpickle.load(f)

self.inference_spec = inference_spec
self.schema_builder = schema_builder
Expand Down
14 changes: 12 additions & 2 deletions sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@ def _start_triton_server(
env_vars.update(
{
"TRITON_MODEL_DIR": "/models/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
)

# Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where
# pickle integrity verification is needed. The ONNX path does not
# use pickles, so no secret key is required.
if secret_key and isinstance(secret_key, str) and secret_key.strip():
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

if "cpu" not in image_uri:
self.container = docker_client.containers.run(
image=image_uri,
Expand Down Expand Up @@ -133,7 +138,12 @@ def _upload_triton_artifacts(
env_vars = {
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
"TRITON_MODEL_DIR": "/opt/ml/model/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}

# Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where
# pickle integrity verification is needed.
if secret_key and isinstance(secret_key, str) and secret_key.strip():
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

return s3_upload_path, env_vars
20 changes: 16 additions & 4 deletions sagemaker-serve/tests/unit/test_model_builder_utils_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ class TestPrepareForTriton(unittest.TestCase):
"""Test _prepare_for_triton method."""

@patch('shutil.copy2')
@patch.object(_ModelBuilderUtils, '_hmac_signing')
@patch.object(_ModelBuilderUtils, '_export_pytorch_to_onnx')
def test_prepare_for_triton_pytorch(self, mock_export, mock_copy):
"""Test preparing PyTorch model for Triton."""
def test_prepare_for_triton_pytorch(self, mock_export, mock_hmac, mock_copy):
"""Test preparing PyTorch model for Triton.

ONNX path: no pickle is created or loaded at runtime,
so no HMAC signing is needed.
"""
utils = _ModelBuilderUtils()
utils.framework = Framework.PYTORCH
utils.model = Mock()
Expand All @@ -94,11 +99,17 @@ def test_prepare_for_triton_pytorch(self, mock_export, mock_copy):
utils._prepare_for_triton()

mock_export.assert_called_once()
mock_hmac.assert_not_called()

@patch('shutil.copy2')
@patch.object(_ModelBuilderUtils, '_hmac_signing')
@patch.object(_ModelBuilderUtils, '_export_tf_to_onnx')
def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy):
"""Test preparing TensorFlow model for Triton."""
def test_prepare_for_triton_tensorflow(self, mock_export, mock_hmac, mock_copy):
"""Test preparing TensorFlow model for Triton.

ONNX path: no pickle is created or loaded at runtime,
so no HMAC signing is needed.
"""
utils = _ModelBuilderUtils()
utils.framework = Framework.TENSORFLOW
utils.model = Mock()
Expand All @@ -109,6 +120,7 @@ def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy):
utils._prepare_for_triton()

mock_export.assert_called_once()
mock_hmac.assert_not_called()

@patch('shutil.copy2')
@patch.object(_ModelBuilderUtils, '_generate_config_pbtxt')
Expand Down
Loading