diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index af4d83f55a6ee..f45108a8c4100 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2275,13 +2275,9 @@ USE_TRT_CONVERTER(flatten_contiguous_range); USE_TRT_CONVERTER(matmul); USE_TRT_CONVERTER(matmul_v2); USE_TRT_CONVERTER(bmm); -USE_TRT_CONVERTER(rsqrt); USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(relu); -USE_TRT_CONVERTER(exp); -USE_TRT_CONVERTER(log); USE_TRT_CONVERTER(sigmoid); -USE_TRT_CONVERTER(tanh); USE_TRT_CONVERTER(fc); USE_TRT_CONVERTER(pool2d); USE_TRT_CONVERTER(softmax); @@ -2337,6 +2333,32 @@ USE_TRT_CONVERTER(conv3d_transpose); USE_TRT_CONVERTER(mish); USE_TRT_CONVERTER(deformable_conv); USE_TRT_CONVERTER(pool3d) +USE_TRT_CONVERTER(square); +// unary op +USE_TRT_CONVERTER(exp); +USE_TRT_CONVERTER(log); +USE_TRT_CONVERTER(sqrt); +USE_TRT_CONVERTER(reciprocal); +USE_TRT_CONVERTER(abs); +USE_TRT_CONVERTER(sin); +USE_TRT_CONVERTER(cos); +USE_TRT_CONVERTER(tan); +USE_TRT_CONVERTER(sinh); +USE_TRT_CONVERTER(cosh); +USE_TRT_CONVERTER(tanh); +USE_TRT_CONVERTER(asin); +USE_TRT_CONVERTER(acos); +USE_TRT_CONVERTER(atan); +USE_TRT_CONVERTER(asinh); +USE_TRT_CONVERTER(acosh); +USE_TRT_CONVERTER(atanh); +USE_TRT_CONVERTER(ceil); +USE_TRT_CONVERTER(floor); +#if IS_TRT_VERSION_GE(8200) +USE_TRT_CONVERTER(round); +USE_TRT_CONVERTER(sign); +#endif +USE_TRT_CONVERTER(rsqrt); USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(preln_skip_layernorm) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 6b9e20b68deab..e23ea7e7b1b2b 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -18,6 +18,7 @@ list( group_norm_op.cc pad_op.cc split_op.cc + square_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/square_op.cc b/paddle/fluid/inference/tensorrt/convert/square_op.cc new file mode 100644 index 0000000000000..5ed372167fa6d --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/square_op.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class SquareOpConverter : public OpConverter { + public: + SquareOpConverter() {} + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + framework::OpDesc op_desc(op, nullptr); + VLOG(3) << "convert a fluid sqaure op to tensorrt layer "; + nvinfer1::ITensor* input_tensor = + engine_->GetITensor(op_desc.Input("X")[0]); + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *input_tensor, + *input_tensor, + nvinfer1::ElementWiseOperation::kPROD); + + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "square", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(square, SquareOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/unary_op.cc b/paddle/fluid/inference/tensorrt/convert/unary_op.cc index 342b966bdcee4..9279e25a1836c 100644 --- a/paddle/fluid/inference/tensorrt/convert/unary_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/unary_op.cc @@ -85,6 +85,7 @@ const std::unordered_map> {"acos", {nvinfer1::UnaryOperation::kACOS}}, {"atan", {nvinfer1::UnaryOperation::kATAN}}, {"asinh", {nvinfer1::UnaryOperation::kASINH}}, + {"acosh", {nvinfer1::UnaryOperation::kACOSH}}, {"atanh", {nvinfer1::UnaryOperation::kATANH}}, {"ceil", {nvinfer1::UnaryOperation::kCEIL}}, {"floor", {nvinfer1::UnaryOperation::kFLOOR}}, @@ -92,12 +93,13 @@ const std::unordered_map> {nvinfer1::UnaryOperation::kSQRT, nvinfer1::UnaryOperation::kRECIP}}, {"logical_not", {nvinfer1::UnaryOperation::kNOT}}, {"reciprocal", {nvinfer1::UnaryOperation::kRECIP}}, -#if IS_TRT_VERSION_GE(8200) - {"sign", {nvinfer1::UnaryOperation::kSIGN}}, -#endif #if IS_TRT_VERSION_GE(7000) {"erf", {nvinfer1::UnaryOperation::kERF}}, #endif +#if IS_TRT_VERSION_GE(8200) + {"sign", {nvinfer1::UnaryOperation::kSIGN}}, + {"round", {nvinfer1::UnaryOperation::kROUND}}, +#endif }; class ExpOpConverter : public UnaryOpConverter { @@ -154,6 +156,10 @@ class AsinhOpConverter : public UnaryOpConverter { public: AsinhOpConverter() { op_type_ = "asinh"; } }; +class AcoshOpConverter : public UnaryOpConverter { + public: + AcoshOpConverter() { op_type_ = "acosh"; } +}; class AtanhOpConverter : public UnaryOpConverter { public: AtanhOpConverter() { op_type_ = "atanh"; } @@ -194,6 +200,10 @@ class ErfOpConverter : public UnaryOpConverter { public: ErfOpConverter() { op_type_ = "erf"; } }; +class RoundOpConverter : public UnaryOpConverter { + public: + RoundOpConverter() { op_type_ = "round"; } +}; #endif } // namespace tensorrt @@ -213,15 +223,17 @@ REGISTER_TRT_OP_CONVERTER(asin, AsinOpConverter); REGISTER_TRT_OP_CONVERTER(acos, AcosOpConverter); REGISTER_TRT_OP_CONVERTER(atan, AtanOpConverter); REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter); +REGISTER_TRT_OP_CONVERTER(acosh, AcoshOpConverter); REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter); REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter); REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter); REGISTER_TRT_OP_CONVERTER(rsqrt, RsqrtOpConverter); REGISTER_TRT_OP_CONVERTER(logical_not, LogicalNotOpConverter); REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter); -#if IS_TRT_VERSION_GE(8200) -REGISTER_TRT_OP_CONVERTER(sign, SignOpConverter); -#endif #if IS_TRT_VERSION_GE(7000) REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter); #endif +#if IS_TRT_VERSION_GE(8200) +REGISTER_TRT_OP_CONVERTER(sign, SignOpConverter); +REGISTER_TRT_OP_CONVERTER(round, RoundOpConverter); +#endif diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 4367927bb1734..cad8e2a5b4a51 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -65,6 +65,10 @@ struct SimpleOpTypeSetTeller : public Teller { int8_teller_set.insert("sparse_fc"); teller_set.insert("sparse_multihead_matmul"); int8_teller_set.insert("sparse_multihead_matmul"); +#endif +#if IS_TRT_VERSION_GE(8200) + teller_set.insert("round"); + int8_teller_set.insert("round"); #endif } @@ -79,18 +83,18 @@ struct SimpleOpTypeSetTeller : public Teller { desc.HasAttr("skip_quant")) return false; std::unordered_set act_op_list = { - "relu", "relu6", "sigmoid", - "elu", "selu", "softsign", - "softplus", "stanh", "thresholded_relu", - "exp", "log", "sqrt", - "abs", "sin", "cos", - "tan", "tanh", "sinh", - "cosh", "asin", "acos", - "atan", "asinh", "atanh", - "ceil", "floor", "erf", - "reciprocal", "silu", "celu", - "tanh_shrink", "logsigmoid", "sign", - "logical_not"}; + "relu", "relu6", "sigmoid", + "elu", "selu", "softsign", + "softplus", "stanh", "thresholded_relu", + "exp", "log", "sqrt", + "abs", "sin", "cos", + "tan", "tanh", "sinh", + "cosh", "asin", "acos", + "atan", "asinh", "acosh", + "atanh", "ceil", "celu", + "erf", "floor", "round", + "sign", "silu", "logical_not", + "reciprocal", "tanh_shrink", "logsigmoid"}; if (act_op_list.find(op_type) != act_op_list.end()) { auto* block = desc.Block(); if (block == nullptr) { @@ -2446,6 +2450,7 @@ struct SimpleOpTypeSetTeller : public Teller { "acos", "atan", "asinh", + "acosh", "atanh", "ceil", "floor", @@ -2454,6 +2459,7 @@ struct SimpleOpTypeSetTeller : public Teller { "reciprocal", "logical_not", "erf", + "square", "softmax", "sigmoid", "hard_swish", @@ -2589,6 +2595,7 @@ struct SimpleOpTypeSetTeller : public Teller { "acos", "atan", "asinh", + "acosh", "atanh", "ceil", "floor", @@ -2597,6 +2604,7 @@ struct SimpleOpTypeSetTeller : public Teller { "reciprocal", "logical_not", "erf", + "square", "softmax", "sigmoid", "hard_swish", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_activation.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_activation.py index 72c656a757c51..85583c37f6443 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_activation.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_activation.py @@ -25,6 +25,10 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 8200: + if program_config.ops[0].type == "round": + return False return True def sample_program_configs(self): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_square.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_square.py new file mode 100644 index 0000000000000..a3cd487881b09 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_square.py @@ -0,0 +1,140 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import TrtLayerAutoScanTest + +import paddle.inference as paddle_infer + + +class TrtConvertSquareTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input1(dims): + if dims == 1: + return np.ones([3]).astype(np.float32) + elif dims == 2: + return np.ones([3, 64]).astype(np.float32) + elif dims == 3: + return np.ones([3, 64, 64]).astype(np.float32) + else: + return np.ones([1, 3, 64, 64]).astype(np.float32) + + for dims in [1, 2, 3, 4]: + for alpha in [1.0, 2.0, 3.0]: + self.dims = dims + + ops_config = [ + { + "op_type": "square", + "op_inputs": { + "X": ["input_data"], + }, + "op_outputs": {"Out": ["output_data"]}, + "op_attrs": {}, + } + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input1, dims) + ) + }, + outputs=["output_data"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + if self.dims == 1: + self.dynamic_shape.min_input_shape = {"input_data": [1]} + self.dynamic_shape.max_input_shape = {"input_data": [128]} + self.dynamic_shape.opt_input_shape = {"input_data": [64]} + elif self.dims == 2: + self.dynamic_shape.min_input_shape = {"input_data": [1, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]} + elif self.dims == 3: + self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]} + self.dynamic_shape.max_input_shape = { + "input_data": [10, 64, 64] + } + self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]} + else: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 3, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 3, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 3, 64, 64] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if not dynamic_shape and self.dims == 1: + return 0, 3 + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), (1e-3, 1e-3) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), (1e-3, 1e-3) + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py index ba364220e3985..0fd95eaa29fc9 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py @@ -25,6 +25,10 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 8200: + if program_config.ops[0].type == "round": + return False return True def sample_program_configs(self): @@ -54,11 +58,13 @@ def generate_input1(dims, batch, attrs: List[Dict[str, Any]]): "acos", "atan", "asinh", + "acosh", "atanh", "ceil", "floor", "rsqrt", "reciprocal", + "round", "sign", ]: self.dims = dims