This repository has been archived by the owner on May 23, 2024. It is now read-only.

Wait tfs before starting gunicorn #192

Merged · 4 commits · Mar 23, 2021
Changes from 2 commits
23 changes: 23 additions & 0 deletions docker/build_artifacts/sagemaker/serve.py
@@ -18,6 +18,8 @@
import signal
import subprocess
import tfs_utils
import requests
import time

from contextlib import contextmanager

@@ -60,6 +62,7 @@ def __init__(self):
        # Use this to specify memory that is needed to initialize CUDA/cuDNN and other GPU libraries
        self._tfs_gpu_margin = float(os.environ.get("SAGEMAKER_TFS_FRACTIONAL_GPU_MEM_MARGIN", 0.2))
        self._tfs_instance_count = int(os.environ.get("SAGEMAKER_TFS_INSTANCE_COUNT", 1))
        self._tfs_wait_time = int(os.environ.get("SAGEMAKER_TFS_WAIT_TIME", 600))

Reviewer: Please make the time unit part of the environment variable name and the field name.

Contributor Author: Sure, will add it in the next revision.

        self._tfs_inter_op_parallelism = os.environ.get("SAGEMAKER_TFS_INTER_OP_PARALLELISM", 0)
        self._tfs_intra_op_parallelism = os.environ.get("SAGEMAKER_TFS_INTRA_OP_PARALLELISM", 0)
        self._gunicorn_worker_class = os.environ.get("SAGEMAKER_GUNICORN_WORKER_CLASS", 'gevent')
@@ -324,6 +327,24 @@ def _wait_for_gunicorn(self):
        log.info("gunicorn server is ready!")
        return

    def _wait_for_tfs(self):
        # Wait until tfs server is up and running
        while True:
            try:
                tfs_ready_count = 0
                for i in range(self._tfs_instance_count):

Reviewer (nitpick, feel free to ignore): Please consider using a bit more descriptive variable name instead of i (perhaps tfs_index or tfs_ordinal etc.).

                    tfs_url = "http://localhost:{}/v1/models/{}/metadata" \
                        .format(self._tfs_rest_port[i], self._tfs_default_model_name)

Reviewer: This looks suspicious - shouldn't there be a list/dict of corresponding model names (potentially different) for each TF server?

Contributor Author: All of these TF servers use the same model here. For a multi-model endpoint, the TensorFlow server is not started during container initialization.

                    log.info("Trying to connect with model server \n {}".format(tfs_url))
                    response = requests.get(tfs_url)
                    logging.info(response)

Reviewer: Unless the response already includes server/endpoint metadata, please consider logging some additional information to help customers identify which server returned which response.

Contributor Author: We logged the server info on line 338 for this purpose.

                    if response.status_code == 200:
                        tfs_ready_count += 1
                if tfs_ready_count == self._tfs_instance_count:
                    break
            except requests.exceptions.ConnectionError:
                time.sleep(30)

Reviewer: Please consider adding some configuration for this (including the relevant time unit name in the (env) variable / field names) - otherwise it's just a hard-coded magic number.

Contributor Author: Sure, will add it in the next revision.
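The configurable retry interval the reviewer asks for could be sketched as a small standalone helper. This is an illustration of the polling pattern, not code from this PR; `SAGEMAKER_TFS_WAIT_INTERVAL_SECONDS` is a hypothetical variable name chosen to show the time unit encoded in the name, and `wait_for_endpoints` is a made-up function:

```python
import os
import time

import requests


def wait_for_endpoints(urls, interval_seconds=None, max_attempts=10):
    """Poll each TF Serving metadata URL until all return HTTP 200.

    Returns True once every URL answers with status 200, or False when
    the attempts are exhausted.
    """
    if interval_seconds is None:
        # Hypothetical env var: the time unit is part of the name, as the
        # review suggests, rather than a hard-coded time.sleep(30).
        interval_seconds = int(
            os.environ.get("SAGEMAKER_TFS_WAIT_INTERVAL_SECONDS", "30"))
    for _ in range(max_attempts):
        try:
            if all(requests.get(url).status_code == 200 for url in urls):
                return True
        except requests.exceptions.ConnectionError:
            pass  # server not up yet; retry after the interval
        time.sleep(interval_seconds)
    return False
```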


    @contextmanager
    def _timeout(self, seconds):
        def _raise_timeout_error(signum, frame):
@@ -406,6 +427,8 @@ def start(self):
        else:
            self._create_tfs_config()
            self._start_tfs()
            with self._timeout(seconds=self._tfs_wait_time):
                self._wait_for_tfs()

        self._create_nginx_config()

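The `with self._timeout(seconds=...)` wrapper used in `start()` is a context manager; the diff's `_raise_timeout_error(signum, frame)` handler suggests it is built on SIGALRM. A minimal standalone sketch of that pattern, assuming a Unix host and the main thread (an illustration, not the file's exact implementation):

```python
import signal
from contextlib import contextmanager


@contextmanager
def timeout(seconds):
    """Raise TimeoutError if the with-block runs longer than `seconds`.

    SIGALRM-based, so this only works on Unix and in the main thread.
    """
    def _raise_timeout_error(signum, frame):
        raise TimeoutError("timed out after {} seconds".format(seconds))

    previous_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
    signal.alarm(seconds)  # schedule SIGALRM after `seconds`
    try:
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous_handler)
```

Wrapping `_wait_for_tfs()` this way bounds the total wait (here via `SAGEMAKER_TFS_WAIT_TIME`, defaulting to 600) even though the inner loop itself is `while True`.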