diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py
index 89c68e7..b7eab8a 100644
--- a/src/sagemaker_inference/environment.py
+++ b/src/sagemaker_inference/environment.py
@@ -26,6 +26,7 @@
 DEFAULT_STARTUP_TIMEOUT = "600"  # 10 minutes
 DEFAULT_HTTP_PORT = "8080"
 DEFAULT_VMARGS = "-XX:-UseContainerSupport"
+DEFAULT_MAX_REQUEST_SIZE = None
 
 SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml")  # type: str
 
@@ -81,6 +82,9 @@ def __init__(self):
         self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT)
         self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV)
         self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS)
+        self._max_request_size_in_mb = os.environ.get(
+            parameters.MAX_REQUEST_SIZE, DEFAULT_MAX_REQUEST_SIZE
+        )
 
     @staticmethod
     def _parse_module_name(program_param):
@@ -147,3 +151,15 @@ def safe_port_range(self):  # type: () -> str
     def vmargs(self):  # type: () -> str
         """str: vmargs can be provided for the JVM, to be overriden"""
         return self._vmargs
+
+    @property
+    def max_request_size(self):
+        """int or None: max request size in bytes, as set by the SageMaker platform.
+
+        The platform supplies the limit in megabytes through the
+        ``parameters.MAX_REQUEST_SIZE`` environment variable; it is converted
+        to bytes here. ``None`` is returned when the variable is not set.
+        """
+        if self._max_request_size_in_mb is None:
+            return None
+        return int(self._max_request_size_in_mb) * 1024 * 1024
diff --git a/src/sagemaker_inference/model_server.py b/src/sagemaker_inference/model_server.py
index 4d3be2e..e315f58 100644
--- a/src/sagemaker_inference/model_server.py
+++ b/src/sagemaker_inference/model_server.py
@@ -160,6 +160,7 @@ def _generate_mms_config_properties(env, handler_service=None):
         "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
         "management_address": "http://0.0.0.0:{}".format(env.management_http_port),
         "vmargs": env.vmargs,
+        "max_request_size": env.max_request_size,
     }
     # If provided, add handler service to user config
     if handler_service:
diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py
index d1f438b..ca92a04 100644
--- a/src/sagemaker_inference/parameters.py
+++ b/src/sagemaker_inference/parameters.py
@@ -25,3 +25,4 @@
 BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT"  # type: str
 SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE"  # type: str
 MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL"  # type: str
+MAX_REQUEST_SIZE = "SAGEMAKER_MAX_PAYLOAD_IN_MB"  # type: str
diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py
index afb90de..d1cb4e7 100644
--- a/test/unit/test_environment.py
+++ b/test/unit/test_environment.py
@@ -29,6 +29,7 @@
         parameters.BIND_TO_PORT_ENV: "1738",
         parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
         parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport",
+        parameters.MAX_REQUEST_SIZE: "10",
     },
     clear=True,
 )
@@ -47,6 +48,7 @@ def test_env():
     assert env.management_http_port == "1738"
     assert env.safe_port_range == "1111-2222"
     assert "-XX:-UseContainerSupport" in env.vmargs
+    assert env.max_request_size == 10 * 1024 * 1024
 
 
 @pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])