From 38e96451d790f98de9b9634a4e5d9f0fc50a04b8 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Tue, 22 Nov 2022 17:38:23 +0800 Subject: [PATCH] [Serving] add ppdet serving example (#641) * serving support ppdet * Update README.md update ppadet/README --- .../paddleclas/serving/README.md | 2 +- .../detection/paddledetection/README.md | 1 + .../serving/models/postprocess/1/model.py | 110 +++++++++++++++++ .../serving/models/postprocess/config.pbtxt | 30 +++++ .../models/postprocess/mask_config.pbtxt | 34 ++++++ .../serving/models/ppdet/1/README.md | 3 + .../models/ppdet/faster_rcnn_config.pbtxt | 80 ++++++++++++ .../models/ppdet/mask_rcnn_config.pbtxt | 88 ++++++++++++++ .../serving/models/ppdet/ppyolo_config.pbtxt | 80 ++++++++++++ .../serving/models/ppdet/ppyoloe_config.pbtxt | 72 +++++++++++ .../serving/models/preprocess/1/model.py | 114 ++++++++++++++++++ .../serving/models/preprocess/config.pbtxt | 35 ++++++ .../serving/models/runtime/1/README.md | 5 + .../runtime/faster_rcnn_runtime_config.pbtxt | 58 +++++++++ .../runtime/mask_rcnn_runtime_config.pbtxt | 63 ++++++++++ .../runtime/ppyolo_runtime_config.pbtxt | 58 +++++++++ .../runtime/ppyoloe_runtime_config.pbtxt | 55 +++++++++ .../serving/paddledet_grpc_client.py | 109 +++++++++++++++++ 18 files changed, 996 insertions(+), 1 deletion(-) create mode 100644 examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py create mode 100644 examples/vision/detection/paddledetection/serving/models/postprocess/config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/postprocess/mask_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/ppdet/1/README.md create mode 100644 examples/vision/detection/paddledetection/serving/models/ppdet/faster_rcnn_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/ppdet/mask_rcnn_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/ppdet/ppyolo_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/ppdet/ppyoloe_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/preprocess/1/model.py create mode 100644 examples/vision/detection/paddledetection/serving/models/preprocess/config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/runtime/1/README.md create mode 100644 examples/vision/detection/paddledetection/serving/models/runtime/faster_rcnn_runtime_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/runtime/mask_rcnn_runtime_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/runtime/ppyolo_runtime_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/models/runtime/ppyoloe_runtime_config.pbtxt create mode 100644 examples/vision/detection/paddledetection/serving/paddledet_grpc_client.py diff --git a/examples/vision/classification/paddleclas/serving/README.md b/examples/vision/classification/paddleclas/serving/README.md index 59e5d2c05..77d8046be 100644 --- a/examples/vision/classification/paddleclas/serving/README.md +++ b/examples/vision/classification/paddleclas/serving/README.md @@ -13,7 +13,7 @@ tar -xvf ResNet50_vd_infer.tgz wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg # 将配置文件放入预处理目录 -mv ResNet50_vd_infer/inference_cls.yaml models/preprocess/1/ +mv ResNet50_vd_infer/inference_cls.yaml models/preprocess/1/inference_cls.yaml # 将模型放入 models/runtime/1目录下, 并重命名为model.pdmodel和model.pdiparams mv ResNet50_vd_infer/inference.pdmodel models/runtime/1/model.pdmodel diff --git a/examples/vision/detection/paddledetection/README.md b/examples/vision/detection/paddledetection/README.md index e7b833a28..9b5b1b1dd 100644 --- a/examples/vision/detection/paddledetection/README.md +++ b/examples/vision/detection/paddledetection/README.md @@ -47,3 +47,4 @@ - [Python部署](python) - [C++部署](cpp) +- [服务化部署](serving) diff --git a/examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py b/examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py new file mode 100644 index 000000000..4872b0dee --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import numpy as np +import time + +import fastdeploy as fd + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # You must parse model_config. JSON string is not parsed here + self.model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("postprocess input names:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + self.output_dtype.append(dtype) + print("postprocess output names:", self.output_names) + + self.postprocess_ = fd.vision.detection.PaddleDetPostprocessor() + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + responses = [] + for request in requests: + infer_outputs = [] + for name in self.input_names: + infer_output = pb_utils.get_input_tensor_by_name(request, name) + if infer_output: + infer_output = infer_output.as_numpy() + infer_outputs.append(infer_output) + + results = self.postprocess_.run(infer_outputs) + r_str = fd.vision.utils.fd_result_to_json(results) + + r_np = np.array(r_str, dtype=np.object) + out_tensor = pb_utils.Tensor(self.output_names[0], r_np) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor, ]) + responses.append(inference_response) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') diff --git a/examples/vision/detection/paddledetection/serving/models/postprocess/config.pbtxt b/examples/vision/detection/paddledetection/serving/models/postprocess/config.pbtxt new file mode 100644 index 000000000..bb09e32c6 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/postprocess/config.pbtxt @@ -0,0 +1,30 @@ +name: "postprocess" +backend: "python" + +input [ + { + name: "post_input1" + data_type: TYPE_FP32 + dims: [ -1, 6 ] + }, + { + name: "post_input2" + data_type: TYPE_INT32 + dims: [ -1 ] + } +] + +output [ + { + name: "post_output" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/serving/models/postprocess/mask_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/postprocess/mask_config.pbtxt new file mode 100644 index 000000000..8985cc78a --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/postprocess/mask_config.pbtxt @@ -0,0 +1,34 @@ +backend: "python" + +input [ + { + name: "post_input1" + data_type: TYPE_FP32 + dims: [ -1, 6 ] + }, + { + name: "post_input2" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "post_input3" + data_type: TYPE_INT32 + dims: [ -1, -1, -1 ] + } +] + +output [ + { + name: "post_output" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/serving/models/ppdet/1/README.md b/examples/vision/detection/paddledetection/serving/models/ppdet/1/README.md new file mode 100644 index 000000000..877efdf8d --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/ppdet/1/README.md @@ -0,0 +1,3 @@ +# PaddleDetection Pipeline + +The pipeline directory does not have model files, but a version number directory needs to be maintained. diff --git a/examples/vision/detection/paddledetection/serving/models/ppdet/faster_rcnn_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/ppdet/faster_rcnn_config.pbtxt new file mode 100644 index 000000000..91d132b9a --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/ppdet/faster_rcnn_config.pbtxt @@ -0,0 +1,80 @@ +platform: "ensemble" + +input [ + { + name: "INPUT" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1, 3 ] + } +] +output [ + { + name: "DET_RESULT" + data_type: TYPE_STRING + dims: [ -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocess" + model_version: 1 + input_map { + key: "preprocess_input" + value: "INPUT" + } + output_map { + key: "preprocess_output1" + value: "RUNTIME_INPUT1" + } + output_map { + key: "preprocess_output2" + value: "RUNTIME_INPUT2" + } + output_map { + key: "preprocess_output3" + value: "RUNTIME_INPUT3" + } + }, + { + model_name: "runtime" + model_version: 1 + input_map { + key: "image" + value: "RUNTIME_INPUT1" + } + input_map { + key: "scale_factor" + value: "RUNTIME_INPUT2" + } + input_map { + key: "im_shape" + value: "RUNTIME_INPUT3" + } + output_map { + key: "concat_12.tmp_0" + value: "RUNTIME_OUTPUT1" + } + output_map { + key: "concat_8.tmp_0" + value: "RUNTIME_OUTPUT2" + } + }, + { + model_name: "postprocess" + model_version: 1 + input_map { + key: "post_input1" + value: "RUNTIME_OUTPUT1" + } + input_map { + key: "post_input2" + value: "RUNTIME_OUTPUT2" + } + output_map { + key: "post_output" + value: "DET_RESULT" + } + } + ] +} \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/serving/models/ppdet/mask_rcnn_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/ppdet/mask_rcnn_config.pbtxt new file mode 100644 index 000000000..b0ee4e092 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/ppdet/mask_rcnn_config.pbtxt @@ -0,0 +1,88 @@ +platform: "ensemble" + +input [ + { + name: "INPUT" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1, 3 ] + } +] +output [ + { + name: "DET_RESULT" + data_type: TYPE_STRING + dims: [ -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocess" + model_version: 1 + input_map { + key: "preprocess_input" + value: "INPUT" + } + output_map { + key: "preprocess_output1" + value: "RUNTIME_INPUT1" + } + output_map { + key: "preprocess_output2" + value: "RUNTIME_INPUT2" + } + output_map { + key: "preprocess_output3" + value: "RUNTIME_INPUT3" + } + }, + { + model_name: "runtime" + model_version: 1 + input_map { + key: "image" + value: "RUNTIME_INPUT1" + } + input_map { + key: "scale_factor" + value: "RUNTIME_INPUT2" + } + input_map { + key: "im_shape" + value: "RUNTIME_INPUT3" + } + output_map { + key: "concat_9.tmp_0" + value: "RUNTIME_OUTPUT1" + } + output_map { + key: "concat_5.tmp_0" + value: "RUNTIME_OUTPUT2" + }, + output_map { + key: "tmp_109" + value: "RUNTIME_OUTPUT3" + } + }, + { + model_name: "postprocess" + model_version: 1 + input_map { + key: "post_input1" + value: "RUNTIME_OUTPUT1" + } + input_map { + key: "post_input2" + value: "RUNTIME_OUTPUT2" + } + input_map { + key: "post_input3" + value: "RUNTIME_OUTPUT3" + } + output_map { + key: "post_output" + value: "DET_RESULT" + } + } + ] +} \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/serving/models/ppdet/ppyolo_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/ppdet/ppyolo_config.pbtxt new file mode 100644 index 000000000..f7c1fe612 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/ppdet/ppyolo_config.pbtxt @@ -0,0 +1,80 @@ +platform: "ensemble" + +input [ + { + name: "INPUT" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1, 3 ] + } +] +output [ + { + name: "DET_RESULT" + data_type: TYPE_STRING + dims: [ -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocess" + model_version: 1 + input_map { + key: "preprocess_input" + value: "INPUT" + } + output_map { + key: "preprocess_output1" + value: "RUNTIME_INPUT1" + } + output_map { + key: "preprocess_output2" + value: "RUNTIME_INPUT2" + } + output_map { + key: "preprocess_output3" + value: "RUNTIME_INPUT3" + } + }, + { + model_name: "runtime" + model_version: 1 + input_map { + key: "image" + value: "RUNTIME_INPUT1" + } + input_map { + key: "scale_factor" + value: "RUNTIME_INPUT2" + } + input_map { + key: "im_shape" + value: "RUNTIME_INPUT3" + } + output_map { + key: "matrix_nms_0.tmp_0" + value: "RUNTIME_OUTPUT1" + } + output_map { + key: "matrix_nms_0.tmp_2" + value: "RUNTIME_OUTPUT2" + } + }, + { + model_name: "postprocess" + model_version: 1 + input_map { + key: "post_input1" + value: "RUNTIME_OUTPUT1" + } + input_map { + key: "post_input2" + value: "RUNTIME_OUTPUT2" + } + output_map { + key: "post_output" + value: "DET_RESULT" + } + } + ] +} \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/serving/models/ppdet/ppyoloe_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/ppdet/ppyoloe_config.pbtxt new file mode 100644 index 000000000..3cb479b46 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/ppdet/ppyoloe_config.pbtxt @@ -0,0 +1,72 @@ +platform: "ensemble" + +input [ + { + name: "INPUT" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1, 3 ] + } +] +output [ + { + name: "DET_RESULT" + data_type: TYPE_STRING + dims: [ -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocess" + model_version: 1 + input_map { + key: "preprocess_input" + value: "INPUT" + } + output_map { + key: "preprocess_output1" + value: "RUNTIME_INPUT1" + } + output_map { + key: "preprocess_output2" + value: "RUNTIME_INPUT2" + } + }, + { + model_name: "runtime" + model_version: 1 + input_map { + key: "image" + value: "RUNTIME_INPUT1" + } + input_map { + key: "scale_factor" + value: "RUNTIME_INPUT2" + } + output_map { + key: "multiclass_nms3_0.tmp_0" + value: "RUNTIME_OUTPUT1" + } + output_map { + key: "multiclass_nms3_0.tmp_2" + value: "RUNTIME_OUTPUT2" + } + }, + { + model_name: "postprocess" + model_version: 1 + input_map { + key: "post_input1" + value: "RUNTIME_OUTPUT1" + } + input_map { + key: "post_input2" + value: "RUNTIME_OUTPUT2" + } + output_map { + key: "post_output" + value: "DET_RESULT" + } + } + ] +} \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/serving/models/preprocess/1/model.py b/examples/vision/detection/paddledetection/serving/models/preprocess/1/model.py new file mode 100644 index 000000000..2ea72054d --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/preprocess/1/model.py @@ -0,0 +1,114 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import numpy as np +import os + +import fastdeploy as fd + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # You must parse model_config. JSON string is not parsed here + self.model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("preprocess input names:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + # dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + # self.output_dtype.append(dtype) + self.output_dtype.append(output_config["data_type"]) + print("preprocess output names:", self.output_names) + + # init PaddleClasPreprocess class + yaml_path = os.path.abspath(os.path.dirname( + __file__)) + "/infer_cfg.yml" + self.preprocess_ = fd.vision.detection.PaddleDetPreprocessor(yaml_path) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + responses = [] + for request in requests: + data = pb_utils.get_input_tensor_by_name(request, + self.input_names[0]) + data = data.as_numpy() + outputs = self.preprocess_.run(data) + + output_tensors = [] + for idx, name in enumerate(self.output_names): + dlpack_tensor = outputs[idx].to_dlpack() + output_tensor = pb_utils.Tensor.from_dlpack(name, + dlpack_tensor) + output_tensors.append(output_tensor) + + inference_response = pb_utils.InferenceResponse( + output_tensors=output_tensors) + responses.append(inference_response) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') diff --git a/examples/vision/detection/paddledetection/serving/models/preprocess/config.pbtxt b/examples/vision/detection/paddledetection/serving/models/preprocess/config.pbtxt new file mode 100644 index 000000000..39a42113b --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/preprocess/config.pbtxt @@ -0,0 +1,35 @@ +name: "preprocess" +backend: "python" + +input [ + { + name: "preprocess_input" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1, 3 ] + } +] + +output [ + { + name: "preprocess_output1" + data_type: TYPE_FP32 + dims: [ -1, 3, -1, -1 ] + }, + { + name: "preprocess_output2" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + }, + { + name: "preprocess_output3" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/serving/models/runtime/1/README.md b/examples/vision/detection/paddledetection/serving/models/runtime/1/README.md new file mode 100644 index 000000000..1e5d914b4 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/runtime/1/README.md @@ -0,0 +1,5 @@ +# Runtime Directory + +This directory holds the model files. +Paddle models must be model.pdmodel and model.pdiparams files. +ONNX models must be model.onnx files. diff --git a/examples/vision/detection/paddledetection/serving/models/runtime/faster_rcnn_runtime_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/runtime/faster_rcnn_runtime_config.pbtxt new file mode 100644 index 000000000..9f4b9833e --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/runtime/faster_rcnn_runtime_config.pbtxt @@ -0,0 +1,58 @@ +backend: "fastdeploy" + +# Input configuration of the model +input [ + { + # input name + name: "image" + # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING + data_type: TYPE_FP32 + # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w] + dims: [ -1, 3, -1, -1 ] + }, + { + name: "scale_factor" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + }, + { + name: "im_shape" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + } +] + +# The output of the model is configured in the same format as the input +output [ + { + name: "concat_12.tmp_0" + data_type: TYPE_FP32 + dims: [ -1, 6 ] + }, + { + name: "concat_8.tmp_0" + data_type: TYPE_INT32 + dims: [ -1 ] + } +] + +# Number of instances of the model +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use GPU, CPU inference option is:KIND_CPU + kind: KIND_GPU + # The instance is deployed on the 0th GPU card + gpus: [0] + } +] + +optimization { + execution_accelerators { + gpu_execution_accelerator : [ { + # use Paddle engine + name: "paddle", + } + ] +}} diff --git a/examples/vision/detection/paddledetection/serving/models/runtime/mask_rcnn_runtime_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/runtime/mask_rcnn_runtime_config.pbtxt new file mode 100644 index 000000000..13fdd5b41 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/runtime/mask_rcnn_runtime_config.pbtxt @@ -0,0 +1,63 @@ +backend: "fastdeploy" + +# Input configuration of the model +input [ + { + # input name + name: "image" + # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING + data_type: TYPE_FP32 + # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w] + dims: [ -1, 3, -1, -1 ] + }, + { + name: "scale_factor" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + }, + { + name: "im_shape" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + } +] + +# The output of the model is configured in the same format as the input +output [ + { + name: "concat_9.tmp_0" + data_type: TYPE_FP32 + dims: [ -1, 6 ] + }, + { + name: "concat_5.tmp_0" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "tmp_109" + data_type: TYPE_INT32 + dims: [ -1, -1, -1 ] + } +] + +# Number of instances of the model +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use GPU, CPU inference option is:KIND_CPU + kind: KIND_GPU + # The instance is deployed on the 0th GPU card + gpus: [0] + } +] + +optimization { + execution_accelerators { + gpu_execution_accelerator : [ { + # use Paddle engine + name: "paddle", + } + ] +}} diff --git a/examples/vision/detection/paddledetection/serving/models/runtime/ppyolo_runtime_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/runtime/ppyolo_runtime_config.pbtxt new file mode 100644 index 000000000..0f7b63308 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/runtime/ppyolo_runtime_config.pbtxt @@ -0,0 +1,58 @@ +backend: "fastdeploy" + +# Input configuration of the model +input [ + { + # input name + name: "image" + # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING + data_type: TYPE_FP32 + # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w] + dims: [ -1, 3, -1, -1 ] + }, + { + name: "scale_factor" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + }, + { + name: "im_shape" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + } +] + +# The output of the model is configured in the same format as the input +output [ + { + name: "matrix_nms_0.tmp_0" + data_type: TYPE_FP32 + dims: [ -1, 6 ] + }, + { + name: "matrix_nms_0.tmp_2" + data_type: TYPE_INT32 + dims: [ -1 ] + } +] + +# Number of instances of the model +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use GPU, CPU inference option is:KIND_CPU + kind: KIND_GPU + # The instance is deployed on the 0th GPU card + gpus: [0] + } +] + +optimization { + execution_accelerators { + gpu_execution_accelerator : [ { + # use Paddle engine + name: "paddle", + } + ] +}} diff --git a/examples/vision/detection/paddledetection/serving/models/runtime/ppyoloe_runtime_config.pbtxt b/examples/vision/detection/paddledetection/serving/models/runtime/ppyoloe_runtime_config.pbtxt new file mode 100644 index 000000000..39b2c6045 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/models/runtime/ppyoloe_runtime_config.pbtxt @@ -0,0 +1,55 @@ +# optional, If name is specified it must match the name of the model repository directory containing the model. +name: "ppyoloe_runtime" +backend: "fastdeploy" + +# Input configuration of the model +input [ + { + # input name + name: "image" + # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING + data_type: TYPE_FP32 + # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w] + dims: [ -1, 3, -1, -1 ] + }, + { + name: "scale_factor" + data_type: TYPE_FP32 + dims: [ -1, 2 ] + } +] + +# The output of the model is configured in the same format as the input +output [ + { + name: "multiclass_nms3_0.tmp_0" + data_type: TYPE_FP32 + dims: [ -1, 6 ] + }, + { + name: "multiclass_nms3_0.tmp_2" + data_type: TYPE_INT32 + dims: [ -1 ] + } +] + +# Number of instances of the model +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use GPU, CPU inference option is:KIND_CPU + kind: KIND_GPU + # The instance is deployed on the 0th GPU card + gpus: [0] + } +] + +optimization { + execution_accelerators { + gpu_execution_accelerator : [ { + # use Paddle engine + name: "paddle", + } + ] +}} diff --git a/examples/vision/detection/paddledetection/serving/paddledet_grpc_client.py b/examples/vision/detection/paddledetection/serving/paddledet_grpc_client.py new file mode 100644 index 000000000..842239496 --- /dev/null +++ b/examples/vision/detection/paddledetection/serving/paddledet_grpc_client.py @@ -0,0 +1,109 @@ +import logging +import numpy as np +import time +from typing import Optional +import cv2 +import json + +from tritonclient import utils as client_utils +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2 + +LOGGER = logging.getLogger("run_inference_on_triton") + + +class SyncGRPCTritonRunner: + DEFAULT_MAX_RESP_WAIT_S = 120 + + def __init__( + self, + server_url: str, + model_name: str, + model_version: str, + *, + verbose=False, + resp_wait_s: Optional[float]=None, ): + self._server_url = server_url + self._model_name = model_name + self._model_version = model_version + self._verbose = verbose + self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s + + self._client = InferenceServerClient( + self._server_url, verbose=self._verbose) + error = self._verify_triton_state(self._client) + if error: + raise RuntimeError( + f"Could not communicate to Triton Server: {error}") + + LOGGER.debug( + f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} " + f"are up and ready!") + + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + LOGGER.info(f"Model config {model_config}") + LOGGER.info(f"Model metadata {model_metadata}") + + for tm in model_metadata.inputs: + print("tm:", tm) + self._inputs = {tm.name: tm for tm in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = {tm.name: tm for tm in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def Run(self, inputs): + """ + Args: + inputs: list, Each value corresponds to an input name of self._input_names + Returns: + results: dict, {name : numpy.array} + """ + infer_inputs = [] + for idx, data in enumerate(inputs): + infer_input = InferInput(self._input_names[idx], data.shape, + "UINT8") + infer_input.set_data_from_numpy(data) + infer_inputs.append(infer_input) + + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=infer_inputs, + outputs=self._outputs_req, + client_timeout=self._response_wait_t, ) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + def _verify_triton_state(self, triton_client): + if not triton_client.is_server_live(): + return f"Triton server {self._server_url} is not live" + elif not triton_client.is_server_ready(): + return f"Triton server {self._server_url} is not ready" + elif not triton_client.is_model_ready(self._model_name, + self._model_version): + return f"Model {self._model_name}:{self._model_version} is not ready" + return None + + +if __name__ == "__main__": + model_name = "ppdet" + model_version = "1" + url = "localhost:8001" + runner = SyncGRPCTritonRunner(url, model_name, model_version) + im = cv2.imread("000000014439.jpg") + im = np.array([im, ]) + # batch input + # im = np.array([im, im, im]) + for i in range(1): + result = runner.Run([im, ]) + for name, values in result.items(): + print("output_name:", name) + # values is batch + for value in values: + value = json.loads(value) + print(value['boxes'])