-
Notifications
You must be signed in to change notification settings - Fork 60
/
mms_model_server.py
153 lines (125 loc) · 5.03 KB
/
mms_model_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import os
import pathlib
import subprocess
from sagemaker_inference import environment, logging
from sagemaker_inference.model_server import (
DEFAULT_MMS_LOG_FILE,
DEFAULT_MMS_MODEL_NAME,
ENABLE_MULTI_MODEL,
MMS_CONFIG_FILE,
REQUIREMENTS_PATH,
_add_sigchild_handler,
_add_sigterm_handler,
_create_model_server_config_file,
_install_requirements,
_retry_retrieve_mms_server_process,
_set_python_path,
)
from sagemaker_huggingface_inference_toolkit import handler_service
from sagemaker_huggingface_inference_toolkit.optimum_utils import is_optimum_neuron_available
from sagemaker_huggingface_inference_toolkit.transformers_utils import (
HF_API_TOKEN,
HF_MODEL_REVISION,
_load_model_from_hub,
)
logger = logging.get_logger()

# Module path of this toolkit's handler service, used as the default MMS handler.
DEFAULT_HANDLER_SERVICE = handler_service.__name__
# Export directory for models pulled from the Hugging Face Hub; also serves as
# the MMS model store in that case (see start_model_server).
DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
# Model store used when the model is baked into the container rather than
# pulled from the Hub — presumably the filesystem root so MMS can resolve
# absolute model paths; NOTE(review): confirm against sagemaker_inference defaults.
DEFAULT_MODEL_STORE = "/"
def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
    """Configure and start the multi-model-server (MMS) process, then block on it.

    The function prepares the model store (optionally downloading the model
    from the Hugging Face Hub), writes the MMS config file, installs any
    user-supplied requirements, launches ``multi-model-server`` as a
    subprocess, and waits on it until it exits.

    Args:
        handler_service (str): python path pointing to a module that defines
            a class with the following:
                - A ``handle`` method, which is invoked for all incoming inference
                    requests to the model server.
                - A ``initialize`` method, which is invoked at model server start up
                    for loading the model.
            Defaults to ``sagemaker_huggingface_inference_toolkit.handler_service``.
    """
    # Presence of HF_MODEL_ID means the model must be fetched from the Hub
    # rather than read from the container's built-in model directory.
    use_hf_hub = "HF_MODEL_ID" in os.environ
    model_store = DEFAULT_MODEL_STORE
    if ENABLE_MULTI_MODEL:
        # Multi-model mode: models are loaded dynamically per request, so only
        # the handler env var and python path need to be set up here.
        if not os.getenv("SAGEMAKER_HANDLER"):
            os.environ["SAGEMAKER_HANDLER"] = handler_service
        _set_python_path()
    elif use_hf_hub:
        # Use different model store directory
        model_store = DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY
        storage_dir = _load_model_from_hub(
            model_id=os.environ["HF_MODEL_ID"],
            model_dir=DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
            revision=HF_MODEL_REVISION,
            use_auth_token=HF_API_TOKEN,
        )
        # Repackage the downloaded model into the layout MMS expects.
        _adapt_to_mms_format(handler_service, storage_dir)
    else:
        _set_python_path()

    env = environment.Environment()

    # Set the number of workers to available number if optimum neuron is available and not already set
    if is_optimum_neuron_available() and os.environ.get("SAGEMAKER_MODEL_SERVER_WORKERS", None) is None:
        # Imported lazily so the toolkit works when optimum-neuron is absent.
        from optimum.neuron.utils.cache_utils import get_num_neuron_cores

        try:
            # NOTE: writes a private attribute of Environment — relies on
            # sagemaker_inference internals to pick it up for the config file.
            env._model_server_workers = str(get_num_neuron_cores())
        except Exception:
            # Best-effort: fall back to a single worker if core detection fails.
            env._model_server_workers = "1"

    # Note: multi-model default config already sets default_service_handler
    handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service
    _create_model_server_config_file(env, handler_service_for_config)

    # Install user-provided requirements.txt (if any) before the server starts.
    if os.path.exists(REQUIREMENTS_PATH):
        _install_requirements()

    multi_model_server_cmd = [
        "multi-model-server",
        "--start",
        "--model-store",
        model_store,
        "--mms-config",
        MMS_CONFIG_FILE,
        "--log-config",
        DEFAULT_MMS_LOG_FILE,
    ]
    # Single-model, non-Hub case: register the built-in model directory
    # explicitly; in the other cases the model store already contains it.
    if not ENABLE_MULTI_MODEL and not use_hf_hub:
        multi_model_server_cmd += ["--models", DEFAULT_MMS_MODEL_NAME + "=" + environment.model_dir]

    logger.info(multi_model_server_cmd)
    subprocess.Popen(multi_model_server_cmd)
    # retry for configured timeout
    mms_process = _retry_retrieve_mms_server_process(env.startup_timeout)
    # Forward SIGTERM to the server and reap children so shutdown is clean.
    _add_sigterm_handler(mms_process)
    _add_sigchild_handler()
    # Block until the server process exits; this keeps the container alive.
    mms_process.wait()
def _adapt_to_mms_format(handler_service, model_path):
    """Repackage the model at ``model_path`` into the MMS model-store layout.

    Invokes ``model-archiver`` in ``no-archive`` mode, which copies the model
    directory (instead of producing a ``.mar`` archive) into the Hub export
    directory under a name derived from the model path, then refreshes the
    python path so the handler module is importable.

    Args:
        handler_service (str): python path of the handler service module.
        model_path (str): directory containing the downloaded model artifacts.
    """
    os.makedirs(DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, exist_ok=True)

    # The archive name is the final component of the model directory path.
    archive_name = pathlib.PurePath(model_path).name

    archiver_cmd = [
        "model-archiver",
        "--model-name", archive_name,
        "--handler", handler_service,
        "--model-path", model_path,
        "--export-path", DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
        "--archive-format", "no-archive",
        "--f",
    ]

    logger.info(archiver_cmd)
    subprocess.check_call(archiver_cmd)
    _set_python_path()