Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/sagemaker_inference/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes
DEFAULT_HTTP_PORT = "8080"
DEFAULT_VMARGS = "-XX:-UseContainerSupport"
DEFAULT_MAX_REQUEST_SIZE = None

SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str

Expand Down Expand Up @@ -81,6 +82,9 @@ def __init__(self):
self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT)
self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV)
self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS)
self._max_request_size_in_mb = os.environ.get(
parameters.MAX_REQUEST_SIZE, DEFAULT_MAX_REQUEST_SIZE
)

@staticmethod
def _parse_module_name(program_param):
Expand Down Expand Up @@ -147,3 +151,11 @@ def safe_port_range(self): # type: () -> str
def vmargs(self): # type: () -> str
"""str: vmargs can be provided for the JVM, to be overriden"""
return self._vmargs

@property
def max_request_size(self): # type: () -> str
"""str: max request size set by Sagemaker platform in bytes"""
if self._max_request_size_in_mb is not None:
return int(self._max_request_size_in_mb) * 1024 * 1024
else:
return None
1 change: 1 addition & 0 deletions src/sagemaker_inference/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def _generate_mms_config_properties(env, handler_service=None):
"inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
"management_address": "http://0.0.0.0:{}".format(env.management_http_port),
"vmargs": env.vmargs,
"max_request_size": env.max_request_size,
}
# If provided, add handler service to user config
if handler_service:
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker_inference/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str
SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str
MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str
MAX_REQUEST_SIZE = "SAGEMAKER_MAX_PAYLOAD_IN_MB" # type: str
2 changes: 2 additions & 0 deletions test/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
parameters.BIND_TO_PORT_ENV: "1738",
parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport",
parameters.MAX_REQUEST_SIZE: "10",
},
clear=True,
)
Expand All @@ -47,6 +48,7 @@ def test_env():
assert env.management_http_port == "1738"
assert env.safe_port_range == "1111-2222"
assert "-XX:-UseContainerSupport" in env.vmargs
assert env.max_request_size == 10 * 1024 * 1024


@pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])
Expand Down