# triton_wrapper.py
# For python niceties
import warnings
from typing import Dict, List, Optional

import numpy

# For triton specific handling
_triton_import_error = None
try:
    import tritonclient.grpc
    import tritonclient.http
    import tritonclient.utils
except (ImportError, ModuleNotFoundError) as err:
    _triton_import_error = err

from .helper import nonserializable_attribute, numpy_call_wrapper


class triton_wrapper(nonserializable_attribute, numpy_call_wrapper):
    """
    Wrapper for running triton inference.

    The goal of this class is that all triton-specific operations are wrapped
    and abstracted away from the users. The users should then only need to
    handle awkward-level operations to mangle the arrays into the input
    format expected by the model of interest.
    """

    batch_size_fallback = 10  # Fallback value if the batch size cannot be determined.
    http_client_concurrency = 12  # TODO: check whether this value is optimum

    def __init__(
        self, model_url: str, client_args: Optional[Dict] = None, batch_size: int = -1
    ):
        """
        Parameters
        ----------
        - model_url: A string in the format of:
          triton+<protocol>://<address>/<model>/<version>
        - client_args: Optional keyword arguments to pass to the underlying
          `InferenceServerClient` objects.
        - batch_size: How the input arrays should be split up for batched
          processing. Leave negative to have this automatically resolved.
        """
        if _triton_import_error is not None:
            warnings.warn(
                "Users should make sure the tritonclient package is installed before proceeding!\n"
                "> pip install tritonclient[grpc,http]\n"
                "or\n"
                "> conda install tritonclient[grpc,http]",
                UserWarning,
            )
            raise _triton_import_error

        nonserializable_attribute.__init__(
            self, ["client", "model_metadata", "model_inputs", "model_outputs"]
        )
        fullprotocol, location = model_url.split("://")
        _, self.protocol = fullprotocol.split("+")
        self.address, self.model, self.version = location.split("/")
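        # For illustration (hypothetical values): a model_url of
        # "triton+grpc://localhost:8001/mymodel/1" parses above into
        # protocol="grpc", address="localhost:8001", model="mymodel",
        # version="1".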
        # Additional pseudo-lazy objects that require additional parsing after
        # the lazy objects have been initialized.
        self._batch_size = batch_size
        self._client_args = client_args
"""
Spawning the unserializable triton client, as well as other helper objects
that require the triton client to be present.
"""

    @property
    def pmod(self):
        """Getting the protocol module based on the URL protocol string."""
        if self.protocol == "grpc":
            return tritonclient.grpc
        elif self.protocol == "http":
            return tritonclient.http
        else:
            raise ValueError(
                f"{self.protocol} does not encode a valid protocol (grpc or http)"
            )

    def _create_client(self):
        return self.pmod.InferenceServerClient(url=self.address, **self.client_args)

    @property
    def client_args(self) -> Dict:
        """
        Function for adding default arguments to the client constructor kwargs.
        """
        if self.protocol == "grpc":
            kwargs = dict(verbose=False, ssl=True)
        elif self.protocol == "http":
            kwargs = dict(verbose=False, concurrency=self.http_client_concurrency)
        if self._client_args is not None:
            kwargs.update(self._client_args)
        return kwargs
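    # A minimal sketch of the merge semantics above (values hypothetical):
    # passing client_args={"ssl": False} to the constructor overrides the
    # default, yielding dict(verbose=False, ssl=False) for a grpc client.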

    def _create_model_metadata(self) -> Dict:
        return self.client.get_model_metadata(self.model, self.version, as_json=True)

    def _create_model_inputs(self) -> Dict[str, Dict]:
        """
        Extracting the model input data formats from the model_metadata. Here
        we slightly reshape the objects into a format that is easier to
        manipulate and compare with numpy arrays.
        """
        return {
            x["name"]: {
                "shape": tuple(int(i) for i in x["shape"]),
                "datatype": x["datatype"],
            }
            for x in self.model_metadata["inputs"]
        }
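    # For illustration (names and shapes hypothetical): a metadata entry such
    # as {"name": "input__0", "shape": ["-1", "10"], "datatype": "FP32"}
    # is reshaped to {"input__0": {"shape": (-1, 10), "datatype": "FP32"}}.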

    def _create_model_outputs(self) -> List[str]:
        """Getting a list of names of the possible outputs."""
        return [x["name"] for x in self.model_metadata["outputs"]]

    @property
    def batch_size(self) -> int:
        """
        Getting the batch size to be used for array splitting. If it is
        explicitly set by the user, use that; otherwise, extract it from the
        model configuration hosted on the server.
        """
        if self._batch_size < 0:
            model_config = self.client.get_model_config(
                self.model, self.version, as_json=True
            )["config"]
            if "dynamic_batching" in model_config:
                self._batch_size = model_config["dynamic_batching"][
                    "preferred_batch_size"
                ][0]
            elif "max_batch_size" in model_config:
                self._batch_size = model_config["max_batch_size"]
            else:
                warnings.warn(
                    f"Batch size not set by model! Setting to default value {self.batch_size_fallback}. "
                    "Contact the model maintainer to check whether this is expected.",
                    UserWarning,
                )
                self._batch_size = self.batch_size_fallback
        return self._batch_size
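    # For illustration, a (hypothetical) model configuration of
    #   {"dynamic_batching": {"preferred_batch_size": [64, 128]}, ...}
    # resolves to a batch size of 64 (the first preferred size), while one
    # that only specifies
    #   {"max_batch_size": 32, ...}
    # resolves to 32.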
"""
Numpy/awkward/dask_awkward inference
"""

    def validate_numpy_input(
        self, output_list: List[str], input_dict: Dict[str, numpy.ndarray]
    ) -> None:
        """
        Validating the inputs and requested outputs against the expected
        input array dimensions and the available output values, both of
        which tritonclient can return.
        """
        # Input value checking
        for iname, iarr in input_dict.items():
            # Checking the name
            if iname not in self.model_inputs.keys():
                raise ValueError(
                    f"Input [{iname}] not defined in model! "
                    f"Inputs defined by model: {list(self.model_inputs.keys())}"
                )
            # Checking the shape
            ishape = numpy.array(iarr.shape)
            mshape = numpy.array(self.model_inputs[iname]["shape"])
            if len(ishape) != len(mshape):
                raise ValueError(
                    f"Input [{iname}] got wrong dimension: {len(ishape)} "
                    f"(Expected {len(mshape)})"
                )
            if not all(numpy.where(mshape > 0, ishape == mshape, True)):
                raise ValueError(
                    f"Input [{iname}] got array of shape {ishape} "
                    f"(Expected: {mshape}, -1 means arbitrary)"
                )
            # Checking the data type. Notice that a mismatch will only raise
            # a warning! The data types defined by triton can be found here:
            # https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#datatypes
            itype = iarr.dtype
            mtype = tritonclient.utils.triton_to_np_dtype(
                self.model_inputs[iname]["datatype"]
            )
            if itype != mtype:
                warnings.warn(
                    f"Input [{iname}] got array of type [{itype}] (Expected [{mtype.__name__}])."
                    " Automatic conversion will be performed using numpy.ndarray.astype.",
                    UserWarning,
                )
        # Checking for missing inputs
        for mname in self.model_inputs.keys():
            if mname not in input_dict.keys():
                raise ValueError(f"Input [{mname}] not given in input dictionary!")
        # Checking outputs
        for oname in output_list:
            if oname not in self.model_outputs:
                raise ValueError(
                    f"Requested output [{oname}] not defined by model "
                    f"(Defined: {list(self.model_outputs)})"
                )
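    # A small illustration of the shape check above (hypothetical shapes): a
    # model input shape of (-1, 10) accepts an array of shape (256, 10),
    # since non-positive model dimensions are treated as arbitrary, but
    # rejects (256, 8) (dimension mismatch) and (256, 10, 1) (wrong number
    # of dimensions).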

    def numpy_call(
        self, output_list: List[str], input_dict: Dict[str, numpy.ndarray]
    ) -> Dict[str, numpy.ndarray]:
        """
        Parameters
        ----------
        - output_list: List of strings corresponding to the names of the
          outputs of interest. These strings are automatically translated
          into the required `tritonclient.InferRequestedOutput` objects.
        - input_dict: Dictionary with the model's input names as keys and the
          appropriate numpy arrays as values. This dictionary is
          automatically translated into a list of `tritonclient.InferInput`
          objects.

        Return
        ------
        A dictionary of numpy arrays keyed by the output_list arguments.
        """

        # Setting up the inference input containers
        def _get_infer_shape(name):
            ishape = numpy.array(input_dict[name].shape)
            mshape = numpy.array(self.model_inputs[name]["shape"])
            mshape = numpy.where(mshape < 0, ishape, mshape)
            mshape[0] = self.batch_size
            return mshape

        infer_inputs = [
            self.pmod.InferInput(name, _get_infer_shape(name), prop["datatype"])
            for name, prop in self.model_inputs.items()
        ]

        # Setting up the inference output containers
        infer_outputs = [
            self.pmod.InferRequestedOutput(output) for output in output_list
        ]

        # Setting up the container for storing the output
        output = None

        # Padding the outermost dimension to a multiple of the batch size
        orig_len = list(input_dict.values())[0].shape[0]  # saving original length
        for start_idx in range(0, orig_len, self.batch_size):
            stop_idx = min(start_idx + self.batch_size, orig_len)
            for idx, name in enumerate(self.model_inputs.keys()):
                mtype = tritonclient.utils.triton_to_np_dtype(
                    self.model_inputs[name]["datatype"]
                )
                shape = list(input_dict[name].shape)
                shape[0] = self.batch_size  # Always pad to fixed length
                infer_inputs[idx].set_data_from_numpy(
                    numpy.resize(
                        input_dict[name][start_idx:stop_idx],  # We need a copy here
                        tuple(shape),
                    ).astype(mtype)
                )
            # Making the request to the server
            request = self.client.infer(
                self.model,
                model_version=self.version,
                inputs=infer_inputs,
                outputs=infer_outputs,
            )
            if output is None:
                output = {
                    o: request.as_numpy(o)[start_idx:stop_idx] for o in output_list
                }
            else:
                for o in output_list:
                    output[o] = numpy.concatenate(
                        (output[o], request.as_numpy(o)), axis=0
                    )
        # Trimming away the entries padded onto the final batch
        return {k: v[:orig_len] for k, v in output.items()}
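

# A minimal, self-contained usage sketch, kept under a __main__ guard so that
# importing this module stays side-effect free. The server URL, model name,
# version, and tensor names below are hypothetical placeholders; substitute
# the values for your own triton deployment.
if __name__ == "__main__":
    wrapper = triton_wrapper(
        model_url="triton+grpc://localhost:8001/mymodel/1",  # hypothetical URL
        client_args={"ssl": False},  # hypothetical override of the defaults
    )
    # 256 events of 10 features each; the shape and dtype must match what the
    # hosted model actually expects.
    dummy_input = numpy.random.random((256, 10)).astype(numpy.float32)
    wrapper.validate_numpy_input(["output__0"], {"input__0": dummy_input})
    result = wrapper.numpy_call(["output__0"], {"input__0": dummy_input})
    print({name: arr.shape for name, arr in result.items()})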