diff --git a/src/sagemaker_inference/default_handler_service.py b/src/sagemaker_inference/default_handler_service.py index 83fd446..ca4e9ae 100644 --- a/src/sagemaker_inference/default_handler_service.py +++ b/src/sagemaker_inference/default_handler_service.py @@ -13,8 +13,12 @@ """This module contains functionality for the default handler service.""" from __future__ import absolute_import +import os + from sagemaker_inference.transformer import Transformer +PYTHON_PATH_ENV = "PYTHONPATH" + class DefaultHandlerService(object): """Default handler service that is executed by the model server. @@ -51,4 +55,12 @@ def initialize(self, context): """ properties = context.system_properties model_dir = properties.get("model_dir") + + # add model_dir/code to python path + code_dir_path = "{}:".format(model_dir + "/code") + if PYTHON_PATH_ENV in os.environ: + os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV] + else: + os.environ[PYTHON_PATH_ENV] = code_dir_path + self._service.validate_and_initialize(model_dir=model_dir) diff --git a/test/container/mxnet/Dockerfile b/test/container/mxnet/Dockerfile index df9b16a..72db452 100644 --- a/test/container/mxnet/Dockerfile +++ b/test/container/mxnet/Dockerfile @@ -111,7 +111,10 @@ RUN useradd -m model-server \ COPY mxnet/mms_entrypoint.py /usr/local/bin/dockerd_entrypoint.py COPY mxnet/config.properties /home/model-server -COPY mxnet/inference.py /opt/ml/models/code/inference.py + +# TEST: model_dir -> /opt/ml/models//inference.py +COPY mxnet/inference.py /opt/ml/models/resnet_18/code/inference.py +COPY mxnet/inference.py /opt/ml/models/resnet_152/code/inference.py RUN chmod +x /usr/local/bin/dockerd_entrypoint.py diff --git a/test/container/mxnet/sagemaker_inference.tar.gz b/test/container/mxnet/sagemaker_inference.tar.gz index 3614eb5..ccdd66a 100644 Binary files a/test/container/mxnet/sagemaker_inference.tar.gz and b/test/container/mxnet/sagemaker_inference.tar.gz differ diff --git a/test/container/mxnet/sagemaker_mxnet_inference.tar.gz b/test/container/mxnet/sagemaker_mxnet_inference.tar.gz index 0d836f0..8b462fc 100644 Binary files a/test/container/mxnet/sagemaker_mxnet_inference.tar.gz and b/test/container/mxnet/sagemaker_mxnet_inference.tar.gz differ diff --git a/test/unit/test_default_handler_service.py b/test/unit/test_default_handler_service.py index 6c9797b..a6d2ba8 100644 --- a/test/unit/test_default_handler_service.py +++ b/test/unit/test_default_handler_service.py @@ -10,7 +10,7 @@ # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from mock import Mock, patch +from mock import MagicMock, Mock, patch from sagemaker_inference.default_handler_service import DefaultHandlerService from sagemaker_inference.transformer import Transformer @@ -48,7 +48,13 @@ def test_handle(): def test_initialize(): transformer = Mock() + properties = {"model_dir": "/opt/ml/models/model-name"} - DefaultHandlerService(transformer).initialize(CONTEXT) + def getitem(key): + return properties[key] + + context = MagicMock() + context.system_properties.__getitem__.side_effect = getitem + DefaultHandlerService(transformer).initialize(context) assert transformer.validate_and_initialize().called_once()