From fe352eba7d1bb7de5848dcff922b45d59b6ddd7c Mon Sep 17 00:00:00 2001 From: rohithkrn Date: Fri, 31 Mar 2023 16:04:49 -0700 Subject: [PATCH 1/3] handle max_request_size param set by SM platform --- src/sagemaker_inference/environment.py | 10 ++++++++++ src/sagemaker_inference/model_server.py | 1 + src/sagemaker_inference/parameters.py | 1 + test/unit/test_environment.py | 2 ++ 4 files changed, 14 insertions(+) diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py index 89c68e7..94495e3 100644 --- a/src/sagemaker_inference/environment.py +++ b/src/sagemaker_inference/environment.py @@ -26,6 +26,7 @@ DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes DEFAULT_HTTP_PORT = "8080" DEFAULT_VMARGS = "-XX:-UseContainerSupport" +DEFAULT_MAX_REQUEST_SIZE = None SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str @@ -81,6 +82,7 @@ def __init__(self): self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV) self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS) + self._max_request_size_in_mb = os.environ.get(parameters.MAX_REQUEST_SIZE, DEFAULT_MAX_REQUEST_SIZE) @staticmethod def _parse_module_name(program_param): @@ -147,3 +149,11 @@ def safe_port_range(self): # type: () -> str def vmargs(self): # type: () -> str """str: vmargs can be provided for the JVM, to be overriden""" return self._vmargs + + @property + def max_request_size(self): # type: () -> str + """str: max request size set by Sagemaker platform in bytes""" + if self._max_request_size_in_mb is not None: + return int(self._max_request_size_in_mb) * 1024 * 1024 + else: + return None diff --git a/src/sagemaker_inference/model_server.py b/src/sagemaker_inference/model_server.py index 4d3be2e..e315f58 100644 --- a/src/sagemaker_inference/model_server.py +++ b/src/sagemaker_inference/model_server.py @@ -160,6 +160,7 @@ def 
_generate_mms_config_properties(env, handler_service=None): "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port), "management_address": "http://0.0.0.0:{}".format(env.management_http_port), "vmargs": env.vmargs, + "max_request_size": env.max_request_size, } # If provided, add handler service to user config if handler_service: diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py index d1f438b..ca92a04 100644 --- a/src/sagemaker_inference/parameters.py +++ b/src/sagemaker_inference/parameters.py @@ -25,3 +25,4 @@ BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str +MAX_REQUEST_SIZE = "SAGEMAKER_MAX_PAYLOAD_IN_MB" # type: str diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py index afb90de..ff6de43 100644 --- a/test/unit/test_environment.py +++ b/test/unit/test_environment.py @@ -29,6 +29,7 @@ parameters.BIND_TO_PORT_ENV: "1738", parameters.SAFE_PORT_RANGE_ENV: "1111-2222", parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport", + parameters.MAX_REQUEST_SIZE: "10", }, clear=True, ) @@ -47,6 +48,7 @@ def test_env(): assert env.management_http_port == "1738" assert env.safe_port_range == "1111-2222" assert "-XX:-UseContainerSupport" in env.vmargs + assert env.max_request_size == 10485760 @pytest.mark.parametrize("sagemaker_program", ["program.py", "program"]) From a4952ce95cb1283817de568365311b51d39be118 Mon Sep 17 00:00:00 2001 From: rohithkrn Date: Fri, 31 Mar 2023 16:41:50 -0700 Subject: [PATCH 2/3] fix lint --- src/sagemaker_inference/environment.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py index 94495e3..b7eab8a 100644 --- a/src/sagemaker_inference/environment.py +++ b/src/sagemaker_inference/environment.py @@ -82,7 +82,9 @@ def __init__(self): 
self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV) self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS) - self._max_request_size_in_mb = os.environ.get(parameters.MAX_REQUEST_SIZE, DEFAULT_MAX_REQUEST_SIZE) + self._max_request_size_in_mb = os.environ.get( + parameters.MAX_REQUEST_SIZE, DEFAULT_MAX_REQUEST_SIZE + ) @staticmethod def _parse_module_name(program_param): From cf28325c67bf2c324181e803eb591af099a48410 Mon Sep 17 00:00:00 2001 From: rohithkrn Date: Mon, 3 Apr 2023 10:35:40 -0700 Subject: [PATCH 3/3] update test with a readable value --- test/unit/test_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py index ff6de43..d1cb4e7 100644 --- a/test/unit/test_environment.py +++ b/test/unit/test_environment.py @@ -48,7 +48,7 @@ def test_env(): assert env.management_http_port == "1738" assert env.safe_port_range == "1111-2222" assert "-XX:-UseContainerSupport" in env.vmargs - assert env.max_request_size == 10485760 + assert env.max_request_size == 10 * 1024 * 1024 @pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])