22 changes: 21 additions & 1 deletion src/sagemaker_pytorch_serving_container/torchserve.py
@@ -24,6 +24,7 @@
from retrying import retry

import sagemaker_pytorch_serving_container
from sagemaker_pytorch_serving_container import ts_environment
from sagemaker_inference import default_handler_service, environment, utils
from sagemaker_inference.environment import code_dir

@@ -149,14 +150,33 @@ def _create_torchserve_config_file():

def _generate_ts_config_properties():
env = environment.Environment()

user_defined_configuration = {
"default_response_timeout": env.model_server_timeout,
"default_workers_per_model": env.model_server_workers,
"inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
"management_address": "http://0.0.0.0:{}".format(env.management_http_port),
}

ts_env = ts_environment.TorchServeEnvironment()

if ts_env.is_env_set() and not ENABLE_MULTI_MODEL:
models_string = f'''{{\\
"{DEFAULT_TS_MODEL_NAME}": {{\\
"1.0": {{\\
"defaultVersion": true,\\
"marName": "{DEFAULT_TS_MODEL_NAME}.mar",\\
"minWorkers": {ts_env._min_workers},\\
"maxWorkers": {ts_env._max_workers},\\
"batchSize": {ts_env._batch_size},\\
"maxBatchDelay": {ts_env._max_batch_delay},\\
"responseTimeout": {ts_env._response_timeout}\\
}}\\
}}\\
}}'''
user_defined_configuration["models"] = models_string
logger.warn("Sagemaker TS environment variables have been set and will be used "
"for single model endpoint.")

custom_configuration = str()

for key in user_defined_configuration:
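When the SAGEMAKER_TS_* variables are set on a single-model endpoint, the block above renders a one-line "models" entry for TorchServe's config.properties; the escaped trailing backslashes are properties-file line continuations. As a sketch, with the module defaults below and assuming DEFAULT_TS_MODEL_NAME is "model", the rendered entry would collapse to roughly:

models={"model": {"1.0": {"defaultVersion": true, "marName": "model.mar", "minWorkers": 1, "maxWorkers": 1, "batchSize": 1, "maxBatchDelay": 100, "responseTimeout": 60}}}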
94 changes: 94 additions & 0 deletions src/sagemaker_pytorch_serving_container/ts_environment.py
@@ -0,0 +1,94 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality that provides access to system
characteristics, environment variables and configuration settings.
"""

from __future__ import absolute_import

import logging
import os

from sagemaker_pytorch_serving_container import ts_parameters

logger = logging.getLogger()

DEFAULT_TS_BATCH_SIZE = 1
DEFAULT_TS_MAX_BATCH_DELAY = 100
DEFAULT_TS_MIN_WORKERS = 1
DEFAULT_TS_MAX_WORKERS = 1
DEFAULT_TS_RESPONSE_TIMEOUT = 60


class TorchServeEnvironment(object):
"""Provides access to aspects of the TorchServe environment relevant to serving containers,
including system characteristics, environment variables and configuration settings.

The Environment is a read-only snapshot of the container environment.
It is a dictionary-like object, allowing any built-in function that works with a dictionary.

Attributes:
batch_size (int): This is the maximum batch size that a model is expected to handle
max_batch_delay (int): This is the maximum batch delay time, in milliseconds, that TorchServe waits to receive
batch_size requests. If TorchServe does not receive batch_size requests
before this timer expires, it sends whatever requests were received to the model handler
min_workers (int): Minimum number of workers that TorchServe is allowed to scale down to
max_workers (int): Maximum number of workers that TorchServe is allowed to scale up to
response_timeout (int): Time delay after which inference will time out in the absence of a response
"""
def __init__(self):
self._batch_size = int(os.environ.get(ts_parameters.MODEL_SERVER_BATCH_SIZE, DEFAULT_TS_BATCH_SIZE))
self._max_batch_delay = int(os.environ.get(ts_parameters.MODEL_SERVER_MAX_BATCH_DELAY,
DEFAULT_TS_MAX_BATCH_DELAY))
self._min_workers = int(os.environ.get(ts_parameters.MODEL_SERVER_MIN_WORKERS, DEFAULT_TS_MIN_WORKERS))
self._max_workers = int(os.environ.get(ts_parameters.MODEL_SERVER_MAX_WORKERS, DEFAULT_TS_MAX_WORKERS))
self._response_timeout = int(os.environ.get(ts_parameters.MODEL_SERVER_RESPONSE_TIMEOUT,
DEFAULT_TS_RESPONSE_TIMEOUT))

def is_env_set(self): # type: () -> bool
"""bool: whether or not any of the SageMaker TS environment variables have been set"""
ts_env_list = [ts_parameters.MODEL_SERVER_BATCH_SIZE, ts_parameters.MODEL_SERVER_MAX_BATCH_DELAY,
ts_parameters.MODEL_SERVER_MIN_WORKERS, ts_parameters.MODEL_SERVER_MAX_WORKERS,
ts_parameters.MODEL_SERVER_RESPONSE_TIMEOUT]
return any(env in ts_env_list for env in os.environ)

@property
def batch_size(self): # type: () -> int
"""int: number of requests to batch before running inference on the server"""
return self._batch_size

@property
def max_batch_delay(self): # type: () -> int
"""int: time delay in milliseconds, to wait for incoming requests to be batched,
before running inference on the server
"""
return self._max_batch_delay

@property
def min_workers(self): # type: () -> int
"""int: minimum number of workers for the model
"""
return self._min_workers

@property
def max_workers(self): # type: () -> int
"""int: maximum number of workers for the model
"""
return self._max_workers

@property
def response_timeout(self): # type: () -> int
"""int: time delay after which inference will time out in the absence of a response
"""
return self._response_timeout
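A minimal usage sketch for the class above, assuming the sagemaker_pytorch_serving_container package is importable; the two exported values are illustrative, and the remaining settings fall back to the module defaults:

import os

# Illustrative values; on SageMaker these would be set on the model or endpoint.
os.environ["SAGEMAKER_TS_BATCH_SIZE"] = "4"
os.environ["SAGEMAKER_TS_MAX_BATCH_DELAY"] = "250"

from sagemaker_pytorch_serving_container import ts_environment

ts_env = ts_environment.TorchServeEnvironment()
if ts_env.is_env_set():
    print(ts_env.batch_size)        # 4
    print(ts_env.max_batch_delay)   # 250
    print(ts_env.min_workers)       # 1 (default)
    print(ts_env.max_workers)       # 1 (default)
    print(ts_env.response_timeout)  # 60 (default)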
21 changes: 21 additions & 0 deletions src/sagemaker_pytorch_serving_container/ts_parameters.py
@@ -0,0 +1,21 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains string constants that define inference toolkit
parameters."""
from __future__ import absolute_import

MODEL_SERVER_BATCH_SIZE = "SAGEMAKER_TS_BATCH_SIZE" # type: str
MODEL_SERVER_MAX_BATCH_DELAY = "SAGEMAKER_TS_MAX_BATCH_DELAY" # type: str
MODEL_SERVER_MIN_WORKERS = "SAGEMAKER_TS_MIN_WORKERS" # type: str
MODEL_SERVER_MAX_WORKERS = "SAGEMAKER_TS_MAX_WORKERS" # type: str
MODEL_SERVER_RESPONSE_TIMEOUT = "SAGEMAKER_TS_RESPONSE_TIMEOUT" # type: str
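These are ordinary process environment variables, so a caller would typically pass them through the model's env mapping at deployment time. A hedged sketch using the SageMaker Python SDK; the artifact URL, role ARN, entry point, and version strings below are placeholders, not values from this change:

from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(
    model_data="s3://my-bucket/model.tar.gz",             # placeholder artifact
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    entry_point="inference.py",                           # placeholder handler script
    framework_version="1.8.1",
    py_version="py3",
    env={
        "SAGEMAKER_TS_BATCH_SIZE": "8",
        "SAGEMAKER_TS_MAX_BATCH_DELAY": "100",
        "SAGEMAKER_TS_MIN_WORKERS": "1",
        "SAGEMAKER_TS_MAX_WORKERS": "4",
        "SAGEMAKER_TS_RESPONSE_TIMEOUT": "120",
    },
)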