# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import os
import sys
import time
from abc import ABC
from inspect import signature
from sagemaker_inference import environment, utils
from transformers.pipelines import SUPPORTED_TASKS
from mms.service import PredictionException
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder
from sagemaker_huggingface_inference_toolkit.transformers_utils import (
_is_gpu_available,
get_pipeline,
infer_task_from_model_architecture,
)

ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
PYTHON_PATH_ENV = "PYTHONPATH"
MODEL_FN = "model_fn"
INPUT_FN = "input_fn"
PREDICT_FN = "predict_fn"
OUTPUT_FN = "output_fn"
TRANSFORM_FN = "transform_fn"

logger = logging.getLogger(__name__)


class HuggingFaceHandlerService(ABC):
"""Default handler service that is executed by the model server.
The handler service is responsible for defining our InferenceHandler.
- The ``handle`` method is invoked for all incoming inference requests to the model server.
- The ``initialize`` method is invoked at model server start up.
Implementation of: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md
"""
def __init__(self):
self.error = None
self.batch_size = 1
self.model_dir = None
self.model = None
self.device = -1
self.initialized = False
self.attempted_init = False
self.context = None
self.manifest = None
self.environment = environment.Environment()
self.load_extra_arg = []
self.preprocess_extra_arg = []
self.predict_extra_arg = []
self.postprocess_extra_arg = []
self.transform_extra_arg = []
def initialize(self, context):
"""
        Initialize the model. This will be called at model loading time.
        :param context: Initial context containing model server system properties.
:return:
"""
self.attempted_init = True
self.context = context
properties = context.system_properties
self.model_dir = properties.get("model_dir")
self.batch_size = context.system_properties["batch_size"]
code_dir_path = os.path.join(self.model_dir, "code")
sys.path.insert(0, code_dir_path)
self.validate_and_initialize_user_module()
self.device = self.get_device()
self.model = self.load(*([self.model_dir] + self.load_extra_arg))
self.initialized = True
# # Load methods from file
# if (not self._initialized) and ENABLE_MULTI_MODEL:
# code_dir = os.path.join(context.system_properties.get("model_dir"), "code")
# sys.path.append(code_dir)
# self._initialized = True
# # add model_dir/code to python path
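
    # For illustration, the unpacked model archive that ``initialize`` operates
    # on is assumed to look roughly like this (a typical layout, not enforced
    # by this file):
    #
    #   model_dir/
    #     config.json               <- enables task inference from the architecture
    #     pytorch_model.bin
    #     code/
    #       inference.py            <- optional user module with handler overrides
    #       requirements.txt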
def get_device(self):
"""
        Returns the device id for the deep learning framework: the GPU id if a GPU is available, otherwise -1 (CPU).
"""
if _is_gpu_available():
return int(self.context.system_properties.get("gpu_id"))
else:
return -1
def load(self, model_dir, context=None):
"""
The Load handler is responsible for loading the Hugging Face transformer model.
It can be overridden to load the model from storage.
Args:
model_dir (str): The directory where model files are stored.
context (obj): metadata on the incoming request data (default: None).
Returns:
hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
"""
# gets pipeline from task tag
if "HF_TASK" in os.environ:
hf_pipeline = get_pipeline(task=os.environ["HF_TASK"], model_dir=model_dir, device=self.device)
elif "config.json" in os.listdir(model_dir):
task = infer_task_from_model_architecture(f"{model_dir}/config.json")
hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
elif "model_index.json" in os.listdir(model_dir):
task = "text-to-image"
hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
else:
raise ValueError(
f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} or text-to-image as env 'HF_TASK'.",
403,
)
return hf_pipeline
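
    # A minimal sketch of a user-provided ``model_fn`` that would override the
    # default ``load`` above. It would live in ``code/inference.py`` inside the
    # model archive; the "text-classification" task is an illustrative
    # assumption, not something this file prescribes.
    #
    #   from transformers import pipeline
    #
    #   def model_fn(model_dir):
    #       # build a pipeline directly from the unpacked model artifacts
    #       return pipeline("text-classification", model=model_dir)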
def preprocess(self, input_data, content_type, context=None):
"""
        The preprocess handler is responsible for deserializing the input data into
        an object for prediction; JSON is handled by default.
The preprocess handler can be overridden for data or feature transformation.
Args:
input_data: the request payload serialized in the content_type format.
content_type: the request content_type.
context (obj): metadata on the incoming request data (default: None).
Returns:
            decoded_input_data (dict): the input_data deserialized into a Python dictionary.
"""
        # raises an error for zero-shot-classification and table-question-answering with CSV input,
        # which is not possible due to the nested properties of their payloads
if (
os.environ.get("HF_TASK", None) == "zero-shot-classification"
or os.environ.get("HF_TASK", None) == "table-question-answering"
) and content_type == content_types.CSV:
raise PredictionException(
f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",
400,
)
decoded_input_data = decoder_encoder.decode(input_data, content_type)
return decoded_input_data
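
    # A minimal sketch of a user-provided ``input_fn`` override in
    # ``code/inference.py``; restricting the payload to JSON here is an
    # illustrative choice.
    #
    #   import json
    #
    #   def input_fn(input_data, content_type):
    #       # only accept JSON and hand back the decoded dictionary
    #       if content_type != "application/json":
    #           raise ValueError(f"unsupported content type {content_type}")
    #       return json.loads(input_data)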
def predict(self, data, model, context=None):
"""The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available.
The predict handler can be overridden to implement the model inference.
Args:
            data (dict): the deserialized decoded_input_data returned by the input_fn.
            model: Model returned by the `load` method, or by `model_fn` if a custom module is provided.
context (obj): metadata on the incoming request data (default: None).
Returns:
obj (dict): prediction result.
"""
# pop inputs for pipeline
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", None)
# pass inputs with all kwargs in data
if parameters is not None:
prediction = model(inputs, **parameters)
else:
prediction = model(inputs)
return prediction
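
    # For illustration, a request body such as (an assumed payload)
    #   {"inputs": "I love this!", "parameters": {"top_k": 2}}
    # is split by ``predict`` into inputs and parameters and forwarded to the
    # pipeline as ``model("I love this!", top_k=2)``; without "parameters"
    # the pipeline is called with the inputs alone.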
def postprocess(self, prediction, accept, context=None):
"""
        The postprocess handler is responsible for serializing the prediction result to
        the desired accept type; JSON is handled by default.
The postprocess handler can be overridden for inference response transformation.
Args:
prediction (dict): a prediction result from predict.
accept (str): type which the output data needs to be serialized.
context (obj): metadata on the incoming request data (default: None).
        Returns: the output data serialized into the accept type.
"""
return decoder_encoder.encode(prediction, accept)
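
    # A minimal sketch of a user-provided ``output_fn`` override in
    # ``code/inference.py``; serializing only to JSON is an illustrative choice.
    #
    #   import json
    #
    #   def output_fn(prediction, accept):
    #       if accept != "application/json":
    #           raise ValueError(f"unsupported accept type {accept}")
    #       return json.dumps(prediction)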
def transform_fn(self, model, input_data, content_type, accept, context=None):
"""
        The transform function ("transform_fn") can be used to implement pre-processing, prediction, and post-processing in a single function.
        This function cannot be mixed with "input_fn", "output_fn", or "predict_fn".
Args:
model: Model returned by the model_fn above
input_data: Data received for inference
content_type: The content type of the inference data
accept: The response accept type.
context (obj): metadata on the incoming request data (default: None).
Returns: Response in the "accept" format type.
"""
# run pipeline
start_time = time.time()
processed_data = self.preprocess(*([input_data, content_type] + self.preprocess_extra_arg))
preprocess_time = time.time() - start_time
predictions = self.predict(*([processed_data, model] + self.predict_extra_arg))
predict_time = time.time() - preprocess_time - start_time
response = self.postprocess(*([predictions, accept] + self.postprocess_extra_arg))
postprocess_time = time.time() - predict_time - preprocess_time - start_time
logger.info(
f"Preprocess time - {preprocess_time * 1000} ms\n"
f"Predict time - {predict_time * 1000} ms\n"
f"Postprocess time - {postprocess_time * 1000} ms"
)
return response
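
    # A minimal sketch of a user-provided ``transform_fn``, which replaces
    # input_fn/predict_fn/output_fn in one step; the JSON handling is an
    # illustrative assumption.
    #
    #   import json
    #
    #   def transform_fn(model, input_data, content_type, accept):
    #       data = json.loads(input_data)
    #       prediction = model(data["inputs"])
    #       return json.dumps(prediction)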
def handle(self, data, context):
"""Handles an inference request with input data and makes a prediction.
Args:
data (obj): the request data.
context (obj): metadata on the incoming request data.
Returns:
list[obj]: The return value from the Transformer.transform method,
which is a serialized prediction result wrapped in a list if
inference is successful. Otherwise returns an error message
with the context set appropriately.
"""
try:
if not self.initialized:
if self.attempted_init:
                    logger.warning(
                        "Model is not initialized, will try to load the model again.\n"
                        "Please consider increasing the wait time for model loading.\n"
                    )
self.initialize(context)
input_data = data[0].get("body")
request_property = context.request_processor[0].get_request_properties()
content_type = utils.retrieve_content_type_header(request_property)
accept = request_property.get("Accept") or request_property.get("accept")
if not accept or accept == content_types.ANY:
accept = content_types.JSON
if content_type in content_types.UTF8_TYPES:
input_data = input_data.decode("utf-8")
predict_start = time.time()
response = self.transform_fn(*([self.model, input_data, content_type, accept] + self.transform_extra_arg))
predict_end = time.time()
context.metrics.add_time("Transform Fn", round((predict_end - predict_start) * 1000, 2))
context.set_response_content_type(0, accept)
return [response]
except Exception as e:
raise PredictionException(str(e), 400)
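
    # A hedged sketch of calling a deployed endpoint served by this handler;
    # the endpoint name and payload are illustrative assumptions.
    #
    #   import json
    #
    #   import boto3
    #
    #   client = boto3.client("sagemaker-runtime")
    #   response = client.invoke_endpoint(
    #       EndpointName="my-hf-endpoint",
    #       ContentType="application/json",
    #       Accept="application/json",
    #       Body=json.dumps({"inputs": "Hello world!"}),
    #   )
    #   print(json.loads(response["Body"].read()))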
def validate_and_initialize_user_module(self):
"""Retrieves and validates the inference handlers provided within the user module.
Can override load, preprocess, predict and post process function.
"""
user_module_name = self.environment.module_name
if importlib.util.find_spec(user_module_name) is not None:
logger.info("Inference script implementation found at `{}`.".format(user_module_name))
user_module = importlib.import_module(user_module_name)
load_fn = getattr(user_module, MODEL_FN, None)
preprocess_fn = getattr(user_module, INPUT_FN, None)
predict_fn = getattr(user_module, PREDICT_FN, None)
postprocess_fn = getattr(user_module, OUTPUT_FN, None)
transform_fn = getattr(user_module, TRANSFORM_FN, None)
if transform_fn and (preprocess_fn or predict_fn or postprocess_fn):
raise ValueError(
"Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format(
TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN
)
)
self.log_func_implementation_found_or_not(load_fn, MODEL_FN)
if load_fn is not None:
self.load_extra_arg = self.function_extra_arg(self.load, load_fn)
self.load = load_fn
self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN)
if preprocess_fn is not None:
self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn)
self.preprocess = preprocess_fn
self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN)
if predict_fn is not None:
self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn)
self.predict = predict_fn
self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN)
if postprocess_fn is not None:
self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn)
self.postprocess = postprocess_fn
self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN)
if transform_fn is not None:
self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn)
self.transform_fn = transform_fn
else:
logger.info(
"No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format(
user_module_name
)
)
@staticmethod
def log_func_implementation_found_or_not(func, func_name):
if func is not None:
logger.info("`{}` implementation found. It will be used in place of the default one.".format(func_name))
else:
logger.info(
"No `{}` implementation was found. The default one from the handler service will be used.".format(
func_name
)
)
def function_extra_arg(self, default_func, func):
"""Helper to call the handler function which covers 2 cases:
1. the handle function takes context
2. the handle function does not take context
"""
num_default_func_input = len(signature(default_func).parameters)
num_func_input = len(signature(func).parameters)
if num_default_func_input == num_func_input:
# function takes context
extra_args = [self.context]
elif num_default_func_input == num_func_input + 1:
# function does not take context
extra_args = []
else:
raise TypeError(
"{} definition takes {} or {} arguments but {} were given.".format(
func.__name__, num_default_func_input - 1, num_default_func_input, num_func_input
)
)
return extra_args
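
    # For illustration: the bound default ``predict`` above takes
    # (data, model, context), so a hypothetical user ``predict_fn`` may be
    # declared either way:
    #
    #   def predict_fn(data, model):            # no context -> extra_args == []
    #       ...
    #
    #   def predict_fn(data, model, context):   # takes context -> [self.context]
    #       ...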