diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py
index 6752dfdf..408758d9 100644
--- a/src/sagemaker_pytorch_serving_container/handler_service.py
+++ b/src/sagemaker_pytorch_serving_container/handler_service.py
@@ -17,8 +17,14 @@
 from sagemaker_pytorch_serving_container.default_inference_handler import \
     DefaultPytorchInferenceHandler
 
+import os
+import sys
+
+ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
+
 
 class HandlerService(DefaultHandlerService):
+
     """Handler service that is executed by the model server.
 
     Determines specific default inference handlers to use based on the type MXNet model being used.
@@ -31,5 +37,16 @@ class HandlerService(DefaultHandlerService):
     """
 
     def __init__(self):
+        self._initialized = False
+
         transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler())
         super(HandlerService, self).__init__(transformer=transformer)
+
+    def initialize(self, context):
+        # Adding the 'code' directory path to sys.path to allow importing user modules when multi-model mode is enabled.
+        if (not self._initialized) and ENABLE_MULTI_MODEL:
+            code_dir = os.path.join(context.system_properties.get("model_dir"), 'code')
+            sys.path.append(code_dir)
+            self._initialized = True
+
+        super().initialize(context)
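
For context, a minimal, illustrative sketch of the effect of the new initialize() hook when multi-model mode is on. FakeContext and the /opt/ml/models/model-a path are hypothetical stand-ins for the model server's real context object and per-model directory; they are not part of this change.

# Illustrative sketch only. FakeContext mimics the single attribute the new
# initialize() reads: context.system_properties.get("model_dir").
import os
import sys


class FakeContext:
    def __init__(self, model_dir):
        self.system_properties = {"model_dir": model_dir}


# With SAGEMAKER_MULTI_MODEL=true set on the container, the module-level
# ENABLE_MULTI_MODEL flag evaluates to True at import time.
os.environ["SAGEMAKER_MULTI_MODEL"] = "true"
ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"

# Hypothetical per-model directory laid out as <model_dir>/code/inference.py.
context = FakeContext("/opt/ml/models/model-a")

# Same path manipulation the new initialize() performs: the model's 'code'
# directory is appended to sys.path so its bundled user module (inference.py)
# becomes importable for this model.
if ENABLE_MULTI_MODEL:
    code_dir = os.path.join(context.system_properties.get("model_dir"), "code")
    sys.path.append(code_dir)

print(code_dir in sys.path)  # True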